Skip to content

Commit

Permalink
Add InvertedRoI mixin and FixedIndegree strategy (#729)
Browse files Browse the repository at this point in the history
* add InvertedRoI mixin

* add FixedIndegree connection strat

* fix FixedIndegree, add tests

* add multiple cell types test

* docstrings

* changed cell type number to increase reliability of test

* fix tests when some cells draw 0 targets from a pop in multipop

* doc mixins module
  • Loading branch information
Helveg committed Jun 9, 2023
1 parent 1cdfa8b commit 53bf5d0
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 2 deletions.
43 changes: 43 additions & 0 deletions bsb/connectivity/general.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import itertools
import os
import numpy as np
import functools
from .strategy import ConnectionStrategy
from ..exceptions import SourceQualityError
from .. import config, _util as _gutil
from ..config import types
from ..mixins import InvertedRoI
from ..reporting import warn


Expand Down Expand Up @@ -114,3 +116,44 @@ def _map(self, data, map, targets):
return np.vectorize(dict(zip(map, targets)).get)(data)
except TypeError:
raise SourceQualityError("Missing GIDs in external map.")


@config.node
class FixedIndegree(InvertedRoI, ConnectionStrategy):
"""
Connect a group of postsynaptic cell types to ``indegree`` uniformly random
presynaptic cells from all the presynaptic cell types.
"""

indegree = config.attr(type=int, required=True)

def get_region_of_interest(self, chunk):
from_chunks = set(
itertools.chain.from_iterable(
ct.get_placement_set().get_all_chunks()
for ct in self.presynaptic.cell_types
)
)
return from_chunks

def connect(self, pre, post):
in_ = self.indegree
rng = np.random.default_rng()
high = sum(len(ps) for ps in pre.placement.values())
for post_ct, ps in post.placement.items():
l = len(ps)
pre_targets = np.full((l * in_, 3), -1)
post_targets = np.full((l * in_, 3), -1)
ptr = 0
for i in range(l):
post_targets[ptr : ptr + in_, 0] = i
pre_targets[ptr : ptr + in_, 0] = rng.choice(high, in_, replace=False)
ptr += in_
lowmux = 0
for pre_ct, pre_ps in pre.placement.items():
highmux = lowmux + len(pre_ps)
demux_idx = (pre_targets[:, 0] >= lowmux) & (pre_targets[:, 0] < highmux)
demuxed = pre_targets[demux_idx]
demuxed[:, 0] -= lowmux
self.connect_cells(pre_ps, ps, demuxed, post_targets[demux_idx])
lowmux = highmux
6 changes: 5 additions & 1 deletion bsb/connectivity/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,12 @@ def _get_connect_args_from_job(self, pre_roi, post_roi):
return pre, post

def connect_cells(self, pre_set, post_set, src_locs, dest_locs, tag=None):
if len(self.presynaptic.cell_types) > 1 or len(self.postsynaptic.cell_types) > 1:
name = f"{self.name}_{pre_set.cell_type.name}_to_{post_set.cell_type.name}"
else:
name = self.name
cs = self.scaffold.require_connectivity_set(
pre_set.cell_type, post_set.cell_type, tag if tag is not None else self.name
pre_set.cell_type, post_set.cell_type, tag if tag is not None else name
)
cs.connect(pre_set, post_set, src_locs, dest_locs)

Expand Down
47 changes: 47 additions & 0 deletions bsb/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,50 @@ def __init_subclass__(cls, **kwargs):
"NotParallel can only be applied to placement or "
"connectivity strategies"
)


class InvertedRoI:
"""
This mixin inverts the perspective of the ``get_region_of_interest`` interface and
lets you find presynaptic regions of interest for a postsynaptic chunk.
Usage:
..code-block:: python
class MyConnStrat(InvertedRoI, ConnectionStrategy):
def get_region_of_interest(post_chunk):
return [pre_chunk1, pre_chunk2]
"""

def queue(self, pool):
# Reset jobs that we own
self._queued_jobs = []
# Get the queued jobs of all the strategies we depend on.
deps = set(
itertools.chain.from_iterable(
strat._queued_jobs for strat in self.get_after()
)
)
post_types = self.postsynaptic.cell_types
# Iterate over each chunk that is populated by our postsynaptic cell types.
to_chunks = set(
itertools.chain.from_iterable(
ct.get_placement_set().get_all_chunks() for ct in post_types
)
)
rois = {
chunk: roi
for chunk in to_chunks
if (roi := self.get_region_of_interest(chunk))
}
if not rois:
warn(
f"No overlap found between {[post.name for post in post_types]} and "
f"{[pre.name for pre in self.presynaptic.cell_types]} "
f"in '{self.name}'."
)
for chunk, roi in rois.items():
job = pool.queue_connectivity(self, roi, [chunk], deps=deps)
self._queued_jobs.append(job)
report(f"Queued {len(self._queued_jobs)} jobs for {self.name}", level=2)
68 changes: 68 additions & 0 deletions bsb/unittest/data/configs/test_indegree.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
{
"storage": {
"engine": "hdf5"
},
"network": {
"x": 150.0,
"y": 150.0,
"z": 150.0
},
"partitions": {
"all": {
"type": "layer",
"thickness": 150.0
}
},
"cell_types": {
"inhibitory": {
"spatial": {
"radius": 1.0,
"count": 2000
}
},
"excitatory": {
"spatial": {
"radius": 1.0,
"count": 2000
}
},
"extra": {
"spatial": {
"radius": 1.0,
"count": 2000
}
}
},
"placement": {
"random": {
"strategy": "bsb.placement.RandomPlacement",
"cell_types": ["inhibitory","excitatory", "extra"],
"partitions": ["all"]
}
},
"connectivity": {
"indegree": {
"strategy": "bsb.connectivity.FixedIndegree",
"indegree": 50,
"presynaptic": {
"cell_types": ["excitatory"]
},
"postsynaptic": {
"cell_types": ["inhibitory"]
}
},
"multi_indegree": {
"strategy": "bsb.connectivity.FixedIndegree",
"indegree": 50,
"presynaptic": {
"cell_types": ["excitatory", "extra"]
},
"postsynaptic": {
"cell_types": ["inhibitory", "extra"]
}
}
},
"simulations": {

}
}
7 changes: 7 additions & 0 deletions docs/bsb/bsb.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ bsb.exceptions module
:undoc-members:
:show-inheritance:

bsb.exceptions module
---------------------

.. automodule:: bsb.mixins
:members:
:undoc-members:

bsb.option module
-----------------

Expand Down
40 changes: 39 additions & 1 deletion tests/test_connectivity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from bsb.core import Scaffold
from bsb.services import MPI
from bsb.config import Configuration
from bsb.config import Configuration, from_file
from bsb.morphologies import Morphology, Branch
from bsb.unittest import (
NumpyTestCase,
Expand All @@ -9,6 +9,7 @@
MorphologiesFixture,
NetworkFixture,
skip_parallel,
get_config_path,
)
import unittest
import numpy as np
Expand Down Expand Up @@ -634,3 +635,40 @@ def test_zero_contacts(self):
self.network.compile(clear=True)
conns = len(self.network.get_connectivity_set("intersect"))
self.assertEqual(0, conns, "expected no contacts")


class TestFixedIndegree(
NetworkFixture, RandomStorageFixture, unittest.TestCase, engine_name="hdf5"
):
def setUp(self) -> None:
self.cfg = from_file(get_config_path("test_indegree.json"))
super().setUp()

def test_indegree(self):
self.network.compile()
cs = self.network.get_connectivity_set("indegree")
_, post_locs = cs.load_connections().all()
ps = self.network.get_placement_set("inhibitory")
u, c = np.unique(post_locs[:, 0], return_counts=True)
self.assertTrue(
np.array_equal(np.arange(len(ps)), np.sort(u)),
"Not all post cells have connections",
)
self.assertTrue(np.all(c == 50), "Not all cells have indegree 50")

def test_multi_indegree(self):
self.network.compile()
for post_name in ("inhibitory", "extra"):
post_ps = self.network.get_placement_set(post_name)
total = np.zeros(len(post_ps))
for pre_name in ("excitatory", "extra"):
cs = self.network.get_connectivity_set(
f"multi_indegree_{pre_name}_to_{post_name}"
)
_, post_locs = cs.load_connections().all()
ps = self.network.get_placement_set("inhibitory")
u, c = np.unique(post_locs[:, 0], return_counts=True)
this = np.zeros(len(post_ps))
this[u] = c
total += this
self.assertTrue(np.all(total == 50), "Not all cells have indegree 50")

0 comments on commit 53bf5d0

Please sign in to comment.