Skip to content

Commit

Permalink
WIP: Allow attribute access in SynapticSubgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Feb 6, 2023
1 parent c548f8f commit 4a06fea
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions brian2/synapses/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def find_synapses(index, synaptic_neuron):
return synapses


class SynapticSubgroup:
class SynapticSubgroup(Group):
"""
A simple subgroup of `Synapses` that can be used for indexing.
Expand All @@ -543,35 +543,65 @@ class SynapticSubgroup:
when new synapses where added after creating this object.
"""

def __init__(self, synapses, indices):
self.synapses = weakproxy_with_fallback(synapses)
def __init__(self, synapses, indices, name=None):
self.source = weakproxy_with_fallback(synapses)
if name is None:
name = f"{self.source.name}_subgroup*"
Group.__init__(
self,
clock=None,
name=name,
)
self._stored_indices = indices
self._synaptic_pre = synapses.variables["_synaptic_pre"]
self._source_N = self._synaptic_pre.size # total number of synapses
self._N = len(indices)
# All the variables have to go via the _sub_idx to refer to the
# appropriate values in the source group
self.variables = Variables(self, default_index="_sub_idx")
self.variables.add_constant("N", value=self._N)
self.variables.add_constant("_source_N", value=self._synaptic_pre.size)
# add references for all variables in the original group
self.variables.add_references(self.source, list(self.source.variables.keys()))
# Only the variable _sub_idx itself is stored in the subgroup
# and needs the normal index for this group
self.variables.add_array(
"_sub_idx",
size=self._N,
dtype=np.int32,
values=indices,
index="_idx",
constant=True,
read_only=True,
unique=True,
)

def _indices(self, index_var="_idx"):
if index_var != "_idx":
raise AssertionError(f"Did not expect index {index_var} here.")
self.namespace = self.source.namespace

self._enable_group_attributes()

def _indices(self, item, index_var="_idx"):
if len(self._synaptic_pre.get_value()) != self._source_N:
raise RuntimeError(
"Synapses have been added/removed since this "
"synaptic subgroup has been created"
)
return self._stored_indices
orig_indexing = SynapticIndexing(self.source)
orig_indices = orig_indexing(item)
return sorted(set(self._stored_indices).intersection(orig_indices))

def __len__(self):
return len(self._stored_indices)

def __repr__(self):
return (
f"<{self.__class__.__name__}, storing {len(self._stored_indices):d} "
f"indices of {self.synapses.name}>"
f"indices of {self.source.name}>"
)


class SynapticIndexing:
def __init__(self, synapses):
self.synapses = weakref.proxy(synapses)
self.synapses = weakproxy_with_fallback(synapses)
self.source = weakproxy_with_fallback(self.synapses.source)
self.target = weakproxy_with_fallback(self.synapses.target)
self.synaptic_pre = synapses.variables["_synaptic_pre"]
Expand Down

0 comments on commit 4a06fea

Please sign in to comment.