Skip to content

Commit

Permalink
Add morphology loader function for Hemitypes. (#735)
Browse files Browse the repository at this point in the history
* Add morphology loader function for Hemitypes.

* Update bsb/connectivity/strategy.py

---------

Co-authored-by: Robin De Schepper <robin.deschepper93@gmail.com>
  • Loading branch information
drodarie and Helveg committed Jul 7, 2023
1 parent 2a00fdb commit e61bfaa
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 7 deletions.
4 changes: 2 additions & 2 deletions bsb/connectivity/detailed/fiber_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def connect(self):
to_ps = self.scaffold.get_placement_set(to_type.name, labels=labels_post)

# Load the morphology and voxelization data for the entrire morphology, for each cell type.
from_morphology_set = from_placement_set.load_morphologies()
from_morphology_set = self.presynaptic.morpho_loader(from_ps)

to_morphology_set = to_placement_set.load_morphologies()
to_morphology_set = self.postsynaptic.morpho_loader(to_ps)
joined_map = (
from_morphology_set._morphology_map + to_morphology_set._morphology_map
)
Expand Down
4 changes: 3 additions & 1 deletion bsb/connectivity/detailed/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def get_region_of_interest(self, chunk):
def _get_rect_ext(self, chunk_size, pre_post_flag):
if pre_post_flag:
types = self.presynaptic.cell_types
loader = self.presynaptic.morpho_loader
else:
types = self.postsynaptic.cell_types
loader = self.postsynaptic.morpho_loader
ps_list = [ct.get_placement_set() for ct in types]
ms_list = [ps.load_morphologies() for ps in ps_list]
ms_list = [loader(ps) for ps in ps_list]
if not sum(map(len, ms_list)):
# No cells placed, return smallest possible RoI.
return [np.array([0, 0, 0]), np.array([0, 0, 0])]
Expand Down
12 changes: 8 additions & 4 deletions bsb/connectivity/detailed/voxel_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,28 @@ def connect(self, pre, post):
candidates = post
self._n_tvoxels = self.voxels_pre
self._n_cvoxels = self.voxels_post
target_morpho = self.presynaptic.morpho_loader
cand_morpho = self.postsynaptic.morpho_loader
else:
targets = post
candidates = pre
self._n_tvoxels = self.voxels_post
self._n_cvoxels = self.voxels_pre
target_morpho = self.postsynaptic.morpho_loader
cand_morpho = self.presynaptic.morpho_loader
combo_itr = self.candidate_intersection(targets, candidates)
mset_cache = {}
for target_set, cand_set, match_itr in combo_itr:
if self.cache:
if id(target_set) not in mset_cache:
mset_cache[id(target_set)] = target_set.load_morphologies()
mset_cache[id(target_set)] = target_morpho(target_set)
if id(cand_set) not in mset_cache:
mset_cache[id(cand_set)] = cand_set.load_morphologies()
mset_cache[id(cand_set)] = cand_morpho(cand_set)
target_mset = mset_cache[id(target_set)]
cand_mset = mset_cache[id(cand_set)]
else:
target_mset = target_set.load_morphologies()
cand_mset = cand_set.load_morphologies()
target_mset = target_morpho(target_set)
cand_mset = cand_morpho(cand_set)
self._match_voxel_intersection(
match_itr, target_set, cand_set, target_mset, cand_mset
)
Expand Down
18 changes: 18 additions & 0 deletions bsb/connectivity/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,23 @@

@config.node
class Hemitype:
"""
Class used to represent one (pre- or postsynaptic) side of a connection rule.
"""

cell_types = config.reflist(refs.cell_type_ref, required=True)
"""List of cell types to use in connection."""
labels = config.attr(type=types.list())
"""List of labels to filter the placement set by."""
morphology_labels = config.attr(type=types.list())
"""List of labels to filter the morphologies by."""
morpho_loader = config.attr(
type=types.function_(),
required=False,
call_default=False,
default=(lambda ps: ps.load_morphologies()),
)
"""Function to load the morphologies (MorphologySet) from a PlacementSet"""


class HemitypeCollection:
Expand Down Expand Up @@ -46,9 +60,13 @@ def __getitem__(self, item):
@config.dynamic(attr_name="strategy", required=True)
class ConnectionStrategy(abc.ABC, SortableByAfter):
name = config.attr(key=True)
"""Name used to refer to the connectivity strategy"""
presynaptic = config.attr(type=Hemitype, required=True)
"""Presynaptic (source) neuron population"""
postsynaptic = config.attr(type=Hemitype, required=True)
"""Postsynaptic (target) neuron population"""
after = config.reflist(refs.connectivity_ref)
"""Action to perform after connecting the neurons with the current strategy."""

def __init_subclass__(cls, **kwargs):
super(cls, cls).__init_subclass__(**kwargs)
Expand Down
6 changes: 6 additions & 0 deletions bsb/morphologies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ class MorphologySet:
"""

def __init__(self, loaders, m_indices=None, /, labels=None):
"""
:param loaders: list of Morphology loader functions.
:type loaders: List[Callable[[], bsb.storage.interfaces.StoredMorphology]]
:param m_indices: indices of the loaders for each of the morphologies.
:type: List[int]
"""
if m_indices is None:
loaders, m_indices = np.unique(loaders, return_inverse=True)
self._m_indices = np.array(m_indices, copy=False, dtype=int)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,13 @@ class TestConnWithSubCellLabels(
engine_name="hdf5",
morpho_filters=["PurkinjeCell", "StellateCell"],
):
def _morpho_loader(self, ps):
self.increment += 1
return ps.load_morphologies()

def setUp(self):
super().setUp()
self.increment = 0
self.network.connectivity.add(
"self_intersect",
dict(
Expand All @@ -363,6 +368,7 @@ def setUp(self):
postsynaptic=dict(
cell_types=["test_cell"],
morphology_labels=["tag_16", "tag_17", "tag_18"],
morpho_loader=self._morpho_loader,
),
),
)
Expand Down Expand Up @@ -412,6 +418,12 @@ def connect_spy(strat, pre, post):
except Exception as e:
raise
self.fail(f"Unexpected error: {e}")
self.assertEqual(
self.increment,
len(self.chunks) + 1,
"expect one call of the loading function per chunk + 1 for processing"
" the region of interest.",
)
cs = self.network.get_connectivity_set("self_intersect")
sloc, dloc = cs.load_connections().all()
self.assertAll(sloc > -1, "expected only true conn")
Expand Down

0 comments on commit e61bfaa

Please sign in to comment.