Skip to content

Commit

Permalink
Adds NotParallel and uses it for ParallelArray (#665)
Browse files Browse the repository at this point in the history
* Added `unique` to global util file

* Added `NotParallel` mixin

* Fixed ParallelArrayPlacement; Now is a serial strategy

* Revert "Fixed ParallelArrayPlacement; Now is a serial strategy"

This reverts commit 558459a.

* Fixed ParallelArrayPlacement; Now is a serial strategy

* Fixed ParallelArrayPlacement; Calling place_cells with arguments 'indicator, [p], chunk=mychunk' allows to select the morphologies correctly.

* Now place_cells is called once per destination chunk

* Fixed No morphology data available error appearing in the case of a single cell per chunk

* Revert "Fixed No morphology data available error appearing in the case of a single cell per chunk"

This reverts commit a64989a.

* Modified spacing_x; Now it does not depend on the angle of the array.

* Update bsb/placement/arrays.py

Co-authored-by: Robin De Schepper <robin.deschepper93@gmail.com>

* Reformatted arrays.py with black

Co-authored-by: Robin De Schepper <robin.deschepper93@gmail.com>
  • Loading branch information
alessiomarta and Helveg committed Dec 16, 2022
1 parent d702e09 commit e9029a0
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 3 deletions.
6 changes: 6 additions & 0 deletions bsb/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os as _os
import sys as _sys
import contextlib as _ctxlib
import typing

import numpy as _np
from .exceptions import OrderError as _OrderError
import functools
Expand Down Expand Up @@ -190,3 +192,7 @@ def resolve_order(cls, objects):
)
# Return the sorted array.
return sorting_objects


def unique(iter_: typing.Iterable[typing.Any]):
return [*set(iter_)]
61 changes: 61 additions & 0 deletions bsb/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from .connectivity import ConnectionStrategy
from .placement import PlacementStrategy
from .reporting import report
from .storage import Chunk
from . import _util as _gutil

import itertools


def _queue_placement(self, pool, chunk_size):
# Reset jobs that we own
self._queued_jobs = []
# Get the queued jobs of all the strategies we depend on.
deps = set(itertools.chain(*(strat._queued_jobs for strat in self.get_after())))
# todo: perhaps pass the volume or partition boundaries as chunk size
job = pool.queue_placement(self, Chunk([0, 0, 0], None), deps=deps)
self._queued_jobs.append(job)
report(f"Queued serial job for {self.name}", level=2)


def _all_chunks(iter_):
return _gutil.unique(
itertools.chain.from_iterable(
ct.get_placement_set().get_all_chunks() for ct in iter_
)
)


def _queue_connectivity(self, pool):
# Reset jobs that we own
self._queued_jobs = []
# Get the queued jobs of all the strategies we depend on.
deps = set(
itertools.chain.from_iterable(strat._queued_jobs for strat in self.get_after())
)
# Schedule all chunks in 1 job
pre_chunks = _all_chunks(self.presynaptic.cell_types)
post_chunks = _all_chunks(self.postsynaptic.cell_types)
job = pool.queue_connectivity(self, pre_chunks, post_chunks, deps=deps)
self._queued_jobs.append(job)
report(f"Queued serial job for {self.name}", level=2)


def _raise_na(*args, **kwargs):
raise NotImplementedError("NotParallel connection strategies have no RoI.")


class NotParallel:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if PlacementStrategy in cls.__mro__:
cls.queue = _queue_placement
elif ConnectionStrategy in cls.__mro__:
cls.queue = _queue_connectivity
if "get_region_of_interest" not in cls.__dict__:
cls.get_region_of_interest = _raise_na
else:
raise Exception(
"NotParallel can only be applied to placement or "
"connectivity strategies"
)
22 changes: 19 additions & 3 deletions bsb/placement/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import math, numpy as np
from .. import config
from ..config import types
from ..mixins import NotParallel
from ..storage import Chunk
from ..reporting import report, warn


@config.node
class ParallelArrayPlacement(PlacementStrategy):
class ParallelArrayPlacement(NotParallel, PlacementStrategy):
"""
Implementation of the placement of cells in parallel arrays.
"""
Expand Down Expand Up @@ -37,7 +40,7 @@ def place(self, chunk, indicators):
# Amount of parallel arrays of cells
n_arrays = x_pos.shape[0]
# Number of cells
n = np.sum(indicator.guess(chunk))
n = np.sum(indicator.guess(prt.data))
# Add extra cells to fill the lattice error volume which will be pruned
n += int((n_arrays * spacing_x % width) / width * n)
# cells to distribute along the rows
Expand Down Expand Up @@ -79,4 +82,17 @@ def place(self, chunk, indicators):
cells[(i * len(x)) : ((i + 1) * len(x)), 2] = z
# Place all the cells in 1 batch (more efficient)
positions = cells[cells[:, 0] < width - radius]
self.place_cells(indicator, positions, chunk=chunk)

# Determine in which chunks the cells must be placed
cs = self.scaffold.configuration.network.chunk_size
chunks_list = np.array(
[chunk.data + np.floor_divide(p, cs[0]) for p in positions]
)
unique_chunks_list = np.unique(chunks_list, axis=0)

# For each chunk, place the cells
for c in unique_chunks_list:
idx = np.where((chunks_list == c).all(axis=1))
pos_current_chunk = positions[idx]
self.place_cells(indicator, pos_current_chunk, chunk=c)
report(f"Placed {len(positions)} {cell_type.name} in {prt.name}", level=3)
53 changes: 53 additions & 0 deletions tests/test_placement.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import unittest, os, sys, numpy as np, h5py

from bsb import config
from bsb.connectivity import ConnectionStrategy
from bsb.mixins import NotParallel

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__)))
from bsb.services import MPI
Expand Down Expand Up @@ -153,6 +157,47 @@ def test_chunked_job(self):
job = pool.queue_chunk(test_chunk, _chunk(0, 0, 0))
pool.execute()

def test_notparallel_ps_job(test):
spy = 0

@config.node
class SerialPStrat(NotParallel, PlacementStrategy):
def place(self, chunk, indicators):
nonlocal spy
test.assertEqual(Chunk([0, 0, 0], None), chunk)
spy += 1

pool = JobPool(_net)
pstrat = _net.placement.add(
"test", SerialPStrat(strategy="", cell_types=[], partitions=[])
)
pstrat.queue(pool, None)
pool.execute()
test.assertEqual(1, sum(MPI.allgather(spy)))

def test_notparallel_cs_job(test):
spy = 0

@config.node
class SerialCStrat(NotParallel, ConnectionStrategy):
def connect(self, pre, post):
nonlocal spy

spy += 1

pool = JobPool(_net)
cstrat = _net.connectivity.add(
"test",
SerialCStrat(
strategy="",
presynaptic={"cell_types": []},
postsynaptic={"cell_types": []},
),
)
cstrat.queue(pool)
pool.execute()
test.assertEqual(1, sum(MPI.allgather(spy)))


@unittest.skipIf(MPI.get_size() < 2, "Skipped during serial testing.")
class TestParallelScheduler(unittest.TestCase, SchedulerBaseTest):
Expand Down Expand Up @@ -205,6 +250,14 @@ def spy_queue(jobs):
if not MPI.get_rank():
self.assertTrue(result, "A job with unfinished dependencies was scheduled.")

@unittest.expectedFailure
def test_notparallel_cs_job(test):
raise Exception("MPI voodoo deadlocks simple nonlocal assigment")

@unittest.expectedFailure
def test_notparallel_ps_job(test):
raise Exception("MPI voodoo deadlocks simple nonlocal assigment")


@unittest.skipIf(MPI.get_size() > 1, "Skipped during parallel testing.")
class TestSerialScheduler(unittest.TestCase, SchedulerBaseTest):
Expand Down

0 comments on commit e9029a0

Please sign in to comment.