Skip to content

Commit

Permalink
Reworked job submission system (#803)
Browse files Browse the repository at this point in the history
* Add placement name and class in error feedback

* Reorganisation of Pool and introduce error handling in master process

* Update dependencies with mpipool

* remove job id system

* Update in the job pool scheme

* removed event_loop argument

* Remove register_pool_listeners

* Renew the test_jobs suite

Add new tests for single and multiple jobs submissions, testing cancel job and listeners usage.

* Renewed job pool scheme

Simplify arguments handling for job execution methods
Cancel method is added for jobs
queue_function renamed in queue
Modify condition for execution loop
Add breaking condition for listeners

* Fix errors

Fixed run_placement
Fixed ConnectivityJob
Fixed the pool execution loop
Fixed test
Make Listeners to raise errors

* Remove exception raise from run_placement and run_connectivity

* Listeners handling moved to scaffold

* Fix run_after methods

Move some cancel  job tests to only parallel scheduler
Fix wrong argument in run_after methods

* Added expect failures and skip in unit tests

* Add SubmissionContext class

* Remodel Submission class and add tests

* Add test for submit context in no node case

* Change Job results

results handling is changed, now results are stored in tmp files

* Test adjusted for new get and set result

* Add a function to clean the listeners list

* Add delay before listeners last call

* Fix listeners usage in tests

* Switch to pickle for tmp file w/r

* wip tests

* Add PoolProgress class to handle contex with listeners

* update deps

* Pool rework finalized. Broken morphology pickling error

* parallel execution fixed, serial needs touch up

* changed imports, add docstr

* lint

* remove `test_entities` (old NEST based entities test, now relays)

* update to mpipool with shutdown(wait, cancel_futures) support

* fix unsupported pickling of morphology return value

* fix test flake

* fix serial error raising, store unhandled errors and raise in notify

* fixed tests

* fix tests

* fix docs

* leave comment why that's there

* fixed fixme, didnt need fixing

* unskip tests

* fix deadlock that occurs exclusively on GHA

* try to print the error on GHA

* fixed weird exception context error with a workaround.

* Address review comments

* Apply suggestions from code review

* added docs

* fix flow of pool based core functions

---------

Co-authored-by: Robin De Schepper <robin.deschepper93@gmail.com>
  • Loading branch information
filimarc and Helveg committed Mar 7, 2024
1 parent 4fc49a0 commit 83083a7
Show file tree
Hide file tree
Showing 26 changed files with 1,209 additions and 442 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ bsb-*/
# Merge note: NRRD & jupyter checkpoints
*.nrrd
.ipynb_checkpoints

bsb/services/try_tmp.py
24 changes: 24 additions & 0 deletions bsb/_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,26 @@ def get_default(self):
return False


class DebugPoolFlag(
BsbOption,
name="debug_pool",
cli=("dp", "debug_pool"),
project=("debug_pool",),
env=("BSB_DEBUG_POOL",),
script=("debug_pool",),
flag=True,
):
"""
Debug job pools
"""

def setter(self, value):
return bool(value)

def get_default(self):
return False


def verbosity():
return VerbosityOption

Expand All @@ -134,3 +154,7 @@ def config():

def profiling():
return ProfilingOption


def debug_pool():
return DebugPoolFlag
53 changes: 38 additions & 15 deletions bsb/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import itertools as _it
import os as _os
import sys as _sys
import typing
import typing as _t

import numpy as _np
import numpy as np

ichain = _it.chain.from_iterable


def merge_dicts(a, b):
"""
Merge 2 dictionaries and their subdictionaries
"""
for key in b:
if key in a and isinstance(a[key], dict) and isinstance(b[key], dict):
merge_dicts(a[key], b[key])
Expand All @@ -22,6 +24,10 @@ def merge_dicts(a, b):


def obj_str_insert(__str__):
"""
Decorator to insert the return value of __str__ into '<classname {returnvalue} at 0x...>'
"""

@_ft.wraps(__str__)
def wrapper(self):
obj_str = object.__repr__(self)
Expand All @@ -32,6 +38,11 @@ def wrapper(self):

@_ctxlib.contextmanager
def suppress_stdout():
"""
Context manager that attempts to silence regular stdout and stderr. Some binary
components may yet circumvene this if they access the underlying OS's stdout directly,
like streaming to `/dev/stdout`.
"""
with open(_os.devnull, "w") as devnull:
old_stdout = _sys.stdout
old_stderr = _sys.stderr
Expand Down Expand Up @@ -67,6 +78,9 @@ def listify_input(value):


def sanitize_ndarray(arr_input, shape, dtype=None):
"""
Convert an object to an ndarray and shape, avoiding to copy it wherever possible.
"""
kwargs = {"copy": False}
if dtype is not None:
kwargs["dtype"] = dtype
Expand All @@ -76,13 +90,21 @@ def sanitize_ndarray(arr_input, shape, dtype=None):


def assert_samelen(*args):
"""
Assert that all input arguments have the same length.
"""
len_ = None
assert all(
(len_ := len(arg) if len_ is None else len(arg)) == len_ for arg in args
), "Input arguments should be of same length."


def immutable():
"""
Decorator to mark a method as immutable, so that any calls to it return, and are
performed on, a copy of the instance.
"""

def immutable_decorator(f):
@_ft.wraps(f)
def immutable_action(self, *args, **kwargs):
Expand All @@ -95,7 +117,8 @@ def immutable_action(self, *args, **kwargs):
return immutable_decorator


def unique(iter_: typing.Iterable[typing.Any]):
def unique(iter_: _t.Iterable[_t.Any]):
"""Return a new list containing all the unique elements of an iterator"""
return [*set(iter_)]


Expand All @@ -107,19 +130,19 @@ def rotation_matrix_from_vectors(vec1, vec2):
:return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
"""
if (
np.isnan(vec1).any()
or np.isnan(vec2).any()
or not np.any(vec1)
or not np.any(vec2)
_np.isnan(vec1).any()
or _np.isnan(vec2).any()
or not _np.any(vec1)
or not _np.any(vec2)
):
raise ValueError("Vectors should not contain nan and their norm should not be 0.")
a = (vec1 / np.linalg.norm(vec1)).reshape(3)
b = (vec2 / np.linalg.norm(vec2)).reshape(3)
v = np.cross(a, b)
a = (vec1 / _np.linalg.norm(vec1)).reshape(3)
b = (vec2 / _np.linalg.norm(vec2)).reshape(3)
v = _np.cross(a, b)
if any(v): # if not all zeros then
c = np.dot(a, b)
s = np.linalg.norm(v)
kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2))
c = _np.dot(a, b)
s = _np.linalg.norm(v)
kmat = _np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
return _np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2))
else:
return np.eye(3) # cross of all zeros only occurs on identical directions
return _np.eye(3) # cross of all zeros only occurs on identical directions
1 change: 0 additions & 1 deletion bsb/connectivity/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

if typing.TYPE_CHECKING:
from ..cell_types import CellType
from ..connectivity import ConnectionStrategy
from ..core import Scaffold
from ..morphologies import MorphologySet
from ..storage.interfaces import PlacementSet
Expand Down
104 changes: 55 additions & 49 deletions bsb/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import os
import sys
import time
import typing

Expand All @@ -8,12 +9,13 @@
from ._util import obj_str_insert
from .config._config import Configuration
from .connectivity import ConnectionStrategy
from .exceptions import InputError, NodeNotFoundError, RedoError, ScaffoldError
from .exceptions import InputError, NodeNotFoundError, RedoError
from .placement import PlacementStrategy
from .profiling import meter
from .reporting import report, warn
from .services import MPI
from .services.pool import create_job_pool
from .services import MPI, JobPool
from .services._pool_listeners import NonTTYTerminalListener
from .services.pool import Job
from .simulation import get_simulation_adapter
from .storage import Chunk, Storage, open_storage

Expand Down Expand Up @@ -124,6 +126,9 @@ def __init__(self, config=None, storage=None, clear=False, comm=None):
:returns: A network object
:rtype: :class:`~.core.Scaffold`
"""
self._pool_listeners: list[tuple[typing.Callable[[list["Job"]], None], float]] = (
[]
)
self._configuration = None
self._storage = None
self._comm = comm or MPI
Expand Down Expand Up @@ -246,32 +251,23 @@ def resize(self, x=None, y=None, z=None):
)

@meter()
def run_placement(self, strategies=None, DEBUG=True, pipelines=True):
def run_placement(self, strategies=None, fail_fast=True, pipelines=True):
"""
Run placement strategies.
"""
if pipelines:
self.run_pipelines()
if strategies is None:
strategies = [*self.placement]
strategies = [*self.placement.values()]
strategies = PlacementStrategy.sort_deps(strategies)
pool = create_job_pool(self)
if pool.is_master():
pool = self.create_job_pool(fail_fast=fail_fast)
if pool.is_main():
for strategy in strategies:
strategy.queue(pool, self.network.chunk_size)
loop = self._progress_terminal_loop(pool, debug=DEBUG)
try:
pool.execute(loop)
except Exception:
self._stop_progress_loop(loop, debug=DEBUG)
raise
finally:
self._stop_progress_loop(loop, debug=DEBUG)
else:
pool.execute()
pool.execute()

@meter()
def run_connectivity(self, strategies=None, DEBUG=True, pipelines=True):
def run_connectivity(self, strategies=None, fail_fast=True, pipelines=True):
"""
Run connection strategies.
"""
Expand All @@ -280,20 +276,11 @@ def run_connectivity(self, strategies=None, DEBUG=True, pipelines=True):
if strategies is None:
strategies = set(self.connectivity.values())
strategies = ConnectionStrategy.sort_deps(strategies)
pool = create_job_pool(self)
if pool.is_master():
pool = self.create_job_pool(fail_fast=fail_fast)
if pool.is_main():
for strategy in strategies:
strategy.queue(pool)
loop = self._progress_terminal_loop(pool, debug=DEBUG)
try:
pool.execute(loop)
except Exception:
self._stop_progress_loop(loop, debug=DEBUG)
raise
finally:
self._stop_progress_loop(loop, debug=DEBUG)
else:
pool.execute()
pool.execute()

@meter()
def run_placement_strategy(self, strategy):
Expand All @@ -303,13 +290,13 @@ def run_placement_strategy(self, strategy):
self.run_placement([strategy])

@meter()
def run_after_placement(self, pipelines=True):
def run_after_placement(self, fail_fast=None, pipelines=True):
"""
Run after placement hooks.
"""
if self.after_placement:
warn("After placement disabled")
# pool = create_job_pool(self)
# pool = self.create_job_pool(fail_fast)
# for hook in self.configuration.after_placement.values():
# pool.queue(hook.after_placement)
# pool.execute(self._pool_event_loop)
Expand Down Expand Up @@ -337,6 +324,7 @@ def compile(
append=False,
redo=False,
force=False,
fail_fast=True,
):
"""
Run reconstruction steps in the scaffold sequence to obtain a full network.
Expand Down Expand Up @@ -375,17 +363,17 @@ def compile(
# append mode is luckily simpler, just don't clear anything :)

t = time.time()
self.run_pipelines()
self.run_pipelines(fail_fast=fail_fast)
if not skip_placement:
placement_todo = ", ".join(s.name for s in p_strats)
report(f"Starting placement strategies: {placement_todo}", level=2)
self.run_placement(p_strats, pipelines=False)
self.run_placement(p_strats, fail_fast=fail_fast, pipelines=False)
if not skip_after_placement:
self.run_after_placement(pipelines=False)
self.run_after_placement(pipelines=False, fail_fast=fail_fast)
if not skip_connectivity:
connectivity_todo = ", ".join(s.name for s in c_strats)
report(f"Starting connectivity strategies: {connectivity_todo}", level=2)
self.run_connectivity(c_strats, pipelines=False)
self.run_connectivity(c_strats, fail_fast=fail_fast, pipelines=False)
if not skip_after_connectivity:
self.run_after_connectivity(pipelines=False)
report("Runtime: {}".format(time.time() - t), 2)
Expand All @@ -394,23 +382,14 @@ def compile(
self.storage._preexisted = True

@meter()
def run_pipelines(self, pipelines=None, DEBUG=True):
def run_pipelines(self, fail_fast=True, pipelines=None):
if pipelines is None:
pipelines = self.get_dependency_pipelines()
pool = create_job_pool(self)
if pool.is_master():
pool = self.create_job_pool(fail_fast=fail_fast)
if pool.is_main():
for pipeline in pipelines:
pipeline.queue(pool)
loop = self._progress_terminal_loop(pool, debug=DEBUG)
try:
pool.execute(loop)
except Exception:
self._stop_progress_loop(loop, debug=DEBUG)
raise
finally:
self._stop_progress_loop(loop, debug=DEBUG)
else:
pool.execute()
pool.execute()

@meter()
def run_simulation(self, simulation_name: str):
Expand Down Expand Up @@ -785,6 +764,33 @@ def _load_cs_types(
) from None
return cs

def create_job_pool(self, fail_fast=None, quiet=False):
pool = JobPool(self, fail_fast=fail_fast)
try:
tty = os.isatty(sys.stdout.fileno())
except Exception:
tty = False
if tty:
# todo: Create the TTY terminal listener
default_listener = NonTTYTerminalListener
else:
default_listener = NonTTYTerminalListener
if self._pool_listeners:
for listener, max_wait in self._pool_listeners:
pool.add_listener(listener, max_wait=max_wait)
elif not quiet:
pool.add_listener(default_listener())
return pool

def register_listener(self, listener, max_wait=None):
self._pool_listeners.append((listener, max_wait))

def remove_listener(self, listener):
for i, (l, _) in enumerate(self._pool_listeners):
if l is listener:
self._pool_listeners.pop(i)
break


class ReportListener:
def __init__(self, scaffold, file):
Expand Down
4 changes: 3 additions & 1 deletion bsb/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
GatewayError=_e(
AllenApiError=_e(),
),
JobPoolError=_e(
JobCancelledError=_e(),
),
TopologyError=_e(
UnmanagedPartitionError=_e(),
LayoutError=_e(),
Expand Down Expand Up @@ -122,7 +125,6 @@
),
),
ClassError=_e(),
TestError=_e(FixtureError=_e()),
),
)

Expand Down

0 comments on commit 83083a7

Please sign in to comment.