Skip to content

Commit

Permalink
Add static type hints to config nodes (#757)
Browse files Browse the repository at this point in the history
* change pyver in CI

* renamed and type hinted

* type hinting

* black

* fix console scripts

* bumped deps, fix 3.12?

* add libhdf5-dev for morphio wheel building

* patch entry points pre 3.10

* add a module spec, might break things?

* remove currently unsupported feature from tests

* maybe None works better as spec

* revert none spec, fix doc link

* fix docs

* bump sphinx

* bump doc reqs
  • Loading branch information
Helveg committed Oct 23, 2023
1 parent 0cdfb28 commit 7bdf19b
Show file tree
Hide file tree
Showing 38 changed files with 346 additions and 211 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3.5.0
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -18,7 +18,7 @@ jobs:
- name: Install apt dependencies
run: |
sudo apt-get update
sudo apt-get install openmpi-bin libopenmpi-dev
sudo apt-get install openmpi-bin libopenmpi-dev libhdf5-dev
- name: Cache pip
uses: actions/cache@v3
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
sudo apt-get install -y openmpi-bin libopenmpi-dev
- name: Install dependencies
run: |
python -m pip install pip
python -m pip install --upgrade pip
pip install -r docs/requirements.txt
- name: Install self
run: |
Expand Down
2 changes: 2 additions & 0 deletions bsb/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import glob
import itertools
from importlib.machinery import ModuleSpec
from shutil import copy2 as copy_file
import builtins
import traceback
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(self, name):
# The __path__ attribute needs to be retained to mark this module as a package with
# submodules (config.refs, config.parsers.json, ...)
__path__ = _path
__spec__ = ModuleSpec(__name__, __loader__, origin=__file__)

# Load the Configuration class on demand, not on import, to avoid circular
# dependencies.
Expand Down
8 changes: 8 additions & 0 deletions bsb/config/_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,10 @@ def should_call_default(self):


class cfglist(builtins.list):
"""
Extension of the builtin list to manipulate lists of configuration nodes.
"""

def get_node_name(self):
return self._config_parent.get_node_name() + "." + self._config_attr_name

Expand Down Expand Up @@ -679,6 +683,10 @@ def tree(self, instance):


class cfgdict(builtins.dict):
"""
Extension of the builtin dictionary to manipulate dicts of configuration nodes.
"""

def __getattr__(self, name):
try:
return self[name]
Expand Down
89 changes: 62 additions & 27 deletions bsb/config/_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import typing

from . import attr, list, dict, root, node, types
from .. import config
from . import types
from ..cell_types import CellType
from ._attrs import _boot_nodes
from ._attrs import _boot_nodes, cfgdict, cfglist
from ..placement import PlacementStrategy
from ..storage._files import CodeDependencyNode, MorphologyDependencyNode
from ..storage.interfaces import StorageNode
Expand All @@ -17,24 +18,25 @@
Region,
Partition,
)
import builtins
import numpy as np

if typing.TYPE_CHECKING:
from ..core import Scaffold


@node
@config.node
class NetworkNode:
scaffold: "Scaffold"

x = attr(type=float, required=True)
y = attr(type=float, required=True)
z = attr(type=float, required=True)
origin = attr(
type=types.list(type=float, size=3), default=lambda: [0, 0, 0], call_default=True
x: float = config.attr(type=float, required=True)
y: float = config.attr(type=float, required=True)
z: float = config.attr(type=float, required=True)
origin: list[float] = config.attr(
type=types.list(type=float, size=3),
default=lambda: [0, 0, 0],
call_default=True,
)
chunk_size = attr(
chunk_size: list[float] = config.attr(
type=types.or_(
types.list(float),
types.scalar_expand(float, expand=lambda s: np.ones(3) * s),
Expand All @@ -47,32 +49,65 @@ def boot(self):
self.chunk_size = np.array(self.chunk_size)


@root
@config.root
class Configuration:
"""
The main Configuration object containing the full definition of a scaffold model.
"""

scaffold: "Scaffold"

name = attr()
components = list(type=CodeDependencyNode)
morphologies = list(type=MorphologyDependencyNode)
storage = attr(type=StorageNode, required=True)
network = attr(type=NetworkNode, required=True)
regions = dict(type=Region)
partitions = dict(type=Partition, required=True)
cell_types = dict(type=CellType, required=True)
placement = dict(type=PlacementStrategy, required=True)
after_placement = dict(type=PostProcessingHook)
connectivity = dict(type=ConnectionStrategy, required=True)
after_connectivity = dict(type=PostProcessingHook)
simulations = dict(type=Simulation)
name: str = config.attr()
"""
Descriptive name of the model
"""
components: cfglist[CodeDependencyNode] = config.list(
type=CodeDependencyNode,
)
morphologies: cfglist[MorphologyDependencyNode] = config.list(
type=MorphologyDependencyNode,
)
storage: StorageNode = config.attr(
type=StorageNode,
required=True,
)
network: NetworkNode = config.attr(
type=NetworkNode,
required=True,
)
regions: cfgdict[str, Region] = config.dict(
type=Region,
)
partitions: cfgdict[str, Partition] = config.dict(
type=Partition,
required=True,
)
cell_types: cfgdict[str, CellType] = config.dict(
type=CellType,
required=True,
)
placement: cfgdict[str, PlacementStrategy] = config.dict(
type=PlacementStrategy,
required=True,
)
after_placement: cfgdict[str, PostProcessingHook] = config.dict(
type=PostProcessingHook,
)
connectivity: cfgdict[str, ConnectionStrategy] = config.dict(
type=ConnectionStrategy,
required=True,
)
after_connectivity: cfgdict[str, PostProcessingHook] = config.dict(
type=PostProcessingHook,
)
simulations: cfgdict[str, Simulation] = config.dict(
type=Simulation,
)
__module__ = "bsb.config"

@classmethod
def default(cls, **kwargs):
default_args = builtins.dict(
default_args = dict(
storage={"engine": "hdf5"},
network={"x": 200, "y": 200, "z": 200},
partitions={},
Expand All @@ -91,7 +126,7 @@ def _bootstrap(self, scaffold):
_boot_nodes(self, scaffold)
self._config_isbooted = True
# Initialise the topology from the defined regions
regions = builtins.list(self.regions.values())
regions = list(self.regions.values())
# Arrange the topology based on network boundaries
start = self.network.origin.copy()
net = self.network
Expand All @@ -100,7 +135,7 @@ def _bootstrap(self, scaffold):
if unmanaged := set(self.partitions.values()) - get_partitions(regions):
p = "', '".join(p.name for p in unmanaged)
r = scaffold.regions.add(
"__unmanaged__", RegionGroup(children=builtins.list(unmanaged))
"__unmanaged__", RegionGroup(children=list(unmanaged))
)
regions.append(r)
scaffold.topology = create_topology(regions, start, end)
Expand Down
6 changes: 4 additions & 2 deletions bsb/config/_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
@config.node
class Distribution:
scaffold: "Scaffold"
distribution = config.attr(type=types.in_(_available_distributions), required=True)
parameters = config.catch_all(type=types.any_())
distribution: str = config.attr(
type=types.in_(_available_distributions), required=True
)
parameters: dict[str, typing.Any] = config.catch_all(type=types.any_())

def __init__(self, **kwargs):
if self.distribution == "constant":
Expand Down
14 changes: 8 additions & 6 deletions bsb/connectivity/general.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import itertools
import os
import typing

import numpy as np
import functools
from .strategy import ConnectionStrategy
from ..exceptions import SourceQualityError
from .. import config, _util as _gutil
from .. import config
from ..config import types
from ..mixins import InvertedRoI
from ..reporting import warn

if typing.TYPE_CHECKING:
from ..config import Distribution


@config.node
Expand All @@ -17,7 +19,7 @@ class Convergence(ConnectionStrategy):
to X target cells.
"""

convergence = config.attr(type=types.distribution(), required=True)
convergence: "Distribution" = config.attr(type=types.distribution(), required=True)

def connect(self):
raise NotImplementedError("Needs to be restored, please open an issue.")
Expand Down Expand Up @@ -48,7 +50,7 @@ class FixedIndegree(InvertedRoI, ConnectionStrategy):
presynaptic cells from all the presynaptic cell types.
"""

indegree = config.attr(type=int, required=True)
indegree: int = config.attr(type=int, required=True)

def connect(self, pre, post):
in_ = self.indegree
Expand Down
25 changes: 14 additions & 11 deletions bsb/connectivity/import_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,28 @@
import csv
import io
import typing
from collections import defaultdict

import psutil
import numpy as np
from tqdm import tqdm

from ..exceptions import ConfigurationError
from .strategy import ConnectionStrategy
from ..storage import Chunk
from .. import config
from ..config import refs
from ..mixins import NotParallel
from ..storage.interfaces import PlacementSet

if typing.TYPE_CHECKING:
from ..storage import FileDependencyNode
from ..cell_types import CellType
from ..topology import Partition


@config.node
class ImportConnectivity(NotParallel, ConnectionStrategy, abc.ABC, classmap_entry=None):
source = config.file(required=True)
cell_types = config.reflist(refs.cell_type_ref, required=False)
partitions = config.reflist(refs.partition_ref, required=False)
source: "FileDependencyNode" = config.file(required=True)
cell_types: list["CellType"] = config.reflist(refs.cell_type_ref, required=False)
partitions: list["Partition"] = config.reflist(refs.partition_ref, required=False)

@config.property(default=False)
def cache(self):
Expand All @@ -40,11 +43,11 @@ def parse_source(self, pre, post):

@config.node
class CsvImportConnectivity(ImportConnectivity):
pre_header = config.attr(default="pre")
post_header = config.attr(default="post")
mapping_key = config.attr()
delimiter = config.attr(default=",")
progress_bar = config.attr(type=bool, default=True)
pre_header: str = config.attr(default="pre")
post_header: str = config.attr(default="post")
mapping_key: str = config.attr()
delimiter: str = config.attr(default=",")
progress_bar: bool = config.attr(type=bool, default=True)

def __boot__(self):
if (
Expand Down
25 changes: 16 additions & 9 deletions bsb/connectivity/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

if typing.TYPE_CHECKING:
from ..core import Scaffold
from ..cell_types import CellType
from ..storage.interfaces import PlacementSet
from ..morphologies import MorphologySet
from ..connectivity import ConnectionStrategy


@config.node
Expand All @@ -20,13 +24,13 @@ class Hemitype:

scaffold: "Scaffold"

cell_types = config.reflist(refs.cell_type_ref, required=True)
cell_types: list["CellType"] = config.reflist(refs.cell_type_ref, required=True)
"""List of cell types to use in connection."""
labels = config.attr(type=types.list())
labels: list[str] = config.attr(type=types.list())
"""List of labels to filter the placement set by."""
morphology_labels = config.attr(type=types.list())
morphology_labels: list[str] = config.attr(type=types.list())
"""List of labels to filter the morphologies by."""
morpho_loader = config.attr(
morpho_loader: typing.Callable[["PlacementSet"], "MorphologySet"] = config.attr(
type=types.function_(),
required=False,
call_default=False,
Expand Down Expand Up @@ -58,14 +62,17 @@ def placement(self):
@config.dynamic(attr_name="strategy", required=True)
class ConnectionStrategy(abc.ABC, SortableByAfter):
scaffold: "Scaffold"
name = config.attr(key=True)
name: str = config.attr(key=True)
"""Name used to refer to the connectivity strategy"""
presynaptic = config.attr(type=Hemitype, required=True)
presynaptic: Hemitype = config.attr(type=Hemitype, required=True)
"""Presynaptic (source) neuron population"""
postsynaptic = config.attr(type=Hemitype, required=True)
postsynaptic: Hemitype = config.attr(type=Hemitype, required=True)
"""Postsynaptic (target) neuron population"""
after = config.reflist(refs.connectivity_ref)
"""Action to perform after connecting the neurons with the current strategy."""
after: list["ConnectionStrategy"] = config.reflist(refs.connectivity_ref)
"""
This strategy should be executed only after all the connections in this list have
been executed.
"""

def __init_subclass__(cls, **kwargs):
super(cls, cls).__init_subclass__(**kwargs)
Expand Down
3 changes: 2 additions & 1 deletion bsb/morphologies/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ..config import types
from .. import config
from ..config._attrs import cfglist
from ..services import MPI
import concurrent
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -38,7 +39,7 @@ def pick(self, morphology):

@config.node
class NameSelector(MorphologySelector, classmap_entry="by_name"):
names = config.list(type=str, required=types.shortform())
names: cfglist[str] = config.list(type=str, required=types.shortform())

def __init__(self, name=None, /, **kwargs):
if name is not None:
Expand Down
4 changes: 2 additions & 2 deletions bsb/placement/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class ParallelArrayPlacement(NotParallel, PlacementStrategy):
Implementation of the placement of cells in parallel arrays.
"""

spacing_x = config.attr(type=float, required=True)
angle = config.attr(type=types.deg_to_radian(), required=True)
spacing_x: float = config.attr(type=float, required=True)
angle: float = config.attr(type=types.deg_to_radian(), required=True)

def place(self, chunk, indicators):
"""
Expand Down

0 comments on commit 7bdf19b

Please sign in to comment.