Skip to content

Commit

Permalink
Improved dependency system (#779)
Browse files Browse the repository at this point in the history
* remove dead file

* add HasDependencies mixin

* remove SortableByAfter

* switch to HasDependencies

* remove unused OrderError

* fix test imports

* add job sorting tests

* fix import after move
  • Loading branch information
Helveg committed Nov 30, 2023
1 parent 26e849b commit 7d0df7f
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 327 deletions.
119 changes: 1 addition & 118 deletions bsb/_util.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import itertools as _it
import abc as _abc
import os as _os
import sys as _sys
import contextlib as _ctxlib
import typing

import numpy as np

import numpy as _np
from .exceptions import OrderError as _OrderError
import functools

ichain = _it.chain.from_iterable
Expand Down Expand Up @@ -80,122 +79,6 @@ def assert_samelen(*args):
), "Input arguments should be of same length."


class SortableByAfter:
@_abc.abstractmethod
def has_after(self):
pass

@_abc.abstractmethod
def create_after(self):
pass

@_abc.abstractmethod
def get_after(self):
pass

@_abc.abstractmethod
def get_ordered(self, objects):
pass

def add_after(self, after_item):
if not self.has_after():
self.create_after()
self.get_after().append(after_item)

def is_after_satisfied(self, objects):
"""
Determine whether the ``after`` specification of this object is met. Any objects
appearing in ``self.after`` need to occur in ``objects`` before the object.
:param objects: Proposed order for which the after condition is checked.
:type objects: list
"""
if not self.has_after(): # No after?
# Condition without constraints always True.
return True
self_met = False
after = self.get_after()
# Determine whether this object is out of order.
for type in objects:
if type is self:
# We found ourselves, from this point on nothing that appears in the after
# array is allowed to be encountered
self_met = True
elif self_met and type in after:
# We have encountered ourselves, so everything we find from now on is not
# allowed to be in our after array, if it is, we fail the after condition.
return False
# We didn't meet anything behind us that was supposed to be in front of us
# => Condition met.
return True

def satisfy_after(self, objects):
"""
Given an array of objects, place this object after all of the objects specified in
the ``after`` condition. If objects in the after condition are missing from the
given array this object is placed at the end of the array. Modifies the `objects`
array in place.
"""
before_types = self.get_after().copy()
i = 0
place_after = False
# Loop over the objects until we've found all our before types.
while len(before_types) > 0 and i < len(objects):
if objects[i] in before_types:
# We encountered one of our before types and can remove it from the list
# of things we still need to look for
before_types.remove(objects[i])
# We increment i unless we encounter and remove ourselves
if objects[i] == self:
# We are still in the loop, so there must still be things in our after
# condition that we are looking for; therefor we remove ourselves from
# the list and wait until we found all our conditions and place ourselves
# there
objects.remove(self)
place_after = True
else:
i += 1
if place_after:
# We've looped to either after our last after condition or the last element
# and should reinsert ourselves here.
objects.insert(i, self)

@classmethod
def resolve_order(cls, objects):
"""
Orders a given dictionary of objects by the class's default mechanism and
then apply the `after` attribute for further restrictions.
"""
# Sort by the default approach
sorting_objects = list(cls.get_ordered(objects))
# Afterwards cell types can be specified that need to be placed after other types.
after_specifications = [c for c in sorting_objects if c.has_after()]
j = 0
# Keep rearranging as long as any cell type's after condition isn't satisfied.
while any(
not c.is_after_satisfied(sorting_objects) for c in after_specifications
):
j += 1
# Rearrange each element that is out of place.
for after_type in after_specifications:
if not after_type.is_after_satisfied(sorting_objects):
after_type.satisfy_after(sorting_objects)
# If we have had to rearrange all elements more than there are elements, the
# conditions cannot be met, and a circular dependency is at play.
if j > len(objects):
circulars = ", ".join(
c.name
for c in after_specifications
if not c.is_after_satisfied(sorting_objects)
)
raise _OrderError(
f"Couldn't resolve order, probably a circular dependency including: "
f"{circulars}"
)
# Return the sorted array.
return sorting_objects


def immutable():
def immutable_decorator(f):
@functools.wraps(f)
Expand Down
38 changes: 16 additions & 22 deletions bsb/connectivity/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from ..config import refs, types
from ..profiling import node_meter
from ..reporting import report, warn
from .._util import SortableByAfter, obj_str_insert, ichain
from .._util import obj_str_insert, ichain
from ..mixins import HasDependencies
import abc
from itertools import chain

Expand Down Expand Up @@ -60,25 +61,29 @@ def placement(self):


@config.dynamic(attr_name="strategy", required=True)
class ConnectionStrategy(abc.ABC, SortableByAfter):
class ConnectionStrategy(abc.ABC, HasDependencies):
scaffold: "Scaffold"
name: str = config.attr(key=True)
"""Name used to refer to the connectivity strategy"""
presynaptic: Hemitype = config.attr(type=Hemitype, required=True)
"""Presynaptic (source) neuron population"""
postsynaptic: Hemitype = config.attr(type=Hemitype, required=True)
"""Postsynaptic (target) neuron population"""
after: list["ConnectionStrategy"] = config.reflist(refs.connectivity_ref)
"""
This strategy should be executed only after all the connections in this list have
been executed.
"""
depends_on: list["ConnectionStrategy"] = config.reflist(refs.connectivity_ref)
"""The list of strategies that must run before this one"""

def __init_subclass__(cls, **kwargs):
super(cls, cls).__init_subclass__(**kwargs)
# Decorate subclasses to measure performance
node_meter("connect")(cls)

def __hash__(self):
return id(self)

def __lt__(self, other):
# This comparison should sort connection strategies by name, via __repr__ below
return str(self) < str(other)

def __boot__(self):
self._queued_jobs = []

Expand All @@ -90,24 +95,13 @@ def __repr__(self):
post = [ct.name for ct in self.postsynaptic.cell_types]
return f"'{self.name}', connecting {pre} to {post}"

@classmethod
def get_ordered(cls, objects):
# No need to sort connectivity strategies, just obey dependencies.
return objects

def get_after(self):
return [] if not self.has_after() else self.after

def has_after(self):
return hasattr(self, "after")

def create_after(self):
self.after = []

@abc.abstractmethod
def connect(self, presyn_collection, postsyn_collection):
pass

def get_deps(self):
return set(self.depends_on)

def _get_connect_args_from_job(self, pre_roi, post_roi):
pre = HemitypeCollection(self.presynaptic, pre_roi)
post = HemitypeCollection(self.postsynaptic, post_roi)
Expand Down Expand Up @@ -135,7 +129,7 @@ def queue(self, pool):
# Reset jobs that we own
self._queued_jobs = []
# Get the queued jobs of all the strategies we depend on.
deps = set(chain.from_iterable(strat._queued_jobs for strat in self.get_after()))
deps = set(chain.from_iterable(strat._queued_jobs for strat in self.get_deps()))
pre_types = self.presynaptic.cell_types
# Iterate over each chunk that is populated by our presynaptic cell types.
from_chunks = set(
Expand Down
6 changes: 3 additions & 3 deletions bsb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def run_placement(self, strategies=None, DEBUG=True, pipelines=True):
self.run_pipelines()
if strategies is None:
strategies = [*self.placement]
strategies = PlacementStrategy.resolve_order(strategies)
strategies = PlacementStrategy.sort_deps(strategies)
pool = create_job_pool(self)
if pool.is_master():
for strategy in strategies:
Expand All @@ -282,8 +282,8 @@ def run_connectivity(self, strategies=None, DEBUG=True, pipelines=True):
if pipelines:
self.run_pipelines()
if strategies is None:
strategies = list(self.connectivity.values())
strategies = ConnectionStrategy.resolve_order(strategies)
strategies = set(self.connectivity.values())
strategies = ConnectionStrategy.sort_deps(strategies)
pool = create_job_pool(self)
if pool.is_master():
for strategy in strategies:
Expand Down
1 change: 0 additions & 1 deletion bsb/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@
JsonImportError=_e(),
),
),
OrderError=_e(),
ClassError=_e(),
TestError=_e(FixtureError=_e()),
),
Expand Down

0 comments on commit 7d0df7f

Please sign in to comment.