18 changes: 0 additions & 18 deletions .github/workflows/Sync_branches.yml

This file was deleted.

6 changes: 4 additions & 2 deletions brainpy/__init__.py
@@ -35,7 +35,7 @@
# convenient alias
conn = connect
init = initialize
optimizers = optim
globals()['optimizers'] = optim

# numerical integrators
from brainpy import integrators
@@ -58,8 +58,11 @@
synapses, # synaptic dynamics
synouts, # synaptic output
synplast, # synaptic plasticity
experimental, # experimental model
)
from brainpy._src.dyn.base import not_pass_shargs
from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
Module as Module,
Container as Container,
Sequential as Sequential,
Network as Network,
@@ -71,7 +74,6 @@
TwoEndConn as TwoEndConn,
CondNeuGroup as CondNeuGroup,
Channel as Channel)
from brainpy._src.dyn.base import (DSPartial as DSPartial)
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
3 changes: 2 additions & 1 deletion brainpy/_src/checkpoints/serialization.py
@@ -1264,7 +1264,8 @@ def save_pytree(

if os.path.splitext(filename)[-1] != '.bp':
filename = filename + '.bp'
os.makedirs(os.path.dirname(filename), exist_ok=True)
if os.path.dirname(filename):
os.makedirs(os.path.dirname(filename), exist_ok=True)
if not overwrite and os.path.exists(filename):
raise InvalidCheckpointPath(filename)
target = to_bytes(target)
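
The added guard matters because ``os.path.dirname`` returns an empty string for a bare filename, and ``os.makedirs('')`` raises ``FileNotFoundError``. A minimal sketch of the behaviour the check avoids (the filename is hypothetical)::

    import os

    filename = 'checkpoint.bp'              # bare filename, no directory component
    dirname = os.path.dirname(filename)     # '' in this case
    if dirname:                             # skip os.makedirs(''), which would raise FileNotFoundError
        os.makedirs(dirname, exist_ok=True)
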
132 changes: 68 additions & 64 deletions brainpy/_src/dyn/base.py
@@ -2,28 +2,28 @@

import collections
import gc
from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any
import warnings
from typing import Union, Dict, Callable, Sequence, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np

from brainpy import tools, check
from brainpy import tools
from brainpy._src import math as bm
from brainpy._src.math.ndarray import Variable, VariableView
from brainpy._src.math.object_transform.base import BrainPyObject, Collector
from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
from brainpy._src.initialize import Initializer, parameter, variable, Uniform, noise as init_noise
from brainpy._src.integrators import odeint, sdeint
from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm
from brainpy._src.math.ndarray import Variable, VariableView
from brainpy._src.math.object_transform.base import BrainPyObject, Collector
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape

__all__ = [
# general class
'DynamicalSystem',
'Module',
'FuncAsDynSys',
'DSPartial',

# containers
'Container', 'Network', 'Sequential', 'System',
@@ -48,6 +48,46 @@
SLICE_VARS = 'slice_vars'


def not_pass_shargs(func: Callable):
"""Label the update function as the one without passing shared arguments.

By default, the update function expects the shared arguments as its first argument::

class TheModel(DynamicalSystem):
def update(self, s, x):
# s holds the shared arguments, like `t`, `dt`, etc.
pass

So, every time we call the model we must pass the shared arguments explicitly::

TheModel()(shared, inputs)

When the update function is labeled with ``not_pass_shargs``, there is no
need to call the dynamical system with shared arguments::

class NewModel(DynamicalSystem):
@not_pass_shargs
def update(self, x):
pass

NewModel()(inputs)

.. versionadded:: 2.3.5

Parameters
----------
func: Callable
The update function of a :py:class:`~.DynamicalSystem`.

Returns
-------
func: Callable
The same function, labeled as one that does not receive shared arguments.
"""
func._new_style = True
return func


class DynamicalSystem(BrainPyObject):
"""Base Dynamical System class.

@@ -65,7 +105,6 @@ class DynamicalSystem(BrainPyObject):
we recommend using :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`,
:py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`.


Parameters
----------
name : optional, str
@@ -74,12 +113,6 @@ class DynamicalSystem(BrainPyObject):
The model computation mode. It should be instance of :py:class:`~.Mode`.
"""

online_fit_by: Optional[OnlineAlgorithm]
'''Online fitting method.'''

offline_fit_by: Optional[OfflineAlgorithm]
'''Offline fitting method.'''

global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], Variable]] = dict()
'''Global delay data, which stores the delay variables and corresponding delay targets.
This variable is useful when the same target variable is used in multiple mappings,
@@ -97,15 +130,11 @@ def __init__(
f'but we got {type(mode)}: {mode}')
self._mode = mode

super(DynamicalSystem, self).__init__(name=name)

# local delay variables
self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()

# fitting parameters
self.online_fit_by = None
self.offline_fit_by = None
self.fit_record = dict()
# super initialization
super(DynamicalSystem, self).__init__(name=name)

@property
def mode(self) -> bm.Mode:
@@ -124,7 +153,21 @@ def __repr__(self):

def __call__(self, *args, **kwargs):
"""The shortcut to call ``update`` methods."""
return self.update(*args, **kwargs)
if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'):
if len(args) and isinstance(args[0], dict):
bm.share.save_shargs(**args[0])
return self.update(*args[1:], **kwargs)
else:
return self.update(*args, **kwargs)
else:
if len(args) and isinstance(args[0], dict):
return self.update(*args, **kwargs)
else:
# If the first argument is not the dict of shared arguments,
# get the shared arguments from the global context instead.
# Users must set and update the shared arguments in the
# global context when relying on this calling mode.
return self.update(bm.share.get_shargs(), *args, **kwargs)
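
With this dispatch, both calling conventions work side by side. A hedged sketch of the behaviour (the class names and values below are illustrative only, assuming a brainpy build that includes this PR)::

    import brainpy as bp

    class OldStyle(bp.DynamicalSystem):
        def update(self, s, x):          # old style: shared dict comes first
            return x * s['dt']

    class NewStyle(bp.DynamicalSystem):
        @bp.not_pass_shargs
        def update(self, x):             # new style: no shared dict in the signature
            return x * 2.

    old, new = OldStyle(), NewStyle()
    old({'t': 0., 'dt': 0.1}, 1.)   # dict detected -> forwarded to update() unchanged
    new({'t': 0., 'dt': 0.1}, 1.)   # dict saved through bm.share, then stripped from the call
    new(1.)                         # plain call, nothing else required
    # old(1.) would pull the shared dict from the global context,
    # so 't'/'dt' must already be registered there.
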

def register_delay(
self,
@@ -339,26 +382,13 @@ def __del__(self):
del self.__dict__[key]
gc.collect()

@tools.not_customized
def online_init(self):
raise NoImplementationError('Subclass must implement online_init() function when using OnlineTrainer.')

@tools.not_customized
def online_fit(self,
target: ArrayType,
fit_record: Dict[str, ArrayType]):
raise NoImplementationError('Subclass must implement online_fit() function when using OnlineTrainer.')

@tools.not_customized
def offline_fit(self,
target: ArrayType,
fit_record: Dict[str, ArrayType]):
raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.')

def clear_input(self):
pass


Module = DynamicalSystem


class FuncAsDynSys(DynamicalSystem):
"""Transform a Python function as a :py:class:`~.DynamicalSystem`

@@ -411,31 +441,6 @@ def __repr__(self):
f'{indent}num_of_vars={len(self.implicit_vars)})')


class DSPartial(FuncAsDynSys):
def __init__(
self,
target: Callable,
*args,
child_objs: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None,
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
shared: Dict = None,
**keywords
):
super().__init__(target=target, child_objs=child_objs, dyn_vars=dyn_vars)

check.is_dict_data(shared, all_none=True)
self.target = check.is_callable(target, )
self.args = tuple(args)
self.keywords = keywords
self.shared = dict() if shared is None else shared

def __call__(self, s, *args, **keywords):
assert isinstance(s, dict)
s = tools.DotDict(s).update(self.shared)
args = self.args + (s,) + args
keywords = {**self.keywords, **keywords}
return self.target(*args, **keywords)


class Container(DynamicalSystem):
"""Container object which is designed to add other instances of DynamicalSystem.
@@ -639,7 +644,7 @@ def __repr__(self):
entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(self._modules))
return f'{self.__class__.__name__}(\n{entries}\n)'

def update(self, *args) -> ArrayType:
def update(self, s, x) -> ArrayType:
"""Update function of a sequential model.

Parameters
@@ -654,7 +659,6 @@ def update(self, *args) -> ArrayType:
y: ArrayType
The output tensor.
"""
s, x = (dict(), args[0]) if len(args) == 1 else (args[0], args[1])
for m in self._modules:
if isinstance(m, DynamicalSystem):
x = m(s, x)
@@ -818,7 +822,7 @@ def get_batch_shape(self, batch_size=None):
else:
return (batch_size,) + self.varshape

def update(self, tdi, x=None):
def update(self, *args):
"""The function to specify the updating rule.

Parameters
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/layers/base.py
@@ -4,7 +4,7 @@
from typing import Optional

import brainpy.math as bm
from brainpy._src.dyn.base import DynamicalSystem
from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs

__all__ = [
'Layer'
9 changes: 5 additions & 4 deletions brainpy/_src/dyn/layers/conv.py
@@ -5,6 +5,7 @@
from jax import lax

from brainpy import math as bm, tools, check
from brainpy._src.dyn.base import not_pass_shargs
from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.types import ArrayType
from .base import Layer
@@ -153,8 +154,8 @@ def _check_input_dim(self, x):
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")

def update(self, *args):
x = args[0] if len(args) == 1 else args[1]
@not_pass_shargs
def update(self, x):
self._check_input_dim(x)
w = self.w.value
if self.mask is not None:
@@ -525,8 +526,8 @@ def __init__(
def _check_input_dim(self, x):
raise NotImplementedError

def update(self, *args):
x = args[0] if len(args) == 1 else args[1]
@not_pass_shargs
def update(self, x):
self._check_input_dim(x)

w = self.w.value
7 changes: 3 additions & 4 deletions brainpy/_src/dyn/layers/dropout.py
@@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp


from brainpy import math as bm, check
from .base import Layer
from brainpy._src.dyn.base import not_pass_shargs

__all__ = [
'Dropout'
@@ -49,8 +48,8 @@ def __init__(
self.prob = check.is_float(prob, min_bound=0., max_bound=1.)
self.rng = bm.random.default_rng(seed)

def update(self, sha, x):
if sha.get('fit', True):
def update(self, s, x):
if s['fit']:
keep_mask = self.rng.bernoulli(self.prob, x.shape)
return bm.where(bm.as_jax(keep_mask), x / self.prob, 0.)
else:
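
Because the new code indexes ``s['fit']`` directly, callers now have to supply the ``fit`` flag in the shared arguments. A hypothetical usage sketch (note that ``prob`` is the keep probability in this layer)::

    import brainpy as bp
    import brainpy.math as bm

    layer = bp.layers.Dropout(prob=0.5, seed=42)
    x = bm.random.rand(10)

    y_train = layer({'fit': True}, x)    # keep-mask applied, kept values rescaled by 1/prob
    y_eval = layer({'fit': False}, x)    # input returned unchanged
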