Skip to content

Commit

Permalink
Make RoI implementations optional (#750)
Browse files Browse the repository at this point in the history
* removed v3 external connections code

* removed RoIs that were equivalent to the new no roi fallback

* make RoI optional

If `None` is returned, this is valid and means that the entirety of the PS is the RoI. If instead an empty list is returned, there is no RoI for the chunk and it will be skipped.

* added 2 helper functions for finding all the pre or post chunks

* restructured connectivity docs

* hemitypecoll.placement is now a list, which is simpler to use

* fix docs

* fix docs and test

* Documented the purpose of the different iterators
  • Loading branch information
Helveg committed Oct 4, 2023
1 parent 57492bb commit 2305e0f
Show file tree
Hide file tree
Showing 14 changed files with 393 additions and 483 deletions.
5 changes: 2 additions & 3 deletions bsb/connectivity/detailed/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,9 @@ def _get_rect_ext(self, chunk_size, pre_post_flag):

def candidate_intersection(self, target_coll, candidate_coll):
target_cache = [
(ttype, tset, tset.load_boxes())
for ttype, tset in target_coll.placement.items()
(tset.cell_type, tset, tset.load_boxes()) for tset in target_coll.placement
]
for ctype, cset in candidate_coll.placement.items():
for cset in candidate_coll.placement:
box_tree = cset.load_box_tree()
for ttype, tset, tboxes in target_cache:
yield (tset, cset, self._affinity_filter(box_tree.query(tboxes)))
Expand Down
96 changes: 5 additions & 91 deletions bsb/connectivity/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,10 @@ class AllToAll(ConnectionStrategy):
All to all connectivity between two neural populations
"""

def get_region_of_interest(self, chunk):
# All to all needs all pre chunks per post chunk.
# Fingers crossed for out of memory errors.
return self._get_all_post_chunks()

@functools.cache
def _get_all_post_chunks(self):
all_ps = (ct.get_placement_set() for ct in self.postsynaptic.cell_types)
chunks = set(_gutil.ichain(ps.get_all_chunks() for ps in all_ps))
return list(chunks)

def connect(self, pre, post):
for from_ps in pre.placement.values():
for from_ps in pre.placement:
fl = len(from_ps)
for to_ps in post.placement.values():
for to_ps in post.placement:
len_ = len(to_ps)
ml = fl * len_
src_locs = np.full((ml, 3), -1)
Expand All @@ -52,72 +41,6 @@ def connect(self, pre, post):
self.connect_cells(from_ps, to_ps, src_locs, dest_locs)


class ExternalConnections(ConnectionStrategy):
"""
Load the connection matrix from an external source.
"""

required = ["source"]
casts = {"format": str, "warn_missing": bool, "use_map": bool, "headers": bool}
defaults = {
"format": "csv",
"headers": True,
"use_map": False,
"warn_missing": True,
"delimiter": ",",
}

has_external_source = True

def check_external_source(self):
return os.path.exists(self.source)

def get_external_source(self):
return self.source

def validate(self):
if self.warn_missing and not self.check_external_source():
src = self.get_external_source()
warn(f"Missing external source '{src}' for '{self.name}'")

def connect(self):
if self.format == "csv":
return self._connect_from_csv()

def _connect_from_csv(self):
if not self.check_external_source():
src = self.get_external_source()
raise RuntimeError(f"Missing source file '{src}' for `{self.name}`.")
from_type = self.from_cell_types[0]
to_type = self.to_cell_types[0]
# Read the entire csv, skipping the headers if there are any.
data = np.loadtxt(
self.get_external_source(),
skiprows=int(self.headers),
delimiter=self.delimiter,
)
if self.use_map:

def emap_name(t):
return t.placement.name + "_ext_map"

from_gid_map = self.scaffold.load_appendix(emap_name(from_type))
to_gid_map = self.scaffold.load_appendix(emap_name(to_type))
from_targets = self.scaffold.get_placement_set(from_type).identifiers
to_targets = self.scaffold.get_placement_set(to_type).identifiers
data[:, 0] = self._map(data[:, 0], from_gid_map, from_targets)
data[:, 1] = self._map(data[:, 1], to_gid_map, to_targets)
self.scaffold.connect_cells(self, data)

def _map(self, data, map, targets):
# Create a dict with pairs between the map and the target values
# Vectorize its dictionary lookup and perform the vector function on the data
try:
return np.vectorize(dict(zip(map, targets)).get)(data)
except TypeError:
raise SourceQualityError("Missing GIDs in external map.")


@config.node
class FixedIndegree(InvertedRoI, ConnectionStrategy):
"""
Expand All @@ -127,20 +50,11 @@ class FixedIndegree(InvertedRoI, ConnectionStrategy):

indegree = config.attr(type=int, required=True)

def get_region_of_interest(self, chunk):
from_chunks = set(
itertools.chain.from_iterable(
ct.get_placement_set().get_all_chunks()
for ct in self.presynaptic.cell_types
)
)
return from_chunks

def connect(self, pre, post):
in_ = self.indegree
rng = np.random.default_rng()
high = sum(len(ps) for ps in pre.placement.values())
for post_ct, ps in post.placement.items():
high = sum(len(ps) for ps in pre.placement)
for ps in post.placement:
l = len(ps)
pre_targets = np.full((l * in_, 3), -1)
post_targets = np.full((l * in_, 3), -1)
Expand All @@ -150,7 +64,7 @@ def connect(self, pre, post):
pre_targets[ptr : ptr + in_, 0] = rng.choice(high, in_, replace=False)
ptr += in_
lowmux = 0
for pre_ct, pre_ps in pre.placement.items():
for pre_ps in pre.placement:
highmux = lowmux + len(pre_ps)
demux_idx = (pre_targets[:, 0] >= lowmux) & (pre_targets[:, 0] < highmux)
demuxed = pre_targets[demux_idx]
Expand Down
4 changes: 2 additions & 2 deletions bsb/connectivity/import_.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __boot__(self):
)

def parse_source(self, pre, post):
pre = next(iter(pre.placement.values()))
post = next(iter(post.placement.values()))
pre = pre.placement[0]
post = post.placement[0]
if self.mapping_key:

def make_maps(pre_chunks, post_chunks):
Expand Down
32 changes: 16 additions & 16 deletions bsb/connectivity/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ..config import refs, types
from ..profiling import node_meter
from ..reporting import report, warn
from .._util import SortableByAfter, obj_str_insert
from .._util import SortableByAfter, obj_str_insert, ichain
import abc
from itertools import chain

Expand Down Expand Up @@ -38,23 +38,14 @@ def __iter__(self):

@property
def placement(self):
return {
ct: ct.get_placement_set(
self.roi,
return [
ct.get_placement_set(
chunks=self.roi,
labels=self.hemitype.labels,
morphology_labels=self.hemitype.morphology_labels,
)
for ct in self.hemitype.cell_types
}

def __getattr__(self, attr):
if attr == "placement":
return type(self).placement.__get__(self)
else:
return self.placement[attr]

def __getitem__(self, item):
return self.placement[item]
]


@config.dynamic(attr_name="strategy", required=True)
Expand Down Expand Up @@ -117,7 +108,6 @@ def connect_cells(self, pre_set, post_set, src_locs, dest_locs, tag=None):
)
cs.connect(pre_set, post_set, src_locs, dest_locs)

@abc.abstractmethod
def get_region_of_interest(self, chunk):
pass

Expand All @@ -141,7 +131,7 @@ def queue(self, pool):
rois = {
chunk: roi
for chunk in from_chunks
if (roi := self.get_region_of_interest(chunk))
if (roi := self.get_region_of_interest(chunk)) is None or len(roi)
}
if not rois:
warn(
Expand All @@ -156,3 +146,13 @@ def queue(self, pool):

def get_cell_types(self):
return set(self.presynaptic.cell_types) | set(self.postsynaptic.cell_types)

def get_all_pre_chunks(self):
all_ps = (ct.get_placement_set() for ct in self.presynaptic.cell_types)
chunks = set(ichain(ps.get_all_chunks() for ps in all_ps))
return list(chunks)

def get_all_post_chunks(self):
all_ps = (ct.get_placement_set() for ct in self.postsynaptic.cell_types)
chunks = set(ichain(ps.get_all_chunks() for ps in all_ps))
return list(chunks)
2 changes: 1 addition & 1 deletion bsb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def run_placement(self, strategies=None, DEBUG=True, pipelines=True):
if pipelines:
self.run_pipelines()
if strategies is None:
strategies = list(self.placement.values())
strategies = [*self.placement]
strategies = PlacementStrategy.resolve_order(strategies)
pool = create_job_pool(self)
if pool.is_master():
Expand Down
2 changes: 1 addition & 1 deletion bsb/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def queue(self, pool):
rois = {
chunk: roi
for chunk in to_chunks
if (roi := self.get_region_of_interest(chunk))
if (roi := self.get_region_of_interest(chunk)) is None or len(roi)
}
if not rois:
warn(
Expand Down
14 changes: 14 additions & 0 deletions bsb/storage/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,12 @@ def __copy__(self):
return ConnectivityIterator(self._cs, self._dir, lchunks, gchunks)

def __iter__(self):
"""
Iterate over the data chunk by chunk, offset as their global ids
:returns: presyn global cell IDs, postsyn global cell IDs (0 to N)
:rtype: Tuple[numpy.ndarray, numpy.ndarray]
"""
yield from (
self._offset_block(*data)
for data in self._cs.flat_iter_connections(
Expand All @@ -1103,6 +1109,14 @@ def __iter__(self):
)

def chunk_iter(self):
"""
Iterate over the data chunk by chunk, with the chunk-local cell IDs, and the local
and global chunks they stem from returned as well.
:returns: presyn chunk, presyn chunk-local cell IDs (0 to N), postsyn chunk,
postsyn chunk-local cell IDs (0 to N)
:rtype: Tuple[~bsb.storage.Chunk, numpy.ndarray, ~bsb.storage.Chunk, numpy.ndarray]
"""
yield from (
(data[2], data[3][1], data[1], data[3][0])
if dir == "inc"
Expand Down

0 comments on commit 2305e0f

Please sign in to comment.