Skip to content

Commit

Permalink
Merge pull request #1358 from brian-team/fix_spatialsubgroup_indexing
Browse files Browse the repository at this point in the history
Fix indexing for sections in SpatialNeuron
  • Loading branch information
mstimberg committed Oct 8, 2021
2 parents 7a2c69e + 9950683 commit a669f65
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
14 changes: 5 additions & 9 deletions brian2/groups/subgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,11 @@ class Subgroup(Group, SpikeSource):
A unique name for the group, or use ``source.name+'_subgroup_0'``, etc.
'''
def __init__(self, source, start, stop, name=None):
# First check if the source is itself a Subgroup
# If so, then make this a Subgroup of the original Group
if isinstance(source, Subgroup):
source = source.source
start = start + source.start
stop = stop + source.start
self.source = source
else:
self.source = weakproxy_with_fallback(source)
# A Subgroup should never be constructed from another Subgroup
# Instead, use Subgroup(source.source,
# start + source.start, stop + source.start)
assert not isinstance(source, Subgroup)
self.source = weakproxy_with_fallback(source)

# Store a reference to the source's equations (if any)
self.equations = None
Expand Down
16 changes: 13 additions & 3 deletions brian2/spatialneuron/spatialneuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ class SpatialNeuron(NeuronGroup):
----------
morphology : `Morphology`
The morphology of the neuron.
model : (str, `Equations`)
model : str, `Equations`
The equations defining the group.
method : (str, function), optional
method : str, function, optional
The numerical integration method. Either a string with the name of a
registered method (e.g. "euler") or a function that receives an
`Equations` object and returns the corresponding abstract code. If no
Expand Down Expand Up @@ -495,11 +495,17 @@ def spatialneuron_segment(neuron, item):
start, stop = to_start_stop(item.indices[:], neuron._N)
else:
start, stop = to_start_stop(item, neuron._N)
if isinstance(neuron, SpatialSubgroup):
start += neuron.start
stop += neuron.start

if start >= stop:
raise IndexError('Illegal start/end values for subgroup, %d>=%d' %
(start, stop))

if isinstance(neuron, SpatialSubgroup):
# Note that the start/stop values calculated above are always
# absolute values, even for subgroups
neuron = neuron.source
return Subgroup(neuron, start, stop)


Expand All @@ -522,6 +528,10 @@ class SpatialSubgroup(Subgroup):

def __init__(self, source, start, stop, morphology, name=None):
self.morphology = morphology
if isinstance(source, SpatialSubgroup):
source = source.source
start += source.start
stop += source.start
Subgroup.__init__(self, source, start, stop, name)

def __getattr__(self, name):
Expand Down
6 changes: 5 additions & 1 deletion brian2/tests/test_spatialneuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def test_spatialneuron_indexing():
sec.sec2 = Cylinder(length=50 * um, diameter=10 * um, n=16)
sec.sec2.sec21 = Cylinder(length=50 * um, diameter=10 * um, n=32)
neuron = SpatialNeuron(sec, 'Im = 0*amp/meter**2 : amp/meter**2')

neuron.v = 'i*volt'
# Accessing indices/variables of a subtree refers to the full subtree
assert len(neuron.indices[:]) == 1 + 2 + 4 + 8 + 16 + 32
assert len(neuron.sec1.indices[:]) == 2 + 4 + 8
Expand Down Expand Up @@ -659,6 +659,10 @@ def test_spatialneuron_indexing():
assert len(neuron[0:1].indices[:]) == 1
assert len(neuron[sec.sec2.indices[:]]) == 16
assert len(neuron[sec.sec2]) == 16
assert_equal(neuron.sec1.sec11.v, [3, 4, 5, 6]*volt)
assert_equal(neuron.sec1.sec11[1].v, neuron.sec1.sec11.v[1])
assert_equal(neuron.sec1.sec11[1:3].v, neuron.sec1.sec11.v[1:3])
assert_equal(neuron.sec1.sec11[1:3].v, [4, 5]*volt)


@pytest.mark.codegen_independent
Expand Down

0 comments on commit a669f65

Please sign in to comment.