Skip to content

Commit

Permalink
WIP: Support non-contiguous subgroups in synapses
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Nov 25, 2022
1 parent 65ec4c3 commit 8d5b98d
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ _vectorisation_idx = 1
"k" is called the inner variable #}

for _{{outer_index}} in range({{outer_index_size}}):
{% if non_contiguous_outer %}
_raw{{outer_index_array}} = {{get_array_name(variables[outer_sub_idx])}}[_{{outer_index}}]
{% else %}
_raw{{outer_index_array}} = _{{outer_index}} + {{outer_index_offset}}
{% endif %}
{% if not result_index_condition %}
{{vector_code['create_cond']|autoindent}}
if not _cond:
Expand Down Expand Up @@ -184,7 +188,11 @@ for _{{outer_index}} in range({{outer_index_size}}):
{% endif %}

_vectorisation_idx = _{{result_index}}
{% if non_contiguous_result %}
_raw{{result_index_array}} = {{get_array_name(variables[result_sub_idx])}}[_{{result_index}}]
{% else %}
_raw{{result_index_array}} = _{{result_index}} + {{result_index_offset}}
{% endif %}
{{vector_code['update']|autoindent}}

if not _numpy.isscalar(_n):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
constants or scalar arrays#}
const size_t _N_pre = {{constant_or_scalar('N_pre', variables['N_pre'])}};
const size_t _N_post = {{constant_or_scalar('N_post', variables['N_post'])}};
{{_dynamic_N_incoming}}.resize(_N_post + _target_offset);
{{_dynamic_N_outgoing}}.resize(_N_pre + _source_offset);
const size_t _raw_N_pre = {{constant_or_scalar('raw_N_pre', variables['raw_N_pre'])}};
const size_t _raw_N_post = {{constant_or_scalar('raw_N_post', variables['raw_N_post'])}};
{{_dynamic_N_incoming}}.resize(_raw_N_post);
{{_dynamic_N_outgoing}}.resize(_raw_N_pre);
size_t _raw_pre_idx, _raw_post_idx;
{# For a connect call j='k+i for k in range(0, N_post, 2) if k+i < N_post'
"j" is called the "result index" (and "_post_idx" the "result index array", etc.)
Expand All @@ -35,7 +37,11 @@
for(size_t _{{outer_index}}=0; _{{outer_index}}<_{{outer_index_size}}; _{{outer_index}}++)
{
bool __cond, _cond;
{% if non_contiguous_outer %}
_raw{{outer_index_array}} = {{get_array_name(variables[outer_sub_idx])}}[_{{outer_index}}];
{% else %}
_raw{{outer_index_array}} = _{{outer_index}} + {{outer_index_offset}};
{% endif %}
{% if not result_index_condition %}
{
{{vector_code['create_cond']|autoindent}}
Expand Down Expand Up @@ -181,7 +187,11 @@
}
_{{result_index}} = __{{result_index}}; // make the previously locally scoped var available
{{outer_index_array}} = _{{outer_index_array}};
{% if non_contiguous_result %}
_raw{{result_index_array}} = {{get_array_name(variables[result_sub_idx])}}[_{{result_index}}];
{% else %}
_raw{{result_index_array}} = _{{result_index}} + {{result_index_offset}};
{% endif %}
{% if result_index_condition %}
{
{% if result_index_used %}
Expand Down Expand Up @@ -222,10 +232,10 @@
{{vector_code['update']|autoindent}}

for (size_t _repetition=0; _repetition<_n; _repetition++) {
{{_dynamic_N_outgoing}}[_pre_idx] += 1;
{{_dynamic_N_incoming}}[_post_idx] += 1;
{{_dynamic__synaptic_pre}}.push_back(_pre_idx);
{{_dynamic__synaptic_post}}.push_back(_post_idx);
{{_dynamic_N_outgoing}}[_raw_pre_idx] += 1;
{{_dynamic_N_incoming}}[_raw_post_idx] += 1;
{{_dynamic__synaptic_pre}}.push_back(_raw_pre_idx);
{{_dynamic__synaptic_post}}.push_back(_raw_post_idx);
}
}
}
Expand Down
152 changes: 108 additions & 44 deletions brian2/synapses/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,13 +1149,17 @@ def verify_dependencies(
)

N_outgoing_pre = property(
fget=lambda self: self.variables["N_outgoing"].get_value(),
fget=lambda self: self.variables["N_outgoing"].get_value()
if not isinstance(self.source, Subgroup)
else self.variables["N_outgoing"].get_value()[self.source._sub_idx],
doc=(
"The number of outgoing synapses for each neuron in the pre-synaptic group."
),
)
N_incoming_post = property(
fget=lambda self: self.variables["N_incoming"].get_value(),
fget=lambda self: self.variables["N_incoming"].get_value()
if not isinstance(self.target, Subgroup)
else self.variables["N_incoming"].get_value()[self.target._sub_idx],
doc=(
"The number of incoming synapses for each neuron in the "
"post-synaptic group."
Expand Down Expand Up @@ -1316,7 +1320,14 @@ def _create_variables(self, equations, user_dtype=None):
read_only=True,
index="_presynaptic_idx",
)

if isinstance(self.source, Subgroup):
self.variables.add_reference("raw_N_pre", self.source.source, "N")
else:
self.variables.add_reference("raw_N_pre", self.source, "N")
if isinstance(self.target, Subgroup):
self.variables.add_reference("raw_N_post", self.target.source, "N")
else:
self.variables.add_reference("raw_N_post", self.target, "N")
# We have to make a distinction here between the indices
# and the arrays (even though they refer to the same object)
# the synaptic propagation template would otherwise overwrite
Expand All @@ -1326,45 +1337,68 @@ def _create_variables(self, equations, user_dtype=None):
self.variables.add_reference("_presynaptic_idx", self, "_synaptic_pre")
self.variables.add_reference("_postsynaptic_idx", self, "_synaptic_post")

# Except for subgroups (which potentially add an offset), the "i" and
# "j" variables are simply equivalent to `_synaptic_pre` and
# `_synaptic_post`
if getattr(self.source, "start", 0) == 0:
self.variables.add_reference("i", self, "_synaptic_pre")
else:
if isinstance(self.source, Subgroup) and not self.source.contiguous:
raise TypeError(
"Cannot use a non-contiguous subgroup as a "
"source group for Synapses."
# For subgroups, i and j either are shifted versions of the original indices
# (contiguous subgroups), or map into the real indices using the sub_idx
# variable
if (
isinstance(self.source, Subgroup)
and getattr(self.source, "start", None) != 0
):
if self.source.contiguous:
self.variables.add_reference(
"_source_i", self.source.source, "i", index="_presynaptic_idx"
)
self.variables.add_reference("_source_offset", self.source, "_offset")
self.variables.add_subexpression(
"i",
dtype=self.source.source.variables["i"].dtype,
expr="_source_i - _source_offset",
index="_presynaptic_idx",
)
else:
inverted_idcs = np.zeros(len(self.source.source), dtype=np.int32)
inverted_idcs[self.source._sub_idx] = np.arange(len(self.source))
self.variables.add_array(
"i",
dtype=np.int32,
size=len(self.source.source),
values=inverted_idcs,
index="_presynaptic_idx",
)
self.variables.add_reference(
"_source_i", self.source.source, "i", index="_presynaptic_idx"
)
self.variables.add_reference("_source_offset", self.source, "_offset")
self.variables.add_subexpression(
"i",
dtype=self.source.source.variables["i"].dtype,
expr="_source_i - _source_offset",
index="_presynaptic_idx",
)
if getattr(self.target, "start", 0) == 0:
self.variables.add_reference("j", self, "_synaptic_post")
else:
if isinstance(self.target, Subgroup) and not self.target.contiguous:
raise TypeError(
"Cannot use a non-contiguous subgroup as a "
"target group for Synapses."
# For subgroups that start at zero, or when not using a subgroup, i is
# simply a reference to the presynaptic indices
self.variables.add_reference("i", self, "_synaptic_pre")

if (
isinstance(self.target, Subgroup)
and getattr(self.target, "start", None) != 0
):
if self.target.contiguous:
self.variables.add_reference(
"_target_i", self.target.source, "i", index="_postsynaptic_idx"
)
self.variables.add_reference(
"_target_j", self.target.source, "i", index="_postsynaptic_idx"
)
self.variables.add_reference("_target_offset", self.target, "_offset")
self.variables.add_subexpression(
"j",
dtype=self.target.source.variables["i"].dtype,
expr="_target_j - _target_offset",
index="_postsynaptic_idx",
)
self.variables.add_reference("_target_offset", self.target, "_offset")
self.variables.add_subexpression(
"j",
dtype=self.target.source.variables["i"].dtype,
expr="_target_i - _target_offset",
index="_postsynaptic_idx",
)
else:
inverted_idcs = np.zeros(len(self.target.source), dtype=np.int32)
inverted_idcs[self.target._sub_idx] = np.arange(len(self.target))
self.variables.add_array(
"j",
dtype=np.int32,
size=len(self.target.source),
values=inverted_idcs,
index="_postsynaptic_idx",
)
else:
# For subgroups that start at zero, or when not using a subgroup, j is
# simply a reference to the postsynaptic indices
self.variables.add_reference("j", self, "_synaptic_post")

# Add the standard variables
self.variables.add_array(
Expand Down Expand Up @@ -1779,17 +1813,18 @@ def _update_synapse_numbers(self, old_num_synapses):
source_offset = self.variables["_source_offset"].get_value()
target_offset = self.variables["_target_offset"].get_value()
# This resizing is only necessary if we are connecting to/from synapses
post_with_offset = int(self.variables["N_post"].get_value()) + target_offset
pre_with_offset = int(self.variables["N_pre"].get_value()) + source_offset
self.variables["N_incoming"].resize(post_with_offset)
self.variables["N_outgoing"].resize(pre_with_offset)
post_size = self.variables["raw_N_post"].get_value()
pre_size = self.variables["raw_N_pre"].get_value()
self.variables["N_incoming"].resize(post_size)
self.variables["N_outgoing"].resize(pre_size)
N_outgoing = self.variables["N_outgoing"].get_value()
N_incoming = self.variables["N_incoming"].get_value()
synaptic_pre = self.variables["_synaptic_pre"].get_value()
synaptic_post = self.variables["_synaptic_post"].get_value()

# Update the number of total outgoing/incoming synapses per
# source/target neuron
# source/target neuron. Note that for subgroups, all entries that correspond
# to neurons not present in the subgroup are left empty.
N_outgoing[:] += np.bincount(
synaptic_pre[old_num_synapses:], minlength=len(N_outgoing)
)
Expand Down Expand Up @@ -1884,12 +1919,19 @@ def _add_synapses_from_arrays(self, sources, targets, n, p, namespace=None):
sources = sources.repeat(n)
targets = targets.repeat(n)

# For non-contiguous subgroups, we directly translate the indices here, i.e. for
# the template it is as if no subgroups were used at all
if isinstance(self.source, Subgroup) and not self.source.contiguous:
sources = self.source._sub_idx[sources]
if isinstance(self.target, Subgroup) and not self.target.contiguous:
targets = self.target._sub_idx[targets]
variables.add_array(
"sources", len(sources), dtype=np.int32, values=sources, read_only=True
)
variables.add_array(
"targets", len(targets), dtype=np.int32, values=targets, read_only=True
)

# These definitions are important to get the types right in C++
variables.add_auxiliary_variable("_real_sources", dtype=np.int32)
variables.add_auxiliary_variable("_real_targets", dtype=np.int32)
Expand Down Expand Up @@ -1985,6 +2027,16 @@ def _add_synapses_generator(
outer_index_size = "N_pre" if over_presynaptic else "N_post"
outer_index_array = "_pre_idx" if over_presynaptic else "_post_idx"
outer_index_offset = "_source_offset" if over_presynaptic else "_target_offset"
if over_presynaptic:
non_contiguous_outer = not getattr(self.source, "contiguous", True)
non_contiguous_result = not getattr(self.target, "contiguous", True)
else:
non_contiguous_outer = not getattr(self.target, "contiguous", True)
non_contiguous_result = not getattr(self.source, "contiguous", True)
# The ..._sub_idx variable are unused if the group is a contiguous subgroup,
# or not a subgroup at all
outer_sub_idx = "_source_sub_idx" if over_presynaptic else "_target_sub_idx"
result_sub_idx = "_target_sub_idx" if over_presynaptic else "_source_sub_idx"
result_index = "j" if over_presynaptic else "i"
result_index_size = "N_post" if over_presynaptic else "N_pre"
target_idx = "_postsynaptic_idx" if over_presynaptic else "_presynaptic_idx"
Expand All @@ -1997,11 +2049,15 @@ def _add_synapses_generator(
"outer_index_size": outer_index_size,
"outer_index_array": outer_index_array,
"outer_index_offset": outer_index_offset,
"non_contiguous_outer": non_contiguous_outer,
"outer_sub_idx": outer_sub_idx,
"result_index": result_index,
"result_index_size": result_index_size,
"result_index_name": result_index_name,
"result_index_array": result_index_array,
"result_index_offset": result_index_offset,
"non_contiguous_result": non_contiguous_result,
"result_sub_idx": result_sub_idx,
}
)
abstract_code = {
Expand Down Expand Up @@ -2094,6 +2150,13 @@ def _add_synapses_generator(
else:
variables.add_constant("_target_offset", value=0)

if not getattr(self.source, "contiguous", True):
variables.add_reference("_source_sub_idx", self.source, "_sub_idx")
needed_variables.append("_source_sub_idx")
if not getattr(self.target, "contiguous", True):
variables.add_reference("_target_sub_idx", self.target, "_sub_idx")
needed_variables.append("_target_sub_idx")

variables.add_auxiliary_variable("_raw_pre_idx", dtype=np.int32)
variables.add_auxiliary_variable("_raw_post_idx", dtype=np.int32)

Expand All @@ -2113,6 +2176,7 @@ def _add_synapses_generator(
f"'{self.target.name}', using generator "
f"'{parsed['original_expression']}'"
)
needed_variables.extend(["raw_N_pre", "raw_N_post"])

codeobj = create_runner_codeobj(
self,
Expand Down
26 changes: 26 additions & 0 deletions brian2/tests/test_subgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,32 @@ def test_synapse_creation():
assert all(S.N_incoming[:, 5] == 1)


@pytest.mark.standalone_compatible
def test_synapse_creation_non_contiguous():
G1 = NeuronGroup(10, "")
G2 = NeuronGroup(20, "")
SG1 = G1[::2]
SG2 = G2[[0, 1, 2, 3, 16, 17, 18, 19]]
S = Synapses(SG1, SG2)
S.connect(i=2, j=2) # Should correspond to (4, 2)
S.connect("i==2 and j==5") # Should correspond to (4, 17)

run(0 * ms) # for standalone

# Internally, the "real" neuron indices should be used
assert_equal(S._synaptic_pre[:], np.array([4, 4]))
assert_equal(S._synaptic_post[:], np.array([2, 17]))
# For the user, the subgroup-relative indices should be presented
assert_equal(S.i[:], np.array([2, 2]))
assert_equal(S.j[:], np.array([2, 5]))
# N_incoming and N_outgoing should also be correct
assert all(S.N_outgoing[2, :] == 2)
assert_equal(S.N_outgoing_pre, [0, 0, 2, 0, 0])
assert all(S.N_incoming[:, 2] == 1)
assert all(S.N_incoming[:, 5] == 1)
assert_equal(S.N_incoming_post, [0, 0, 1, 0, 0, 1, 0, 0])


@pytest.mark.standalone_compatible
def test_synapse_creation_state_vars():
G1 = NeuronGroup(10, "v : 1")
Expand Down

0 comments on commit 8d5b98d

Please sign in to comment.