Skip to content

Commit

Permalink
Make CS tags by default equal to config key (#659)
Browse files Browse the repository at this point in the history
* fixed tags not respecting CS config key

* adjusted CS tags in the tests

* 1 snuck through
  • Loading branch information
Helveg committed Dec 16, 2022
1 parent e9029a0 commit fbda191
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 31 deletions.
2 changes: 1 addition & 1 deletion bsb/connectivity/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _get_connect_args_from_job(self, chunk, roi):

def connect_cells(self, pre_set, post_set, src_locs, dest_locs, tag=None):
cs = self.scaffold.require_connectivity_set(
pre_set.cell_type, post_set.cell_type, tag
pre_set.cell_type, post_set.cell_type, tag if tag is not None else self.name
)
cs.connect(pre_set, post_set, src_locs, dest_locs)

Expand Down
54 changes: 25 additions & 29 deletions tests/test_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def setUp(self):

def test_per_block(self):
# Test that connections can be stored over chunked layout and can be loaded again.
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
for lchunk, g_itr in cs.nested_iter_connections(direction="out"):
for gchunk, conns in g_itr:
ids = conns[0][:, 0]
Expand All @@ -52,12 +52,10 @@ def test_per_block(self):
self.assertEqual(25, len(u), "expected exactly 25 global cells")
self.assertClose(np.arange(0, 25), np.sort(u))
self.assertClose(25, c)
self.assertEqual(
100 * 100, len(self.network.get_connectivity_set("test_cell_to_test_cell"))
)
self.assertEqual(100 * 100, len(self.network.get_connectivity_set("all_to_all")))

def test_per_local(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
for lchunk in cs.get_local_chunks(direction="out"):
local_locs, gchunk_ids, global_locs = cs.load_local_connections("out", lchunk)
ids = local_locs[:, 0]
Expand All @@ -72,9 +70,7 @@ def test_per_local(self):
self.assertEqual(25, len(u), "expected exactly 25 global cells")
self.assertClose(np.arange(0, 25), np.sort(u))
self.assertClose(100, c, "expected 25 local sources per global cell")
self.assertEqual(
100 * 100, len(self.network.get_connectivity_set("test_cell_to_test_cell"))
)
self.assertEqual(100 * 100, len(self.network.get_connectivity_set("all_to_all")))


class TestConnectivitySet(
Expand All @@ -99,7 +95,7 @@ def setUp(self):
self.network.compile(clear=True)

def test_load_all(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
data = cs.load_connections()
try:
lcol, lloc, gcol, gloc = data
Expand All @@ -111,7 +107,7 @@ def test_load_all(self):
self.assertEqual(10000, len(gloc), "expected full 10k global locs")

def test_load_local(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
chunks = cs.get_local_chunks("inc")
data = cs.load_local_connections("inc", chunks[0])
try:
Expand All @@ -127,12 +123,12 @@ def test_load_local(self):
self.assertEqual(100, unique_globals, "Expected 100 globals")

def test_flat_iter(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
itr = cs.flat_iter_connections()
self.check_a2a_flat_iter(itr, ["inc", "out"], 4, 4)

def test_nested_iter(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
try:
iter(cs.nested_iter_connections())
except TypeError:
Expand Down Expand Up @@ -170,27 +166,27 @@ def test_nested_iter(self):
)

def test_incoming(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
self.check_a2a_flat_iter(iter(cs.incoming), ["inc"], 4, 4)

def test_outgoing(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
self.check_a2a_flat_iter(iter(cs.outgoing), ["out"], 4, 4)

def test_from(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
chunks = cs.get_local_chunks("inc")
self.check_a2a_flat_iter(iter(cs.from_(chunks)), ["out"], 4, 4)
self.check_a2a_flat_iter(iter(cs.from_(chunks[0])), ["out"], 1, 4)

def test_to(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
chunks = cs.get_local_chunks("inc")
self.check_a2a_flat_iter(iter(cs.to(chunks)), ["out"], 4, 4)
self.check_a2a_flat_iter(iter(cs.to(chunks[0])), ["out"], 4, 1)

def test_from_to(self):
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
chunks = cs.get_local_chunks("inc")
self.check_a2a_flat_iter(iter(cs.from_(chunks).to(chunks)), ["out"], 4, 4)
self.check_a2a_flat_iter(iter(cs.to(chunks).from_(chunks)), ["out"], 4, 4)
Expand Down Expand Up @@ -307,14 +303,14 @@ def setUp(self):
def test_from_label(self):
self.network.connectivity.all_to_all.presynaptic.labels = ["from_X"]
self.network.compile(append=True, skip_placement=True)
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
allcon = cs.load_connections()[0]
self.assertEqual(300, len(allcon), "should have 3 x 100 cells with from_X label")

def test_to_label(self):
self.network.connectivity.all_to_all.postsynaptic.labels = ["from_X"]
self.network.compile(append=True, skip_placement=True)
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
allcon = cs.load_connections()[0]
self.assertEqual(300, len(allcon), "should have 100 x 3 cells with from_X label")

Expand All @@ -325,7 +321,7 @@ def test_dupe_from_labels(self):
"from_Y",
]
self.network.compile(append=True, skip_placement=True)
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
allcon = cs.load_connections()[0]
self.assertEqual(500, len(allcon), "should have 3 x 100 cells with from_X label")

Expand All @@ -337,7 +333,7 @@ def test_dupe_labels(self):
]
self.network.connectivity.all_to_all.postsynaptic.labels = ["from_X", "from_F"]
self.network.compile(append=True, skip_placement=True)
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("all_to_all")
allcon = cs.load_connections()[0]
self.assertEqual(
(3 + 2) * 5, len(allcon), "should have 3 x 100 cells with from_X label"
Expand Down Expand Up @@ -413,7 +409,7 @@ def connect_spy(strat, pre, post):
except Exception as e:
raise
self.fail(f"Unexpected error: {e}")
cs = self.network.get_connectivity_set("test_cell_to_test_cell")
cs = self.network.get_connectivity_set("self_intersect")
_, sloc, _, dloc = cs.load_connections()
self.assertAll(sloc > -1, "expected only true conn")
self.assertAll(dloc > -1, "expected only true conn")
Expand Down Expand Up @@ -544,7 +540,7 @@ def setUp(self):
def test_single_voxel(self):
# Tests whethervoxel intersection works using a few fixed positions and outcomes.
self.network.compile()
cs = self.network.get_connectivity_set("test_cell_A_to_test_cell_B")
cs = self.network.get_connectivity_set("intersect")
pre_chunks, pre_locs, post_chunks, post_locs = cs.load_connections()
self.assertClose(0, pre_chunks, "expected only conns in base chunk")
self.assertClose(0, post_chunks, "expected only conns in base chunk")
Expand All @@ -563,7 +559,7 @@ def test_single_voxel_labelled(self):
self.network.connectivity.intersect.presynaptic.morphology_labels = ["tip"]
self.network.connectivity.intersect.postsynaptic.morphology_labels = ["top"]
self.network.compile()
cs = self.network.get_connectivity_set("test_cell_A_to_test_cell_B")
cs = self.network.get_connectivity_set("intersect")
pre_chunks, pre_locs, post_chunks, post_locs = cs.load_connections()
self.assertClose(0, pre_chunks, "expected only conns in base chunk")
self.assertClose(0, post_chunks, "expected only conns in base chunk")
Expand All @@ -581,7 +577,7 @@ def test_single_voxel_label404(self):
self.network.connectivity.intersect.presynaptic.morphology_labels = ["tip"]
self.network.connectivity.intersect.postsynaptic.morphology_labels = ["top"]
self.network.compile()
cs = self.network.get_connectivity_set("test_cell_A_to_test_cell_B")
cs = self.network.get_connectivity_set("intersect")
pre_chunks, pre_locs, post_chunks, post_locs = cs.load_connections()
self.assertClose(0, pre_chunks, "expected only conns in base chunk")
self.assertClose(0, post_chunks, "expected only conns in base chunk")
Expand Down Expand Up @@ -612,21 +608,21 @@ def test_contacts(self):
self.network.placement.fixed_pos_B.positions = [[0, 0, 0]]
self.network.cell_types.test_cell_A.spatial.morphologies[0].names = ["C"]
self.network.compile()
conns = len(self.network.get_connectivity_set("test_cell_A_to_test_cell_B"))
conns = len(self.network.get_connectivity_set("intersect"))
self.assertGreater(conns, 0, "no connections formed")
self.network.connectivity.intersect.contacts = 2
self.network.compile(clear=True)
new_conns = len(self.network.get_connectivity_set("test_cell_A_to_test_cell_B"))
new_conns = len(self.network.get_connectivity_set("intersect"))
self.assertEqual(conns * 2, new_conns, "Expected double contacts")

def test_zero_contacts(self):
self.network.connectivity.intersect.contacts = 0
self.network.placement.fixed_pos_B.positions = [[100, 0, 0]]
self.network.cell_types.test_cell_A.spatial.morphologies[0].names = ["C"]
self.network.compile()
conns = len(self.network.get_connectivity_set("test_cell_A_to_test_cell_B"))
conns = len(self.network.get_connectivity_set("intersect"))
self.assertEqual(0, conns, "expected no contacts")
self.network.connectivity.intersect.contacts = -3
self.network.compile(clear=True)
conns = len(self.network.get_connectivity_set("test_cell_A_to_test_cell_B"))
conns = len(self.network.get_connectivity_set("intersect"))
self.assertEqual(0, conns, "expected no contacts")
2 changes: 1 addition & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_str(self):
*self.network.placement.values(),
*self.network.connectivity.values(),
self.network.get_placement_set("test_cell"),
self.network.get_connectivity_set("test_cell_to_test_cell"),
self.network.get_connectivity_set("all_to_all"),
):
self.assertNotEqual(object.__repr__(obj), str(obj))
self.assertNotEqual(object.__repr__(obj), repr(obj))

0 comments on commit fbda191

Please sign in to comment.