diff --git a/.github/workflows/Sync_branches.yml b/.github/workflows/Sync_branches.yml deleted file mode 100644 index 4a4192425..000000000 --- a/.github/workflows/Sync_branches.yml +++ /dev/null @@ -1,18 +0,0 @@ -name: Sync multiple branches -on: - push: - branches: - - master -jobs: - sync-branch: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@master - - - name: Merge master -> brainpy-2.3.x - uses: devmasx/merge-branch@master - with: - type: now - from_branch: master - target_branch: brainpy-2.3.x - github_token: ${{ github.token }} \ No newline at end of file diff --git a/brainpy/__init__.py b/brainpy/__init__.py index ed21f9a99..9b3f8acb7 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -35,7 +35,7 @@ # convenient alias conn = connect init = initialize -optimizers = optim +globals()['optimizers'] = optim # numerical integrators from brainpy import integrators @@ -58,8 +58,11 @@ synapses, # synaptic dynamics synouts, # synaptic output synplast, # synaptic plasticity + experimental, # experimental model ) +from brainpy._src.dyn.base import not_pass_shargs from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem, + Module as Module, Container as Container, Sequential as Sequential, Network as Network, @@ -71,7 +74,6 @@ TwoEndConn as TwoEndConn, CondNeuGroup as CondNeuGroup, Channel as Channel) -from brainpy._src.dyn.base import (DSPartial as DSPartial) from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations LoopOverTime as LoopOverTime,) from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner diff --git a/brainpy/_src/checkpoints/serialization.py b/brainpy/_src/checkpoints/serialization.py index 3e5c27525..d93c04600 100644 --- a/brainpy/_src/checkpoints/serialization.py +++ b/brainpy/_src/checkpoints/serialization.py @@ -1264,7 +1264,8 @@ def save_pytree( if os.path.splitext(filename)[-1] != '.bp': filename = filename + '.bp' - os.makedirs(os.path.dirname(filename), exist_ok=True) + if os.path.dirname(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) if not overwrite and os.path.exists(filename): raise InvalidCheckpointPath(filename) target = to_bytes(target) diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index cb68d8743..2963a691e 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -2,28 +2,28 @@ import collections import gc -from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any +import warnings +from typing import Union, Dict, Callable, Sequence, Optional, Tuple import jax import jax.numpy as jnp import numpy as np -from brainpy import tools, check +from brainpy import tools from brainpy._src import math as bm -from brainpy._src.math.ndarray import Variable, VariableView -from brainpy._src.math.object_transform.base import BrainPyObject, Collector from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All from brainpy._src.initialize import Initializer, parameter, variable, Uniform, noise as init_noise from brainpy._src.integrators import odeint, sdeint -from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm +from brainpy._src.math.ndarray import Variable, VariableView +from brainpy._src.math.object_transform.base import BrainPyObject, Collector from brainpy.errors import NoImplementationError, UnsupportedError from brainpy.types import ArrayType, Shape __all__ = [ # general class 'DynamicalSystem', + 'Module', 'FuncAsDynSys', - 'DSPartial', # containers 'Container', 'Network', 'Sequential', 
'System', @@ -48,6 +48,46 @@ SLICE_VARS = 'slice_vars' +def not_pass_shargs(func: Callable): + """Label the update function as the one without passing shared arguments. + + The original update function explicitly requires shared arguments at the first place:: + + class TheModel(DynamicalSystem): + def update(self, s, x): + # s is the shared arguments, like `t`, `dt`, etc. + pass + + So, each time we call the model we should provide shared arguments into the model:: + + TheModel()(shared, inputs) + + When we label the update function as ``do_not_pass_sha_args``, this time there is no + need to call the dynamical system with shared arguments:: + + class NewModel(DynamicalSystem): + @no_shared + def update(self, x): + pass + + NewModel()(inputs) + + .. versionadded:: 2.3.5 + + Parameters + ---------- + func: Callable + The function in the :py:class:`~.DynamicalSystem`. + + Returns + ------- + func: Callable + The wrapped function for the class. + """ + func._new_style = True + return func + + class DynamicalSystem(BrainPyObject): """Base Dynamical System class. @@ -65,7 +105,6 @@ class DynamicalSystem(BrainPyObject): we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`, :py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`. - Parameters ---------- name : optional, str @@ -74,12 +113,6 @@ class DynamicalSystem(BrainPyObject): The model computation mode. It should be instance of :py:class:`~.Mode`. """ - online_fit_by: Optional[OnlineAlgorithm] - '''Online fitting method.''' - - offline_fit_by: Optional[OfflineAlgorithm] - '''Offline fitting method.''' - global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], Variable]] = dict() '''Global delay data, which stores the delay variables and corresponding delay targets. This variable is useful when the same target variable is used in multiple mappings, @@ -97,15 +130,11 @@ def __init__( f'but we got {type(mode)}: {mode}') self._mode = mode - super(DynamicalSystem, self).__init__(name=name) - # local delay variables self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector() - # fitting parameters - self.online_fit_by = None - self.offline_fit_by = None - self.fit_record = dict() + # super initialization + super(DynamicalSystem, self).__init__(name=name) @property def mode(self) -> bm.Mode: @@ -124,7 +153,21 @@ def __repr__(self): def __call__(self, *args, **kwargs): """The shortcut to call ``update`` methods.""" - return self.update(*args, **kwargs) + if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'): + if len(args) and isinstance(args[0], dict): + bm.share.save_shargs(**args[0]) + return self.update(*args[1:], **kwargs) + else: + return self.update(*args, **kwargs) + else: + if len(args) and isinstance(args[0], dict): + return self.update(*args, **kwargs) + else: + # If first argument is not shared argument, + # we should get the shared arguments from the global context. + # However, users should set and update shared arguments + # in the global context when using this mode. 
+ return self.update(bm.share.get_shargs(), *args, **kwargs) def register_delay( self, @@ -339,26 +382,13 @@ def __del__(self): del self.__dict__[key] gc.collect() - @tools.not_customized - def online_init(self): - raise NoImplementationError('Subclass must implement online_init() function when using OnlineTrainer.') - - @tools.not_customized - def online_fit(self, - target: ArrayType, - fit_record: Dict[str, ArrayType]): - raise NoImplementationError('Subclass must implement online_fit() function when using OnlineTrainer.') - - @tools.not_customized - def offline_fit(self, - target: ArrayType, - fit_record: Dict[str, ArrayType]): - raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.') - def clear_input(self): pass +Module = DynamicalSystem + + class FuncAsDynSys(DynamicalSystem): """Transform a Python function as a :py:class:`~.DynamicalSystem` @@ -411,31 +441,6 @@ def __repr__(self): f'{indent}num_of_vars={len(self.implicit_vars)})') -class DSPartial(FuncAsDynSys): - def __init__( - self, - target: Callable, - *args, - child_objs: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None, - dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, - shared: Dict = None, - **keywords - ): - super().__init__(target=target, child_objs=child_objs, dyn_vars=dyn_vars) - - check.is_dict_data(shared, all_none=True) - self.target = check.is_callable(target, ) - self.args = tuple(args) - self.keywords = keywords - self.shared = dict() if shared is None else shared - - def __call__(self, s, *args, **keywords): - assert isinstance(s, dict) - s = tools.DotDict(s).update(self.shared) - args = self.args + (s,) + args - keywords = {**self.keywords, **keywords} - return self.target(*args, **keywords) - class Container(DynamicalSystem): """Container object which is designed to add other instances of DynamicalSystem. @@ -639,7 +644,7 @@ def __repr__(self): entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(self._modules)) return f'{self.__class__.__name__}(\n{entries}\n)' - def update(self, *args) -> ArrayType: + def update(self, s, x) -> ArrayType: """Update function of a sequential model. Parameters @@ -654,7 +659,6 @@ def update(self, *args) -> ArrayType: y: ArrayType The output tensor. """ - s, x = (dict(), args[0]) if len(args) == 1 else (args[0], args[1]) for m in self._modules: if isinstance(m, DynamicalSystem): x = m(s, x) @@ -818,7 +822,7 @@ def get_batch_shape(self, batch_size=None): else: return (batch_size,) + self.varshape - def update(self, tdi, x=None): + def update(self, *args): """The function to specify the updating rule. 
Parameters diff --git a/brainpy/_src/dyn/layers/base.py b/brainpy/_src/dyn/layers/base.py index 4f5f6defd..830267557 100644 --- a/brainpy/_src/dyn/layers/base.py +++ b/brainpy/_src/dyn/layers/base.py @@ -4,7 +4,7 @@ from typing import Optional import brainpy.math as bm -from brainpy._src.dyn.base import DynamicalSystem +from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs __all__ = [ 'Layer' diff --git a/brainpy/_src/dyn/layers/conv.py b/brainpy/_src/dyn/layers/conv.py index 98a2b2204..6c1e51f7c 100644 --- a/brainpy/_src/dyn/layers/conv.py +++ b/brainpy/_src/dyn/layers/conv.py @@ -5,6 +5,7 @@ from jax import lax from brainpy import math as bm, tools, check +from brainpy._src.dyn.base import not_pass_shargs from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter from brainpy.types import ArrayType from .base import Layer @@ -153,8 +154,8 @@ def _check_input_dim(self, x): raise ValueError(f"input channels={x.shape[-1]} needs to have " f"the same size as in_channels={self.in_channels}.") - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] + @not_pass_shargs + def update(self, x): self._check_input_dim(x) w = self.w.value if self.mask is not None: @@ -525,8 +526,8 @@ def __init__( def _check_input_dim(self, x): raise NotImplementedError - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] + @not_pass_shargs + def update(self, x): self._check_input_dim(x) w = self.w.value diff --git a/brainpy/_src/dyn/layers/dropout.py b/brainpy/_src/dyn/layers/dropout.py index 9f8d2fcac..f73988b5e 100644 --- a/brainpy/_src/dyn/layers/dropout.py +++ b/brainpy/_src/dyn/layers/dropout.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -import jax.numpy as jnp - from brainpy import math as bm, check from .base import Layer +from brainpy._src.dyn.base import not_pass_shargs __all__ = [ 'Dropout' @@ -49,8 +48,8 @@ def __init__( self.prob = check.is_float(prob, min_bound=0., max_bound=1.) self.rng = bm.random.default_rng(seed) - def update(self, sha, x): - if sha.get('fit', True): + def update(self, s, x): + if s['fit']: keep_mask = self.rng.bernoulli(self.prob, x.shape) return bm.where(bm.as_jax(keep_mask), x / self.prob, 0.) else: diff --git a/brainpy/_src/dyn/layers/function.py b/brainpy/_src/dyn/layers/function.py index d6ac8a1fe..7f36179fc 100644 --- a/brainpy/_src/dyn/layers/function.py +++ b/brainpy/_src/dyn/layers/function.py @@ -6,6 +6,7 @@ import brainpy.math as bm from brainpy import check from .base import Layer +from brainpy._src.dyn.base import not_pass_shargs __all__ = [ 'Activation', @@ -26,6 +27,7 @@ class Activation(Layer): mode: Mode Enable training this node or not. (default True). 
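A minimal usage sketch (not part of the patch) of the two calling conventions that the ``not_pass_shargs`` decorator and the new ``__call__`` dispatch above are meant to support; ``brainpy.not_pass_shargs`` and the ``brainpy.math.share`` context are assumed to be exported exactly as added in this diff:

import brainpy as bp
import brainpy.math as bm

class OldStyle(bp.DynamicalSystem):
  def update(self, s, x):          # old style: the shared dict comes first
    return x * s['dt']

class NewStyle(bp.DynamicalSystem):
  @bp.not_pass_shargs              # new style: no shared dict in the signature
  def update(self, x):
    return x + 1.

old, new = OldStyle(), NewStyle()
y1 = old({'t': 0., 'dt': 0.1}, bm.ones(3))   # shared dict is passed through to update()
y2 = new(bm.ones(3))                         # data only
y3 = new({'t': 0., 'dt': 0.1}, bm.ones(3))   # dict is stored via bm.share.save_shargs first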
""" + update_style = 'x' def __init__( self, @@ -38,9 +40,9 @@ def __init__( self.activate_fun = activate_fun self.kwargs = kwargs - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] - return self.activate_fun(x, **self.kwargs) + @not_pass_shargs + def update(self, *args, **kwargs): + return self.activate_fun(*args, **kwargs, **self.kwargs) class Flatten(Layer): @@ -62,8 +64,8 @@ def __init__( super().__init__(name, mode) check.is_subclass(self.mode, (bm.NonBatchingMode, bm.BatchingMode, bm.TrainingMode), self.name) - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] + @not_pass_shargs + def update(self, x): if isinstance(self.mode, bm.BatchingMode): return x.reshape((x.shape[0], -1)) else: @@ -76,19 +78,12 @@ def __init__( fun: Callable, name: Optional[str] = None, mode: bm.Mode = None, - has_shared: bool = False, **kwargs, ): super().__init__(name, mode) self._fun = fun self.kwargs = kwargs - self.has_shared = has_shared - - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] - if self.has_shared: - assert len(args) > 1 - s = args[0] - return self._fun(s, x, **self.kwargs) - else: - return self._fun(x, **self.kwargs) + + @not_pass_shargs + def update(self, *args, **kwargs): + return self._fun(*args, **kwargs, **self.kwargs) diff --git a/brainpy/_src/dyn/layers/linear.py b/brainpy/_src/dyn/layers/linear.py index 66b4e5caf..c5c21e4e3 100644 --- a/brainpy/_src/dyn/layers/linear.py +++ b/brainpy/_src/dyn/layers/linear.py @@ -6,11 +6,12 @@ import jax.numpy as jnp from brainpy import math as bm -from .base import Layer +from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm +from brainpy.check import is_initializer from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter -from brainpy.check import is_initializer from brainpy.types import ArrayType +from .base import Layer __all__ = [ 'Dense', @@ -42,6 +43,12 @@ class Dense(Layer): Enable training this node or not. (default True) """ + online_fit_by: Optional[OnlineAlgorithm] + '''Online fitting method.''' + + offline_fit_by: Optional[OfflineAlgorithm] + '''Offline fitting method.''' + def __init__( self, num_in: int, @@ -76,6 +83,11 @@ def __init__( self.W = bm.TrainVar(self.W) self.b = None if (self.b is None) else bm.TrainVar(self.b) + # fitting parameters + self.online_fit_by = None + self.offline_fit_by = None + self.fit_record = dict() + def __repr__(self): return (f'{self.__class__.__name__}(name={self.name}, ' f'num_in={self.num_in}, ' diff --git a/brainpy/_src/dyn/layers/nvar.py b/brainpy/_src/dyn/layers/nvar.py index a3c83d86e..84c666748 100644 --- a/brainpy/_src/dyn/layers/nvar.py +++ b/brainpy/_src/dyn/layers/nvar.py @@ -9,6 +9,7 @@ import brainpy.math as bm from brainpy import check from .base import Layer +from brainpy._src.dyn.base import not_pass_shargs __all__ = [ 'NVAR' @@ -129,7 +130,8 @@ def reset_state(self, batch_size=None): else: self.store.value = jnp.zeros((self.num_delay, batch_size, self.num_in)) - def update(self, sha, x): + @not_pass_shargs + def update(self, x): all_parts = [] select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay # 1. 
Store the current input diff --git a/brainpy/_src/dyn/layers/pooling.py b/brainpy/_src/dyn/layers/pooling.py index 49e8d1fe1..0967e4bff 100644 --- a/brainpy/_src/dyn/layers/pooling.py +++ b/brainpy/_src/dyn/layers/pooling.py @@ -8,6 +8,7 @@ from brainpy import math as bm, check from .base import Layer +from brainpy._src.dyn.base import not_pass_shargs __all__ = [ 'MaxPool', @@ -80,8 +81,8 @@ def __init__( f'padding should be sequence of Tuple[int, int]. {padding}' assert all([len(x) == 2 for x in padding]), f"each entry in padding {padding} must be length 2" - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] + @not_pass_shargs + def update(self, x): x = bm.as_jax(x) window_shape = self._infer_shape(x.ndim, self.kernel_size) stride = self._infer_shape(x.ndim, self.stride) @@ -257,8 +258,8 @@ def __init__( mode=mode, name=name) - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] + @not_pass_shargs + def update(self, x): x = bm.as_jax(x) window_shape = self._infer_shape(x.ndim, self.kernel_size) strides = self._infer_shape(x.ndim, self.stride) @@ -358,8 +359,8 @@ def __init__( # channel_axis self.channel_axis = check.is_integer(channel_axis, allow_none=True) - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] + @not_pass_shargs + def update(self, x): x = bm.as_jax(x) x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) if x.ndim < x_dim: @@ -524,8 +525,8 @@ def __init__( class _AvgPoolNd(_MaxPoolNd): - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] + @not_pass_shargs + def update(self, x): x = bm.as_jax(x) x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) if x.ndim < x_dim: @@ -762,18 +763,16 @@ def __init__( raise ValueError("`target_size` must either be an int or tuple of length " f"{num_spatial_dims} containing ints.") - def update(self, *args): + @not_pass_shargs + def update(self, x): """Input-output mapping. Parameters ---------- - s: dict - Shared arguments. x: Array Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)` or `(..., dim_1, dim_2)`. 
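As a consequence of the ``@not_pass_shargs`` updates above, the pooling layers can now be called with the input array alone. A rough sketch mirroring the updated test later in this patch (shapes are illustrative):

import brainpy as bp
import brainpy.math as bm

with bm.training_environment():
  net = bp.layers.MaxPool((2, 2), (2, 2), channel_axis=-1)

x = bm.random.rand(10, 20, 20, 4)
y = net(x)            # previously net(None, x); the shared dict is no longer needed
print(y.shape)        # e.g. (10, 10, 10, 4)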
""" - x = args[0] if len(args) == 1 else args[1] x = bm.as_jax(x) # channel axis diff --git a/brainpy/_src/dyn/layers/reservoir.py b/brainpy/_src/dyn/layers/reservoir.py index 18675672c..feffa3854 100644 --- a/brainpy/_src/dyn/layers/reservoir.py +++ b/brainpy/_src/dyn/layers/reservoir.py @@ -10,6 +10,7 @@ from brainpy.tools import to_size from brainpy.types import ArrayType from .base import Layer +from brainpy._src.dyn.base import not_pass_shargs __all__ = [ 'Reservoir', @@ -191,10 +192,10 @@ def __init__( def reset_state(self, batch_size=None): self.state.value = variable(jnp.zeros, batch_size, self.output_shape) - def update(self, *args): + @not_pass_shargs + def update(self, x): """Feedforward output.""" # inputs - x = args[0] if len(args) == 1 else args[1] x = bm.as_jax(x) if self.noise_ff > 0: x += self.noise_ff * self.rng.uniform(-1, 1, x.shape) diff --git a/brainpy/_src/dyn/layers/rnncells.py b/brainpy/_src/dyn/layers/rnncells.py index d9613c804..c99b33ab2 100644 --- a/brainpy/_src/dyn/layers/rnncells.py +++ b/brainpy/_src/dyn/layers/rnncells.py @@ -17,6 +17,8 @@ from brainpy.types import ArrayType from .base import Layer from .conv import _GeneralConv +from brainpy._src.dyn.base import not_pass_shargs + __all__ = [ 'RNNCell', 'GRUCell', 'LSTMCell', @@ -115,7 +117,8 @@ def reset_state(self, batch_size=None): self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) self.state[:] = self.state2train - def update(self, sha, x): + @not_pass_shargs + def update(self, x): h = x @ self.Wi h += self.state.value @ self.Wh if self.b is not None: @@ -225,7 +228,8 @@ def reset_state(self, batch_size=None): self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) self.state[:] = self.state2train - def update(self, sha, x): + @not_pass_shargs + def update(self, x): gates_x = jnp.matmul(x, bm.as_jax(self.Wi)) zr_x, a_x = jnp.split(gates_x, indices_or_sections=[2 * self.num_out], axis=-1) w_h_z, w_h_a = jnp.split(bm.as_jax(self.Wh), indices_or_sections=[2 * self.num_out], axis=-1) @@ -361,6 +365,7 @@ def reset_state(self, batch_size=None): self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False) self.state[:] = self.state2train + @not_pass_shargs def update(self, sha, x): h, c = jnp.split(self.state.value, 2, axis=-1) gated = x @ self.Wi @@ -558,6 +563,7 @@ def reset_state(self, batch_size: int = 1): self.h[:] = self.h_to_train self.c[:] = self.c_to_train + @not_pass_shargs def update(self, *args): x = args[0] if len(args) == 1 else args[1] gates = self.input_to_hidden(x) + self.hidden_to_hidden(self.h) diff --git a/brainpy/_src/dyn/layers/tests/test_pooling.py b/brainpy/_src/dyn/layers/tests/test_pooling.py index ad977a98b..8aad1d788 100644 --- a/brainpy/_src/dyn/layers/tests/test_pooling.py +++ b/brainpy/_src/dyn/layers/tests/test_pooling.py @@ -34,7 +34,7 @@ def test_maxpool2(self): x = self.rng.rand(10, 20, 20, 4) with bm.training_environment(): net = bp.layers.MaxPool((2, 2), (2, 2), channel_axis=-1) - y = net(None, x) + y = net(x) print("out shape: ", y.shape) def test_minpool(self): diff --git a/brainpy/_src/dyn/neurons/biological_models.py b/brainpy/_src/dyn/neurons/biological_models.py index f016420f4..32d1c68aa 100644 --- a/brainpy/_src/dyn/neurons/biological_models.py +++ b/brainpy/_src/dyn/neurons/biological_models.py @@ -4,7 +4,7 @@ import brainpy.math as bm from brainpy import check -from brainpy._src.dyn.base import NeuGroup +from brainpy._src.dyn.base import NeuGroup, 
not_pass_shargs from brainpy._src.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_ from brainpy._src.integrators.joint_eq import JointEq from brainpy._src.integrators.ode.generic import odeint @@ -129,7 +129,7 @@ class HH(NeuGroup): >>> group = bp.neurons.HH(2) >>> >>> I1 = bp.inputs.spike_input(sp_times=[500., 550., 1000, 1030, 1060, 1100, 1200], sp_lens=5, sp_sizes=5., duration=2000, ) - >>> I2 = bp.inputs.spike_input(sp_times=[600., 900, 950, 1500], sp_lens=5, sp_sizes=5., duration=2000, ) + >>> I2 = bp.inputs.spike_input(sp_times=[600., 900, 950, 1500], sp_lens=5, sp_sizes=5., duration=2000, ) >>> I1 += bp.math.random.normal(0, 3, size=I1.shape) >>> I2 += bp.math.random.normal(0, 3, size=I2.shape) >>> I = bm.stack((I1, I2), axis=-1) @@ -244,26 +244,15 @@ def __init__( self._n_initializer = n_initializer self._V_initializer = V_initializer - # variables - self.V = variable_(self._V_initializer, self.varshape, self.mode) - self.m = (bm.Variable(self.m_inf(self.V.value)) - if m_initializer is None else - variable_(self._m_initializer, self.varshape, self.mode)) - self.h = (bm.Variable(self.h_inf(self.V.value)) - if h_initializer is None else - variable_(self._h_initializer, self.varshape, self.mode)) - self.n = (bm.Variable(self.n_inf(self.V.value)) - if n_initializer is None else - variable_(self._n_initializer, self.varshape, self.mode)) - self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - # integral if self.noise is None: self.integral = odeint(method=method, f=self.derivative) else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + # model + self.reset_state(self.mode) + # m channel m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18) @@ -283,37 +272,38 @@ def __init__( dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) if self._m_initializer is None: - self.m.value = self.m_inf(self.V.value) + self.m = self.m_inf(self.V.value) else: - self.m.value = variable_(self._m_initializer, self.varshape, batch_size) + self.m = variable_(self._m_initializer, self.varshape, batch_size) if self._h_initializer is None: - self.h.value = self.h_inf(self.V.value) + self.h = self.h_inf(self.V.value) else: - self.h.value = variable_(self._h_initializer, self.varshape, batch_size) + self.h = variable_(self._h_initializer, self.varshape, batch_size) if self._n_initializer is None: - self.n.value = self.n_inf(self.V.value) + self.n = self.n_inf(self.V.value) else: - self.n.value = variable_(self._n_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.n = variable_(self._n_initializer, self.varshape, batch_size) + self.input = variable_(bm.zeros, self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - def dV(self, V, t, m, h, n, I_ext): + def dV(self, V, t, m, h, n): I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) I_K = (self.gK * n ** 4.0) * (V - self.EK) I_leak = self.gL * (V - self.EL) - dVdt = (- I_Na - I_K - I_leak + I_ext) / self.C + dVdt = (- I_Na - I_K 
- I_leak + self.input) / self.C return dVdt @property def derivative(self): return JointEq(self.dV, self.dm, self.dh, self.dn) - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] + @not_pass_shargs + def update(self, x=None): + s = bm.share.get_shargs() if x is not None: self.input += x - V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, self.input, dt) + V, m, h, n = self.integral(self.V, self.m, self.h, self.n, s['t'], s['dt']) self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.m.value = m @@ -454,10 +444,7 @@ def __init__( self._V_initializer = V_initializer # variables - self.W = variable_(self._W_initializer, self.varshape, self.mode) - self.V = variable_(self._V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -466,10 +453,10 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.W.value = variable_(self._W_initializer, self.varshape, batch_size) - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.W = variable_(self._W_initializer, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.input = variable_(bm.zeros, self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, W, I_ext): M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) diff --git a/brainpy/_src/dyn/neurons/input_groups.py b/brainpy/_src/dyn/neurons/input_groups.py index 146d2e5c7..5724ab8b7 100644 --- a/brainpy/_src/dyn/neurons/input_groups.py +++ b/brainpy/_src/dyn/neurons/input_groups.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroup +from brainpy._src.dyn.base import NeuGroup, not_pass_shargs from brainpy._src.initialize import Initializer, parameter, variable_ from brainpy.types import Shape, ArrayType @@ -41,7 +41,8 @@ def __init__( mode=mode) self.spike = None - def update(self, tdi, x=None): + @not_pass_shargs + def update(self, x): return x def reset_state(self, batch_size=None): @@ -72,8 +73,9 @@ def __init__( mode=mode) self.spike = None - def update(self, tdi, x=None): - pass + @not_pass_shargs + def update(self, x): + return x def reset_state(self, batch_size=None): pass diff --git a/brainpy/_src/dyn/neurons/reduced_models.py b/brainpy/_src/dyn/neurons/reduced_models.py index 099b63c8f..28f80adee 100644 --- a/brainpy/_src/dyn/neurons/reduced_models.py +++ b/brainpy/_src/dyn/neurons/reduced_models.py @@ -6,7 +6,7 @@ from jax.lax import stop_gradient import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroup +from brainpy._src.dyn.base import NeuGroup, not_pass_shargs from brainpy._src.initialize import (ZeroInit, OneInit, Initializer, parameter, variable_, noise as init_noise) from brainpy._src.integrators import sdeint, odeint, JointEq @@ -99,25 +99,25 @@ def __init__( is_initializer(V_initializer, 'V_initializer') self._V_initializer = V_initializer - # variables - self.V = variable_(self._V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, 
self.mode) - # integral if self.noise is None: self.integral = odeint(method=method, f=self.derivative) else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + # variables + self.reset_state(self.mode) + def derivative(self, V, t, I_ext): return (-V + self.V_rest + self.R * I_ext) / self.tau def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.input = variable_(bm.zeros, self.varshape, batch_size) def update(self, tdi, x=None): - if x is not None: self.input += x + if x is not None: + self.input += x self.V.value = self.integral(self.V.value, tdi.t, self.input.value, tdi.dt) def clear_input(self): diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py index 7dff6e0ef..fe7ae9fc8 100644 --- a/brainpy/_src/dyn/runners.py +++ b/brainpy/_src/dyn/runners.py @@ -615,6 +615,7 @@ def _step_func_predict(self, shared_args, t, i, x): # input step shared = tools.DotDict(t=t, i=i, dt=self.dt) shared.update(shared_args) + bm.share.save_shargs(**shared) self.target.clear_input() self._step_func_input(shared) @@ -629,6 +630,7 @@ def _step_func_predict(self, shared_args, t, i, x): # finally if self.progress_bar: id_tap(lambda *arg: self._pbar.update(), ()) + bm.share.remove_shargs() return out, mon def _get_f_predict(self, shared_args: Dict = None): diff --git a/brainpy/_src/experimental/__init__.py b/brainpy/_src/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/brainpy/_src/experimental/delay.py b/brainpy/_src/experimental/delay.py new file mode 100644 index 000000000..8e8ad3cbf --- /dev/null +++ b/brainpy/_src/experimental/delay.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- + +from typing import Union, Callable, Optional, Tuple, Sequence, Dict + +import jax +import jax.numpy as jnp +import numpy as np +from jax.lax import stop_gradient + +from brainpy import check, math as bm +from brainpy._src.math.object_transform.base import Collector +from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs +from brainpy.check import is_integer, jit_error_checking + +ROTATE_UPDATE = 'rotation' +CONCAT_UPDATE = 'concat' + + +class Delay(DynamicalSystem): + """Delay variable which has a fixed delay length. + + The data in this delay variable is arranged as:: + + delay = 0 [ data + delay = 1 data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + + Parameters + ---------- + target: Variable + The initial delay data. + length: int + The delay data length. + initial_delay_data: Any + The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: + + delay = 1 [ data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + method: str + The method used for updating delay. 
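The runner change above seeds the global shared-argument context before each step and clears it afterwards. A rough sketch of that lifecycle, assuming ``brainpy.math.share`` exposes the three helpers exactly as used in this hunk:

import brainpy.math as bm

bm.share.save_shargs(t=0., dt=0.1, i=0)   # what _step_func_predict now does per step
s = bm.share.get_shargs()                 # what a new-style update() reads internally
print(s['t'], s['dt'])
bm.share.remove_shargs()                  # cleared again at the end of the step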
+ + """ + + data: Optional[bm.Variable] + idx: Optional[bm.Variable] + length: int + + def __init__( + self, + target: bm.Variable, + length: int = 0, + initial_delay_data: Union[float, int, bool, bm.Array, jax.Array, Callable] = None, + entries: Optional[Dict] = None, + mode: bm.Mode = None, + name: str = None, + method: str = None, + ): + super().__init__(mode=mode, name=name) + + # delay updating method + if method is None: + if self.mode.is_a(bm.NonBatchingMode): + method = ROTATE_UPDATE + else: + method = CONCAT_UPDATE + assert method in [ROTATE_UPDATE, CONCAT_UPDATE] + self.method = method + + # target + self.target = target + if not isinstance(target, bm.Variable): + raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}') + + # delay length + self.length = is_integer(length, allow_none=False, min_bound=0) + + # delay data + if initial_delay_data is not None: + assert isinstance(initial_delay_data, (int, float, bool, bm.Array, jax.Array, Callable)) + self._initial_delay_data = initial_delay_data + if length > 0: + self._init_data(length) + else: + self.data = None + + # time variables + if self.method == ROTATE_UPDATE: + self.idx = bm.Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32))) + + # other info + self._access_to_step = dict() + for entry, value in entries.items(): + self.register_entry(entry, value) + + def register_entry( + self, + entry: str, + delay_time: Optional[Union[float, bm.Array, Callable]] = None, + delay_step: Optional[Union[int, bm.Array, Callable]] = None, + ) -> 'Delay': + """Register an entry to access the data. + + Args: + entry (str): The entry to access the delay data. + delay_step: The delay step of the entry (must be an integer, denoting the delay step). + delay_time: The delay time of the entry (can be a float). + + Returns: + Return the self. + """ + if entry in self._access_to_step: + raise KeyError(f'Entry {entry} has been registered.') + + if delay_time is not None: + if delay_step is not None: + raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.') + if callable(delay_time): + delay_time = bm.as_jax(delay_time(self.delay_target_shape)) + delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) + elif isinstance(delay_time, float): + delay_step = int(delay_time / bm.get_dt()) + else: + delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) + + # delay steps + if delay_step is None: + delay_type = 'none' + elif isinstance(delay_step, int): + delay_type = 'homo' + elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)): + if delay_step.size == 1 and delay_step.ndim == 0: + delay_type = 'homo' + else: + delay_type = 'heter' + delay_step = bm.Array(delay_step) + elif callable(delay_step): + delay_step = delay_step(self.delay_target_shape) + delay_type = 'heter' + else: + raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' + f'integer, array of integers, callable function, brainpy.init.Initializer.') + if delay_type == 'heter': + if delay_step.dtype not in [jnp.int32, jnp.int64]: + raise ValueError('Only support delay steps of int32, int64. 
If your ' + 'provide delay time length, please divide the "dt" ' + 'then provide us the number of delay steps.') + if self.delay_target_shape[0] != delay_step.shape[0]: + raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') + if delay_type == 'heter': + max_delay_step = int(max(delay_step)) + elif delay_type == 'homo': + max_delay_step = delay_step + else: + max_delay_step = None + + # delay variable + if max_delay_step is not None: + if self.length < max_delay_step: + self._init_data(max_delay_step) + self.length = max_delay_step + self._access_to_step[entry] = delay_step + return self + + def at_entry(self, entry: str, *indices) -> bm.Array: + """Get the data at the given entry. + + Args: + entry (str): The entry to access the data. + *indices: + + Returns: + The data. + """ + assert isinstance(entry, str) + if entry not in self._access_to_step: + raise KeyError(f'Does not find delay entry "{entry}".') + delay_step = self._access_to_step[entry] + if delay_step is None: + return self.target.value + else: + if self.data is None: + return self.target.value + else: + if isinstance(delay_step, slice): + return self.retrieve(delay_step, *indices) + elif np.ndim(delay_step) == 0: + return self.retrieve(delay_step, *indices) + else: + if len(indices) == 0 and len(delay_step) == self.target.shape[0]: + indices = (jnp.arange(delay_step.size),) + return self.retrieve(delay_step, *indices) + + @property + def delay_target_shape(self): + """The data shape of the delay target.""" + return self.target.shape + + def __repr__(self): + name = self.__class__.__name__ + return (f'{name}(num_delay_step={self.length}, ' + f'delay_target_shape={self.delay_target_shape}, ' + f'update_method={self.method})') + + def _check_delay(self, delay_len): + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.length}. ' + f'But we got {delay_len}') + + def retrieve(self, delay_step, *indices): + """Retrieve the delay data according to the delay length. + + Parameters + ---------- + delay_step: int, ArrayType + The delay length used to retrieve the data. + """ + assert delay_step is not None + if check.is_checking(): + jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) + + if self.method == ROTATE_UPDATE: + delay_idx = (self.idx.value + delay_step) % (self.length + 1) + delay_idx = stop_gradient(delay_idx) + + elif self.method == CONCAT_UPDATE: + delay_idx = delay_step + + else: + raise ValueError(f'Unknown updating method "{self.method}"') + + # the delay index + if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') + indices = (delay_idx,) + tuple(indices) + + # the delay data + return self.data[indices] + + @not_pass_shargs + def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None: + """Update delay variable with the new data. 
+ """ + if self.data is not None: + # get the latest target value + if latest_value is None: + latest_value = self.target.value + + # update the delay data at the rotation index + if self.method == ROTATE_UPDATE: + self.idx.value = stop_gradient(bm.as_jax((self.idx - 1) % (self.length + 1))) + self.data[self.idx.value] = latest_value + + # update the delay data at the first position + elif self.method == CONCAT_UPDATE: + if self.length >= 2: + self.data.value = bm.vstack([latest_value, self.data[1:]]) + else: + self.data[0] = latest_value + + def reset_state(self, batch_size: int = None): + """Reset the delay data. + """ + # initialize delay data + if self.data is not None: + self._init_data(self.length, batch_size) + + # time variables + if self.method == ROTATE_UPDATE: + self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32)) + + def _init_data(self, length, batch_size: int = None): + if batch_size is not None: + if self.target.batch_size != batch_size: + raise ValueError(f'The batch sizes of delay variable and target variable differ ' + f'({self.target.batch_size} != {batch_size}). ' + 'Please reset the target variable first, because delay data ' + 'depends on the target variable. ') + + if self.target.batch_axis is None: + batch_axis = None + else: + batch_axis = self.target.batch_axis + 1 + self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), + batch_axis=batch_axis) + # update delay data + self.data[0] = self.target.value + if isinstance(self._initial_delay_data, (bm.Array, jax.Array, float, int, bool)): + self.data[1:] = self._initial_delay_data + elif callable(self._initial_delay_data): + self.data[1:] = self._initial_delay_data((length,) + self.target.shape, dtype=self.target.dtype) diff --git a/brainpy/_src/experimental/neurons.py b/brainpy/_src/experimental/neurons.py new file mode 100644 index 000000000..d45acc862 --- /dev/null +++ b/brainpy/_src/experimental/neurons.py @@ -0,0 +1,157 @@ +from typing import Union, Callable, Optional + +from jax.lax import stop_gradient + +import brainpy.math as bm +from brainpy._src.dyn.base import NeuGroup, not_pass_shargs +from brainpy._src.initialize import (ZeroInit, OneInit, Initializer, parameter, variable_) +from brainpy._src.integrators import odeint +from brainpy.check import is_initializer, is_callable, is_subclass +from brainpy.types import Shape, ArrayType + + +class LIF(NeuGroup): + r"""Leaky integrate-and-fire neuron model. + + **Model Descriptions** + + The formal equations of a LIF model [1]_ is given by: + + .. math:: + + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ + \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad + \text{last} \quad \tau_{ref} \quad \text{ms} + + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`V_{reset}` is the reset membrane potential, + :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, + :math:`\tau_{ref}` is the refractory time period, + and :math:`I` is the time-variant synaptic inputs. + + **Model Examples** + + - `(Brette, Romain. 2004) LIF phase locking `_ + + + Parameters + ---------- + size: sequence of int, int + The size of the neuron group. + V_rest: float, ArrayType, Initializer, callable + Resting membrane potential. + V_reset: float, ArrayType, Initializer, callable + Reset potential after spike. + V_th: float, ArrayType, Initializer, callable + Threshold potential of spike. + R: float, ArrayType, Initializer, callable + Membrane resistance. 
+ tau: float, ArrayType, Initializer, callable + Membrane time constant. + tau_ref: float, ArrayType, Initializer, callable + Refractory period length.(ms) + V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + noise: ArrayType, Initializer, callable + The noise added onto the membrane potential + method: str + The numerical integration method. + name: str + The group name. + + References + ---------- + + .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model + neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # neuron parameter + V_rest: Union[float, ArrayType, Initializer, Callable] = 0., + V_reset: Union[float, ArrayType, Initializer, Callable] = -5., + V_th: Union[float, ArrayType, Initializer, Callable] = 20., + R: Union[float, ArrayType, Initializer, Callable] = 1., + tau: Union[float, ArrayType, Initializer, Callable] = 10., + tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = None, + V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + + # training parameter + mode: Optional[bm.Mode] = None, + spike_fun: Callable = bm.surrogate.inv_square_grad, + + # other parameters + method: str = 'exp_auto', + name: Optional[str] = None, + ): + # initialization + super(LIF, self).__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode) + is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode), self.name) + + # parameters + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_reset = parameter(V_reset, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.spike_fun = is_callable(spike_fun, 'spike_fun') + + # initializers + is_initializer(V_initializer, 'V_initializer') + self._V_initializer = V_initializer + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + self.reset_state(self.mode) + + def derivative(self, V, t, I_ext): + return (-V + self.V_rest + self.R * I_ext) / self.tau + + def reset_state(self, batch_size=None): + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.spike = variable_(bm.zeros, self.varshape, batch_size) + if self.tau_ref is not None: + self.t_last_spike = variable_(OneInit(-1e7), self.varshape, batch_size) + + @not_pass_shargs + def update(self, current): + t = bm.share.get('t') + + # integrate membrane potential + V = self.integral(self.V.value, t, current, bm.dt) + + if self.tau_ref is not None: + refractory = stop_gradient((t - self.t_last_spike) <= self.tau_ref) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + spike = self.spike_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) + V += (self.V_reset - V) * spike_no_grad + t_last_spike = bm.where(spike_no_grad, t, self.t_last_spike) + + # updates + self.V.value = V + self.spike.value = spike + self.t_last_spike.value = stop_gradient(t_last_spike) + + else: + # spike, spiking time, and membrane potential reset + spike = self.spike_fun(V - self.V_th) + V += (self.V_reset - V) * stop_gradient(spike) + + # updates + self.V.value = V + self.spike.value = spike + + return spike diff --git a/brainpy/_src/experimental/synapses.py 
b/brainpy/_src/experimental/synapses.py new file mode 100644 index 000000000..5bb47df55 --- /dev/null +++ b/brainpy/_src/experimental/synapses.py @@ -0,0 +1,265 @@ +from typing import Union, Dict, Callable, Optional, Tuple + +import jax +import numpy as np + +import brainpy.math as bm +from brainpy import check +from brainpy._src import tools +from brainpy._src.connect import TwoEndConnector, All2All, One2One, MatConn, IJConn +from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs +from brainpy._src.initialize import Initializer, variable_, parameter +from brainpy._src.integrators import odeint +from brainpy.types import ArrayType +from .synout import SynOut +from .synstp import SynSTP + + +class Synapse(DynamicalSystem): + def __init__( + self, + conn: TwoEndConnector, + out: Optional[SynOut] = None, + stp: Optional[SynSTP] = None, + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(name=name, mode=mode) + + # parameters + assert isinstance(conn, TwoEndConnector) + self.conn = self._init_conn(conn) + self.pre_size = conn.pre_size + self.post_size = conn.post_size + self.pre_num = conn.pre_num + self.post_num = conn.post_num + assert out is None or isinstance(out, SynOut) + assert stp is None or isinstance(stp, SynSTP) + self.out = out + self.stp = stp + + def _init_conn(self, conn): + if isinstance(conn, TwoEndConnector): + pass + elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)): + if (self.pre_num, self.post_num) != conn.shape: + raise ValueError(f'"conn" is provided as a matrix, and it is expected ' + f'to be an array with shape of (self.pre_num, self.post_num) = ' + f'{(self.pre_num, self.post_num)}, however we got {conn.shape}') + conn = MatConn(conn_mat=conn) + elif isinstance(conn, dict): + if not ('i' in conn and 'j' in conn): + raise ValueError(f'"conn" is provided as a dict, and it is expected to ' + f'be a dictionary with "i" and "j" specification, ' + f'however we got {conn}') + conn = IJConn(i=conn['i'], j=conn['j']) + elif conn is None: + conn = None + else: + raise ValueError(f'Unknown "conn" type: {conn}') + return conn + + def _init_weights( + self, + weight: Union[float, ArrayType, Initializer, Callable], + comp_method: str, + data_if_sparse: str = 'csr' + ) -> Tuple[Union[float, ArrayType], ArrayType]: + if comp_method not in ['sparse', 'dense']: + raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}') + if data_if_sparse not in ['csr', 'ij', 'coo']: + raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {data_if_sparse}') + + # connections and weights + if isinstance(self.conn, One2One): + weight = parameter(weight, (self.pre_num,), allow_none=False) + conn_mask = None + + elif isinstance(self.conn, All2All): + weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) + conn_mask = None + + else: + if comp_method == 'sparse': + if data_if_sparse == 'csr': + conn_mask = self.conn.require('pre2post') + elif data_if_sparse in ['ij', 'coo']: + conn_mask = self.conn.require('post_ids', 'pre_ids') + else: + ValueError(f'Unknown sparse data type: {data_if_sparse}') + weight = parameter(weight, conn_mask[0].shape, allow_none=False) + elif comp_method == 'dense': + weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) + conn_mask = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {comp_method}') + + # training weights + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + return weight, 
conn_mask + + def _syn2post_with_all2all(self, syn_value, syn_weight, include_self): + if bm.ndim(syn_weight) == 0: + if isinstance(self.mode, bm.BatchingMode): + post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) + else: + post_vs = bm.sum(syn_value) + if not include_self: + post_vs = post_vs - syn_value + post_vs = syn_weight * post_vs + else: + post_vs = syn_value @ syn_weight + return post_vs + + def _syn2post_with_one2one(self, syn_value, syn_weight): + return syn_value * syn_weight + + def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): + if bm.ndim(syn_weight) == 0: + post_vs = (syn_weight * syn_value) @ conn_mat + else: + post_vs = syn_value @ (syn_weight * conn_mat) + return post_vs + + +class Exponential(Synapse): + r"""Exponential decay synapse model. + + **Model Descriptions** + + The single exponential decay synapse model assumes the release of neurotransmitter, + its diffusion across the cleft, the receptor binding, and channel opening all happen + very quickly, so that the channels instantaneously jump from the closed to the open state. + Therefore, its expression is given by + + .. math:: + + g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau} + + where :math:`\tau_{delay}` is the time constant of the synaptic state decay, + :math:`t_0` is the time of the pre-synaptic spike, + :math:`g_{\mathrm{max}}` is the maximal conductance. + + Accordingly, the differential form of the exponential synapse is given by + + .. math:: + + \begin{aligned} + & g_{\mathrm{syn}}(t) = g_{max} g * \mathrm{STP} \\ + & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}). + \end{aligned} + + where :math:`\mathrm{STP}` is used to model the short-term plasticity effect. + + Parameters + ---------- + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + tau: float, ArrayType + The time constant of decay. [ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- + + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. 
+ + """ + + def __init__( + self, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + out: Optional[SynOut] = None, + stp: Optional[SynSTP] = None, + comp_method: str = 'sparse', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + tau: Union[float, ArrayType] = 8.0, + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(Exponential, self).__init__(conn=conn, + out=out, + stp=stp, + name=name, + mode=mode) + + # parameters + self.comp_method = comp_method + self.tau = check.is_float(tau, allow_int=True) + + # connections and weights + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, data_if_sparse='csr') + + # function + self.integral = odeint(lambda g, t: -g / self.tau, method=method) + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.g = variable_(bm.zeros, self.post_num, batch_size) + if self.out is not None: + self.out.reset_state(batch_size) + if self.stp is not None: + self.stp.reset_state(batch_size) + + @not_pass_shargs + def update(self, pre_spike): + if self.stp is not None: + syn_value = self.stp(pre_spike) * pre_spike + else: + syn_value = bm.asarray(pre_spike, dtype=bm.float_) + + # post values + if isinstance(self.conn, All2All): + post_vs = self._syn2post_with_all2all(syn_value, self.g_max, self.conn.include_self) + elif isinstance(self.conn, One2One): + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) + else: + if self.comp_method == 'sparse': + bl = tools.import_brainpylib() + + if self.stp is None: + f = lambda s: bl.event_ops.event_csr_matvec(self.g_max, + self.conn_mask[0], + self.conn_mask[1], + s, + shape=(self.pre_num, self.post_num), + transpose=True) + if isinstance(self.mode, bm.BatchingMode): + f = jax.vmap(f) + post_vs = f(pre_spike) + else: + f = lambda s: bl.sparse_ops.cusparse_csr_matvec(self.g_max, + self.conn_mask[0], + self.conn_mask[1], + s, + shape=(self.pre_num, self.post_num), + transpose=True) + if isinstance(self.mode, bm.BatchingMode): + f = jax.vmap(f) + post_vs = f(syn_value) + else: + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + + # updates + self.g.value = self.integral(self.g.value, bm.share.get('t'), bm.dt) + post_vs + + # outputs + if self.out is not None: + return self.out(self.g.value) + else: + return self.g.value + diff --git a/brainpy/_src/experimental/synout.py b/brainpy/_src/experimental/synout.py new file mode 100644 index 000000000..c93eb7907 --- /dev/null +++ b/brainpy/_src/experimental/synout.py @@ -0,0 +1,131 @@ +from typing import Union, Optional + +import brainpy.math as bm +from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs +from brainpy.types import ArrayType + + +class SynOut(DynamicalSystem): + @not_pass_shargs + def update(self, g): + raise NotImplementedError + + def reset_state(self, batch_size: Optional[int] = None): + pass + + +class MgBlock(SynOut): + r"""Synaptic output based on Magnesium blocking. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) + + where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. 
+ + Parameters + ---------- + E: float, ArrayType + The reversal potential for the synaptic current. [mV] + alpha: float, ArrayType + Binding constant. Default 0.062 + beta: float, ArrayType + Unbinding constant. Default 3.57 + cc_Mg: float, ArrayType + Concentration of Magnesium ion. Default 1.2 [mM]. + name: str + The model name. + """ + + def __init__( + self, + post_potential: bm.Variable, + E: Union[float, ArrayType] = 0., + cc_Mg: Union[float, ArrayType] = 1.2, + alpha: Union[float, ArrayType] = 0.062, + beta: Union[float, ArrayType] = 3.57, + name: str = None, + ): + super(MgBlock, self).__init__(name=name) + assert isinstance(post_potential, bm.Variable) + self.post_potential = post_potential + self.E = E + self.cc_Mg = cc_Mg + self.alpha = alpha + self.beta = beta + + @not_pass_shargs + def update(self, g): + I = g * (self.E - self.post_potential) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post_potential)) + return I + + +class CUBA(SynOut): + r"""Current-based synaptic output. + + Given the conductance, this model outputs the post-synaptic current with a identity function: + + .. math:: + + I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) + + Parameters + ---------- + name: str + The model name. + + + See Also + -------- + COBA + """ + + def __init__(self, name: str = None, ): + super(CUBA, self).__init__(name=name) + + @not_pass_shargs + def update(self, V, g): + return g + + +class COBA(SynOut): + r"""Conductance-based synaptic output. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) + + Parameters + ---------- + E: float, ArrayType, ndarray + The reversal potential. + name: str + The model name. + + See Also + -------- + CUBA + """ + + def __init__(self, + post_potential: bm.Variable, + E: Union[float, ArrayType] = 0., + name: str = None, ): + super(COBA, self).__init__(name=name) + self.E = E + self.post_potential = post_potential + + @not_pass_shargs + def update(self, g): + I = g * (self.E - self.post_potential) + return I diff --git a/brainpy/_src/experimental/synstp.py b/brainpy/_src/experimental/synstp.py new file mode 100644 index 000000000..e4aac3a22 --- /dev/null +++ b/brainpy/_src/experimental/synstp.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- + +from typing import Union + +import jax.numpy as jnp + +from brainpy import math as bm, tools +from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs +from brainpy._src.initialize import variable_, OneInit, parameter +from brainpy._src.integrators import odeint, JointEq +from brainpy.types import ArrayType, Shape + +__all__ = [ + 'STD', + 'STP', +] + + +class SynSTP(DynamicalSystem): + """Base class for synaptic short-term plasticity.""" + + @not_pass_shargs + def update(self, pre_spike, post_g): + raise NotImplementedError + + +class STD(SynSTP): + r"""Synaptic output with short-term depression. + + This model filters the synaptic current by the following equation: + + .. math:: + + I_{syn}^+(t) = I_{syn}^-(t) * x + + where :math:`x` is the normalized variable between 0 and 1, and + :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STD filtering. + + Moreover, :math:`x` is updated according to the dynamics of: + + .. math:: + + \frac{dx}{dt} = \frac{1-x}{\tau} - U * x * \delta(t-t_{spike}) + + where :math:`U` is the fraction of resources used per action potential, + :math:`\tau` is the time constant of recovery of the synaptic vesicles. 
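A small numeric check (not part of the patch) of the output rules defined above: COBA computes I = g * (E - V), while CUBA passes the conductance through unchanged:

import brainpy.math as bm
from brainpy._src.experimental.synout import COBA, CUBA

V = bm.Variable(-65. * bm.ones(4))   # postsynaptic membrane potential
g = 0.5 * bm.ones(4)                 # synaptic conductance

coba = COBA(post_potential=V, E=0.)
print(coba(g))                       # 0.5 * (0 - (-65)) = 32.5 per neuron
print(CUBA()(V, g))                  # identical to g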
+ + Parameters + ---------- + tau: float + The time constant of recovery of the synaptic vesicles. + U: float + The fraction of resources used per action potential. + + See Also + -------- + STP + """ + + def __init__( + self, + pre_size: Shape, + tau: float = 200., + U: float = 0.07, + method: str = 'exp_auto', + name: str = None + ): + super(STD, self).__init__(name=name) + + # parameters + self.pre_size = tools.to_size(pre_size) + self.num = tools.size2num(self.pre_size) + self.U = parameter(U, self.num) + self.tau = parameter(tau, self.num) + self.method = method + + # integral function + self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=self.method) + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.x = variable_(jnp.ones, self.num, batch_size) + + @not_pass_shargs + def update(self, pre_spike): + x = self.integral(self.x.value, bm.share.get('t'), bm.share.get('dt')) + self.x.value = bm.where(pre_spike, x - self.U * self.x, x) + return self.x.value + + +class STP(SynSTP): + r"""Synaptic output with short-term plasticity. + + This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. + + .. math:: + + I_{syn}^+(t) = I_{syn}^-(t) * x * u + + where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STP filtering, :math:`x` denotes the fraction of resources that remain available + after neurotransmitter depletion, and :math:`u` represents the fraction of available + resources ready for use (release probability). + + The dynamics of :math:`u` and :math:`x` are governed by + + .. math:: + + \begin{aligned} + \frac{du}{dt} & = & -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}), \\ + \frac{dx}{dt} & = & \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\ + \tag{1}\end{aligned} + + where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment + of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding + variables just before the arrival of the spike, and :math:`u^+` + refers to the moment just after the spike. + + Parameters + ---------- + tau_f: float + The time constant of short-term facilitation. + tau_d: float + The time constant of short-term depression. + U: float + The fraction of resources used per action potential. + method: str + The numerical integral method. 
+ + See Also + -------- + STD + """ + + def __init__( + self, + pre_size: Shape, + U: Union[float, ArrayType] = 0.15, + tau_f: Union[float, ArrayType] = 1500., + tau_d: Union[float, ArrayType] = 200., + method: str = 'exp_auto', + name: str = None + ): + super(STP, self).__init__(name=name) + + # parameters + self.pre_size = tools.to_size(pre_size) + self.num = tools.size2num(self.pre_size) + self.tau_f = parameter(tau_f, self.num) + self.tau_d = parameter(tau_d, self.num) + self.U = parameter(U, self.num) + self.method = method + + # integral function + self.integral = odeint(JointEq([self.du, self.dx]), method=self.method) + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.x = variable_(jnp.ones, batch_size, self.num) + self.u = variable_(OneInit(self.U), batch_size, self.num) + + du = lambda self, u, t: self.U - u / self.tau_f + dx = lambda self, x, t: (1 - x) / self.tau_d + + @not_pass_shargs + def update(self, pre_spike): + u, x = self.integral(self.u.value, self.x.value, bm.share.get('t'), bm.get_dt()) + u = bm.where(pre_spike, u + self.U * (1 - self.u), u) + x = bm.where(pre_spike, x - u * self.x, x) + self.x.value = x + self.u.value = u + return self.x.value * self.u.value + diff --git a/brainpy/_src/initialize/generic.py b/brainpy/_src/initialize/generic.py index f1925ebaf..a265f4f11 100644 --- a/brainpy/_src/initialize/generic.py +++ b/brainpy/_src/initialize/generic.py @@ -64,7 +64,7 @@ def parameter( if allow_scalar and isinstance(param, (float, int, bool)): return param if callable(param): - param = bm.asarray(param(size)) + param = param(size) elif isinstance(param, (np.ndarray, jnp.ndarray)): param = bm.asarray(param) elif isinstance(param, bm.Variable): diff --git a/brainpy/_src/math/_utils.py b/brainpy/_src/math/_utils.py index 2999a2f7b..6c4379a21 100644 --- a/brainpy/_src/math/_utils.py +++ b/brainpy/_src/math/_utils.py @@ -12,30 +12,13 @@ def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj -def wraps(fun: Callable): - """Specialized version of functools.wraps for wrapping numpy functions. - - This produces a wrapped function with a modified docstring. In particular, if - `update_doc` is True, parameters listed in the wrapped function that are not - supported by the decorated function will be removed from the docstring. For - this reason, it is important that parameter names match those in the original - numpy function. - """ - - def wrap(op): - docstr = getattr(fun, "__doc__", None) - op.__doc__ = docstr - op.__wrapped__ = fun - return op - - return wrap - - def _is_leaf(a): return isinstance(a, Array) -def _compatible_with_brainpy_array(fun: Callable): +def _compatible_with_brainpy_array( + fun: Callable, module: str = '' +): @functools.wraps(fun) def new_fun(*args, **kwargs): args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) @@ -62,6 +45,13 @@ def new_fun(*args, **kwargs): else: out.value = r - new_fun.__doc__ = getattr(fun, "__doc__", None) + new_fun.__doc__ = ( + f'Similar to ``jax.numpy.{module + fun.__name__}`` function, ' + f'while it is compatible with brainpy Array/Variable. \n\n' + f'Note that this function is also compatible with:\n\n' + f'1. NumPy or PyTorch syntax when receiving ``out`` argument.\n' + f'2. PyTorch syntax when receiving ``keepdim`` or ``dim`` argument.\n' + f'3. TensorFlow syntax when receiving ``keep_dims`` argument.' 
+ ) return new_fun diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py index ad338107c..220fd7c09 100644 --- a/brainpy/_src/math/compat_numpy.py +++ b/brainpy/_src/math/compat_numpy.py @@ -101,36 +101,6 @@ _min = min _max = max -# def concatenate(arrays: Union[np.ndarray, Array, Sequence[Array]], -# axis: Optional[int] = None, -# dim: Optional[int] = None, -# dtype: Optional[DTypeLike] = None) -> Array: -# """Join a sequence of arrays along an existing axis. -# -# -# Parameters -# ---------- -# a1, a2, ... : sequence of array_like -# The arrays must have the same shape, except in the dimension -# corresponding to `axis` (the first, by default). -# axis : int, optional -# The axis along which the arrays will be joined. If axis is None, -# arrays are flattened before use. Default is 0. -# dtype : str or dtype -# If provided, the destination array will have this dtype. Cannot be -# provided together with `out`. -# -# Returns -# ------- -# res : ndarray -# The concatenated array. -# """ -# axis = one_of(0, axis, dim, ['axis', 'dim']) -# r = jnp.concatenate(tree_map(_as_jax_array_, arrays, is_leaf=_is_leaf), -# axis=axis, -# dtype=dtype) -# return _return(r) - def fill_diagonal(a, val, inplace=True): if a.ndim < 2: @@ -454,7 +424,6 @@ def msort(a): logical_or = _compatible_with_brainpy_array(jnp.logical_or) logical_xor = _compatible_with_brainpy_array(jnp.logical_xor) all = _compatible_with_brainpy_array(jnp.all) - any = _compatible_with_brainpy_array(jnp.any) alltrue = all diff --git a/brainpy/_src/math/compat_tensorflow.py b/brainpy/_src/math/compat_tensorflow.py index 72a1d2458..77d4c4feb 100644 --- a/brainpy/_src/math/compat_tensorflow.py +++ b/brainpy/_src/math/compat_tensorflow.py @@ -1,11 +1,16 @@ +from typing import Union, Optional + import jax.numpy as jnp import jax.ops +from jax import lax -from .ndarray import _return, _as_jax_array_ +from brainpy._src.math.arrayinterporate import as_jax +from brainpy._src.math.ndarray import Array from .compat_numpy import ( prod, min, sum, all, any, mean, std, var, concatenate, clip, asarray, ) +from .ndarray import _return, _as_jax_array_ __all__ = [ 'concat', @@ -13,6 +18,7 @@ 'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance', 'reduce_euclidean_norm', 'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum', 'unsorted_segment_prod', 'unsorted_segment_max', 'unsorted_segment_min', 'unsorted_segment_mean', + 'segment_sum', 'segment_prod', 'segment_max', 'segment_min', 'clip_by_value', 'cast', ] @@ -210,6 +216,214 @@ def unsorted_segment_mean(data, segment_ids, num_segments): return _return(jnp.nan_to_num(r / d)) +def segment_sum(data: Union[Array, jnp.ndarray], + segment_ids: Union[Array, jnp.ndarray], + num_segments: Optional[int] = None, + indices_are_sorted: bool = False, + unique_indices: bool = False, + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: + """``segment_sum`` operator for brainpy `Array` and `Variable`. + + Parameters + ---------- + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. 
+ Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode: lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns + ------- + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ + return Array(jax.ops.segment_sum(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) + + +def segment_prod(data: Union[Array, jnp.ndarray], + segment_ids: Union[Array, jnp.ndarray], + num_segments: Optional[int] = None, + indices_are_sorted: bool = False, + unique_indices: bool = False, + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: + """``segment_prod`` operator for brainpy `Array` and `Variable`. + + Parameters + ---------- + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode: lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns + ------- + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ + return Array(jax.ops.segment_prod(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) + + +def segment_max(data: Union[Array, jnp.ndarray], + segment_ids: Union[Array, jnp.ndarray], + num_segments: Optional[int] = None, + indices_are_sorted: bool = False, + unique_indices: bool = False, + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: + """``segment_max`` operator for brainpy `Array` and `Variable`. + + Parameters + ---------- + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. 
Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode: lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns + ------- + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ + return Array(jax.ops.segment_max(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) + + +def segment_min(data: Union[Array, jnp.ndarray], + segment_ids: Union[Array, jnp.ndarray], + num_segments: Optional[int] = None, + indices_are_sorted: bool = False, + unique_indices: bool = False, + bucket_size: Optional[int] = None, + mode: Optional[lax.GatherScatterMode] = None) -> Array: + """``segment_min`` operator for brainpy `Array` and `Variable`. + + Parameters + ---------- + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode: lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns + ------- + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ + return Array(jax.ops.segment_min(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) + + def cast(x, dtype): """Casts a tensor to a new type. 
diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 889e43828..fcb4af366 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -1,27 +1,35 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable +from typing import Union, Callable, Optional +import numpy as np +import jax import jax.numpy as jnp from jax import vmap from jax.lax import cond, stop_gradient -from brainpy import check +from brainpy import check, math as bm from brainpy.check import is_float, is_integer, jit_error_checking from brainpy.errors import UnsupportedError from .object_transform.base import BrainPyObject -from .environment import get_dt, get_float from .ndarray import ndarray, Variable, Array from .arrayinterporate import as_jax -from . import compat_numpy as bm +from .environment import get_dt, get_int __all__ = [ 'AbstractDelay', 'TimeDelay', 'LengthDelay', 'NeuTimeDelay', 'NeuLenDelay', + 'DelayVariable', + 'ROTATE_UPDATE', + 'CONCAT_UPDATE', ] +def _as_jax_array(arr): + return arr.value if isinstance(arr, Array) else arr + + class AbstractDelay(BrainPyObject): def update(self, *args, **kwargs): raise NotImplementedError @@ -125,7 +133,7 @@ def __init__( # delay_len self.t0 = t0 - self.dt = get_dt() if dt is None else dt + self.dt = bm.get_dt() if dt is None else dt is_float(delay_len, 'delay_len', allow_none=False, allow_int=True, min_bound=0.) self.delay_len = delay_len self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1 @@ -139,7 +147,7 @@ def __init__( # time variables self.idx = Variable(jnp.asarray([0])) is_float(t0, 't0', allow_none=False, allow_int=True, ) - self.current_time = Variable(jnp.asarray([t0], dtype=get_float())) + self.current_time = Variable(jnp.asarray([t0], dtype=bm.get_float())) # delay data batch_axis = None @@ -267,8 +275,8 @@ class NeuTimeDelay(TimeDelay): pass -ROTATION_UPDATING = 'rotation' -CONCAT_UPDATING = 'concatenate' +ROTATE_UPDATE = 'rotation' +CONCAT_UPDATE = 'concat' class LengthDelay(AbstractDelay): @@ -318,20 +326,20 @@ class LengthDelay(AbstractDelay): def __init__( self, - delay_target: Union[ndarray, jnp.ndarray], + delay_target: Union[ndarray, jax.Array], delay_len: int, - initial_delay_data: Union[float, int, bool, ndarray, jnp.ndarray, Callable] = None, + initial_delay_data: Union[float, int, bool, ndarray, jax.Array, Callable] = None, name: str = None, batch_axis: int = None, - update_method: str = ROTATION_UPDATING + update_method: str = ROTATE_UPDATE ): super(LengthDelay, self).__init__(name=name) - assert update_method in [ROTATION_UPDATING, CONCAT_UPDATING] + assert update_method in [ROTATE_UPDATE, CONCAT_UPDATE] self.update_method = update_method # attributes and variables self.data: Variable = None - self.num_delay_step: int = None + self.num_delay_step: int = 0 self.idx: Variable = None # initialization @@ -356,9 +364,9 @@ def __repr__(self): def reset( self, delay_target, - delay_len=None, - initial_delay_data=None, - batch_axis=None + delay_len: int = None, + initial_delay_data: Union[float, int, bool, ndarray, jnp.ndarray, Callable] = None, + batch_axis: int = None ): if not isinstance(delay_target, (ndarray, jnp.ndarray)): raise ValueError(f'Must be an instance of brainpy.math.ndarray ' @@ -392,12 +400,12 @@ def reset( self.data[1:] = initial_delay_data elif callable(initial_delay_data): self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape, - dtype=delay_target.dtype) + dtype=delay_target.dtype) else: raise ValueError(f'"delay_data" does not support 
{type(initial_delay_data)}') # time variables - if self.update_method == ROTATION_UPDATING: + if self.update_method == ROTATE_UPDATE: if self.idx is None: self.idx = Variable(stop_gradient(jnp.asarray([0], dtype=jnp.int32))) else: @@ -422,11 +430,11 @@ def retrieve(self, delay_len, *indices): if check.is_checking(): jit_error_checking(jnp.any(delay_len >= self.num_delay_step), self._check_delay, delay_len) - if self.update_method == ROTATION_UPDATING: + if self.update_method == ROTATE_UPDATE: delay_idx = (self.idx[0] + delay_len) % self.num_delay_step delay_idx = stop_gradient(delay_idx) - elif self.update_method == CONCAT_UPDATING: + elif self.update_method == CONCAT_UPDATE: delay_idx = delay_len else: @@ -449,11 +457,11 @@ def update(self, value: Union[float, int, bool, Array, jnp.DeviceArray]): value: Any The value of the latest data, used to update this delay variable. """ - if self.update_method == ROTATION_UPDATING: + if self.update_method == ROTATE_UPDATE: self.idx.value = stop_gradient(as_jax((self.idx - 1) % self.num_delay_step)) self.data[self.idx[0]] = value - elif self.update_method == CONCAT_UPDATING: + elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: self.data.value = bm.vstack([bm.broadcast_to(value, self.data.shape[1:]), self.data[1:]]) else: @@ -463,6 +471,263 @@ def update(self, value: Union[float, int, bool, Array, jnp.DeviceArray]): raise ValueError(f'Unknown updating method "{self.update_method}"') +class DelayVariable(AbstractDelay): + """Delay variable which has a fixed delay length. + + The data in this delay variable is arranged as:: + + delay = 0 [ data + delay = 1 data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + + Parameters + ---------- + target: Variable + The initial delay data. + length: int + The delay data length. + initial_delay_data: Any + The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: + + delay = 1 [ data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + + update_method: str + The method used for updating delay. + + See Also + -------- + TimeDelay + """ + + data: Optional[Variable] + idx: Optional[Variable] + length: int + + def __init__( + self, + target: Variable, + length: int = 0, + initial_delay_data: Union[float, int, bool, Array, jax.Array, Callable] = None, + update_method: str = ROTATE_UPDATE + ): + super().__init__() + + assert update_method in [ROTATE_UPDATE, CONCAT_UPDATE] + self.update_method = update_method + + # target + self.target = target + if not isinstance(target, Variable): + raise ValueError(f'Must be an instance of brainpy.math.Variable. 
But we got {type(target)}') + + # delay length + self.length = is_integer(length, allow_none=False, min_bound=0) + + # delay data + if initial_delay_data is not None: + assert isinstance(initial_delay_data, (int, float, bool, Array, jax.Array, Callable)) + self._initial_delay_data = initial_delay_data + self._init_data(length) + + # time variables + if self.update_method == ROTATE_UPDATE: + self.idx = Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32))) + + # other info + self._access_to_step = dict() + + def register_entry( + self, + entry: str, + delay_step: Optional[Union[int, Array, Callable]] = None, + delay_time: Optional[Union[float, Array, Callable]] = None, + ) -> 'DelayVariable': + """Register an entry to access the data. + + Args: + entry (str): The entry to access the delay data. + delay_step: The delay step of the entry (must be an integer, denoting the delay step). + delay_time: The delay time of the entry (can be a float). + + Returns: + Return the self. + """ + if delay_time is not None: + if delay_step is not None: + raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.') + if callable(delay_time): + delay_time = _as_jax_array(delay_time(self.delay_target_shape)) + delay_step = jnp.asarray(delay_time / get_dt(), dtype=get_int()) + elif isinstance(delay_time, float): + delay_step = int(delay_time / get_dt()) + else: + delay_step = jnp.asarray(_as_jax_array(delay_time) / get_dt(), dtype=get_int()) + + # delay steps + if delay_step is None: + delay_type = 'none' + elif isinstance(delay_step, int): + delay_type = 'homo' + elif isinstance(delay_step, (Array, jax.Array, np.ndarray)): + if delay_step.size == 1 and delay_step.ndim == 0: + delay_type = 'homo' + else: + delay_type = 'heter' + delay_step = Array(delay_step) + elif callable(delay_step): + delay_step = delay_step(self.delay_target_shape) + delay_type = 'heter' + else: + raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' + f'integer, array of integers, callable function, brainpy.init.Initializer.') + if delay_type == 'heter': + if delay_step.dtype not in [jnp.int32, jnp.int64]: + raise ValueError('Only support delay steps of int32, int64. If your ' + 'provide delay time length, please divide the "dt" ' + 'then provide us the number of delay steps.') + if self.delay_target_shape[0] != delay_step.shape[0]: + raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') + if delay_type == 'heter': + max_delay_step = int(max(delay_step)) + elif delay_type == 'homo': + max_delay_step = delay_step + else: + max_delay_step = None + + # delay variable + if max_delay_step is not None: + if self.length < max_delay_step: + self._init_data(max_delay_step) + self.length = max_delay_step + self._access_to_step[entry] = delay_step + return self + + def at_entry(self, entry: str, *indices) -> Array: + """Get the data at the given entry. + + Args: + entry (str): The entry to access the data. + *indices: + + Returns: + The data. 
+ """ + assert isinstance(entry, str) + if entry not in self._access_to_step: + raise KeyError(f'Does not find delay access "{entry}".') + delay_step = self._access_to_step[entry] + if delay_step is None: + return self.target.value + else: + assert self.data is not None + if isinstance(delay_step, slice): + return self.retrieve(delay_step, *indices) + elif np.ndim(delay_step) == 0: + return self.retrieve(delay_step, *indices) + else: + if len(indices) == 0 and len(delay_step) == self.target.shape[0]: + indices = (jnp.arange(delay_step.size),) + return self.retrieve(delay_step, *indices) + + @property + def delay_target_shape(self): + """The data shape of the delay target.""" + return self.target.shape + + def __repr__(self): + name = self.__class__.__name__ + return (f'{name}(num_delay_step={self.length}, ' + f'delay_target_shape={self.delay_target_shape}, ' + f'update_method={self.update_method})') + + def _check_delay(self, delay_len): + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.length}. ' + f'But we got {delay_len}') + + def __call__(self, delay_len, *indices): + return self.retrieve(delay_len, *indices) + + def retrieve(self, delay_step, *indices): + """Retrieve the delay data acoording to the delay length. + + Parameters + ---------- + delay_step: int, ArrayType + The delay length used to retrieve the data. + """ + if check.is_checking(): + jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) + + if self.update_method == ROTATE_UPDATE: + delay_idx = (self.idx.value + delay_step) % self.length + delay_idx = stop_gradient(delay_idx) + + elif self.update_method == CONCAT_UPDATE: + delay_idx = delay_step + + else: + raise ValueError(f'Unknown updating method "{self.update_method}"') + + # the delay index + if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') + indices = (delay_idx,) + tuple(indices) + + # the delay data + return self.data[indices] + + def update(self): + """Update delay variable with the new data. + """ + # update the delay data at the rotation index + if self.update_method == ROTATE_UPDATE: + self.idx.value = stop_gradient(as_jax((self.idx - 1) % self.length)) + self.data[self.idx.value] = self.target.value + + # update the delay data at the first position + elif self.update_method == CONCAT_UPDATE: + if self.length >= 2: + self.data.value = bm.vstack([self.target.value, self.data[1:]]) + else: + self.data[0] = self.target.value + + def reset(self): + """Reset the delay data. + """ + # initialize delay data + self._init_data(self.length) + + # time variables + if self.update_method == ROTATE_UPDATE: + self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32)) + + def _init_data(self, length): + if self.target.batch_axis is None: + batch_axis = None + else: + batch_axis = self.target.batch_axis + 1 + self.data = Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), + batch_axis=batch_axis) + # update delay data + self.data[0] = self.target.value + if isinstance(self._initial_delay_data, (Array, jax.Array, float, int, bool)): + self.data[1:] = self._initial_delay_data + elif callable(self._initial_delay_data): + self.data[1:] = self._initial_delay_data((length,) + self.target.shape, dtype=self.target.dtype) + + class NeuLenDelay(LengthDelay): """Neutral Length Delay. 
Alias of :py:class:`~.LengthDelay`.""" pass diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 8ff95552e..4b73050c5 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -6,16 +6,22 @@ import os import re import sys -from typing import Any, Callable, TypeVar, cast +from typing import Any, Callable, TypeVar, cast, Dict, Union, Optional +import jax +import numpy as np from jax import config, numpy as jnp, devices from jax.lib import xla_bridge +from brainpy._src.tools.dicts import DotDict from . import modes +# from .delayvars import LengthDelay, ROTATE_UPDATE +from .ndarray import Variable, Array bm = None __all__ = [ + 'share', # default data types 'set_float', 'get_float', @@ -52,6 +58,75 @@ ] +def _key_of_var(var: Variable): + if not isinstance(var, Variable): + raise TypeError(f'Delay target should be instance of Variable. But got {type(var)}') + return f'var{id(var)}' + + +def _as_jax_array(arr): + return arr.value if isinstance(arr, Array) else arr + + +class Context: + """Context for brainpy computation.""" + + def __init__(self): + """Initialize function.""" + + '''Shared data across all nodes at current time step. + ''' + self._arguments = DotDict() + + def get(self, key): + """Get the shared data by the ``key``. + + Args: + key (str): the key to indicate the data. + """ + if key in self._arguments: + return self._arguments[key] + else: + raise KeyError(f'Cannot found shared data of {key}.') + + # shared arguments # + # ---------------- # + + def save_shargs(self, **shared) -> None: + """Save shared arguments in the global context.""" + self._arguments.update(shared) + + def get_shargs(self) -> DotDict: + """Get all shared arguments in the global context.""" + r = self._arguments.copy() + return r + + def remove_shargs(self, *args) -> None: + """Clear all shared arguments in the global context.""" + if len(args) > 0: + for a in args: + self._arguments.pop(a) + else: + self._arguments.clear() + + # other # + # ----- # + + def clear(self) -> None: + """Clear all shared data in this computation context.""" + self.remove_shargs() + + +share = Context() +'''Global context manager to manage ``share`` data across all modules.''' + + +def change_share_context(context: Context): + global share + assert isinstance(context, Context), f'Must be instance of {Context.__name__}' + share = context + + # default dtype # -------------------------- diff --git a/brainpy/_src/math/fft.py b/brainpy/_src/math/fft.py index f5058caf2..fd745eadc 100644 --- a/brainpy/_src/math/fft.py +++ b/brainpy/_src/math/fft.py @@ -10,21 +10,21 @@ "irfft2", "irfftn", "rfft", "rfft2", "rfftfreq", "rfftn" ] -fft = _compatible_with_brainpy_array(jfft.fft) -fft2 = _compatible_with_brainpy_array(jfft.fft2) -fftfreq = _compatible_with_brainpy_array(jfft.fftfreq) -fftn = _compatible_with_brainpy_array(jfft.fftn) -fftshift = _compatible_with_brainpy_array(jfft.fftshift) -hfft = _compatible_with_brainpy_array(jfft.hfft) -ifft = _compatible_with_brainpy_array(jfft.ifft) -ifft2 = _compatible_with_brainpy_array(jfft.ifft2) -ifftn = _compatible_with_brainpy_array(jfft.ifftn) -ifftshift = _compatible_with_brainpy_array(jfft.ifftshift) -ihfft = _compatible_with_brainpy_array(jfft.ihfft) -irfft = _compatible_with_brainpy_array(jfft.irfft) -irfft2 = _compatible_with_brainpy_array(jfft.irfft2) -irfftn = _compatible_with_brainpy_array(jfft.irfftn) -rfft = _compatible_with_brainpy_array(jfft.rfft) -rfft2 = _compatible_with_brainpy_array(jfft.rfft2) -rfftfreq = 
_compatible_with_brainpy_array(jfft.rfftfreq) -rfftn = _compatible_with_brainpy_array(jfft.rfftn) +fft = _compatible_with_brainpy_array(jfft.fft, module='fft.') +fft2 = _compatible_with_brainpy_array(jfft.fft2, module='fft.') +fftfreq = _compatible_with_brainpy_array(jfft.fftfreq, module='fft.') +fftn = _compatible_with_brainpy_array(jfft.fftn, module='fft.') +fftshift = _compatible_with_brainpy_array(jfft.fftshift, module='fft.') +hfft = _compatible_with_brainpy_array(jfft.hfft, module='fft.') +ifft = _compatible_with_brainpy_array(jfft.ifft, module='fft.') +ifft2 = _compatible_with_brainpy_array(jfft.ifft2, module='fft.') +ifftn = _compatible_with_brainpy_array(jfft.ifftn, module='fft.') +ifftshift = _compatible_with_brainpy_array(jfft.ifftshift, module='fft.') +ihfft = _compatible_with_brainpy_array(jfft.ihfft, module='fft.') +irfft = _compatible_with_brainpy_array(jfft.irfft, module='fft.') +irfft2 = _compatible_with_brainpy_array(jfft.irfft2, module='fft.') +irfftn = _compatible_with_brainpy_array(jfft.irfftn, module='fft.') +rfft = _compatible_with_brainpy_array(jfft.rfft, module='fft.') +rfft2 = _compatible_with_brainpy_array(jfft.rfft2, module='fft.') +rfftfreq = _compatible_with_brainpy_array(jfft.rfftfreq, module='fft.') +rfftn = _compatible_with_brainpy_array(jfft.rfftn, module='fft.') diff --git a/brainpy/_src/math/linalg.py b/brainpy/_src/math/linalg.py index ba152e421..a47207d3c 100644 --- a/brainpy/_src/math/linalg.py +++ b/brainpy/_src/math/linalg.py @@ -10,23 +10,23 @@ 'tensorinv', 'tensorsolve', 'multi_dot' ] -cholesky = _compatible_with_brainpy_array(linalg.cholesky) -cond = _compatible_with_brainpy_array(linalg.cond) -det = _compatible_with_brainpy_array(linalg.det) -eig = _compatible_with_brainpy_array(linalg.eig) -eigh = _compatible_with_brainpy_array(linalg.eigh) -eigvals = _compatible_with_brainpy_array(linalg.eigvals) -eigvalsh = _compatible_with_brainpy_array(linalg.eigvalsh) -inv = _compatible_with_brainpy_array(linalg.inv) -svd = _compatible_with_brainpy_array(linalg.svd) -lstsq = _compatible_with_brainpy_array(linalg.lstsq) -matrix_power = _compatible_with_brainpy_array(linalg.matrix_power) -matrix_rank = _compatible_with_brainpy_array(linalg.matrix_rank) -norm = _compatible_with_brainpy_array(linalg.norm) -pinv = _compatible_with_brainpy_array(linalg.pinv) -qr = _compatible_with_brainpy_array(linalg.qr) -solve = _compatible_with_brainpy_array(linalg.solve) -slogdet = _compatible_with_brainpy_array(linalg.slogdet) -tensorinv = _compatible_with_brainpy_array(linalg.tensorinv) -tensorsolve = _compatible_with_brainpy_array(linalg.tensorsolve) -multi_dot = _compatible_with_brainpy_array(linalg.multi_dot) \ No newline at end of file +cholesky = _compatible_with_brainpy_array(linalg.cholesky, module='linalg.') +cond = _compatible_with_brainpy_array(linalg.cond, module='linalg.') +det = _compatible_with_brainpy_array(linalg.det, module='linalg.') +eig = _compatible_with_brainpy_array(linalg.eig, module='linalg.') +eigh = _compatible_with_brainpy_array(linalg.eigh, module='linalg.') +eigvals = _compatible_with_brainpy_array(linalg.eigvals, module='linalg.') +eigvalsh = _compatible_with_brainpy_array(linalg.eigvalsh, module='linalg.') +inv = _compatible_with_brainpy_array(linalg.inv, module='linalg.') +svd = _compatible_with_brainpy_array(linalg.svd, module='linalg.') +lstsq = _compatible_with_brainpy_array(linalg.lstsq, module='linalg.') +matrix_power = _compatible_with_brainpy_array(linalg.matrix_power, module='linalg.') +matrix_rank = 
_compatible_with_brainpy_array(linalg.matrix_rank, module='linalg.') +norm = _compatible_with_brainpy_array(linalg.norm, module='linalg.') +pinv = _compatible_with_brainpy_array(linalg.pinv, module='linalg.') +qr = _compatible_with_brainpy_array(linalg.qr, module='linalg.') +solve = _compatible_with_brainpy_array(linalg.solve, module='linalg.') +slogdet = _compatible_with_brainpy_array(linalg.slogdet, module='linalg.') +tensorinv = _compatible_with_brainpy_array(linalg.tensorinv, module='linalg.') +tensorsolve = _compatible_with_brainpy_array(linalg.tensorsolve, module='linalg.') +multi_dot = _compatible_with_brainpy_array(linalg.multi_dot, module='linalg.') \ No newline at end of file diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 2ac63f38b..1aaebeba4 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- + from typing import Union, Optional, NoReturn, Sequence, Any, Tuple as TupleType import warnings import operator @@ -857,9 +858,155 @@ def var(self, axis=None, dtype=None, ddof=0, keepdims=False): r = self.value.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) return _return(r) - def view(self, dtype=None, *args, **kwargs): - """New view of array with the same data.""" - return _return(self.value.view(dtype=dtype, *args, **kwargs)) + def view(self, *args, dtype=None): + r"""New view of array with the same data. + + This function is compatible with pytorch syntax. + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`shape`. + + The returned tensor shares the same data and must have the same number + of elements, but may have a different size. For a tensor to be viewed, the new + view size must be compatible with its original size and stride, i.e., each new + view dimension must either be a subspace of an original dimension, or only span + across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following + contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + + .. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + + Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` + without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a + :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which + returns a view if the shapes are compatible, and copies (equivalent to calling + :meth:`contiguous`) otherwise. + + Args: + shape (int...): the desired size + + Example:: + + >>> x = brainpy.math.randn(4, 4) + >>> x.size + [4, 4] + >>> y = x.view(16) + >>> y.size + [16] + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size + [2, 8] + + >>> a = brainpy.math.randn(1, 2, 3, 4) + >>> a.size + [1, 2, 3, 4] + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size + [1, 3, 2, 4] + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size + [1, 3, 2, 4] + >>> brainpy.math.equal(b, c) + False + + + .. method:: view(dtype) -> Tensor + :noindex: + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`dtype`. + + If the element size of :attr:`dtype` is different than that of ``self.dtype``, + then the size of the last dimension of the output will be scaled + proportionally. 
For instance, if :attr:`dtype` element size is twice that of + ``self.dtype``, then each pair of elements in the last dimension of + :attr:`self` will be combined, and the size of the last dimension of the output + will be half that of :attr:`self`. If :attr:`dtype` element size is half that + of ``self.dtype``, then each element in the last dimension of :attr:`self` will + be split in two, and the size of the last dimension of the output will be + double that of :attr:`self`. For this to be possible, the following conditions + must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + + Additionally, if the element size of :attr:`dtype` is greater than that of + ``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + + If any of the above conditions are not met, an error is thrown. + + + Args: + dtype (:class:`dtype`): the desired dtype + + Example:: + + >>> x = brainpy.math.randn(4, 4) + >>> x + Array([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + brainpy.math.float32 + + >>> y = x.view(brainpy.math.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=brainpy.math.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(brainpy.math.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(brainpy.math.cfloat).size + [4, 2] + + >>> x.view(brainpy.math.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=brainpy.math.uint8) + >>> x.view(brainpy.math.uint8).size + [4, 16] + + """ + if len(args) == 0: + if dtype is None: + raise ValueError('Provide dtype or shape.') + else: + return _return(self.value.view(dtype)) + else: + if isinstance(args[0], int): # shape + if dtype is not None: + raise ValueError('Provide one of dtype or shape. 
Not both.') + return _return(self.value.reshape(*args)) + else: # dtype + assert not isinstance(args[0], int) + assert dtype is None + return _return(self.value.view(args[0])) # ------------------ # NumPy support diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index baa03f1a4..041bb8a70 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -351,7 +351,20 @@ def unique_name(self, name=None, type_=None): check_name_uniqueness(name=name, obj=self) return name - def state_dict(self): + def __state_dict__(self) -> dict: + return self.vars(include_self=True, level=0).unique() + + def __load_state_dict__(self, state_dict: dict) -> Optional[Tuple[Sequence[str], Sequence[str]]]: + variables = self.vars(include_self=True, level=0).unique() + keys1 = set(state_dict.keys()) + keys2 = set(variables.keys()) + for key in keys2.intersection(keys1): + variables[key].value = state_dict[key] + unexpected_keys = list(keys1 - keys2) + missing_keys = list(keys2 - keys1) + return unexpected_keys, missing_keys + + def state_dict(self) -> dict: """Returns a dictionary containing a whole state of the module. Returns @@ -359,9 +372,10 @@ def state_dict(self): out: dict A dictionary containing a whole state of the module. """ - return self.vars().unique().dict() + nodes = self.nodes() # retrieve all nodes + return {key: node.__state_dict__() for key, node in nodes.items()} - def load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True): + def load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True, compatible='v2'): """Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. @@ -380,13 +394,24 @@ def load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True): * **missing_keys** is a list of str containing the missing keys * **unexpected_keys** is a list of str containing the unexpected keys """ - variables = self.vars().unique() - keys1 = set(state_dict.keys()) - keys2 = set(variables.keys()) - unexpected_keys = list(keys1 - keys2) - missing_keys = list(keys2 - keys1) - for key in keys2.intersection(keys1): - variables[key].value = state_dict[key] + if compatible == 'v1': + variables = self.vars().unique() + keys1 = set(state_dict.keys()) + keys2 = set(variables.keys()) + unexpected_keys = list(keys1 - keys2) + missing_keys = list(keys2 - keys1) + for key in keys2.intersection(keys1): + variables[key].value = state_dict[key] + elif compatible == 'v2': + nodes = self.nodes() + missing_keys = [] + unexpected_keys = [] + for name, node in nodes.items(): + missing, unexpected = node.__load_state_dict__(state_dict[name]) + missing_keys.extend([f'{name}.{key}' for key in missing]) + unexpected_keys.extend([f'{name}.{key}' for key in unexpected]) + else: + raise ValueError(f'Unknown compatible version: {compatible}') if warn: if len(unexpected_keys): warnings.warn(f'Unexpected keys in state_dict: {unexpected_keys}', UserWarning) diff --git a/brainpy/_src/math/operators/__init__.py b/brainpy/_src/math/operators/__init__.py index f7e7e0ba9..c628a8269 100644 --- a/brainpy/_src/math/operators/__init__.py +++ b/brainpy/_src/math/operators/__init__.py @@ -7,18 +7,15 @@ from . 
import ( op_register, pre_syn_post, - wrap_jax, sparse_matmul, ) __all__ = ( op_register.__all__ + pre_syn_post.__all__ - + wrap_jax.__all__ + sparse_matmul.__all__ ) from .sparse_matmul import * from .op_register import * from .pre_syn_post import * -from .wrap_jax import * diff --git a/brainpy/_src/math/operators/op_register.py b/brainpy/_src/math/operators/op_register.py index ba75d45e0..099edb488 100644 --- a/brainpy/_src/math/operators/op_register.py +++ b/brainpy/_src/math/operators/op_register.py @@ -3,11 +3,11 @@ import warnings from typing import Callable -import brainpylib from jax.tree_util import tree_map from brainpy._src.math.object_transform.base import BrainPyObject from brainpy._src.math.ndarray import Array +from brainpy._src.tools.package import import_brainpylib __all__ = [ 'XLACustomOp', @@ -80,6 +80,7 @@ def __init__( gpu_func = None # register OP + brainpylib = import_brainpylib() self.op = brainpylib.register_op_with_numba( self.name, cpu_func=cpu_func, diff --git a/brainpy/_src/math/operators/wrap_jax.py b/brainpy/_src/math/operators/wrap_jax.py deleted file mode 100644 index 9ef4ec648..000000000 --- a/brainpy/_src/math/operators/wrap_jax.py +++ /dev/null @@ -1,226 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Union, Optional - -import jax.numpy as jnp -from jax import lax -from jax import ops as jops - -from brainpy._src.math.ndarray import Array -from brainpy._src.math.arrayinterporate import as_jax - -__all__ = [ - 'segment_sum', - 'segment_prod', - 'segment_max', - 'segment_min', -] - - -def segment_sum(data: Union[Array, jnp.ndarray], - segment_ids: Union[Array, jnp.ndarray], - num_segments: Optional[int] = None, - indices_are_sorted: bool = False, - unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: - """``segment_sum`` operator for brainpy `Array` and `Variable`. - - Parameters - ---------- - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns - ------- - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. 
- """ - return Array(jops.segment_sum(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) - - -def segment_prod(data: Union[Array, jnp.ndarray], - segment_ids: Union[Array, jnp.ndarray], - num_segments: Optional[int] = None, - indices_are_sorted: bool = False, - unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: - """``segment_prod`` operator for brainpy `Array` and `Variable`. - - Parameters - ---------- - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns - ------- - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. - """ - return Array(jops.segment_prod(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) - - -def segment_max(data: Union[Array, jnp.ndarray], - segment_ids: Union[Array, jnp.ndarray], - num_segments: Optional[int] = None, - indices_are_sorted: bool = False, - unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: - """``segment_max`` operator for brainpy `Array` and `Variable`. - - Parameters - ---------- - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. 
- mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns - ------- - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. - """ - return Array(jops.segment_max(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) - - -def segment_min(data: Union[Array, jnp.ndarray], - segment_ids: Union[Array, jnp.ndarray], - num_segments: Optional[int] = None, - indices_are_sorted: bool = False, - unique_indices: bool = False, - bucket_size: Optional[int] = None, - mode: Optional[lax.GatherScatterMode] = None) -> Array: - """``segment_min`` operator for brainpy `Array` and `Variable`. - - Parameters - ---------- - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns - ------- - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. 
- """ - return Array(jops.segment_min(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 34670d576..fbd79ec7f 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -745,7 +745,7 @@ def uniform(self, low=0.0, high=1.0, size=None, key=None): r = jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high) return _return(r) - def truncated_normal(self, lower, upper, size, scale=None, key=None): + def truncated_normal(self, lower, upper, size=None, scale=None, key=None): lower = _as_jax_array(lower) lower = _check_py_seq(lower) upper = _as_jax_array(upper) diff --git a/brainpy/_src/math/tests/test_delay_vars.py b/brainpy/_src/math/tests/test_delay_vars.py index e2955e437..8241005ee 100644 --- a/brainpy/_src/math/tests/test_delay_vars.py +++ b/brainpy/_src/math/tests/test_delay_vars.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import brainpy.math as bm -from brainpy._src.math.delayvars import ROTATION_UPDATING, CONCAT_UPDATING +from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE class TestTimeDelay(unittest.TestCase): @@ -94,7 +94,7 @@ def test_current_time2(self): class TestLengthDelay(unittest.TestCase): def test1(self): dim = 3 - for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]: + for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: delay = bm.LengthDelay(jnp.zeros(dim), 10, update_method=update_method) print(delay(1)) self.assertTrue(jnp.allclose(delay(1), jnp.zeros(dim))) @@ -105,7 +105,7 @@ def test1(self): def test2(self): dim = 3 - for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]: + for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: delay = bm.LengthDelay(jnp.zeros(dim), 10, # initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)), @@ -123,7 +123,7 @@ def test2(self): def test3(self): dim = 3 - for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]: + for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: delay = bm.LengthDelay(jnp.zeros(dim), 10, # initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)), diff --git a/brainpy/_src/running/runner.py b/brainpy/_src/running/runner.py index cad121c75..8ea7afd63 100644 --- a/brainpy/_src/running/runner.py +++ b/brainpy/_src/running/runner.py @@ -64,9 +64,6 @@ class Runner(BrainPyObject): jit: Dict[str, bool] '''Flag to denote whether to use JIT.''' - target: BrainPyObject - '''The target model to run.''' - def __init__( self, target: BrainPyObject, diff --git a/brainpy/_src/tools/dicts.py b/brainpy/_src/tools/dicts.py index d177daefe..75013b82b 100644 --- a/brainpy/_src/tools/dicts.py +++ b/brainpy/_src/tools/dicts.py @@ -61,6 +61,9 @@ def __init__(self, *args, **kwargs): self.__dict__ = self self.var_names = () + def copy(self) -> 'DotDict': + return type(self)(super().copy()) + def keys(self): """Retrieve all keys in the dict, excluding ignored keys.""" keys = [] diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py index 13fe3e4e9..a9e852feb 100644 --- a/brainpy/_src/train/back_propagation.py +++ b/brainpy/_src/train/back_propagation.py @@ -5,17 +5,17 @@ from collections.abc import Iterable from functools import partial from typing import Union, Dict, Callable, Sequence, Any, Optional +from tqdm import tqdm import jax.numpy as jnp import numpy as np from jax.tree_util import tree_map 
diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py
index 13fe3e4e9..a9e852feb 100644
--- a/brainpy/_src/train/back_propagation.py
+++ b/brainpy/_src/train/back_propagation.py
@@ -5,17 +5,17 @@
 from collections.abc import Iterable
 from functools import partial
 from typing import Union, Dict, Callable, Sequence, Any, Optional
+from tqdm import tqdm
 
 import jax.numpy as jnp
 import numpy as np
 from jax.tree_util import tree_map
 
-import brainpy.math as bm
-import brainpy._src.optimizers as optim
-from brainpy._src.math.object_transform.base import BrainPyObject
 import brainpy.losses as losses
-from brainpy import tools
+import brainpy.math as bm
+from brainpy import tools, optim
 from brainpy._src.dyn.base import DynamicalSystem
+from brainpy._src.math.object_transform.base import BrainPyObject
 from brainpy._src.running import constants as c
 from brainpy.check import serialize_kwargs
 from brainpy.errors import UnsupportedError, NoLongerSupportError
@@ -49,6 +49,8 @@ class BPTrainer(DSTrainer):
     should be provided.
   loss_has_aux: bool
     To indicate whether the `loss_fun` returns auxiliary data.
+  loss_auto_run: bool
+    pass
   optimizer: optim.Optimizer
     The optimizer used for training.
   numpy_mon_after_run: bool
@@ -72,7 +74,8 @@ def __init__(
       loss_fun: Union[str, Callable],  # loss function
       optimizer: optim.Optimizer = None,  # optimizer
       loss_has_aux: bool = False,  # loss auxiliary
-      logger: Any = sys.stdout,
+      loss_auto_run: bool = True,  # loss auto-run
+      logger: Optional[Any] = None,
 
       # -------------
       # API deprecated
@@ -126,6 +129,7 @@ def __init__(
       raise UnsupportedError(f'Do not support {type(loss_fun)} to specify the loss function. '
                              f'We only support str and callable function.')
     self._loss_func = loss_fun
+    self.loss_auto_run = loss_auto_run
 
     # loss data
     self._report_train_metrics = dict()
@@ -264,8 +268,10 @@ def fit(
 
       # training set
      fit_t0 = time.time()
      fit_epoch_metric = dict(loss=[])
-      for x, y in (train_data() if callable(train_data) else train_data):
+      _training_data = train_data() if callable(train_data) else train_data
+      bar = tqdm(total=len(_training_data) if hasattr(_training_data, '__len__') else None)
+      for x, y in _training_data:
         # reset state
         if reset_state:
           self.target.reset_state(self._get_input_batch_size(x))
@@ -283,6 +289,7 @@ def fit(
             if k not in fit_epoch_metric:
               fit_epoch_metric[k] = []
             fit_epoch_metric[k].append(v)
+        bar.update(1)
 
         # report
         fit_i += 1
@@ -297,9 +304,11 @@ def fit(
             report_train_metric[k].append(aux[k])
             detailed_train_metric[k].extend(v)
             v.clear()
-          print((f'Train {fit_i} steps, use {fit_t + fit_t1 - fit_t0:.4f} s' +
-                 ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))),
-                file=self.logger)
+          _report = (f'Train {fit_i} steps, use {fit_t + fit_t1 - fit_t0:.4f} s' +
+                     ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()])))
+          bar.set_description(_report, refresh=True)
+          if self.logger is not None:
+            self.logger.write(_report + '\n')
           if fun_after_report is not None:
             fun_after_report(fit_i, aux, 'fit')
           fit_t0 = time.time()
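The reporting hunks above and below replace `print(..., file=self.logger)` with a `tqdm` progress bar plus an optional `logger` object that only needs a `write()` method. A self-contained sketch of that reporting pattern; the "training step" is a placeholder, not BrainPy code::

    from tqdm import tqdm

    def run_epoch(data, logger=None):
      # total may be None when the data source has no length (e.g. a generator)
      bar = tqdm(total=len(data) if hasattr(data, '__len__') else None)
      for step, batch in enumerate(data, start=1):
        loss = sum(batch) / len(batch)     # placeholder for the real training step
        bar.update(1)
        report = f'Train {step} steps, loss {loss:.4f}'
        bar.set_description(report, refresh=True)
        if logger is not None:             # any object with .write(), e.g. an open file
          logger.write(report + '\n')
      bar.close()

    run_epoch([[1., 2.], [3., 4.]])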
@@ -316,20 +325,25 @@ def fit(
             report_train_metric[k].append(aux[k])
             detailed_train_metric[k].extend(v)
             v.clear()
-          print((f'Train {epoch_idx} epoch, use {fit_t1 - fit_t0:.4f} s' +
-                 ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))),
-                file=self.logger)
+          _report = (f'Train {epoch_idx} epoch, use {fit_t1 - fit_t0:.4f} s' +
+                     ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()])))
+          bar.set_description(_report, refresh=True)
+          if self.logger is not None:
+            self.logger.write(_report + '\n')
           if fun_after_report is not None:
             fun_after_report(epoch_idx, aux, 'fit')
       else:
         fit_t = time.time() - fit_t0
       self.optimizer.lr.step_epoch()
+      bar.close()
 
       # testing set
       if test_data is not None:
         test_t0 = time.time()
         test_epoch_metric = dict(loss=[])
-        for x, y in (test_data() if callable(test_data) else test_data):
+        _testing_data = test_data() if callable(test_data) else test_data
+        bar = tqdm(total=len(_testing_data) if hasattr(_testing_data, '__len__') else None)
+        for x, y in _testing_data:
           # reset state
           if reset_state:
             self.target.reset_state(self._get_input_batch_size(x))
@@ -350,6 +364,8 @@ def fit(
           else:
             test_epoch_metric['loss'].append(res)
 
+          bar.update(1)
+
           # report
           test_i += 1
           if num_report > 0 and test_i % num_report == 0:
@@ -363,9 +379,11 @@ def fit(
               report_test_metric[k].append(aux[k])
               detailed_test_metric[k].extend(v)
               v.clear()
-            print((f'Test {test_i} steps, use {test_t + test_t1 - test_t0:.4f} s' +
-                   ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))),
-                  file=self.logger)
+            _report = (f'Test {test_i} steps, use {test_t + test_t1 - test_t0:.4f} s' +
+                       ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()])))
+            bar.set_description(_report, refresh=True)
+            if self.logger is not None:
+              self.logger.write(_report + '\n')
             if fun_after_report is not None:
               fun_after_report(test_i, aux, 'test')
             test_t0 = time.time()
@@ -382,14 +400,18 @@ def fit(
               report_test_metric[k].append(aux[k])
               detailed_test_metric[k].extend(v)
               v.clear()
-            print((f'Test {epoch_idx} epoch, use {test_t1 - test_t0:.4f} s' +
-                   ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))),
-                  file=self.logger)
+            _report = (f'Test {epoch_idx} epoch, use {test_t1 - test_t0:.4f} s' +
+                       ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()])))
+            bar.set_description(_report, refresh=True)
+            if self.logger is not None:
+              self.logger.write(_report + '\n')
             if fun_after_report is not None:
               fun_after_report(epoch_idx, aux, 'test')
         else:
           test_t = time.time() - test_t0
 
+        bar.close()
+
       # finally
       self._report_train_metrics = {k: np.asarray(v) for k, v in report_train_metric.items()}
       self._detailed_train_metrics = {k: np.asarray(v) for k, v in detailed_train_metric.items()}
@@ -492,7 +514,7 @@ class BPTT(BPTrainer):
        .. code-block:: python
 
           def loss_fun(predicts, targets):
-              return loss, {'acc': acc, 'spike_num': spike_num}
+             return loss, {'acc': acc, 'spike_num': spike_num}
   optimizer: Optimizer
     The optimizer used for training. Should be an instance of :py:class:`~.Optimizer`.
   numpy_mon_after_run: bool
@@ -544,6 +566,7 @@ def _step_func_fit(self, shared_args, inputs, targets):
 
   def _step_func_predict(self, shared, x=None):
     assert self.data_first_axis == 'B', f'There is no time dimension when using the trainer {self.__class__.__name__}.'
+    bm.share.save_shargs(**shared)
 
     # input step
     self.target.clear_input()
@@ -555,6 +578,7 @@ def _step_func_predict(self, shared, x=None):
 
     # monitor step
     mon = self._step_func_monitor(shared)
+    bm.share.remove_shargs(shared)
     return out, mon
 
   def _get_f_predict(self, shared_args: Dict = None, jit: bool = True):
@@ -630,7 +654,7 @@ def predict(
     return (t1 - t0, outs) if eval_time else outs
 
 
-class OnlineBPTT(BPTT):
+class _OnlineBPTT(BPTT):
   def _step_func_loss(self, shared_args, t, i, input_, target_):
     outputs, mon = self._get_f_predict_one_step(shared_args)(t, i, input_)
     predicts = (outputs, mon) if len(mon) > 0 else outputs
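The corrected docstring above spells out the `loss_has_aux` contract: when it is set, `loss_fun(predicts, targets)` returns a `(loss, aux_dict)` pair and the auxiliary entries are reported during `fit`. A minimal sketch of such a loss function in plain JAX; the mean-squared loss and the `acc` metric are illustrative, not taken from the BrainPy sources::

    import jax.numpy as jnp

    def loss_fun(predicts, targets):
      # main scalar loss plus auxiliary metrics, as expected with loss_has_aux=True
      loss = jnp.mean((predicts - targets) ** 2)
      acc = jnp.mean((predicts > 0.5) == (targets > 0.5))
      return loss, {'acc': acc}

A trainer would then be constructed with `loss_fun=loss_fun, loss_has_aux=True`.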
diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py
index 994eae584..69d534e4a 100644
--- a/brainpy/_src/train/offline.py
+++ b/brainpy/_src/train/offline.py
@@ -237,8 +237,7 @@ def _step_func_monitor(self, shared):
 
   def _check_interface(self):
     for node in self.train_nodes:
-      if hasattr(node.offline_fit, 'not_customized'):
-        if node.offline_fit.not_customized:
+      if not hasattr(node, 'offline_fit'):
           raise NoImplementationError(
             f'''
 The node
diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py
index 3247b159b..7f22fbc3d 100644
--- a/brainpy/_src/train/online.py
+++ b/brainpy/_src/train/online.py
@@ -276,16 +276,14 @@ def _step_func_fit(self, shared_args, t, i, x, ys):
 
   def _check_interface(self):
     for node in self.train_nodes:
-      if hasattr(node.online_fit, 'not_customized'):
-        if node.online_fit.not_customized:
+      if not hasattr(node, 'online_fit'):
           raise NoImplementationError(
             f'The node \n\n{node}\n\n'
             f'is set to be trainable with {self.__class__.__name__} method. '
             f'However, it does not implement the required training '
             f'interface "online_fit()" function. '
           )
-      if hasattr(node.online_init, 'not_customized'):
-        if node.online_init.not_customized:
+      if not hasattr(node, 'online_init'):
           raise NoImplementationError(
             f'The node \n\n{node}\n\n'
             f'is set to be trainable with {self.__class__.__name__} method. '
diff --git a/brainpy/check.py b/brainpy/check.py
index 31f9c22ac..6e7704e72 100644
--- a/brainpy/check.py
+++ b/brainpy/check.py
@@ -191,10 +191,10 @@ def is_dict_data(a_dict: Dict,
                  key_type: Union[Type, Tuple[Type, ...]] = None,
                  val_type: Union[Type, Tuple[Type, ...]] = None,
                  name: str = None,
-                 all_none: bool = True):
+                 allow_none: bool = True):
   """Check the dictionary data.
   """
-  if all_none and a_dict is None:
+  if allow_none and a_dict is None:
     return None
   name = '' if (name is None) else f'"{name}"'
   if not isinstance(a_dict, dict):
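With the rename above, callers of `brainpy.check.is_dict_data` must pass `allow_none` instead of `all_none`. A small usage sketch based only on the signature shown in this hunk; the exact return values are assumed::

    from brainpy import check

    check.is_dict_data({'tau': 10.0}, key_type=str, val_type=float, name='params')
    check.is_dict_data(None, allow_none=True)   # returns None instead of raising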
""" - if all_none and a_dict is None: + if allow_none and a_dict is None: return None name = '' if (name is None) else f'"{name}"' if not isinstance(a_dict, dict): diff --git a/brainpy/experimental.py b/brainpy/experimental.py new file mode 100644 index 000000000..f7540ea37 --- /dev/null +++ b/brainpy/experimental.py @@ -0,0 +1,30 @@ + +# synaptic delays +from brainpy._src.experimental.delay import ( + Delay as Delay, +) + +# synapse plasticity +from brainpy._src.experimental.synstp import ( + STP as STP, + STD as STD, +) + +# synapse outputs +from brainpy._src.experimental.synout import ( + COBA as COBA, + CUBA as CUBA, + MgBlock as MgBlock, +) + +# Synapses +from brainpy._src.experimental.synapses import ( + Exponential as Exponential, +) + +# neurons +from brainpy._src.experimental.neurons import ( + LIF as LIF, +) + + diff --git a/brainpy/math/compat_tensorflow.py b/brainpy/math/compat_tensorflow.py index 6a0bcb8d3..479dbedb4 100644 --- a/brainpy/math/compat_tensorflow.py +++ b/brainpy/math/compat_tensorflow.py @@ -19,6 +19,10 @@ unsorted_segment_max as unsorted_segment_max, unsorted_segment_min as unsorted_segment_min, unsorted_segment_mean as unsorted_segment_mean, + segment_sum as segment_sum, + segment_prod as segment_prod, + segment_max as segment_max, + segment_min as segment_min, clip_by_value as clip_by_value, cast as cast, ) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index 691958840..bae3b9cc2 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -5,4 +5,6 @@ LengthDelay as LengthDelay, NeuTimeDelay as NeuTimeDelay, NeuLenDelay as NeuLenDelay, + ROTATE_UPDATE as ROTATE_UPDATE, + CONCAT_UPDATE as CONCAT_UPDATE, ) diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index 3c0730d72..c7e5df414 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from brainpy._src.math.environment import ( + share as share, set_float as set_float, get_float as get_float, set_int as set_int, diff --git a/brainpy/math/operators.py b/brainpy/math/operators.py index 0aabb9950..0779ece6d 100644 --- a/brainpy/math/operators.py +++ b/brainpy/math/operators.py @@ -28,10 +28,3 @@ csr_matvec as csr_matvec, event_csr_matvec as event_csr_matvec, ) - -from brainpy._src.math.operators.wrap_jax import ( - segment_sum as segment_sum, - segment_prod as segment_prod, - segment_max as segment_max, - segment_min as segment_min, -) diff --git a/docs/auto_generater.py b/docs/auto_generater.py index d6f4db9d9..c24b5ec1f 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -560,19 +560,23 @@ def generate_math_docs(): 'environment': ('Environment Settings', 'brainpy.math'), 'modes': ('Computing Modes', 'brainpy.math'), 'arrayinterporate': ('Array Interoperability', 'brainpy.math'), - # 'compat_numpy': ('Array Operators with NumPy Syntax', 'brainpy.math'), - # 'compat_pytorch': ('Array Operators with PyTorch Syntax', 'brainpy.math'), - # 'compat_tensorflow': ('Array Operators with TensorFlow Syntax', 'brainpy.math'), - 'surrogate': ('Surrogate Gradient Functions', 'brainpy.math.surrogate'), - 'random': ('Random Number Generations', 'brainpy.math.random'), - # 'linalg': ('Linear algebra', 'brainpy.math.linalg'), - # 'fft': ('Discrete Fourier Transform', 'brainpy.math.fft'), + 'compat_numpy': ('Array Operators with NumPy Syntax', 'brainpy.math'), + 'compat_pytorch': ('Array Operators with PyTorch Syntax', 'brainpy.math'), + 'compat_tensorflow': ('Array Operators with TensorFlow Syntax', 
diff --git a/docs/auto_generater.py b/docs/auto_generater.py
index d6f4db9d9..c24b5ec1f 100644
--- a/docs/auto_generater.py
+++ b/docs/auto_generater.py
@@ -560,19 +560,23 @@ def generate_math_docs():
       'environment': ('Environment Settings', 'brainpy.math'),
       'modes': ('Computing Modes', 'brainpy.math'),
       'arrayinterporate': ('Array Interoperability', 'brainpy.math'),
-      # 'compat_numpy': ('Array Operators with NumPy Syntax', 'brainpy.math'),
-      # 'compat_pytorch': ('Array Operators with PyTorch Syntax', 'brainpy.math'),
-      # 'compat_tensorflow': ('Array Operators with TensorFlow Syntax', 'brainpy.math'),
-      'surrogate': ('Surrogate Gradient Functions', 'brainpy.math.surrogate'),
-      'random': ('Random Number Generations', 'brainpy.math.random'),
-      # 'linalg': ('Linear algebra', 'brainpy.math.linalg'),
-      # 'fft': ('Discrete Fourier Transform', 'brainpy.math.fft'),
+      'compat_numpy': ('Array Operators with NumPy Syntax', 'brainpy.math'),
+      'compat_pytorch': ('Array Operators with PyTorch Syntax', 'brainpy.math'),
+      'compat_tensorflow': ('Array Operators with TensorFlow Syntax', 'brainpy.math'),
+      'surrogate': ('``brainpy.math.surrogate`` module: Surrogate Gradient Functions',
+                    'brainpy.math.surrogate'),
+      'random': ('``brainpy.math.random`` module: Random Number Generations',
+                 'brainpy.math.random'),
+      'linalg': ('``brainpy.math.linalg`` module: Linear algebra',
+                 'brainpy.math.linalg'),
+      'fft': ('``brainpy.math.fft`` module: Discrete Fourier Transform',
+              'brainpy.math.fft'),
     }
   )
 
 
 def generate_algorithm_docs(path='apis/auto/algorithms/'):
-  if not os.path.exists(path): os.makedirs(path)
+  os.makedirs(path, exist_ok=True)
 
   module_and_name = [
     ('offline', 'Offline Training Algorithms'),