From 4d0c82a6b7e17e718e6e332b1a602561e0cccd45 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 14 Jul 2022 22:33:52 +0800 Subject: [PATCH 1/5] update apis and models --- brainpy/__init__.py | 12 +- brainpy/dyn/base.py | 38 ++- brainpy/dyn/channels/Ca.py | 28 +-- brainpy/dyn/channels/IH.py | 6 +- brainpy/dyn/channels/K.py | 24 +- brainpy/dyn/channels/KCa.py | 4 +- brainpy/dyn/channels/Na.py | 10 +- brainpy/dyn/channels/base.py | 4 +- brainpy/dyn/channels/leaky.py | 6 +- brainpy/dyn/layers/conv.py | 4 +- brainpy/dyn/layers/dense.py | 6 +- brainpy/dyn/layers/dropout.py | 8 +- brainpy/dyn/layers/nvar.py | 6 +- brainpy/dyn/layers/reservoir.py | 22 +- brainpy/dyn/layers/rnncells.py | 10 +- brainpy/dyn/neurons/biological_models.py | 22 +- brainpy/dyn/neurons/input_groups.py | 12 +- brainpy/dyn/neurons/noise_groups.py | 4 +- brainpy/dyn/neurons/reduced_models.py | 91 ++++--- brainpy/dyn/rates/populations.py | 14 +- brainpy/dyn/synapses/__init__.py | 2 +- brainpy/dyn/synapses/abstract_models.py | 20 +- brainpy/dyn/synapses/biological_models.py | 16 +- .../{couplings.py => delay_couplings.py} | 133 +++++----- brainpy/initialize/generic.py | 10 +- brainpy/math/delayvars.py | 9 +- brainpy/modes.py | 24 +- brainpy/train/offline.py | 4 +- brainpy/train/online.py | 4 +- .../Wang_2002_decision_making_spiking.py | 38 +-- .../whole_brain_simulation_with_fhn.py | 20 +- ...ole_brain_simulation_with_sl_oscillator.py | 14 +- ...Bellec_2020_eprop_evidence_accumulation.py | 234 +++++++++++++----- .../Gauthier_2021_ngrc_double_scroll.py | 4 +- .../training/Gauthier_2021_ngrc_lorenz.py | 1 + examples/training/echo_state_network.py | 6 +- extensions/setup.py | 2 +- extensions/setup_cuda.py | 2 +- extensions/setup_mac.py | 2 +- 39 files changed, 505 insertions(+), 371 deletions(-) rename brainpy/dyn/synapses/{couplings.py => delay_couplings.py} (64%) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index a16ae25b5..d96b67f3c 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -62,16 +62,16 @@ synouts, # synaptic output synplast, # synaptic plasticity ) -from .dyn.base import * -from .dyn.runners import * +# from .dyn.base import * +# from .dyn.runners import * # dynamics training from . import train -from .train.base import * -from .train.online import * -from .train.offline import * -from .train.back_propagation import * +# from .train.base import * +# from .train.online import * +# from .train.offline import * +# from .train.back_propagation import * # automatic dynamics analysis diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index d4cd0acd7..70041c8cc 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -16,7 +16,7 @@ from brainpy.errors import ModelBuildError, NoImplementationError from brainpy.initialize import Initializer, parameter, variable, Uniform, noise as init_noise from brainpy.integrators import odeint, sdeint -from brainpy.modes import Mode, Training, Batching, nonbatching, training +from brainpy.modes import Mode, TrainingMode, BatchingMode, normal, training from brainpy.tools.others import to_size, size2num from brainpy.types import Tensor, Shape @@ -60,14 +60,10 @@ def feedback(self): class DynamicalSystem(Base): """Base Dynamical System class. - Any object has step functions will be a dynamical system. - That is to say, in BrainPy, the essence of the dynamical system - is the "step functions". - Parameters ---------- - name : str, optional - The name of the dynamic system. + name : optional, str + The name of the dynamical system. 
mode: Mode The model computation mode. It should be instance of :py:class:`~.Mode`. """ @@ -87,7 +83,7 @@ class DynamicalSystem(Base): def __init__( self, name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(DynamicalSystem, self).__init__(name=name) @@ -105,7 +101,7 @@ def __init__( self.fit_record = dict() @property - def mode(self): + def mode(self) -> Mode: return self._mode @mode.setter @@ -366,7 +362,7 @@ def __init__( self, *ds_tuple, name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **ds_dict ): super(Container, self).__init__(name=name, mode=mode) @@ -468,7 +464,7 @@ def __init__( self, *ds_tuple, name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **ds_dict ): super(Network, self).__init__(*ds_tuple, @@ -556,7 +552,7 @@ def __init__( size: Shape, name: str = None, keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, ): # size if isinstance(size, (list, tuple)): @@ -625,7 +621,7 @@ def __init__( post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]] = None, name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(SynConn, self).__init__(name=name, mode=mode) @@ -776,7 +772,7 @@ def __init__( stp: Optional[SynSTP] = None, ltp: Optional[SynLTP] = None, name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(TwoEndConn, self).__init__(pre=pre, post=post, @@ -845,13 +841,13 @@ def init_weights( raise ValueError(f'Unknown connection type: {comp_method}') # training weights - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): weight = bm.TrainVar(weight) return weight, conn_mask def syn2post_with_all2all(self, syn_value, syn_weight): if bm.ndim(syn_weight) == 0: - if isinstance(self.mode, Batching): + if isinstance(self.mode, BatchingMode): post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) else: post_vs = bm.sum(syn_value) @@ -931,7 +927,7 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **channels ): NeuGroup.__init__(self, size, keep_size=keep_size, mode=mode) @@ -947,7 +943,7 @@ def __init__( # variables self.V = variable(V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Batching) else bool + sp_type = bm.dftype() if isinstance(self.mode, BatchingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) # function @@ -965,7 +961,7 @@ def derivative(self, V, t): def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Batching) else bool + sp_type = bm.dftype() if isinstance(self.mode, BatchingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) @@ -993,7 +989,7 @@ def __init__( size: Union[int, Sequence[int]], name: str = None, keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(Channel, self).__init__(name=name, mode=mode) # the geometry size @@ -1049,7 +1045,7 @@ def __init__( self, *modules, name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **kw_modules ): super(Sequential, self).__init__(*modules, name=name, mode=mode, **kw_modules) diff --git 
a/brainpy/dyn/channels/Ca.py b/brainpy/dyn/channels/Ca.py index 400042683..9a8efd4fb 100644 --- a/brainpy/dyn/channels/Ca.py +++ b/brainpy/dyn/channels/Ca.py @@ -13,7 +13,7 @@ from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, Batching, nonbatching +from brainpy.modes import Mode, BatchingMode, normal from .base import Calcium, CalciumChannel __all__ = [ @@ -50,7 +50,7 @@ def __init__( C: Union[float, Tensor, Initializer, Callable] = 2.4e-4, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **channels ): super(CalciumFixed, self).__init__(size, @@ -105,7 +105,7 @@ def __init__( C_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.4e-4), method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **channels ): super(CalciumDyna, self).__init__(size, @@ -124,7 +124,7 @@ def __init__( # variables self.C = variable(C_initializer, mode, self.varshape) # Calcium concentration self.E = bm.Variable(self._reversal_potential(self.C), - batch_axis=0 if isinstance(mode, Batching) else None) # Reversal potential + batch_axis=0 if isinstance(mode, BatchingMode) else None) # Reversal potential # function self.integral = odeint(self.derivative, method=method) @@ -271,7 +271,7 @@ def __init__( C_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.4e-4), method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **channels ): super(CalciumDetailed, self).__init__(size, @@ -315,7 +315,7 @@ def __init__( C_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.4e-4), method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **channels ): super(CalciumFirstOrder, self).__init__(size, @@ -382,7 +382,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 3., g_max: Union[float, Tensor, Initializer, Callable] = 2., method: str = 'exp_auto', - mode: Mode = nonbatching, + mode: Mode = normal, name: str = None ): super(ICa_p2q_ss, self).__init__(size, @@ -476,7 +476,7 @@ def __init__( g_max: Union[float, Tensor, Initializer, Callable] = 2., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(ICa_p2q_markov, self).__init__(size, keep_size=keep_size, @@ -578,7 +578,7 @@ def __init__( phi: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(ICaN_IS2008, self).__init__(size, keep_size=keep_size, @@ -676,7 +676,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q @@ -774,7 +774,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q @@ -867,7 +867,7 @@ def __init__( V_sh: Union[float, Tensor, Initializer, Callable] = 25., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(ICaHT_HM1992, self).__init__(size, 
keep_size=keep_size, @@ -974,7 +974,7 @@ def __init__( V_sh: Union[float, Tensor, Initializer, Callable] = 0., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): phi_p = T_base_p ** ((T - 23.) / 10.) if phi_p is None else phi_p phi_q = T_base_q ** ((T - 23.) / 10.) if phi_q is None else phi_q @@ -1061,7 +1061,7 @@ def __init__( V_sh: Union[float, Tensor, Initializer, Callable] = 0., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(ICaL_IS2008, self).__init__(size, keep_size=keep_size, diff --git a/brainpy/dyn/channels/IH.py b/brainpy/dyn/channels/IH.py index 7e97b6d7f..4cb942416 100644 --- a/brainpy/dyn/channels/IH.py +++ b/brainpy/dyn/channels/IH.py @@ -11,7 +11,7 @@ from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, Batching, nonbatching +from brainpy.modes import Mode, BatchingMode, normal from .base import IhChannel, CalciumChannel, Calcium __all__ = [ @@ -63,7 +63,7 @@ def __init__( phi: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(Ih_HM1992, self).__init__(size, keep_size=keep_size, @@ -173,7 +173,7 @@ def __init__( phi: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): # IhChannel.__init__(self, size, name=name, keep_size=keep_size) CalciumChannel.__init__(self, diff --git a/brainpy/dyn/channels/K.py b/brainpy/dyn/channels/K.py index f13719d71..c9ea80168 100644 --- a/brainpy/dyn/channels/K.py +++ b/brainpy/dyn/channels/K.py @@ -11,7 +11,7 @@ from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, Batching, nonbatching +from brainpy.modes import Mode, BatchingMode, normal from .base import PotassiumChannel __all__ = [ @@ -76,7 +76,7 @@ def __init__( phi: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IK_p4_markov, self).__init__(size, keep_size=keep_size, @@ -175,7 +175,7 @@ def __init__( phi: Optional[Union[float, Tensor, Initializer, Callable]] = None, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): phi = T_base ** ((T - 36) / 10) if phi is None else phi super(IKDR_Ba2002, self).__init__(size, @@ -250,7 +250,7 @@ def __init__( V_sh: Union[int, float, Tensor, Initializer, Callable] = -60., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IK_TM1991, self).__init__(size, keep_size=keep_size, @@ -321,7 +321,7 @@ def __init__( V_sh: Union[int, float, Tensor, Initializer, Callable] = -45., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IK_HH1952, self).__init__(size, keep_size=keep_size, @@ -393,7 +393,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IKA_p4q_ss, self).__init__(size, keep_size=keep_size, @@ -509,7 +509,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = 
None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IKA1_HM1992, self).__init__(size, keep_size=keep_size, @@ -604,7 +604,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IKA2_HM1992, self).__init__(size, keep_size=keep_size, @@ -688,7 +688,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IKK2_pq_ss, self).__init__(size, keep_size=keep_size, @@ -800,7 +800,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IKK2A_HM1992, self).__init__(size, keep_size=keep_size, @@ -891,7 +891,7 @@ def __init__( phi_q: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IKK2B_HM1992, self).__init__(size, keep_size=keep_size, @@ -977,7 +977,7 @@ def __init__( V_sh: Union[float, Tensor, Initializer, Callable] = 0., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IKNI_Ya1989, self).__init__(size, keep_size=keep_size, diff --git a/brainpy/dyn/channels/KCa.py b/brainpy/dyn/channels/KCa.py index 44682932f..f413fef09 100644 --- a/brainpy/dyn/channels/KCa.py +++ b/brainpy/dyn/channels/KCa.py @@ -12,7 +12,7 @@ from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators.ode import odeint from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, Batching, nonbatching +from brainpy.modes import Mode, BatchingMode, normal from .base import Calcium, CalciumChannel, PotassiumChannel __all__ = [ @@ -83,7 +83,7 @@ def __init__( phi: Union[float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): CalciumChannel.__init__(self, size=size, diff --git a/brainpy/dyn/channels/Na.py b/brainpy/dyn/channels/Na.py index 47baa932d..f225049f9 100644 --- a/brainpy/dyn/channels/Na.py +++ b/brainpy/dyn/channels/Na.py @@ -11,7 +11,7 @@ from brainpy.initialize import Initializer, parameter, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor, Shape -from brainpy.modes import Mode, Batching, nonbatching +from brainpy.modes import Mode, BatchingMode, normal from .base import SodiumChannel __all__ = [ @@ -61,7 +61,7 @@ def __init__( phi: Union[int, float, Tensor, Initializer, Callable] = 1., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(INa_p3q_markov, self).__init__(size=size, keep_size=keep_size, @@ -171,7 +171,7 @@ def __init__( V_sh: Union[int, float, Tensor, Initializer, Callable] = -50., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(INa_Ba2002, self).__init__(size, keep_size=keep_size, @@ -258,7 +258,7 @@ def __init__( V_sh: Union[int, float, Tensor, Initializer, Callable] = -63., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(INa_TM1991, self).__init__(size, keep_size=keep_size, @@ -345,7 +345,7 @@ def __init__( V_sh: Union[int, float, Tensor, Initializer, Callable] = -45., method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: 
Mode = normal, ): super(INa_HH1952, self).__init__(size, keep_size=keep_size, diff --git a/brainpy/dyn/channels/base.py b/brainpy/dyn/channels/base.py index a71bae257..08987cda2 100644 --- a/brainpy/dyn/channels/base.py +++ b/brainpy/dyn/channels/base.py @@ -5,7 +5,7 @@ import brainpy.math as bm from brainpy.dyn.base import Container, CondNeuGroup, Channel, check_master from brainpy.types import Shape -from brainpy.modes import nonbatching, Mode +from brainpy.modes import normal, Mode __all__ = [ 'Ion', 'IonChannel', @@ -92,7 +92,7 @@ def __init__( keep_size: bool = False, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, **channels ): Ion.__init__(self, size, keep_size=keep_size, mode=mode) diff --git a/brainpy/dyn/channels/leaky.py b/brainpy/dyn/channels/leaky.py index c546f8a0f..2eb67cdff 100644 --- a/brainpy/dyn/channels/leaky.py +++ b/brainpy/dyn/channels/leaky.py @@ -9,7 +9,7 @@ from brainpy.initialize import Initializer, parameter from brainpy.types import Tensor, Shape -from brainpy.modes import Mode, Batching, nonbatching +from brainpy.modes import Mode, BatchingMode, normal from .base import LeakyChannel @@ -38,7 +38,7 @@ def __init__( E: Union[int, float, Tensor, Initializer, Callable] = -70., method: str = None, name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IL, self).__init__(size, keep_size=keep_size, @@ -79,7 +79,7 @@ def __init__( E: Union[int, float, Tensor, Initializer, Callable] = -90., method: str = None, name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): super(IKL, self).__init__(size=size, keep_size=keep_size, diff --git a/brainpy/dyn/layers/conv.py b/brainpy/dyn/layers/conv.py index 4d6e5cbd2..37bee6a02 100644 --- a/brainpy/dyn/layers/conv.py +++ b/brainpy/dyn/layers/conv.py @@ -6,7 +6,7 @@ import brainpy.math as bm from brainpy.dyn.base import DynamicalSystem from brainpy.initialize import XavierNormal, ZeroInit, parameter -from brainpy.modes import Mode, Training, training +from brainpy.modes import Mode, TrainingMode, training __all__ = [ 'GeneralConv', @@ -117,7 +117,7 @@ def __init__( kernel_shape = _check_tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels) self.w = parameter(self.w_init, kernel_shape) self.b = parameter(self.b_init, (1,) * len(self.kernel_size) + (self.out_channels,)) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.w = bm.TrainVar(self.w) self.b = bm.TrainVar(self.b) diff --git a/brainpy/dyn/layers/dense.py b/brainpy/dyn/layers/dense.py index 42697b317..73a13aa48 100644 --- a/brainpy/dyn/layers/dense.py +++ b/brainpy/dyn/layers/dense.py @@ -9,7 +9,7 @@ from brainpy.dyn.base import DynamicalSystem from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter -from brainpy.modes import Mode, Training, training +from brainpy.modes import Mode, TrainingMode, training from brainpy.tools.checking import check_initializer from brainpy.types import Tensor @@ -37,7 +37,7 @@ class Dense(DynamicalSystem): The weight initialization. b_initializer: optional, Initializer The bias initialization. - trainable: bool + mode: Mode Enable training this node or not. 
(default True) """ @@ -71,7 +71,7 @@ def __init__( # parameter initialization self.W = parameter(self.weight_initializer, (num_in, self.num_out)) self.b = parameter(self.bias_initializer, (self.num_out,)) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.W = bm.TrainVar(self.W) self.b = None if (self.b is None) else bm.TrainVar(self.b) diff --git a/brainpy/dyn/layers/dropout.py b/brainpy/dyn/layers/dropout.py index e1e8d8cb3..542844006 100644 --- a/brainpy/dyn/layers/dropout.py +++ b/brainpy/dyn/layers/dropout.py @@ -38,7 +38,13 @@ class Dropout(DynamicalSystem): research 15.1 (2014): 1929-1958. """ - def __init__(self, prob, seed=None, mode: Mode = training, name=None): + def __init__( + self, + prob: float, + seed: int = None, + mode: Mode = training, + name: str = None + ): super(Dropout, self).__init__(mode=mode, name=name) self.prob = prob self.rng = bm.random.RandomState(seed=seed) diff --git a/brainpy/dyn/layers/nvar.py b/brainpy/dyn/layers/nvar.py index 0ac001ead..f373e1e82 100644 --- a/brainpy/dyn/layers/nvar.py +++ b/brainpy/dyn/layers/nvar.py @@ -8,7 +8,7 @@ import brainpy.math as bm from brainpy.dyn.base import DynamicalSystem -from brainpy.modes import Mode, Batching, batching +from brainpy.modes import Mode, BatchingMode, batching from brainpy.tools.checking import (check_integer, check_sequence) __all__ = [ @@ -93,7 +93,7 @@ def __init__( # delay variables self.idx = bm.Variable(jnp.asarray([0])) - if isinstance(self.mode, Batching): + if isinstance(self.mode, BatchingMode): batch_size = 1 # first initialize the state with batch size = 1 self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1) else: @@ -135,7 +135,7 @@ def update(self, sha, x): # 1. Store the current input self.store[self.idx[0]] = x - if isinstance(self.mode, Batching): + if isinstance(self.mode, BatchingMode): # 2. Linear part: # select all previous inputs, including the current, with strides linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature) diff --git a/brainpy/dyn/layers/reservoir.py b/brainpy/dyn/layers/reservoir.py index f378d4cf2..ff6d17625 100644 --- a/brainpy/dyn/layers/reservoir.py +++ b/brainpy/dyn/layers/reservoir.py @@ -5,7 +5,7 @@ import brainpy.math as bm from brainpy.dyn.base import DynamicalSystem from brainpy.initialize import Normal, ZeroInit, Initializer, parameter, variable -from brainpy.modes import Mode, Training, batching +from brainpy.modes import Mode, TrainingMode, batching from brainpy.tools.checking import check_float, check_initializer, check_string from brainpy.tools.others import to_size from brainpy.types import Tensor @@ -64,7 +64,7 @@ class Reservoir(DynamicalSystem): Connectivity of recurrent weights matrix, i.e. ratio of reservoir neurons connected to other reservoir neurons, including themselves. Must be in [0, 1], by default 0.1 - conn_type: str + comp_type: str The connectivity type, can be "dense" or "sparse". spectral_radius : float, optional Spectral radius of recurrent weight matrix, by default None @@ -97,7 +97,7 @@ def __init__( b_initializer: Optional[Union[Initializer, Callable, Tensor]] = ZeroInit(), in_connectivity: float = 0.1, rec_connectivity: float = 0.1, - conn_type='dense', + comp_type='dense', spectral_radius: Optional[float] = None, noise_in: float = 0., noise_rec: float = 0., @@ -138,8 +138,8 @@ def __init__( check_float(rec_connectivity, 'rec_connectivity', 0., 1.) 
self.ff_connectivity = in_connectivity self.rec_connectivity = rec_connectivity - check_string(conn_type, 'conn_type', ['dense', 'sparse']) - self.conn_type = conn_type + check_string(comp_type, 'conn_type', ['dense', 'sparse']) + self.comp_type = comp_type # noises check_float(noise_in, 'noise_ff') @@ -156,10 +156,10 @@ def __init__( if self.ff_connectivity < 1.: conn_mat = self.rng.random(weight_shape) > self.ff_connectivity self.Win[conn_mat] = 0. - if self.conn_type == 'sparse' and self.ff_connectivity < 1.: + if self.comp_type == 'sparse' and self.ff_connectivity < 1.: self.ff_pres, self.ff_posts = bm.where(bm.logical_not(conn_mat)) self.Win = self.Win[self.ff_pres, self.ff_posts] - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.Win = bm.TrainVar(self.Win) # initialize recurrent weights @@ -171,11 +171,11 @@ def __init__( if self.spectral_radius is not None: current_sr = max(abs(bm.linalg.eig(self.Wrec)[0])) self.Wrec *= self.spectral_radius / current_sr - if self.conn_type == 'sparse' and self.rec_connectivity < 1.: + if self.comp_type == 'sparse' and self.rec_connectivity < 1.: self.rec_pres, self.rec_posts = bm.where(bm.logical_not(conn_mat)) self.Wrec = self.Wrec[self.rec_pres, self.rec_posts] self.bias = parameter(self._b_initializer, (self.num_unit,)) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.Wrec = bm.TrainVar(self.Wrec) self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) @@ -190,7 +190,7 @@ def update(self, sha, x): # inputs x = bm.concatenate(x, axis=-1) if self.noise_ff > 0: x += self.noise_ff * self.rng.uniform(-1, 1, x.shape) - if self.conn_type == 'sparse' and self.ff_connectivity < 1.: + if self.comp_type == 'sparse' and self.ff_connectivity < 1.: sparse = {'data': self.Win, 'index': (self.ff_pres, self.ff_posts), 'shape': self.Wff_shape} @@ -198,7 +198,7 @@ def update(self, sha, x): else: hidden = bm.dot(x, self.Win) # recurrent - if self.conn_type == 'sparse' and self.rec_connectivity < 1.: + if self.comp_type == 'sparse' and self.rec_connectivity < 1.: sparse = {'data': self.Wrec, 'index': (self.rec_pres, self.rec_posts), 'shape': (self.num_unit, self.num_unit)} diff --git a/brainpy/dyn/layers/rnncells.py b/brainpy/dyn/layers/rnncells.py index a4bfde547..e31fab454 100644 --- a/brainpy/dyn/layers/rnncells.py +++ b/brainpy/dyn/layers/rnncells.py @@ -11,7 +11,7 @@ parameter, variable, Initializer) -from brainpy.modes import Mode, Training, training +from brainpy.modes import Mode, TrainingMode, training from brainpy.tools.checking import (check_integer, check_initializer) from brainpy.types import Tensor @@ -41,7 +41,7 @@ def __init__(self, # state self.state = variable(bm.zeros, mode, self.num_out) - if train_state and isinstance(self.mode, Training): + if train_state and isinstance(self.mode, TrainingMode): self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False)) self.state[:] = self.state2train @@ -123,7 +123,7 @@ def __init__( self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out)) self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out)) self.b = parameter(self._b_initializer, (self.num_out,)) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.Wi = bm.TrainVar(self.Wi) self.Wh = bm.TrainVar(self.Wh) self.b = None if (self.b is None) else bm.TrainVar(self.b) @@ -221,7 +221,7 @@ def __init__( self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 3)) self.Wh = 
parameter(self._Wh_initializer, (self.num_out, self.num_out * 3)) self.b = parameter(self._b_initializer, (self.num_out * 3,)) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.Wi = bm.TrainVar(self.Wi) self.Wh = bm.TrainVar(self.Wh) self.b = bm.TrainVar(self.b) if (self.b is not None) else None @@ -345,7 +345,7 @@ def __init__( self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 4)) self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 4)) self.b = parameter(self._b_initializer, (self.num_out * 4,)) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.Wi = bm.TrainVar(self.Wi) self.Wh = bm.TrainVar(self.Wh) self.b = None if (self.b is None) else bm.TrainVar(self.b) diff --git a/brainpy/dyn/neurons/biological_models.py b/brainpy/dyn/neurons/biological_models.py index 9189171d9..531a5efd6 100644 --- a/brainpy/dyn/neurons/biological_models.py +++ b/brainpy/dyn/neurons/biological_models.py @@ -8,9 +8,9 @@ from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.integrators.sde import sdeint +from brainpy.modes import Mode, BatchingMode, normal from brainpy.tools.checking import check_initializer from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, Batching, Training, nonbatching, batching, training __all__ = [ 'HH', @@ -212,7 +212,7 @@ def __init__( name: str = None, # training parameter - mode: Mode = nonbatching, + mode: Mode = normal, ): # initialization super(HH, self).__init__(size=size, @@ -407,7 +407,7 @@ def __init__( name: str = None, # training parameter - mode: Mode = nonbatching, + mode: Mode = normal, ): # initialization super(MorrisLecar, self).__init__(size=size, @@ -664,7 +664,7 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): # initialization super(PinskyRinzelModel, self).__init__(size=size, @@ -706,11 +706,11 @@ def __init__( self.Vs = variable(self._Vs_initializer, mode, self.varshape) self.Vd = variable(self._Vd_initializer, mode, self.varshape) self.Ca = variable(self._Ca_initializer, mode, self.varshape) - self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if isinstance(mode, Batching) else None) - self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if isinstance(mode, Batching) else None) - self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(mode, Batching) else None) - self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(mode, Batching) else None) - self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(mode, Batching) else None) + self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None) + self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None) + self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None) + self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None) + self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(mode, BatchingMode) else None) self.Id = variable(bm.zeros, mode, self.varshape) # input to soma self.Is = variable(bm.zeros, mode, self.varshape) # input to dendrite # self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool)) @@ -725,7 +725,7 @@ def reset_state(self, batch_size=None): self.Vd.value = 
variable(self._Vd_initializer, batch_size, self.varshape) self.Vs.value = variable(self._Vs_initializer, batch_size, self.varshape) self.Ca.value = variable(self._Ca_initializer, batch_size, self.varshape) - batch_axis = 0 if isinstance(self.mode, Batching) else None + batch_axis = 0 if isinstance(self.mode, BatchingMode) else None self.h.value = bm.Variable(self.inf_h(self.Vs), batch_axis=batch_axis) self.n.value = bm.Variable(self.inf_n(self.Vs), batch_axis=batch_axis) self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis) @@ -973,7 +973,7 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, - mode: Mode = nonbatching, + mode: Mode = normal, ): # initialization super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode) diff --git a/brainpy/dyn/neurons/input_groups.py b/brainpy/dyn/neurons/input_groups.py index 80520091e..2db392e02 100644 --- a/brainpy/dyn/neurons/input_groups.py +++ b/brainpy/dyn/neurons/input_groups.py @@ -8,8 +8,8 @@ from brainpy.dyn.base import NeuGroup from brainpy.errors import ModelBuildError from brainpy.initialize import Initializer, parameter, variable +from brainpy.modes import Mode, BatchingMode, normal from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, Batching, Training, nonbatching, batching, training __all__ = [ 'InputGroup', @@ -23,7 +23,7 @@ def __init__( self, size: Shape, keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, name: str = None, ): super(InputGroup, self).__init__(name=name, @@ -74,7 +74,7 @@ def __init__( indices: Union[Sequence, Tensor], need_sort: bool = True, keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, name: str = None ): super(SpikeTimeGroup, self).__init__(size=size, @@ -109,7 +109,7 @@ def cond_fun(t): def body_fun(t): i = self.i[0] - if isinstance(self.mode, Batching): + if isinstance(self.mode, BatchingMode): self.spike[:, self.indices[i]] = True else: self.spike[self.indices[i]] = True @@ -136,7 +136,7 @@ def __init__( freqs: Union[int, float, jnp.ndarray, bm.JaxArray, Initializer], seed: int = None, keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, name: str = None ): super(PoissonGroup, self).__init__(size=size, @@ -154,7 +154,7 @@ def __init__( self.rng = bm.random.RandomState(seed=seed) def update(self, tdi, x=None): - shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, Batching) else self.varshape + shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, BatchingMode) else self.varshape self.spike.update(self.rng.random(shape) <= (self.freqs * tdi['dt'] / 1000.)) def reset(self, batch_size=None): diff --git a/brainpy/dyn/neurons/noise_groups.py b/brainpy/dyn/neurons/noise_groups.py index 5fe5527f4..496f7ff8e 100644 --- a/brainpy/dyn/neurons/noise_groups.py +++ b/brainpy/dyn/neurons/noise_groups.py @@ -6,8 +6,8 @@ from brainpy.dyn.base import NeuGroup from brainpy.initialize import Initializer from brainpy.integrators.sde import sdeint +from brainpy.modes import Mode, normal from brainpy.types import Tensor, Shape -from brainpy.modes import Mode, Batching, Training, nonbatching, batching, training __all__ = [ 'OUProcess', @@ -51,7 +51,7 @@ def __init__( tau: Union[float, Tensor, Initializer, Callable] = 10., method: str = 'euler', keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, name: str = None, ): super(OUProcess, self).__init__(size=size, 
name=name, keep_size=keep_size, mode=mode) diff --git a/brainpy/dyn/neurons/reduced_models.py b/brainpy/dyn/neurons/reduced_models.py index 9e6d0f492..834292844 100644 --- a/brainpy/dyn/neurons/reduced_models.py +++ b/brainpy/dyn/neurons/reduced_models.py @@ -9,9 +9,9 @@ from brainpy.initialize import (ZeroInit, OneInit, Initializer, parameter, variable, noise as init_noise) from brainpy.integrators import sdeint, odeint, JointEq +from brainpy.modes import Mode, BatchingMode, TrainingMode, normal from brainpy.tools.checking import check_initializer, check_callable from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, Batching, Training, nonbatching, batching, training __all__ = [ 'LeakyIntegrator', @@ -66,6 +66,7 @@ class LeakyIntegrator(NeuGroup): def __init__( self, + # neuron group size size: Shape, keep_size: bool = False, @@ -77,11 +78,9 @@ def __init__( V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), noise: Union[float, Tensor, Initializer, Callable] = None, - # training parameters - mode: Mode = nonbatching, - # other parameter name: str = None, + mode: Mode = normal, method: str = 'exp_auto', ): super(LeakyIntegrator, self).__init__(size=size, @@ -196,7 +195,7 @@ def __init__( name: str = None, # training parameter - mode: Mode = nonbatching, + mode: Mode = normal, spike_fun: Callable = bm.spike_with_sigmoid_grad, ): # initialization @@ -222,7 +221,7 @@ def __init__( # variables self.V = variable(self._V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(mode, Training) else bool # the gradient of spike is a float + sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool # the gradient of spike is a float self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) if self.tau_ref is not None: self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) @@ -240,7 +239,7 @@ def derivative(self, V, t, I_ext): def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) if self.tau_ref is not None: self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) @@ -256,12 +255,12 @@ def update(self, tdi, x=None): if self.tau_ref is not None: # refractory refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): refractory = stop_gradient(refractory) V = bm.where(refractory, self.V, V) # spike, refractory, spiking time, and membrane potential reset - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): spike = self.spike_fun(V - self.V_th) spike_no_grad = stop_gradient(spike) V += (self.V_reset - V) * spike_no_grad @@ -281,7 +280,7 @@ def update(self, tdi, x=None): else: # spike, spiking time, and membrane potential reset - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): spike = self.spike_fun(V - self.V_th) spike_no_grad = stop_gradient(spike) V += (self.V_reset - V) * spike_no_grad @@ -407,7 +406,7 @@ def __init__( V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), noise: Union[float, Tensor, Initializer, Callable] 
= None, keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, method: str = 'exp_auto', name: str = None ): @@ -435,7 +434,7 @@ def __init__( # variables self.V = variable(V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) if self.tau_ref is not None: @@ -450,7 +449,7 @@ def __init__( def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) if self.tau_ref is not None: @@ -575,7 +574,7 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, name: str = None ): super(AdExIF, self).__init__(size=size, @@ -606,7 +605,7 @@ def __init__( self.V = variable(V_initializer, mode, self.varshape) self.w = variable(w_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(mode, Batching) else bool + sp_type = bm.dftype() if isinstance(mode, BatchingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) # functions @@ -619,7 +618,7 @@ def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.w.value = variable(self._w_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dV(self, V, t, w, I_ext): @@ -727,7 +726,7 @@ def __init__( V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), noise: Union[float, Tensor, Initializer, Callable] = None, keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, method: str = 'exp_auto', name: str = None ): @@ -755,7 +754,7 @@ def __init__( # variables self.V = variable(V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) if self.tau_ref is not None: @@ -770,7 +769,7 @@ def __init__( def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: 
bm.zeros(s, dtype=sp_type), batch_size, self.varshape) self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) if self.tau_ref is not None: @@ -895,7 +894,7 @@ def __init__( noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, - mode: Mode = nonbatching, + mode: Mode = normal, name: str = None ): super(AdQuaIF, self).__init__(size=size, @@ -925,7 +924,7 @@ def __init__( self.V = variable(V_initializer, mode, self.varshape) self.w = variable(w_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) @@ -939,7 +938,7 @@ def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.w.value = variable(self._w_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) @@ -1075,7 +1074,7 @@ def __init__( name: str = None, # parameter for training - mode: Mode = nonbatching, + mode: Mode = normal, spike_fun: Callable = bm.spike_with_sigmoid_grad, ): # initialization @@ -1118,7 +1117,7 @@ def __init__( self.V_th = variable(Vth_initializer, mode, self.varshape) self.V = variable(V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) # integral @@ -1133,7 +1132,7 @@ def reset_state(self, batch_size=None): self.V_th.value = variable(self._Vth_initializer, batch_size, self.varshape) self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dI1(self, I1, t): @@ -1160,7 +1159,7 @@ def update(self, tdi, x=None): I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, t, self.input, dt=dt) # spike and resets - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): spike = self.spike_fun(V - self.V_th) V += (self.V_reset - V) * spike I1 += spike * (self.R1 * I1 + self.A1 - I1) @@ -1235,12 +1234,12 @@ def __init__( a_initializer: Union[Initializer, Callable, Tensor] = OneInit(-50.), # parameter for training - mode: Mode = nonbatching, spike_fun: Callable = bm.spike_with_relu_grad, # other parameters method: str = 'exp_auto', name: str = None, + mode: Mode = normal, ): super(ALIFBellec2020, self).__init__(name=name, size=size, @@ -1268,7 +1267,7 @@ def __init__( self.a = variable(a_initializer, mode, self.varshape) self.V = variable(V_initializer, mode, 
self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) if self.tau_ref is not None: self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) @@ -1294,7 +1293,7 @@ def reset_state(self, batch_size=None): self.a.value = variable(self._a_initializer, batch_size, self.varshape) self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) if self.tau_ref is not None: self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) @@ -1310,11 +1309,11 @@ def update(self, tdi, x=None): if self.tau_ref is not None: # refractory refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): refractory = stop_gradient(refractory) V = bm.where(refractory, self.V, V) # spike and reset - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): spike = self.spike_fun(V - self.V_th_reset - self.beta * self.a) spike_no_grad = stop_gradient(spike) V -= self.V_th_reset * spike_no_grad @@ -1333,7 +1332,7 @@ def update(self, tdi, x=None): else: # spike and reset - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): spike = self.spike_fun(V - self.V_th_reset - self.beta * self.a) V -= self.V_th_reset * stop_gradient(spike) else: @@ -1429,7 +1428,7 @@ def __init__( u_initializer: Union[Initializer, Callable, Tensor] = OneInit(), noise: Union[float, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', - mode: Mode = nonbatching, + mode: Mode = normal, spike_fun: Callable = bm.spike_with_sigmoid_grad, keep_size: bool = False, name: str = None @@ -1460,7 +1459,7 @@ def __init__( self.u = variable(u_initializer, mode, self.varshape) self.V = variable(V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) if self.tau_ref is not None: self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape) @@ -1476,7 +1475,7 @@ def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.u.value = variable(self._u_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) if self.tau_ref is not None: self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape) @@ -1499,12 +1498,12 @@ def update(self, tdi, x=None): if self.tau_ref is not None: refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): refractory = 
stop_gradient(refractory) V = bm.where(refractory, self.V, V) # spike, refractory, and reset membrane potential - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): spike = self.spike_fun(V - self.V_th) spike_no_grad = stop_gradient(spike) V += spike_no_grad * (self.c - self.V_th) @@ -1523,7 +1522,7 @@ def update(self, tdi, x=None): else: # spike, refractory, and reset membrane potential - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): spike = self.spike_fun(V - self.V_th) spike_no_grad = stop_gradient(spike) V += spike_no_grad * (self.c - self.V_th) @@ -1658,7 +1657,7 @@ def __init__( name: str = None, # parameters for training - mode: Mode = nonbatching, + mode: Mode = normal, spike_fun: Callable = bm.spike2_with_sigmoid_grad, ): # initialization @@ -1692,7 +1691,7 @@ def __init__( self.y = variable(self._y_initializer, mode, self.varshape) self.z = variable(self._z_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) # integral @@ -1706,7 +1705,7 @@ def reset_state(self, batch_size=None): self.y.value = variable(self._y_initializer, batch_size, self.varshape) self.z.value = variable(self._z_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dV(self, V, t, y, z, I_ext): @@ -1726,7 +1725,7 @@ def update(self, tdi, x=None): t, dt = tdi.t, tdi.dt if x is not None: self.input += x V, y, z = self.integral(self.V, self.y, self.z, t, self.input, dt=dt) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.spike.value = self.spike_fun(V - self.V_th, self.V - self.V_th) else: self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th) @@ -1834,7 +1833,7 @@ def __init__( name: str = None, # parameters for training - mode: Mode = nonbatching, + mode: Mode = normal, spike_fun: Callable = bm.spike2_with_sigmoid_grad, ): # initialization @@ -1861,7 +1860,7 @@ def __init__( self.V = variable(self._V_initializer, mode, self.varshape) self.w = variable(self._w_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) # integral @@ -1874,7 +1873,7 @@ def reset_state(self, batch_size=None): self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.w.value = variable(self._w_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - sp_type = bm.dftype() if isinstance(self.mode, Training) else bool + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dV(self, V, t, w, I_ext): @@ -1891,7 +1890,7 @@ def update(self, tdi, x=None): t, dt = tdi.t, tdi.dt if x is not None: self.input += x V, w = self.integral(self.V.value, self.w.value, t, 
self.input.value, dt=dt) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.spike.value = self.spike_fun(V - self.Vth, self.V - self.Vth) else: self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py index 92853f767..235fdc59f 100644 --- a/brainpy/dyn/rates/populations.py +++ b/brainpy/dyn/rates/populations.py @@ -8,10 +8,10 @@ from brainpy.initialize import Initializer, Uniform, parameter, variable, ZeroInit from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint +from brainpy.modes import Mode, normal from brainpy.tools.checking import check_float, check_initializer from brainpy.tools.errors import check_error_in_jit from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, Batching, Training, nonbatching, batching, training __all__ = [ 'RateModel', @@ -91,7 +91,7 @@ def __init__( name: str = None, # parameter for training - mode: Mode = nonbatching, + mode: Mode = normal, ): super(FHN, self).__init__(size=size, name=name, @@ -272,7 +272,7 @@ def __init__( dt: float = None, # parameter for training - mode: Mode = nonbatching, + mode: Mode = normal, ): super(FeedbackFHN, self).__init__(size=size, name=name, @@ -467,7 +467,7 @@ def __init__( name: str = None, # parameter for training - mode: Mode = nonbatching, + mode: Mode = normal, ): super(QIF, self).__init__(size=size, name=name, @@ -606,7 +606,7 @@ def __init__( name: str = None, # parameter for training - mode: Mode = nonbatching, + mode: Mode = normal, ): super(StuartLandauOscillator, self).__init__(size=size, name=name, @@ -759,7 +759,7 @@ def __init__( name: str = None, # parameter for training - mode: Mode = nonbatching, + mode: Mode = normal, ): super(WilsonCowanModel, self).__init__(size=size, name=name, keep_size=keep_size) @@ -913,7 +913,7 @@ def __init__( name: str = None, # parameter for training - mode: Mode = nonbatching, + mode: Mode = normal, ): super(ThresholdLinearModel, self).__init__(size, name=name, diff --git a/brainpy/dyn/synapses/__init__.py b/brainpy/dyn/synapses/__init__.py index 9387f94db..ca2960417 100644 --- a/brainpy/dyn/synapses/__init__.py +++ b/brainpy/dyn/synapses/__init__.py @@ -4,7 +4,7 @@ from .biological_models import * from .learning_rules import * from .gap_junction import * -from .couplings import * +from .delay_couplings import * # compatible interface from . 
import compat diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index d270b05b5..347bd3e6f 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -11,7 +11,7 @@ from brainpy.initialize import Initializer, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor -from brainpy.modes import Mode, Batching, Training, nonbatching, batching, training +from brainpy.modes import Mode, BatchingMode, TrainingMode, normal, batching, training from ..synouts import CUBA, MgBlock __all__ = [ @@ -101,7 +101,7 @@ def __init__( name: str = None, # training parameters - mode: Mode = nonbatching, + mode: Mode = normal, stop_spike_gradient: bool = False, ): super(Delta, self).__init__(name=name, @@ -153,7 +153,7 @@ def update(self, tdi, pre_spike=None): else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max) - if isinstance(self.mode, Batching): f = vmap(f) + if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(pre_spike) # if not isinstance(self.stp, _NullSynSTP): # raise NotImplementedError() @@ -283,7 +283,7 @@ def __init__( method: str = 'exp_auto', # training parameters - mode: Mode = nonbatching, + mode: Mode = normal, stop_spike_gradient: bool = False, ): super(Exponential, self).__init__(pre=pre, @@ -341,7 +341,7 @@ def update(self, tdi, pre_spike=None): else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max) - if isinstance(self.mode, Batching): f = vmap(f) + if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(pre_spike) # if not isinstance(self.stp, _NullSynSTP): # raise NotImplementedError() @@ -468,7 +468,7 @@ def __init__( name: str = None, # training parameters - mode: Mode = nonbatching, + mode: Mode = normal, stop_spike_gradient: bool = False, ): super(DualExponential, self).__init__(pre=pre, @@ -543,7 +543,7 @@ def update(self, tdi, pre_spike=None): else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) - if isinstance(self.mode, Batching): f = vmap(f) + if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(syn_value) else: post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) @@ -648,7 +648,7 @@ def __init__( name: str = None, # training parameters - mode: Mode = nonbatching, + mode: Mode = normal, stop_spike_gradient: bool = False, ): super(Alpha, self).__init__(pre=pre, @@ -834,7 +834,7 @@ def __init__( name: str = None, # training parameters - mode: Mode = nonbatching, + mode: Mode = normal, stop_spike_gradient: bool = False, # deprecated @@ -946,7 +946,7 @@ def update(self, tdi, pre_spike=None): else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) - if isinstance(self.mode, Batching): f = vmap(f) + if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(syn_value) else: post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index dc3cea591..48269620a 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -13,7 +13,7 @@ from brainpy.initialize import Initializer, variable from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor -from brainpy.modes import Mode, Batching, Training, 
nonbatching, batching, training +from brainpy.modes import Mode, BatchingMode, TrainingMode, normal, batching, training __all__ = [ 'AMPA', @@ -153,7 +153,7 @@ def __init__( name: str = None, # training parameters - mode: Mode = nonbatching, + mode: Mode = normal, stop_spike_gradient: bool = False, # deprecated @@ -225,7 +225,7 @@ def update(self, tdi, pre_spike=None): # update synaptic variables self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value) TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T self.g.value = self.integral(self.g, t, TT, dt) @@ -240,7 +240,7 @@ def update(self, tdi, pre_spike=None): else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) - if isinstance(self.mode, Batching): f = vmap(f) + if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(syn_value) else: post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) @@ -336,7 +336,7 @@ def __init__( name: str = None, # training parameters - mode: Mode = nonbatching, + mode: Mode = normal, stop_spike_gradient: bool = False, # deprecated @@ -506,7 +506,7 @@ def __init__( name: str = None, # training parameters - mode: Mode = nonbatching, + mode: Mode = normal, stop_spike_gradient: bool = False, ): super(BioNMDA, self).__init__(pre=pre, @@ -580,7 +580,7 @@ def update(self, tdi, pre_spike=None): # update synapse variables self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) - if isinstance(self.mode, Training): + if isinstance(self.mode, TrainingMode): self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value) T = ((t - self.spike_arrival_time) < self.T_dur) * self.T_0 self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt) @@ -595,7 +595,7 @@ def update(self, tdi, pre_spike=None): else: if self.comp_method == 'sparse': f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask) - if isinstance(self.mode, Batching): f = vmap(f) + if isinstance(self.mode, BatchingMode): f = vmap(f) post_vs = f(syn_value) else: post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) diff --git a/brainpy/dyn/synapses/couplings.py b/brainpy/dyn/synapses/delay_couplings.py similarity index 64% rename from brainpy/dyn/synapses/couplings.py rename to brainpy/dyn/synapses/delay_couplings.py index d0a83f50e..a76a573d2 100644 --- a/brainpy/dyn/synapses/couplings.py +++ b/brainpy/dyn/synapses/delay_couplings.py @@ -8,7 +8,8 @@ import brainpy.math as bm from brainpy.dyn.base import DynamicalSystem from brainpy.initialize import Initializer -from brainpy.tools.checking import check_sequence, check_integer +from brainpy.modes import Mode, TrainingMode, normal +from brainpy.tools.checking import check_sequence from brainpy.types import Tensor __all__ = [ @@ -25,7 +26,7 @@ class DelayCoupling(DynamicalSystem): ---------- delay_var: Variable The delay variable. - target_var: Variable, sequence of Variable + var_to_output: Variable, sequence of Variable The target variables to output. conn_mat: JaxArray, ndarray The connection matrix. 
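For orientation, the import renames above track the new names in brainpy.modes: NonBatching, Batching and Training become NormalMode, BatchingMode and TrainingMode, with lowercase singleton instances normal, batching and training. A minimal sketch of how the renamed classes are meant to be checked (nothing here is specific to any one model):

    import brainpy as bp

    # the three singleton instances exported from brainpy.modes
    assert isinstance(bp.modes.normal, bp.modes.NormalMode)
    assert isinstance(bp.modes.batching, bp.modes.BatchingMode)
    assert isinstance(bp.modes.training, bp.modes.TrainingMode)

    # TrainingMode subclasses BatchingMode, so code paths guarded by
    # isinstance(self.mode, BatchingMode) are also taken when training
    assert isinstance(bp.modes.training, bp.modes.BatchingMode)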
@@ -40,14 +41,15 @@ class DelayCoupling(DynamicalSystem): def __init__( self, delay_var: bm.Variable, - target_var: Union[bm.Variable, Sequence[bm.Variable]], + var_to_output: Union[bm.Variable, Sequence[bm.Variable]], conn_mat: Tensor, required_shape: Tuple[int, ...], delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None, initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None, - name: str = None + name: str = None, + mode: Mode = normal, ): - super(DelayCoupling, self).__init__(name=name) + super(DelayCoupling, self).__init__(name=name, mode=mode) # delay variable if not isinstance(delay_var, bm.Variable): @@ -56,10 +58,10 @@ def __init__( self.delay_var = delay_var # output variables - if isinstance(target_var, bm.Variable): - target_var = [target_var] - check_sequence(target_var, 'output_var', elem_type=bm.Variable, allow_none=False) - self.output_var = target_var + if isinstance(var_to_output, bm.Variable): + var_to_output = [var_to_output] + check_sequence(var_to_output, 'output_var', elem_type=bm.Variable, allow_none=False) + self.output_var = var_to_output # Connection matrix self.conn_mat = bm.asarray(conn_mat) @@ -73,10 +75,6 @@ def __init__( self.delay_steps = None self.delay_type = 'none' num_delay_step = None - elif isinstance(delay_steps, int): - self.delay_steps = delay_steps - num_delay_step = delay_steps - self.delay_type = 'int' elif callable(delay_steps): delay_steps = delay_steps(required_shape) if delay_steps.dtype not in [bm.int32, bm.int64, bm.uint32, bm.uint64]: @@ -87,22 +85,30 @@ def __init__( elif isinstance(delay_steps, (bm.JaxArray, jnp.ndarray)): if delay_steps.dtype not in [bm.int32, bm.int64, bm.uint32, bm.uint64]: raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}') - if delay_steps.shape != required_shape: - raise ValueError(f'we expect the delay matrix has the shape of {required_shape}. ' - f'While we got {delay_steps.shape}.') + if delay_steps.ndim == 0: + self.delay_type = 'int' + else: + self.delay_type = 'array' + if delay_steps.shape != required_shape: + raise ValueError(f'we expect the delay matrix has the shape of ' + f'(pre.num, post.num), i.e., {required_shape}. 
' + f'While we got {delay_steps.shape}.') self.delay_steps = delay_steps - self.delay_type = 'array' num_delay_step = self.delay_steps.max() + elif isinstance(delay_steps, int): + self.delay_steps = delay_steps + num_delay_step = delay_steps + self.delay_type = 'int' else: raise ValueError(f'Unknown type of delay steps: {type(delay_steps)}') # delay variables - self.delay_step = self.register_delay(f'delay_{id(delay_var)}', - delay_step=num_delay_step, - delay_target=delay_var, - initial_delay_data=initial_delay_data) + _ = self.register_delay(f'delay_{id(delay_var)}', + delay_step=num_delay_step, + delay_target=delay_var, + initial_delay_data=initial_delay_data) - def reset(self): + def reset_state(self, batch_size=None): pass @@ -120,10 +126,10 @@ class DiffusiveCoupling(DelayCoupling): >>> import brainpy as bp >>> from brainpy.dyn import rates - >>> areas = rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn') - >>> conn = rates.DiffusiveCoupling(areas.x, areas.x, areas.input, - >>> conn_mat=Cmat, delay_steps=Dmat, - >>> initial_delay_data=bp.init.Uniform(0, 0.05)) + >>> areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn') + >>> conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input, + >>> conn_mat=Cmat, delay_steps=Dmat, + >>> initial_delay_data=bp.init.Uniform(0, 0.05)) >>> net = bp.dyn.Network(areas, conn) Parameters @@ -132,7 +138,7 @@ class DiffusiveCoupling(DelayCoupling): The first coupling variable, used for delay. coupling_var2: Variable Another coupling variable. - target_var: Variable, sequence of Variable + var_to_output: Variable, sequence of Variable The target variables to output. conn_mat: JaxArray, ndarray The connection matrix. @@ -148,11 +154,12 @@ def __init__( self, coupling_var1: bm.Variable, coupling_var2: bm.Variable, - target_var: Union[bm.Variable, Sequence[bm.Variable]], + var_to_output: Union[bm.Variable, Sequence[bm.Variable]], conn_mat: Tensor, delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None, initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None, - name: str = None + name: str = None, + mode: Mode = normal, ): if not isinstance(coupling_var1, bm.Variable): raise ValueError(f'"coupling_var1" must be an instance of brainpy.math.Variable. 
' @@ -169,12 +176,13 @@ def __init__( super(DiffusiveCoupling, self).__init__( delay_var=coupling_var1, - target_var=target_var, + var_to_output=var_to_output, conn_mat=conn_mat, required_shape=(coupling_var1.size, coupling_var2.size), delay_steps=delay_steps, initial_delay_data=initial_delay_data, - name=name + name=name, + mode=mode, ) self.coupling_var1 = coupling_var1 @@ -182,22 +190,29 @@ def __init__( def update(self, tdi): # delays + axis = self.coupling_var1.ndim + delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] if self.delay_steps is None: - diffusive = bm.expand_dims(self.coupling_var1, axis=1) - self.coupling_var2 - diffusive = (self.conn_mat * diffusive).sum(axis=0) + diffusive = (bm.expand_dims(self.coupling_var1, axis=axis) - + bm.expand_dims(self.coupling_var2, axis=axis - 1)) + diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) elif self.delay_type == 'array': - delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] - f = vmap(lambda i: delay_var(self.delay_steps[i], bm.arange(self.coupling_var2.size))) # (post.num,) - delays = f(bm.arange(self.coupling_var1.size).value) # (pre.num, post.num) - diffusive = delays - self.coupling_var2 # (pre.num, post.num) - diffusive = (self.conn_mat * diffusive).sum(axis=0) + if isinstance(self.mode, TrainingMode): + indices = (slice(None, None, None), bm.arange(self.coupling_var1.size),) + else: + indices = (bm.arange(self.coupling_var1.size),) + f = vmap(lambda i: delay_var(self.delay_steps[:, i], *indices)) # (..., pre.num) + delays = f(bm.arange(self.coupling_var2.size).value) # (..., post.num, pre.num) + diffusive = (bm.moveaxis(delays, axis - 1, axis) - + bm.expand_dims(self.coupling_var2, axis=axis - 1)) # (..., pre.num, post.num) + diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) elif self.delay_type == 'int': - delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] - delayed_var = delay_var(self.delay_steps) - diffusive = bm.expand_dims(delayed_var, axis=1) - self.coupling_var2 - diffusive = (self.conn_mat * diffusive).sum(axis=0) + delayed_data = delay_var(self.delay_steps) # (..., pre.num) + diffusive = (bm.expand_dims(delayed_data, axis=axis) - + bm.expand_dims(self.coupling_var2, axis=axis - 1)) # (..., pre.num, post.num) + diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) else: - raise ValueError + raise ValueError(f'Unknown delay type {self.delay_type}') # output to target variable for target in self.output_var: @@ -216,7 +231,7 @@ class AdditiveCoupling(DelayCoupling): ---------- coupling_var: Variable The coupling variable, used for delay. - target_var: Variable, sequence of Variable + var_to_output: Variable, sequence of Variable The target variables to output. conn_mat: JaxArray, ndarray The connection matrix. @@ -231,11 +246,12 @@ class AdditiveCoupling(DelayCoupling): def __init__( self, coupling_var: bm.Variable, - target_var: Union[bm.Variable, Sequence[bm.Variable]], + var_to_output: Union[bm.Variable, Sequence[bm.Variable]], conn_mat: Tensor, delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None, initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None, - name: str = None + name: str = None, + mode: Mode = normal, ): if not isinstance(coupling_var, bm.Variable): raise ValueError(f'"coupling_var" must be an instance of brainpy.math.Variable. 
' @@ -246,29 +262,34 @@ def __init__( super(AdditiveCoupling, self).__init__( delay_var=coupling_var, - target_var=target_var, + var_to_output=var_to_output, conn_mat=conn_mat, required_shape=(coupling_var.size, coupling_var.size), delay_steps=delay_steps, initial_delay_data=initial_delay_data, - name=name + name=name, + mode=mode, ) self.coupling_var = coupling_var - def update(self, t, dt): + def update(self, tdi): # delay function + axis = self.coupling_var.ndim + delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] if self.delay_steps is None: additive = self.coupling_var @ self.conn_mat elif self.delay_type == 'array': - delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] - f = vmap(lambda i: delay_var(self.delay_steps[i], bm.arange(self.coupling_var.size))) # (pre.num,) - delays = f(bm.arange(self.coupling_var.size).value) # (post.num, pre.num) - additive = (self.conn_mat * delays.T).sum(axis=0) + if isinstance(self.mode, TrainingMode): + indices = (slice(None, None, None), bm.arange(self.coupling_var.size),) + else: + indices = (bm.arange(self.coupling_var.size),) + f = vmap(lambda i: delay_var(self.delay_steps[:, i], *indices)) # (.., pre.num,) + delays = f(bm.arange(self.coupling_var.size).value) # (..., post.num, pre.num) + additive = (self.conn_mat * bm.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1) elif self.delay_type == 'int': - delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0] - delayed_var = delay_var(self.delay_steps) - additive = (self.conn_mat * delayed_var).sum(axis=0) + delayed_var = delay_var(self.delay_steps) # (..., pre.num) + additive = delayed_var @ self.conn_mat else: raise ValueError diff --git a/brainpy/initialize/generic.py b/brainpy/initialize/generic.py index 654d9399d..258e86f47 100644 --- a/brainpy/initialize/generic.py +++ b/brainpy/initialize/generic.py @@ -8,7 +8,7 @@ import brainpy.math as bm from brainpy.tools.others import to_size from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, NonBatching, Batching, Training +from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode from .base import Initializer @@ -91,9 +91,9 @@ def variable( if callable(data): if var_shape is None: raise ValueError('"varshape" cannot be None when data is a callable function.') - if isinstance(batch_size_or_mode, NonBatching): + if isinstance(batch_size_or_mode, NormalMode): return bm.Variable(data(var_shape)) - elif isinstance(batch_size_or_mode, Batching): + elif isinstance(batch_size_or_mode, BatchingMode): new_shape = var_shape[:batch_axis] + (1,) + var_shape[batch_axis:] return bm.Variable(data(new_shape), batch_axis=batch_axis) elif batch_size_or_mode in (None, False): @@ -105,9 +105,9 @@ def variable( if var_shape is not None: if bm.shape(data) != var_shape: raise ValueError(f'The shape of "data" {bm.shape(data)} does not match with "var_shape" {var_shape}') - if isinstance(batch_size_or_mode, NonBatching): + if isinstance(batch_size_or_mode, NormalMode): return bm.Variable(data(var_shape)) - elif isinstance(batch_size_or_mode, Batching): + elif isinstance(batch_size_or_mode, BatchingMode): return bm.Variable(bm.expand_dims(data, axis=batch_axis), batch_axis=batch_axis) elif batch_size_or_mode in (None, False): return bm.Variable(data) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index 3f5ef4120..cfcd286cb 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -327,17 +327,20 @@ def reset( 
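As a quick reference for the variable() dispatch rewritten above, here is a sketch of the shapes it is expected to produce; the (10,) shape is only illustrative, and the leading position of the batch axis is the default assumed here:

    import brainpy as bp
    import brainpy.math as bm
    from brainpy.initialize import variable

    # normal (non-batching) mode: a plain variable of shape (10,)
    v1 = variable(bm.zeros, bp.modes.normal, (10,))

    # batching mode: a size-1 batch axis is inserted, giving shape (1, 10),
    # and the Variable records which axis is the batch axis
    v2 = variable(bm.zeros, bp.modes.batching, (10,))

    # passing None skips batching altogether, as in reset_state(batch_size=None)
    v3 = variable(bm.zeros, None, (10,))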
batch_axis = None if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None): batch_axis = delay_target.batch_axis + 1 - self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype), + self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, + dtype=delay_target.dtype), batch_axis=batch_axis) else: - self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) + self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, + dtype=delay_target.dtype) self.data[-1] = delay_target if initial_delay_data is None: pass elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)): self.data[:-1] = initial_delay_data elif callable(initial_delay_data): - self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape, dtype=delay_target.dtype) + self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape, + dtype=delay_target.dtype) else: raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}') diff --git a/brainpy/modes.py b/brainpy/modes.py index b1c315812..844c0ee0d 100644 --- a/brainpy/modes.py +++ b/brainpy/modes.py @@ -3,11 +3,11 @@ __all__ = [ 'Mode', - 'NonBatching', - 'Batching', - 'Training', + 'NormalMode', + 'BatchingMode', + 'TrainingMode', - 'nonbatching', + 'normal', 'batching', 'training', ] @@ -18,20 +18,22 @@ def __repr__(self): return self.__class__.__name__ -class NonBatching(Mode): +class NormalMode(Mode): + """Normal non-batching mode.""" pass -class Batching(Mode): +class BatchingMode(Mode): + """Batching mode.""" pass -class Training(Batching): +class TrainingMode(BatchingMode): + """Training mode requires data batching.""" pass -nonbatching = NonBatching() -batching = Batching() -training = Training() - +normal = NormalMode() +batching = BatchingMode() +training = TrainingMode() diff --git a/brainpy/train/offline.py b/brainpy/train/offline.py index 870f295ed..4ad1b4aa4 100644 --- a/brainpy/train/offline.py +++ b/brainpy/train/offline.py @@ -11,7 +11,7 @@ from brainpy.base import Base from brainpy.dyn.base import DynamicalSystem from brainpy.errors import NoImplementationError -from brainpy.modes import Training +from brainpy.modes import TrainingMode from brainpy.tools.checking import serialize_kwargs from brainpy.types import Tensor, Output from .base import DSTrainer @@ -58,7 +58,7 @@ def __init__( # get all trainable nodes nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique() - self.train_nodes = tuple([node for node in nodes.values() if isinstance(node.mode, Training)]) + self.train_nodes = tuple([node for node in nodes.values() if isinstance(node.mode, TrainingMode)]) if len(self.train_nodes) == 0: raise ValueError('Found no trainable nodes.') diff --git a/brainpy/train/online.py b/brainpy/train/online.py index 9928cdcae..8ae1e34f2 100644 --- a/brainpy/train/online.py +++ b/brainpy/train/online.py @@ -12,7 +12,7 @@ from brainpy.base import Base from brainpy.dyn.base import DynamicalSystem from brainpy.errors import NoImplementationError -from brainpy.modes import Training +from brainpy.modes import TrainingMode from brainpy.tools.checking import serialize_kwargs from brainpy.tools.others.dicts import DotDict from brainpy.types import Tensor, Output @@ -57,7 +57,7 @@ def __init__( # get all trainable nodes nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique() - self.train_nodes = tuple([node for node in nodes.values() 
if isinstance(node.mode, Training)]) + self.train_nodes = tuple([node for node in nodes.values() if isinstance(node.mode, TrainingMode)]) if len(self.train_nodes) == 0: raise ValueError('Found no trainable nodes.') diff --git a/examples/simulation/Wang_2002_decision_making_spiking.py b/examples/simulation/Wang_2002_decision_making_spiking.py index 995a686af..485f7aec8 100644 --- a/examples/simulation/Wang_2002_decision_making_spiking.py +++ b/examples/simulation/Wang_2002_decision_making_spiking.py @@ -15,7 +15,7 @@ class PoissonStim(bp.dyn.NeuGroup): - def __init__(self, size, freq_mean, freq_var, t_interval, mode=bp.modes.NonBatching()): + def __init__(self, size, freq_mean, freq_var, t_interval, mode=bp.modes.NormalMode()): super(PoissonStim, self).__init__(size=size, mode=mode) # parameters @@ -43,12 +43,12 @@ def update(self, tdi): in_interval = bm.logical_and(in_interval, (t - self.freq_t_last_change) >= self.t_interval) self.freq.value = bm.where(in_interval, self.rng.normal(self.freq_mean, self.freq_var, self.freq.shape), prev_freq) self.freq_t_last_change.value = bm.where(in_interval, t, self.freq_t_last_change) - shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, bp.modes.Batching) else self.varshape + shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, bp.modes.BatchingMode) else self.varshape self.spike.value = self.rng.random(shape) < self.freq * self.dt class DecisionMaking(bp.dyn.Network): - def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, mode=bp.modes.NonBatching()): + def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, mode=bp.modes.NormalMode()): super(DecisionMaking, self).__init__() num_exc = int(1600 * scale) @@ -262,9 +262,9 @@ def single_run(): def batching_run(): num_row, num_col = 3, 4 - num_batch = 200 + num_batch = 12 coherence = bm.expand_dims(bm.linspace(-100, 100., num_batch), 1) - net = DecisionMaking(scale=1., coherence=coherence, mu0=20., mode=bp.modes.Batching()) + net = DecisionMaking(scale=1., coherence=coherence, mu0=20., mode=bp.modes.BatchingMode()) net.reset_state(batch_size=num_batch) runner = bp.dyn.DSRunner( @@ -272,20 +272,20 @@ def batching_run(): ) runner.run(total_period) - # coherence = coherence.to_numpy() - # fig, gs = bp.visualize.get_figure(num_row, num_col, 3, 4) - # for i in range(num_row): - # for j in range(num_col): - # idx = i * num_col + j - # if idx < num_batch: - # mon = {'A.spike': runner.mon['A.spike'][:, idx], - # 'B.spike': runner.mon['B.spike'][:, idx], - # 'IA.freq': runner.mon['IA.freq'][:, idx], - # 'IB.freq': runner.mon['IB.freq'][:, idx], - # 'ts': runner.mon['ts']} - # ax = fig.add_subplot(gs[i, j]) - # visualize_raster(ax, mon=mon, title=f'coherence={coherence[idx, 0]}%') - # plt.show() + coherence = coherence.to_numpy() + fig, gs = bp.visualize.get_figure(num_row, num_col, 3, 4) + for i in range(num_row): + for j in range(num_col): + idx = i * num_col + j + if idx < num_batch: + mon = {'A.spike': runner.mon['A.spike'][:, idx], + 'B.spike': runner.mon['B.spike'][:, idx], + 'IA.freq': runner.mon['IA.freq'][:, idx], + 'IB.freq': runner.mon['IB.freq'][:, idx], + 'ts': runner.mon['ts']} + ax = fig.add_subplot(gs[i, j]) + visualize_raster(ax, mon=mon, title=f'coherence={coherence[idx, 0]}%') + plt.show() if __name__ == '__main__': diff --git a/examples/simulation/whole_brain_simulation_with_fhn.py b/examples/simulation/whole_brain_simulation_with_fhn.py index 6bd5d2f03..04d7a40de 100644 --- a/examples/simulation/whole_brain_simulation_with_fhn.py +++ 
b/examples/simulation/whole_brain_simulation_with_fhn.py @@ -6,13 +6,12 @@ import brainpy as bp import brainpy.math as bm -from brainpy.dyn import rates bp.check.turn_off() def bifurcation_analysis(): - model = rates.FHN(1, method='exp_auto') + model = bp.rates.FHN(1, method='exp_auto') pp = bp.analysis.Bifurcation2D( model, target_vars={'x': [-2, 2], 'y': [-2, 2]}, @@ -38,11 +37,14 @@ def __init__(self, signal_speed=20.): delay_mat = bm.round(hcp['Dmat'] / signal_speed / bm.get_dt()) bm.fill_diagonal(delay_mat, 0) - self.fhn = rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn') - self.coupling = rates.DiffusiveCoupling(self.fhn.x, self.fhn.x, self.fhn.input, - conn_mat=conn_mat, - delay_steps=delay_mat.astype(bm.int_), - initial_delay_data=bp.init.Uniform(0, 0.05)) + self.fhn = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn') + self.coupling = bp.synapses.DiffusiveCoupling( + self.fhn.x, self.fhn.x, + var_to_output=self.fhn.input, + conn_mat=conn_mat, + delay_steps=delay_mat.astype(bm.int_), + initial_delay_data=bp.init.Uniform(0, 0.05) + ) def brain_simulation(): @@ -55,11 +57,11 @@ def brain_simulation(): fc = bp.measure.functional_connectivity(runner.mon['fhn.x']) ax = axs[0].imshow(fc) plt.colorbar(ax, ax=axs[0]) - axs[1].plot(runner.mon.ts, runner.mon['fhn.x'][:, ::5], alpha=0.8) + axs[1].plot(runner.mon['ts'], runner.mon['fhn.x'][:, ::5], alpha=0.8) plt.tight_layout() plt.show() if __name__ == '__main__': - bifurcation_analysis() + # bifurcation_analysis() brain_simulation() diff --git a/examples/simulation/whole_brain_simulation_with_sl_oscillator.py b/examples/simulation/whole_brain_simulation_with_sl_oscillator.py index 1fa946a49..5bd0fe670 100644 --- a/examples/simulation/whole_brain_simulation_with_sl_oscillator.py +++ b/examples/simulation/whole_brain_simulation_with_sl_oscillator.py @@ -5,13 +5,12 @@ import brainpy as bp import brainpy.math as bm -from brainpy.dyn import rates bp.check.turn_off() def bifurcation_analysis(): - model = rates.StuartLandauOscillator(1, method='exp_auto') + model = bp.rates.StuartLandauOscillator(1, method='exp_auto') pp = bp.analysis.Bifurcation2D( model, target_vars={'x': [-2, 2], 'y': [-2, 2]}, @@ -37,8 +36,11 @@ def __init__(self): gc = 0.6 # global coupling strength self.sl = bp.rates.StuartLandauOscillator(80, x_ou_sigma=0.14, y_ou_sigma=0.14, name='sl') - self.coupling = bp.synapses.DiffusiveCoupling(self.sl.x, self.sl.x, self.sl.input, - conn_mat=conn_mat * gc) + self.coupling = bp.synapses.DiffusiveCoupling( + self.sl.x, self.sl.x, + var_to_output=self.sl.input, + conn_mat=conn_mat * gc + ) def simulation(): @@ -51,11 +53,11 @@ def simulation(): fc = bp.measure.functional_connectivity(runner.mon['sl.x']) ax = axs[0].imshow(fc) plt.colorbar(ax, ax=axs[0]) - axs[1].plot(runner.mon.ts, runner.mon['sl.x'][:, ::5], alpha=0.8) + axs[1].plot(runner.mon['ts'], runner.mon['sl.x'][:, ::5], alpha=0.8) plt.tight_layout() plt.show() if __name__ == '__main__': - bifurcation_analysis() + # bifurcation_analysis() simulation() diff --git a/examples/training/Bellec_2020_eprop_evidence_accumulation.py b/examples/training/Bellec_2020_eprop_evidence_accumulation.py index 231726f19..444f1e567 100644 --- a/examples/training/Bellec_2020_eprop_evidence_accumulation.py +++ b/examples/training/Bellec_2020_eprop_evidence_accumulation.py @@ -8,17 +8,108 @@ of spiking neurons. Nature communications, 11(1), 1-15. 
""" + +import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import brainpy as bp import brainpy.math as bm from jax.lax import stop_gradient +from matplotlib import patches + +bm.set_dt(1.) # Simulation time step [ms] + +# training parameters +n_batch = 64 # batch size + +# neuron model and simulation parameters +reg_f = 1. # regularization coefficient for firing rate +reg_rate = 10 # target firing rate for regularization [Hz] + +# Experiment parameters +t_cue_spacing = 150 # distance between two consecutive cues in ms + +# Frequencies +input_f0 = 40. / 1000. # poisson firing rate of input neurons in khz +regularization_f0 = reg_rate / 1000. # mean target network firing frequency + + +class ALIF(bp.dyn.NeuGroup): + def __init__( + self, num_in, num_rec, tau=20., thr=0.03, + dampening_factor=0.3, tau_adaptation=200., + stop_z_gradients=False, n_refractory=1, + name=None, mode=bp.modes.training, + ): + super(ALIF, self).__init__(name=name, size=num_rec, mode=mode) + + self.n_in = num_in + self.n_rec = num_rec + self.n_regular = int(num_rec / 2) + self.n_adaptive = num_rec - self.n_regular + + self.n_refractory = n_refractory + self.tau_adaptation = tau_adaptation + # generate threshold decay time constants # + rhos = bm.exp(- bm.get_dt() / tau_adaptation) # decay factors for adaptive threshold + beta = 1.7 * (1 - rhos) / (1 - bm.exp(-1 / tau)) # this is a heuristic value + # multiplicative factors for adaptive threshold + self.beta = bm.concatenate([bm.zeros(self.n_regular), beta * bm.ones(self.n_adaptive)]) + + self.decay_b = jnp.exp(-bm.get_dt() / tau_adaptation) + self.decay = jnp.exp(-bm.get_dt() / tau) + self.dampening_factor = dampening_factor + self.stop_z_gradients = stop_z_gradients + self.tau = tau + self.thr = thr + self.mask = jnp.diag(jnp.ones(num_rec, dtype=bool)) -bm.set_dt(1.) 
+ # parameters + self.w_in = bm.TrainVar(bm.random.randn(num_in, self.num) / jnp.sqrt(num_in)) + self.w_rec = bm.TrainVar(bm.random.randn(self.num, self.num) / jnp.sqrt(self.num)) + + # Variables + self.v = bm.Variable(jnp.zeros((1, self.num)), batch_axis=0) + self.b = bm.Variable(jnp.zeros((1, self.num)), batch_axis=0) + self.r = bm.Variable(jnp.zeros((1, self.num)), batch_axis=0) + self.spike = bm.Variable(jnp.zeros((1, self.num)), batch_axis=0) + + def reset_state(self, batch_size=1): + self.v.value = bm.Variable(jnp.zeros((batch_size, self.n_rec))) + self.b.value = bm.Variable(jnp.zeros((batch_size, self.n_rec))) + self.r.value = bm.Variable(jnp.zeros((batch_size, self.n_rec))) + self.spike.value = bm.Variable(jnp.zeros((batch_size, self.n_rec))) + + def compute_z(self, v, b): + adaptive_thr = self.thr + b * self.beta + v_scaled = (v - adaptive_thr) / self.thr + z = bm.spike_with_relu_grad(v_scaled, self.dampening_factor) + z = z * 1 / bm.get_dt() + return z + + def update(self, sha, x): + z = self.spike.value + if self.stop_z_gradients: + z = stop_gradient(z) + + # threshold update does not have to depend on the stopped-gradient-z, it's local + new_b = self.decay_b * self.b.value + self.spike.value + + # gradients are blocked in spike transmission + i_t = jnp.matmul(x.value, self.w_in.value) + jnp.matmul(z, jnp.where(self.mask, 0, self.w_rec.value)) + i_reset = z * self.thr * bm.get_dt() + new_v = self.decay * self.v + i_t - i_reset + + # spike generation + self.spike.value = bm.where(self.r.value > 0, 0., self.compute_z(new_v, new_b)) + new_r = bm.clip(self.r.value + self.n_refractory * self.spike - 1, 0, self.n_refractory) + self.r.value = stop_gradient(new_r) + self.v.value = new_v + self.b.value = new_b class EligSNN(bp.dyn.Network): - def __init__(self, num_in, num_rec, num_out, neuron_model='lif'): + def __init__(self, num_in, num_rec, num_out, stop_z_gradients=False): super(EligSNN, self).__init__() # parameters @@ -27,36 +118,25 @@ def __init__(self, num_in, num_rec, num_out, neuron_model='lif'): self.num_out = num_out # neurons - self.i = bp.neurons.InputGroup(num_in, trainable=True) - self.o = bp.neurons.LeakyIntegrator(num_out, tau=20, trainable=True) - tau_a = 2e3 - tau_v = 2e1 - n_regular = 50 - n_adaptive = num_rec - n_regular - beta_a1 = bm.exp(- bm.get_dt() / tau_a) - beta_a2 = 1.7 * (1 - beta_a1) / (1 - bm.exp(-1 / tau_v)) - self.r = bp.neurons.ALIFBellec2020( - n_regular + n_adaptive, trainable=True, - V_rest=0., tau_ref=5., V_th=0.6, tau_a=tau_a, tau=tau_v, - beta=bm.concatenate([bm.ones(n_regular), bm.ones(n_adaptive) * beta_a2]), - ) + self.r = ALIF(num_in=num_in, num_rec=num_rec, tau=20, tau_adaptation=2000, + n_refractory=5, stop_z_gradients=stop_z_gradients, thr=0.6) + self.o = bp.neurons.LeakyIntegrator(num_out, tau=20, mode=bp.modes.training) # synapses - self.i2r = bp.layers.Dense(num_in, num_rec, W_initializer=bp.init.KaimingNormal()) - self.r2r = bp.layers.Dense(num_rec, num_rec, W_initializer=bp.init.KaimingNormal()) - self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), - output=bp.synouts.CUBA(), tau=10., - g_max=bp.init.KaimingNormal(), - trainable=True) - - def update(self, shared, x): - self.i2r(shared, x) - self.r(shared, x=self.r2r(shared, stop_gradient(self.r.spike.value))) - self.r2o(shared, ) - self.o(shared, ) + self.r2o = bp.layers.Dense(num_rec, num_out, + W_initializer=bp.init.KaimingNormal(), + b_initializer=None) + + def update(self, sha, x): + self.r(sha, x) + self.o.input += self.r2o(sha, self.r.spike.value) + 
self.o(sha) return self.o.V.value +net = EligSNN(num_in=40, num_rec=100, num_out=2, stop_z_gradients=True) + + @bp.tools.numba_jit def generate_click_task_data(batch_size, seq_len, n_neuron, recall_duration, prob, f0=0.5, n_cues=7, t_cue=100, t_interval=150, n_input_symbols=4): @@ -100,29 +180,17 @@ def generate_click_task_data(batch_size, seq_len, n_neuron, recall_duration, pro def get_data(batch_size, n_in, t_interval, f0): # used for obtaining a new randomly generated batch of examples def generate_data(): - for _ in range(100): + for _ in range(10): seq_len = int(t_interval * 7 + 1200) spk_data, _, target_data, _ = generate_click_task_data( batch_size=batch_size, seq_len=seq_len, n_neuron=n_in, recall_duration=150, - prob=0.3, t_cue=100, n_cues=7, t_interval=t_interval, f0=f0, n_input_symbols=4) + prob=0.3, t_cue=100, n_cues=7, t_interval=t_interval, f0=f0, n_input_symbols=4 + ) yield spk_data, target_data return generate_data -# experiment parameters -reg_f = 1. # regularization coefficient for firing rate -reg_rate = 10 # target firing rate for regularization [Hz] -t_cue_spacing = 150 # distance between two consecutive cues in ms - -# frequency -input_f0 = 40. / 1000. # poisson firing rate of input neurons in khz -regularization_f0 = reg_rate / 1000. # mean target network firing frequency - -# model -net = EligSNN(num_in=40, num_rec=100, num_out=2, neuron_model='alif') - - def loss_fun(predicts, targets): predicts, mon = predicts @@ -145,33 +213,67 @@ def loss_fun(predicts, targets): # Aggregate the losses # loss = loss_reg_f + loss_cls - loss_res = {'loss': loss, 'loss reg': loss_reg_f, 'accuracy': accuracy} return loss, loss_res # Training -trainer = bp.train.BPTT(net, - loss_fun, - loss_has_aux=True, - optimizer=bp.optimizers.Adam(lr=1e-2), - monitors={'r.spike': net.r.spike}, ) +trainer = bp.train.BPTT( + net, loss_fun, + loss_has_aux=True, + optimizer=bp.optimizers.Adam(lr=0.005), + monitors={'r.spike': net.r.spike}, +) trainer.fit(get_data(64, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0), - num_epoch=2, num_report=10) - - -fig, gs = bp.visualize.get_figure(2, 2, 4, 5) - -fig.add_subplot(gs[0, 0]) -plt.plot(bm.as_numpy(trainer.train_losses)) -plt.ylabel('Overall Loss') -fig.add_subplot(gs[0, 1]) -plt.plot(bm.as_numpy(trainer.train_loss_aux['loss'])) -plt.ylabel('Accuracy Loss') -fig.add_subplot(gs[1, 0]) -plt.plot(bm.as_numpy(trainer.train_loss_aux['loss reg'])) -plt.ylabel('Regularization Loss') -fig.add_subplot(gs[1, 1]) -plt.plot(bm.as_numpy(trainer.train_loss_aux['accuracy'])) -plt.ylabel('Accuracy') -plt.show() + num_epoch=30, + num_report=10) + +# visualization +dataset, _ = next(get_data(20, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0)()) +runner = bp.train.DSTrainer(net, monitors={'spike': net.r.spike}) +outs = runner.predict(dataset, reset_state=True) + +for i in range(10): + fig, gs = bp.visualize.get_figure(3, 1, 2., 6.) 
+ ax_inp = fig.add_subplot(gs[0, 0]) + ax_rec = fig.add_subplot(gs[1, 0]) + ax_out = fig.add_subplot(gs[2, 0]) + + data = dataset[i] + # insert empty row + n_channel = data.shape[1] // 4 + zero_fill = np.zeros((data.shape[0], int(n_channel / 2))) + data = np.concatenate((data[:, 3 * n_channel:], + zero_fill, + data[:, 2 * n_channel:3 * n_channel], + zero_fill, + data[:, :n_channel], + zero_fill, + data[:, n_channel:2 * n_channel]), axis=1) + ax_inp.set_yticklabels([]) + ax_inp.add_patch(patches.Rectangle((0, 2 * n_channel + 2 * int(n_channel / 2)), + data.shape[0], n_channel, + facecolor="red", alpha=0.1)) + ax_inp.add_patch(patches.Rectangle((0, 3 * n_channel + 3 * int(n_channel / 2)), + data.shape[0], n_channel, + facecolor="blue", alpha=0.1)) + bp.visualize.raster_plot(runner.mon.ts, data, ax=ax_inp, marker='|') + ax_inp.set_ylabel('Input Activity') + ax_inp.set_xticklabels([]) + ax_inp.set_xticks([]) + + # spiking activity + bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'][i], ax=ax_rec, marker='|') + ax_rec.set_ylabel('Spiking Activity') + ax_rec.set_xticklabels([]) + ax_rec.set_xticks([]) + # decision activity + ax_out.set_yticks([0, 0.5, 1]) + ax_out.set_ylabel('Output Activity') + ax_out.plot(runner.mon.ts, outs[i, :, 0], label='Readout 0', alpha=0.7) + ax_out.plot(runner.mon.ts, outs[i, :, 1], label='Readout 1', alpha=0.7) + ax_out.set_xticklabels([]) + ax_out.set_xticks([]) + ax_out.set_xlabel('Time [ms]') + plt.legend() + plt.show() diff --git a/examples/training/Gauthier_2021_ngrc_double_scroll.py b/examples/training/Gauthier_2021_ngrc_double_scroll.py index ff46c5be5..f9b5e458b 100644 --- a/examples/training/Gauthier_2021_ngrc_double_scroll.py +++ b/examples/training/Gauthier_2021_ngrc_double_scroll.py @@ -109,8 +109,8 @@ def __init__(self, num_in): self.r = bp.layers.NVAR(num_in, delay=2, order=3, mode=bp.modes.batching) self.di = bp.layers.Dense(self.r.num_out, num_in, mode=bp.modes.training) - def update(self, shared, x): - di = self.di(shared, self.r(shared, x)) + def update(self, sha, x): + di = self.di(sha, self.r(sha, x)) return x + di diff --git a/examples/training/Gauthier_2021_ngrc_lorenz.py b/examples/training/Gauthier_2021_ngrc_lorenz.py index 710c3d5f0..42fa2fbb5 100644 --- a/examples/training/Gauthier_2021_ngrc_lorenz.py +++ b/examples/training/Gauthier_2021_ngrc_lorenz.py @@ -14,6 +14,7 @@ import brainpy as bp import brainpy.math as bm + bm.enable_x64() diff --git a/examples/training/echo_state_network.py b/examples/training/echo_state_network.py index 8d1320c07..f10c5677c 100644 --- a/examples/training/echo_state_network.py +++ b/examples/training/echo_state_network.py @@ -12,12 +12,12 @@ def __init__(self, num_in, num_hidden, num_out): Wrec_initializer=bp.init.Normal(scale=0.1), in_connectivity=0.02, rec_connectivity=0.02, - conn_type='dense', + comp_type='dense', mode=bp.modes.batching) self.o = bp.layers.Dense(num_hidden, num_out, W_initializer=bp.init.Normal()) - def update(self, shared_args, x): - return self.o(shared_args, self.r(shared_args, x)) + def update(self, sha, x): + return self.o(sha, self.r(sha, x)) class NGRC(bp.dyn.DynamicalSystem): diff --git a/extensions/setup.py b/extensions/setup.py index 45c22d62e..a5b770b75 100644 --- a/extensions/setup.py +++ b/extensions/setup.py @@ -34,7 +34,7 @@ author_email='chao.brain@qq.com', packages=find_packages(exclude=['lib*']), include_package_data=True, - install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"], + install_requires=["jax", "jaxlib", "pybind11>=2.6", "cffi", 
"numba"], extras_require={"test": "pytest"}, python_requires='>=3.7', url='https://github.com/PKU-NIP-Lab/BrainPy', diff --git a/extensions/setup_cuda.py b/extensions/setup_cuda.py index 9750a1454..30a2a46f2 100644 --- a/extensions/setup_cuda.py +++ b/extensions/setup_cuda.py @@ -91,7 +91,7 @@ def build_extension(self, ext): author_email='chao.brain@qq.com', packages=find_packages(exclude=['lib*']), include_package_data=True, - install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"], + install_requires=["jax", "jaxlib", "pybind11>=2.6", "cffi", "numba"], extras_require={"test": "pytest"}, python_requires='>=3.7', url='https://github.com/PKU-NIP-Lab/BrainPy', diff --git a/extensions/setup_mac.py b/extensions/setup_mac.py index f2d8b6dac..1450ee46a 100644 --- a/extensions/setup_mac.py +++ b/extensions/setup_mac.py @@ -36,7 +36,7 @@ author_email='chao.brain@qq.com', packages=find_packages(exclude=['lib*']), include_package_data=True, - install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"], + install_requires=["jax", "jaxlib", "pybind11>=2.6", "cffi", "numba"], extras_require={"test": "pytest"}, python_requires='>=3.7', url='https://github.com/PKU-NIP-Lab/BrainPy', From 97a4976a8f1587f081e0b0e10b7cb3d82b49fbf0 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 27 Jul 2022 18:35:07 +0800 Subject: [PATCH 2/5] enable monitor inputs by add "reset_input()" function --- brainpy/analysis/highdim/slow_points.py | 21 +- brainpy/dyn/base.py | 291 ++++++++++++----------- brainpy/dyn/neurons/biological_models.py | 32 ++- brainpy/dyn/neurons/fractional_models.py | 4 + brainpy/dyn/neurons/input_groups.py | 46 +++- brainpy/dyn/neurons/reduced_models.py | 53 +++-- brainpy/dyn/rates/populations.py | 11 + brainpy/dyn/runners.py | 1 + brainpy/train/back_propagation.py | 1 + brainpy/train/online.py | 1 + 10 files changed, 281 insertions(+), 180 deletions(-) diff --git a/brainpy/analysis/highdim/slow_points.py b/brainpy/analysis/highdim/slow_points.py index 4538b09c2..571ea4ae7 100644 --- a/brainpy/analysis/highdim/slow_points.py +++ b/brainpy/analysis/highdim/slow_points.py @@ -340,7 +340,7 @@ def find_fps_with_gd_method( f_eval_loss = self._get_f_eval_loss() def f_loss(): - return f_eval_loss(tree_map(lambda a: a.value, + return f_eval_loss(tree_map(lambda a: bm.as_device_array(a), fixed_points, is_leaf=lambda x: isinstance(x, bm.JaxArray))).mean() @@ -386,9 +386,11 @@ def batch_train(start_i, n_batch): f'is below tolerance {tolerance:0.10f}.') self._opt_losses = bm.concatenate(opt_losses) - self._losses = f_eval_loss(tree_map(lambda a: a.value, fixed_points, + self._losses = f_eval_loss(tree_map(lambda a: bm.as_device_array(a), + fixed_points, is_leaf=lambda x: isinstance(x, bm.JaxArray))) - self._fixed_points = tree_map(lambda a: a.value, fixed_points, + self._fixed_points = tree_map(lambda a: bm.as_device_array(a), + fixed_points, is_leaf=lambda x: isinstance(x, bm.JaxArray)) self._selected_ids = jnp.arange(num_candidate) @@ -425,7 +427,7 @@ def find_fps_with_opt_solver( print(f"Optimizing with {opt_solver} to find fixed points:") # optimizing - res = f_opt(tree_map(lambda a: a.value, + res = f_opt(tree_map(lambda a: bm.as_device_array(a), candidates, is_leaf=lambda a: isinstance(a, bm.JaxArray))) @@ -720,16 +722,27 @@ def _generate_ds_cell_function( shared = DotDict(t=t, dt=dt, i=0) def f_cell(h: Dict): + target.clear_input() + + # update target variables for k, v in self.target_vars.items(): v.value = (bm.asarray(h[k], dtype=v.dtype) if v.batch_axis is None else 
bm.asarray(bm.expand_dims(h[k], axis=v.batch_axis), dtype=v.dtype)) + + # update excluded variables for k, v in self.excluded_vars.items(): v.value = self.excluded_data[k] + + # add inputs if f_input is not None: f_input(shared) + + # call update functions args = (shared,) + self.args target.update(*args) + + # get new states new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis)) for k, v in self.target_vars.items()} return new_h diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 70041c8cc..c97c2bc09 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -25,7 +25,7 @@ 'DynamicalSystem', # containers - 'Container', 'Network', 'Sequential', + 'Container', 'Network', 'Sequential', 'System', # channel models 'Channel', @@ -38,25 +38,6 @@ ] -def not_customized(fun: Callable) -> Callable: - """Marks the given module method is not implemented. - - Methods wrapped in @not_customized can define submodules directly within the method. - - For instance:: - - @not_customized - init_fb(self): - ... - - @not_customized - def feedback(self): - ... - """ - fun.not_implemented = True - return fun - - class DynamicalSystem(Base): """Base Dynamical System class. @@ -83,7 +64,7 @@ class DynamicalSystem(Base): def __init__( self, name: str = None, - mode: Mode = normal, + mode: Optional[Mode] = None, ): super(DynamicalSystem, self).__init__(name=name) @@ -91,6 +72,7 @@ def __init__( self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector() # mode setting + if mode is None: mode = normal if not isinstance(mode, Mode): raise ValueError(f'Should be instance of {Mode.__name__}, but we got {type(Mode)}: {Mode}') self._mode = mode @@ -257,6 +239,7 @@ def reset_state(self, batch_size=None): if len(child_nodes) > 0: for node in child_nodes.values(): node.reset_state(batch_size=batch_size) + self.reset_local_delays(child_nodes) else: raise NotImplementedError('Must implement "reset_state" function by subclass self. ' f'Error of {self.name}') @@ -320,26 +303,30 @@ def __del__(self): del self.__dict__[key] gc.collect() - @not_customized + @tools.not_customized def online_init(self): raise NoImplementationError('Subclass must implement online_init() function when using OnlineTrainer.') - @not_customized + @tools.not_customized def offline_init(self): raise NoImplementationError('Subclass must implement offline_init() function when using OfflineTrainer.') - @not_customized + @tools.not_customized def online_fit(self, target: Tensor, fit_record: Dict[str, Tensor]): raise NoImplementationError('Subclass must implement online_fit() function when using OnlineTrainer.') - @not_customized + @tools.not_customized def offline_fit(self, target: Tensor, fit_record: Dict[str, Tensor]): raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.') + def clear_input(self): + for node in self.nodes(level=1, include_self=False).subset(NeuGroup).unique().values(): + node.clear_input() + class Container(DynamicalSystem): """Container object which is designed to add other instances of DynamicalSystem. 
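The clear_input() hook added just above is the mechanism behind this commit's message: input currents are no longer zeroed inside update(), so they stay visible to monitors until the runner (or a hand-written loop) clears them. A rough sketch of the intended calling sequence, using a LIF group purely as an example:

    import brainpy as bp
    import brainpy.math as bm

    group = bp.neurons.LIF(10)                # any NeuGroup with an input variable
    shared = dict(t=0., dt=bm.get_dt(), i=0)  # stands in for the shared arguments

    group.input += 5.                         # external current for this step
    group.update(shared)                      # update no longer resets the input
    # group.input can still be read or monitored at this point
    group.clear_input()                       # zero the inputs for the next step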
@@ -367,23 +354,6 @@ def __init__( ): super(Container, self).__init__(name=name, mode=mode) - # # children dynamical systems - # self.implicit_nodes = Collector() - # for ds in ds_tuple: - # if not isinstance(ds, DynamicalSystem): - # raise ModelBuildError(f'{self.__class__.__name__} receives instances of ' - # f'DynamicalSystem, however, we got {type(ds)}.') - # if ds.name in self.implicit_nodes: - # raise ValueError(f'{ds.name} has been paired with {ds}. Please change a unique name.') - # self.register_implicit_nodes({node.name: node for node in ds_tuple}) - # for key, ds in ds_dict.items(): - # if not isinstance(ds, DynamicalSystem): - # raise ModelBuildError(f'{self.__class__.__name__} receives instances of ' - # f'DynamicalSystem, however, we got {type(ds)}.') - # if key in self.implicit_nodes: - # raise ValueError(f'{key} has been paired with {ds}. Please change a unique name.') - # self.register_implicit_nodes(ds_dict) - # add tuple-typed components for module in ds_tuple: if isinstance(module, DynamicalSystem): @@ -442,6 +412,92 @@ def __getattr__(self, item): return super(Container, self).__getattribute__(item) +class Sequential(Container): + def __init__( + self, + *modules, + name: str = None, + mode: Mode = normal, + **kw_modules + ): + super(Sequential, self).__init__(*modules, name=name, mode=mode, **kw_modules) + + def __getattr__(self, item): + """Wrap the dot access ('self.'). """ + child_ds = super(Sequential, self).__getattribute__('implicit_nodes') + if item in child_ds: + return child_ds[item] + else: + return super(Sequential, self).__getattribute__(item) + + def __getitem__(self, key: Union[int, slice]): + if isinstance(key, str): + if key not in self.implicit_nodes: + raise KeyError(f'Does not find a component named {key} in\n {str(self)}') + return self.implicit_nodes[key] + elif isinstance(key, slice): + keys = tuple(self.implicit_nodes.keys())[key] + components = tuple(self.implicit_nodes.values())[key] + return Sequential(dict(zip(keys, components))) + elif isinstance(key, int): + return self.implicit_nodes.values()[key] + elif isinstance(key, (tuple, list)): + all_keys = tuple(self.implicit_nodes.keys()) + all_vals = tuple(self.implicit_nodes.values()) + keys, vals = [], [] + for i in key: + if isinstance(i, int): + raise KeyError(f'We excepted a tuple/list of int, but we got {type(i)}') + keys.append(all_keys[i]) + vals.append(all_vals[i]) + return Sequential(dict(zip(keys, vals))) + else: + raise KeyError(f'Unknown type of key: {type(key)}') + + def __repr__(self): + def f(x): + if not isinstance(x, DynamicalSystem) and callable(x): + signature = inspect.signature(x) + args = [f'{k}={v.default}' for k, v in signature.parameters.items() + if v.default is not inspect.Parameter.empty] + args = ', '.join(args) + while not hasattr(x, '__name__'): + if not hasattr(x, 'func'): + break + x = x.func # Handle functools.partial + if not hasattr(x, '__name__') and hasattr(x, '__class__'): + return x.__class__.__name__ + if args: + return f'{x.__name__}(*, {args})' + return x.__name__ + else: + x = repr(x).split('\n') + x = [x[0]] + [' ' + y for y in x[1:]] + return '\n'.join(x) + + entries = '\n'.join(f' [{i}] {f(x)}' for i, x in enumerate(self)) + return f'{self.__class__.__name__}(\n{entries}\n)' + + def update(self, sha: dict, x: Any) -> Tensor: + """Update function of a sequential model. + + Parameters + ---------- + sha: dict + The shared arguments (ShA) across multiple layers. + x: Any + The input information. + + Returns + ------- + y: Tensor + The output tensor. 
+ """ + for node in self.implicit_nodes.values(): + x = node(sha, x) + return x + + class Network(Container): """Base class to model network objects, an alias of Container. @@ -524,6 +580,10 @@ def reset_state(self, batch_size=None): self.reset_local_delays(nodes) +class System(Network): + pass + + class NeuGroup(DynamicalSystem): """Base class to model neuronal groups. @@ -599,6 +659,9 @@ def update(self, tdi, x=None): raise NotImplementedError(f'Subclass of {self.__class__.__name__} must ' f'implement "update" function.') + def clear_input(self): + pass + class SynConn(DynamicalSystem): """Base class to model two-end synaptic connections. @@ -692,12 +755,6 @@ class SynComponent(DynamicalSystem): def reset_state(self, batch_size=None): pass - def filter(self, g): - return g - - def __call__(self, *args, **kwargs): - return self.filter(*args, **kwargs) - def register_master(self, master: SynConn): if not isinstance(master, SynConn): raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') @@ -708,10 +765,43 @@ def register_master(self, master: SynConn): def __repr__(self): return self.__class__.__name__ + def __call__(self, *args, **kwargs): + return self.filter(*args, **kwargs) + + def filter(self, g): + raise NotImplementedError + class SynOutput(SynComponent): """Base class for synaptic current output.""" + def __init__( + self, + name: str = None, + target_var: Union[str, bm.Variable] = None, + ): + super(SynOutput, self).__init__(name=name) + # check target variable + if target_var is not None: + if not isinstance(target_var, (str, bm.Variable)): + raise TypeError('"target_var" must be instance of string or Variable. ' + f'But we got {type(target_var)}') + self.target_var: Optional[bm.Variable] = target_var + + def register_master(self, master: SynConn): + super(SynOutput, self).register_master(master) + # initialize target variable to output + if isinstance(self.target_var, str): + if not hasattr(self.master.post, self.target_var): + raise KeyError(f'Post-synaptic group does not have target variable: {self.target_var}') + self.target_var = getattr(self.master.post, self.target_var) + + def filter(self, g): + if self.target_var is None: + return g + else: + self.target_var += g + def update(self, tdi): pass @@ -741,25 +831,25 @@ class TwoEndConn(SynConn): Post-synaptic neuron group. conn : optional, ndarray, JaxArray, dict, TwoEndConnector The connection method between pre- and post-synaptic groups. - output: SynOutput + output: Optional, SynOutput The output for the synaptic current. .. versionadded:: 2.1.13 The output component for a two-end connection model. - stp: SynSTP + stp: Optional, SynSTP The short-term plasticity model for the synaptic variables. .. versionadded:: 2.1.13 The short-term plasticity component for a two-end connection model. - ltp: SynLTP + ltp: Optional, SynLTP The long-term plasticity model for the synaptic variables. .. versionadded:: 2.1.13 The long-term plasticity component for a two-end connection model. - name : str, optional + name: Optional, str The name of the dynamic system. 
""" @@ -781,11 +871,12 @@ def __init__( mode=mode) # synaptic output - if output is None: output = SynOutput() - if not isinstance(output, SynOutput): - raise TypeError(f'output must be instance of {SynOutput.__name__}, but we got {type(output)}') - self.output: SynOutput = output - self.output.register_master(master=self) + if output is not None: + if not isinstance(output, SynOutput): + raise TypeError(f'output must be instance of {SynOutput.__name__}, ' + f'but we got {type(output)}') + output.register_master(master=self) + self.output: Optional[SynOutput] = output # short-term synaptic plasticity if stp is not None: @@ -1038,89 +1129,3 @@ def check_master(master, *channels, **named_channels): if not isinstance(channel, Channel): raise ValueError(f'Do not support {type(channel)}. ') _check(master, channel) - - -class Sequential(Container): - def __init__( - self, - *modules, - name: str = None, - mode: Mode = normal, - **kw_modules - ): - super(Sequential, self).__init__(*modules, name=name, mode=mode, **kw_modules) - - def __getattr__(self, item): - """Wrap the dot access ('self.'). """ - child_ds = super(Sequential, self).__getattribute__('implicit_nodes') - if item in child_ds: - return child_ds[item] - else: - return super(Sequential, self).__getattribute__(item) - - def __getitem__(self, key: Union[int, slice]): - if isinstance(key, str): - if key not in self.implicit_nodes: - raise KeyError(f'Does not find a component named {key} in\n {str(self)}') - return self.implicit_nodes[key] - elif isinstance(key, slice): - keys = tuple(self.implicit_nodes.keys())[key] - components = tuple(self.implicit_nodes.values())[key] - return Sequential(dict(zip(keys, components))) - elif isinstance(key, int): - return self.implicit_nodes.values()[key] - elif isinstance(key, (tuple, list)): - all_keys = tuple(self.implicit_nodes.keys()) - all_vals = tuple(self.implicit_nodes.values()) - keys, vals = [], [] - for i in key: - if isinstance(i, int): - raise KeyError(f'We excepted a tuple/list of int, but we got {type(i)}') - keys.append(all_keys[i]) - vals.append(all_vals[i]) - return Sequential(dict(zip(keys, vals))) - else: - raise KeyError(f'Unknown type of key: {type(key)}') - - def __repr__(self): - def f(x): - if not isinstance(x, DynamicalSystem) and callable(x): - signature = inspect.signature(x) - args = [f'{k}={v.default}' for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty] - args = ', '.join(args) - while not hasattr(x, '__name__'): - if not hasattr(x, 'func'): - break - x = x.func # Handle functools.partial - if not hasattr(x, '__name__') and hasattr(x, '__class__'): - return x.__class__.__name__ - if args: - return f'{x.__name__}(*, {args})' - return x.__name__ - else: - x = repr(x).split('\n') - x = [x[0]] + [' ' + y for y in x[1:]] - return '\n'.join(x) - - entries = '\n'.join(f' [{i}] {f(x)}' for i, x in enumerate(self)) - return f'{self.__class__.__name__}(\n{entries}\n)' - - def update(self, sha: dict, x: Any) -> Tensor: - """Update function of a sequential model. - - Parameters - ---------- - sha: dict - The shared arguments (ShA) across multiple layers. - x: Any - The input information. - - Returns - ------- - y: Tensor - The output tensor. 
- """ - for node in self.implicit_nodes.values(): - x = node(sha, x) - return x diff --git a/brainpy/dyn/neurons/biological_models.py b/brainpy/dyn/neurons/biological_models.py index 531a5efd6..25b26bf9a 100644 --- a/brainpy/dyn/neurons/biological_models.py +++ b/brainpy/dyn/neurons/biological_models.py @@ -8,7 +8,7 @@ from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.integrators.sde import sdeint -from brainpy.modes import Mode, BatchingMode, normal +from brainpy.modes import Mode, BatchingMode, TrainingMode, normal from brainpy.tools.checking import check_initializer from brainpy.types import Shape, Tensor @@ -247,8 +247,8 @@ def __init__( self.n = variable(self._n_initializer, mode, self.varshape) self.V = variable(self._V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - # sp_type = bm.dftype() if trainable else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) # integral if self.noise is None: @@ -262,8 +262,8 @@ def reset_state(self, batch_size=None): self.n.value = variable(self._n_initializer, batch_size, self.varshape) self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - # sp_type = bm.dftype() if self.trainable else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dm(self, m, t, V): alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) @@ -303,6 +303,8 @@ def update(self, tdi, x=None): self.m.value = m self.h.value = h self.n.value = n + + def clear_input(self): self.input[:] = 0. @@ -441,8 +443,8 @@ def __init__( self.W = variable(self._W_initializer, mode, self.varshape) self.V = variable(self._V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - # sp_type = bm.dftype() if trainable else bool - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) # integral if self.noise is None: @@ -454,8 +456,8 @@ def reset_state(self, batch_size=None): self.W.value = variable(self._W_initializer, batch_size, self.varshape) self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - # sp_type = bm.dftype() if self.trainable else bool - self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def dV(self, V, t, W, I_ext): M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) @@ -482,6 +484,8 @@ def update(self, tdi, x=None): spike = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.spike.value = spike + + def clear_input(self): self.input[:] = 0. @@ -796,6 +800,8 @@ def update(self, tdi, x=None): self.s.value = s self.c.value = c self.q.value = q + + def clear_input(self): self.Id[:] = 0. 
self.Is[:] = 0. @@ -1003,7 +1009,8 @@ def __init__( self.n = variable(self._n_initializer, mode, self.varshape) self.V = variable(self._V_initializer, mode, self.varshape) self.input = variable(bm.zeros, mode, self.varshape) - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool + self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape) # integral if self.noise is None: @@ -1016,7 +1023,8 @@ def reset_state(self, batch_size=None): self.n.value = variable(self._n_initializer, batch_size, self.varshape) self.V.value = variable(self._V_initializer, batch_size, self.varshape) self.input.value = variable(bm.zeros, batch_size, self.varshape) - self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool + self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape) def m_inf(self, V): alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) @@ -1054,4 +1062,6 @@ def update(self, tdi, x=None): self.V.value = V self.h.value = h self.n.value = n + + def clear_input(self): self.input[:] = 0. diff --git a/brainpy/dyn/neurons/fractional_models.py b/brainpy/dyn/neurons/fractional_models.py index 8947e85e4..3d0277bf3 100644 --- a/brainpy/dyn/neurons/fractional_models.py +++ b/brainpy/dyn/neurons/fractional_models.py @@ -163,6 +163,8 @@ def update(self, tdi, x=None): self.V.value = V self.w.value = w self.y.value = y + + def clear_input(self): self.input[:] = 0. @@ -304,4 +306,6 @@ def update(self, tdi, x=None): self.V.value = bm.where(spikes, self.c, V) self.u.value = bm.where(spikes, u + self.d, u) self.spike.value = spikes + + def clear_input(self): self.input[:] = 0. diff --git a/brainpy/dyn/neurons/input_groups.py b/brainpy/dyn/neurons/input_groups.py index 2db392e02..dbc35bcbe 100644 --- a/brainpy/dyn/neurons/input_groups.py +++ b/brainpy/dyn/neurons/input_groups.py @@ -13,12 +13,23 @@ __all__ = [ 'InputGroup', + 'OutputGroup', 'SpikeTimeGroup', 'PoissonGroup', ] class InputGroup(NeuGroup): + """Input neuron group for place holder. + + Parameters + ---------- + size: int, tuple of int + keep_size: bool + mode: Mode + name: str + """ + def __init__( self, size: Shape, @@ -39,6 +50,37 @@ def reset_state(self, batch_size=None): pass +class OutputGroup(NeuGroup): + """Output neuron group for place holder. + + Parameters + ---------- + size: int, tuple of int + keep_size: bool + mode: Mode + name: str + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + mode: Mode = normal, + name: str = None, + ): + super(OutputGroup, self).__init__(name=name, + size=size, + keep_size=keep_size, + mode=mode) + self.spike = None + + def update(self, tdi, x=None): + pass + + def reset_state(self, batch_size=None): + pass + + class SpikeTimeGroup(NeuGroup): """The input neuron group characterized by spikes emitting at given times. 
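For orientation, `InputGroup` and `OutputGroup` are stateless placeholders: they only reserve a population shape that synapses and layers can connect to, as in this rough sketch (the sizes and the `bp.neurons` namespace follow the training example later in this patch; treat it as illustrative, not as the canonical API):

    import brainpy as bp

    inp = bp.neurons.InputGroup(40)   # entry node; carries no dynamics of its own
    out = bp.neurons.OutputGroup(2)   # exit node; `spike` is None and `update` is a no-op
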
@@ -92,10 +134,10 @@ def __init__( # data about times and indices self.times = bm.asarray(times) - self.indices = bm.asarray(indices) + self.indices = bm.asarray(indices, dtype=bm.ditype()) # variables - self.i = bm.Variable(bm.zeros(1)) + self.i = bm.Variable(bm.zeros(1, dtype=bm.ditype())) self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) if need_sort: sort_idx = bm.argsort(self.times) diff --git a/brainpy/dyn/neurons/reduced_models.py b/brainpy/dyn/neurons/reduced_models.py index 834292844..170de5d0d 100644 --- a/brainpy/dyn/neurons/reduced_models.py +++ b/brainpy/dyn/neurons/reduced_models.py @@ -290,7 +290,7 @@ def update(self, tdi, x=None): self.V.value = V self.spike.value = spike - # reset input + def clear_input(self): self.input[:] = 0. @@ -480,6 +480,8 @@ def update(self, tdi, x=None): self.V.value = V self.spike.value = spike self.t_last_spike.value = t_last_spike + + def clear_input(self): self.input[:] = 0. @@ -642,6 +644,8 @@ def update(self, tdi, x=None): self.V.value = bm.where(spike, self.V_reset, V) self.w.value = bm.where(spike, w + self.b, w) self.spike.value = spike + + def clear_input(self): self.input[:] = 0. @@ -797,6 +801,8 @@ def update(self, tdi, x=None): self.V.value = V self.spike.value = spike self.t_last_spike.value = t_last_spike + + def clear_input(self): self.input[:] = 0. @@ -962,6 +968,8 @@ def update(self, tdi, x=None): self.V.value = bm.where(spike, self.V_reset, V) self.w.value = bm.where(spike, w + self.b, w) self.spike.value = spike + + def clear_input(self): self.input[:] = 0. @@ -1179,7 +1187,7 @@ def update(self, tdi, x=None): self.V_th.value = V_th self.V.value = V - # reset input + def clear_input(self): self.input[:] = 0. @@ -1234,12 +1242,13 @@ def __init__( a_initializer: Union[Initializer, Callable, Tensor] = OneInit(-50.), # parameter for training - spike_fun: Callable = bm.spike_with_relu_grad, + spike_fun: Callable = bm.spike_with_linear_grad, # other parameters method: str = 'exp_auto', name: str = None, mode: Mode = normal, + eprop: bool = False ): super(ALIFBellec2020, self).__init__(name=name, size=size, @@ -1248,7 +1257,7 @@ def __init__( # parameters self.V_rest = parameter(V_rest, self.varshape, allow_none=False) - self.V_th_reset = parameter(V_th, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) self.R = parameter(R, self.varshape, allow_none=False) self.beta = parameter(beta, self.varshape, allow_none=False) self.tau = parameter(tau, self.varshape, allow_none=False) @@ -1256,6 +1265,7 @@ def __init__( self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) self.noise = init_noise(noise, self.varshape, num_vars=2) self.spike_fun = check_callable(spike_fun, 'spike_fun') + self.eprop = eprop # initializers check_initializer(V_initializer, 'V_initializer') @@ -1279,7 +1289,7 @@ def __init__( else: self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - def dVth(self, a, t): + def da(self, a, t): return -a / self.tau_a def dV(self, V, t, I_ext): @@ -1287,7 +1297,7 @@ def dV(self, V, t, I_ext): @property def derivative(self): - return JointEq([self.dV, self.dVth]) + return JointEq([self.dV, self.da]) def reset_state(self, batch_size=None): self.a.value = variable(self._a_initializer, batch_size, self.varshape) @@ -1314,36 +1324,33 @@ def update(self, tdi, x=None): V = bm.where(refractory, self.V, V) # spike and reset if isinstance(self.mode, TrainingMode): - spike = self.spike_fun(V - self.V_th_reset - self.beta * self.a) - 
spike_no_grad = stop_gradient(spike) - V -= self.V_th_reset * spike_no_grad - spike_ = spike_no_grad > 0. + spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) + V -= self.V_th * (stop_gradient(spike) if self.eprop else spike) # will be used in other place, like Delta Synapse, so stop its gradient + spike_ = spike > 0. refractory = stop_gradient(bm.logical_or(refractory, spike_).value) t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike).value) else: - spike = V >= (self.V_th_reset + self.beta * self.a) + spike = V >= (self.V_th + self.beta * self.a) refractory = bm.logical_or(refractory, spike) t_last_spike = bm.where(spike, t, self.t_last_spike) - V -= self.V_th_reset * spike - a += spike + V -= self.V_th * spike self.refractory.value = refractory self.t_last_spike.value = t_last_spike else: # spike and reset if isinstance(self.mode, TrainingMode): - spike = self.spike_fun(V - self.V_th_reset - self.beta * self.a) - V -= self.V_th_reset * stop_gradient(spike) + spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) + V -= self.V_th * (stop_gradient(spike) if self.eprop else spike) else: - spike = V >= (self.V_th_reset + self.beta * self.a) - V -= self.V_th_reset * spike - a += spike + spike = V >= (self.V_th + self.beta * self.a) + V -= self.V_th * spike self.spike.value = spike self.V.value = V - self.a.value = a + self.a.value = a + spike - # reset input + def clear_input(self): self.input[:] = 0. @@ -1536,6 +1543,8 @@ def update(self, tdi, x=None): self.V.value = V self.u.value = u self.spike.value = spike + + def clear_input(self): self.input[:] = 0. @@ -1732,6 +1741,8 @@ def update(self, tdi, x=None): self.V.value = V self.y.value = y self.z.value = z + + def clear_input(self): self.input[:] = 0. @@ -1896,4 +1907,6 @@ def update(self, tdi, x=None): self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) self.V.value = V self.w.value = w + + def clear_input(self): self.input[:] = 0. diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py index 235fdc59f..e3ac7f474 100644 --- a/brainpy/dyn/rates/populations.py +++ b/brainpy/dyn/rates/populations.py @@ -178,6 +178,8 @@ def update(self, tdi, x=None): x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y + + def clear_input(self): self.input[:] = 0. self.input_y[:] = 0. @@ -371,6 +373,8 @@ def update(self, tdi, x=None): x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y + + def clear_input(self): self.input[:] = 0. self.input_y[:] = 0. @@ -552,6 +556,8 @@ def update(self, tdi, x=None): x, y = self.integral(self.x, self.y, t=t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y + + def clear_input(self): self.input[:] = 0. self.input_y[:] = 0. @@ -692,6 +698,8 @@ def update(self, tdi, x=None): dt=dt) self.x.value = x self.y.value = y + + def clear_input(self): self.input[:] = 0. self.input_y[:] = 0. @@ -847,6 +855,8 @@ def update(self, tdi, x=None): x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y + + def clear_input(self): self.input[:] = 0. self.input_y[:] = 0. @@ -965,5 +975,6 @@ def update(self, tdi, x=None): di = di / self.tau_i self.i.value = bm.maximum(self.i + di * dt, 0.) + def clear_input(self): self.Ie[:] = 0. self.Ii[:] = 0. 
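With this patch, models no longer zero their external input at the end of `update`; the reset moves into a dedicated `clear_input` method, which the runner calls once per step (see the `runners.py` hunk that follows). A minimal sketch of a user-defined group following this convention; the class name and dynamics are illustrative only, not part of the library:

    import brainpy as bp
    import brainpy.math as bm

    class MyRate(bp.dyn.NeuGroup):
        def __init__(self, size):
            super(MyRate, self).__init__(size)
            self.x = bm.Variable(bm.zeros(self.num))
            self.input = bm.Variable(bm.zeros(self.num))

        def update(self, tdi, x=None):
            # integrate only; the input reset no longer lives here
            self.x.value = self.x + (self.input - self.x) / 10. * tdi.dt

        def clear_input(self):
            # called by the runner at the start of each step, before inputs are applied
            self.input[:] = 0.
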
diff --git a/brainpy/dyn/runners.py b/brainpy/dyn/runners.py index 80450b538..4dff85db7 100644 --- a/brainpy/dyn/runners.py +++ b/brainpy/dyn/runners.py @@ -568,6 +568,7 @@ def f_predict(self, shared_args: Dict = None): def _step_func(inputs): t, i, x = inputs + self.target.clear_input() # input step shared = DotDict(t=t, i=i, dt=self.dt) self._input_step(shared) diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py index a91d4b18a..68a530560 100644 --- a/brainpy/train/back_propagation.py +++ b/brainpy/train/back_propagation.py @@ -575,6 +575,7 @@ def f_predict_one_step(self, shared_args: Dict = None, jit: bool = False): def run_func(t, i, x): shared = DotDict(t=t, i=i, dt=self.dt) shared.update(shared_args) + self.target.clear_input() outs = self.target(shared, x) hist = monitor_func(shared) return outs, hist diff --git a/brainpy/train/online.py b/brainpy/train/online.py index 8ae1e34f2..1835d9c84 100644 --- a/brainpy/train/online.py +++ b/brainpy/train/online.py @@ -239,6 +239,7 @@ def _step_func(all_inputs): shared = DotDict(t=t, dt=self.dt, i=i) # input step + self.target.clear_input() self._input_step(shared) # update step From b2be3075150e04a8cc96885a26fca9f034642700 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 29 Jul 2022 15:05:43 +0800 Subject: [PATCH 3/5] updates --- brainpy/__init__.py | 7 +- brainpy/base/base.py | 2 +- brainpy/base/tests/test_collector.py | 47 +++--- brainpy/dyn/layers/__init__.py | 2 +- brainpy/dyn/layers/{dense.py => linear.py} | 0 brainpy/dyn/synapses/abstract_models.py | 45 +++--- brainpy/dyn/synapses/biological_models.py | 20 ++- brainpy/dyn/synapses/learning_rules.py | 4 +- brainpy/dyn/synouts/conductances.py | 39 +++-- brainpy/dyn/synouts/ions.py | 24 ++- brainpy/initialize/generic.py | 2 +- brainpy/initialize/random_inits.py | 14 +- brainpy/initialize/regular_inits.py | 13 +- brainpy/math/delayvars.py | 26 ++-- brainpy/math/numpy_ops.py | 19 ++- brainpy/math/operators/op_register.py | 65 ++++++++- brainpy/math/operators/spikegrad.py | 138 +++++++++++++++--- brainpy/math/random.py | 5 +- brainpy/running/runner.py | 1 + brainpy/tools/others/others.py | 23 ++- brainpy/visualization/figures.py | 7 +- .../Wang_2002_decision_making_spiking.py | 44 +++--- .../whole_brain_simulation_with_fhn.py | 32 +++- ...ole_brain_simulation_with_sl_oscillator.py | 38 ++++- ...Bellec_2020_eprop_evidence_accumulation.py | 132 +++++------------ .../Gauthier_2021_ngrc_lorenz_inference.py | 1 - examples/training/Song_2016_EI_RNN.py | 6 +- examples/training/SurrogateGrad_lif.py | 24 +-- .../SurrogateGrad_lif_fashion_mnist.py | 20 +-- extensions/brainpylib/custom_op/regis_op.py | 10 +- 30 files changed, 521 insertions(+), 289 deletions(-) rename brainpy/dyn/layers/{dense.py => linear.py} (100%) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index d96b67f3c..3e3909b3a 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -62,16 +62,11 @@ synouts, # synaptic output synplast, # synaptic plasticity ) -# from .dyn.base import * -# from .dyn.runners import * +from .dyn.runners import * # dynamics training from . 
import train -# from .train.base import * -# from .train.online import * -# from .train.offline import * -# from .train.back_propagation import * # automatic dynamics analysis diff --git a/brainpy/base/base.py b/brainpy/base/base.py index ea82463e6..c7ca6f525 100644 --- a/brainpy/base/base.py +++ b/brainpy/base/base.py @@ -120,7 +120,7 @@ def vars(self, method='absolute', level=-1, include_self=True): for node_path, node in nodes.items(): for k in dir(node): v = getattr(node, k) - if isinstance(v, math.Variable): + if isinstance(v, math.Variable) and not k.startswith('_') and not k.endswith('_'): gather[f'{node_path}.{k}' if node_path else k] = v gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()}) return gather diff --git a/brainpy/base/tests/test_collector.py b/brainpy/base/tests/test_collector.py index 4c2ad0bd4..db60230a6 100644 --- a/brainpy/base/tests/test_collector.py +++ b/brainpy/base/tests/test_collector.py @@ -31,11 +31,11 @@ def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., self.int_s = bp.odeint(lambda s, t, TT: self.alpha * TT * (1 - s) - self.beta * s) - def update(self, t, dt): + def update(self, tdi): spike = bp.math.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat - self.t_last_pre_spike[:] = bp.math.where(spike, t, self.t_last_pre_spike) - TT = ((t - self.t_last_pre_spike) < self.T_duration) * self.T - self.s[:] = self.int_s(self.s, t, TT) + self.t_last_pre_spike[:] = bp.math.where(spike, tdi.t, self.t_last_pre_spike) + TT = ((tdi.t - self.t_last_pre_spike) < self.T_duration) * self.T + self.s[:] = self.int_s(self.s, tdi.t, TT) self.post.inputs -= bp.math.sum(self.s, axis=0) * (self.post.V - self.E) @@ -83,8 +83,8 @@ def derivative(self, V, h, n, t, Iext): return dVdt, dhdt, dndt - def update(self, t, dt): - V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs) + def update(self, tdi): + V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.inputs) self.spikes[:] = bp.math.logical_and(self.V < self.V_th, V >= self.V_th) self.V[:] = V self.h[:] = h @@ -160,8 +160,8 @@ def derivative(self, V, h, n, t, Iext): return dVdt, dhdt, dndt - def update(self, t, dt): - V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs) + def update(self, tdi): + V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.inputs) self.spikes[:] = bp.math.logical_and(self.V < self.V_th, V >= self.V_th) self.V[:] = V self.h[:] = h @@ -215,11 +215,11 @@ def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., self.s = bp.math.Variable(bp.math.zeros(self.size)) self.int_s = bp.odeint(lambda s, t, TT: self.alpha * TT * (1 - s) - self.beta * s) - def update(self, t, _i): + def update(self, tdi): spike = bp.math.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat - self.t_last_pre_spike[:] = bp.math.where(spike, t, self.t_last_pre_spike) - TT = ((t - self.t_last_pre_spike) < self.T_duration) * self.T - self.s[:] = self.int_s(self.s, t, TT) + self.t_last_pre_spike[:] = bp.math.where(spike, tdi.t, self.t_last_pre_spike) + TT = ((tdi.t - self.t_last_pre_spike) < self.T_duration) * self.T + self.s[:] = self.int_s(self.s, tdi.t, TT) self.post.inputs -= bp.math.sum(self.g_max * self.s, axis=0) * (self.post.V - self.E) @@ -240,12 +240,11 @@ def test_net_1(): # nodes print() pprint(list(net.nodes().unique().keys())) - assert len(net.nodes()) == 7 + assert len(net.nodes()) == 5 print() pprint(list(net.nodes(method='relative').unique().keys())) - assert len(net.nodes(method='relative')) == 10 - + assert 
len(net.nodes(method='relative')) == 6 def test_net_vars_2(): @@ -265,9 +264,23 @@ def test_net_vars_2(): # nodes print() pprint(list(net.nodes().keys())) - assert len(net.nodes()) == 7 + assert len(net.nodes()) == 5 print() pprint(list(net.nodes(method='relative').keys())) - assert len(net.nodes(method='relative')) == 10 + assert len(net.nodes(method='relative')) == 6 + + +def test_hidden_variables(): + class BPClass(bp.base.Base): + def __init__(self): + super(BPClass, self).__init__() + + self._rng_ = bp.math.random.RandomState() + self.rng = bp.math.random.RandomState() + + model = BPClass() + + print(model.vars(level=-1).keys()) + assert len(model.vars(level=-1)) == 1 diff --git a/brainpy/dyn/layers/__init__.py b/brainpy/dyn/layers/__init__.py index b6a0292bb..4f95bda2f 100644 --- a/brainpy/dyn/layers/__init__.py +++ b/brainpy/dyn/layers/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from .dropout import * -from .dense import * +from .linear import * from .nvar import * from .reservoir import * from .rnncells import * diff --git a/brainpy/dyn/layers/dense.py b/brainpy/dyn/layers/linear.py similarity index 100% rename from brainpy/dyn/layers/dense.py rename to brainpy/dyn/layers/linear.py diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 347bd3e6f..b2fd11230 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -79,9 +79,6 @@ class Delta(TwoEndConn): The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. g_max: float, ndarray, JaxArray, Initializer, Callable The synaptic strength. Default is 1. - post_input_key: str - The key of the post variable. It should be a string. The key should - be the attribute of the post-synaptic neuron group. post_ref_key: str Whether the post-synaptic group has refractory period. """ @@ -91,16 +88,15 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: Optional[SynOutput] = None, + output: SynOutput = None, stp: Optional[SynSTP] = None, comp_method: str = 'sparse', g_max: Union[float, Tensor, Initializer, Callable] = 1., delay_step: Union[float, Tensor, Initializer, Callable] = None, - post_input_key: str = 'V', post_ref_key: str = None, - name: str = None, - # training parameters + # other parameters + name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, ): @@ -108,14 +104,12 @@ def __init__( pre=pre, post=post, conn=conn, - output=CUBA() if output is None else output, + output=CUBA(target_var='V') if output is None else output, stp=stp, mode=mode) # parameters self.stop_spike_gradient = stop_spike_gradient - self.post_input_key = post_input_key - self.check_post_attrs(post_input_key) self.post_ref_key = post_ref_key if post_ref_key: self.check_post_attrs(post_ref_key) @@ -166,11 +160,9 @@ def update(self, tdi, pre_spike=None): post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) if self.post_ref_key: post_vs = post_vs * (1. 
- getattr(self.post, self.post_ref_key)) - post_vs = self.output(post_vs) # update outputs - target = getattr(self.post, self.post_input_key) - target += post_vs + return self.output(post_vs) class Exponential(TwoEndConn): @@ -279,10 +271,10 @@ def __init__( g_max: Union[float, Tensor, Initializer, Callable] = 1., delay_step: Union[int, Tensor, Initializer, Callable] = None, tau: Union[float, Tensor] = 8.0, - name: str = None, method: str = 'exp_auto', - # training parameters + # other parameters + name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, ): @@ -351,10 +343,9 @@ def update(self, tdi, pre_spike=None): post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # updates self.g.value = self.integral(self.g.value, t, dt) + post_vs - g_out = self.output(self.g) # output - self.post.input += g_out + return self.output(self.g) class DualExponential(TwoEndConn): @@ -465,9 +456,9 @@ def __init__( tau_rise: Union[float, Tensor] = 1., delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', - name: str = None, - # training parameters + # other parameters + name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, ): @@ -547,10 +538,9 @@ def update(self, tdi, pre_spike=None): post_vs = f(syn_value) else: post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) - post_vs = self.output(post_vs) # output - self.post.input += post_vs + return self.output(post_vs) class Alpha(DualExponential): @@ -645,9 +635,9 @@ def __init__( delay_step: Union[int, Tensor, Initializer, Callable] = None, tau_decay: Union[float, Tensor] = 10.0, method: str = 'exp_auto', - name: str = None, - # training parameters + # other parameters + name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, ): @@ -822,7 +812,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: Optional[SynOutput] = None, + output: SynOutput = None, stp: Optional[SynSTP] = None, comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.15, @@ -831,9 +821,9 @@ def __init__( a: Union[float, Tensor] = 0.5, tau_rise: Union[float, Tensor] = 2., method: str = 'exp_auto', - name: str = None, - # training parameters + # other parameters + name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, @@ -950,7 +940,6 @@ def update(self, tdi, pre_spike=None): post_vs = f(syn_value) else: post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) - post_vs = self.output(post_vs) # output - self.post.input += post_vs + return self.output(post_vs) diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index 48269620a..e1e26f1ff 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -150,9 +150,9 @@ def __init__( T: float = 0.5, T_duration: float = 0.5, method: str = 'exp_auto', - name: str = None, - # training parameters + # other parameters + name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, @@ -244,10 +244,9 @@ def update(self, tdi, pre_spike=None): post_vs = f(syn_value) else: post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) - post_vs = self.output(post_vs) # output - self.post.input += post_vs + return self.output(post_vs) class GABAa(AMPA): @@ -333,9 +332,9 @@ def __init__( T: Union[float, Tensor] = 1., T_duration: Union[float, Tensor] = 1., method: str = 'exp_auto', - name: str = 
None, - # training parameters + # other parameters + name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, @@ -491,7 +490,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: Optional[SynOutput] = None, + output: SynOutput = None, stp: Optional[SynSTP] = None, comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.15, @@ -503,10 +502,10 @@ def __init__( T_0: Union[float, Tensor] = 1., T_dur: Union[float, Tensor] = 0.5, method: str = 'exp_auto', - name: str = None, - # training parameters + # other parameters mode: Mode = normal, + name: str = None, stop_spike_gradient: bool = False, ): super(BioNMDA, self).__init__(pre=pre, @@ -599,7 +598,6 @@ def update(self, tdi, pre_spike=None): post_vs = f(syn_value) else: post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask) - post_vs = self.output(post_vs) # output - self.post.input += post_vs + return self.output(post_vs) diff --git a/brainpy/dyn/synapses/learning_rules.py b/brainpy/dyn/synapses/learning_rules.py index cb49ecc40..aeae458d6 100644 --- a/brainpy/dyn/synapses/learning_rules.py +++ b/brainpy/dyn/synapses/learning_rules.py @@ -222,7 +222,7 @@ def derivative(self): dx = lambda x, t: (1 - x) / self.tau_d return JointEq([dI, du, dx]) - def update(self, t, dt): + def update(self, tdi): # delayed pre-synaptic spikes if self.delay_type == 'homo': delayed_I = self.delay_I(self.delay_step) @@ -231,7 +231,7 @@ def update(self, t, dt): else: delayed_I = self.I self.post.input += bm.syn2post(delayed_I, self.post_ids, self.post.num) - self.I.value, u, x = self.integral(self.I, self.u, self.x, t, dt=dt) + self.I.value, u, x = self.integral(self.I, self.u, self.x, tdi.t, tdi.dt) syn_sps = bm.pre2syn(self.pre.spike, self.pre_ids) u = bm.where(syn_sps, u + self.U * (1 - self.u), u) x = bm.where(syn_sps, x - u * self.x, x) diff --git a/brainpy/dyn/synouts/conductances.py b/brainpy/dyn/synouts/conductances.py index 9012a8669..cd443c192 100644 --- a/brainpy/dyn/synouts/conductances.py +++ b/brainpy/dyn/synouts/conductances.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable +from typing import Union, Callable, Optional +from brainpy.math import Variable from brainpy.dyn.base import SynOutput from brainpy.initialize import parameter, Initializer from brainpy.types import Tensor @@ -32,11 +33,12 @@ class CUBA(SynOutput): COBA """ - def __init__(self, name: str = None): - super(CUBA, self).__init__(name=name) - - def filter(self, g): - return g + def __init__( + self, + target_var: Optional[Union[str, Variable]] = 'input', + name: str = None, + ): + super(CUBA, self).__init__(name=name, target_var=target_var) class COBA(SynOutput): @@ -63,14 +65,29 @@ class COBA(SynOutput): def __init__( self, E: Union[float, Tensor, Callable, Initializer] = 0., - name: str = None + target_var: Optional[Union[str, Variable]] = 'input', + membrane_var: Union[str, Variable] = 'V', + name: str = None, ): - super(COBA, self).__init__(name=name) - self._E = E + super(COBA, self).__init__(name=name, target_var=target_var) + self.E = E + self.membrane_var = membrane_var def register_master(self, master): super(COBA, self).register_master(master) - self.E = parameter(self._E, self.master.post.num, allow_none=False) + self.E = parameter(self.E, self.master.post.num, allow_none=False) + + if isinstance(self.membrane_var, str): + if not hasattr(self.master.post, self.membrane_var): + raise KeyError(f'Post-synaptic group does not 
have membrane variable: {self.membrane_var}') + self.membrane_var = getattr(self.master.post, self.membrane_var) + elif isinstance(self.membrane_var, Variable): + self.membrane_var = self.membrane_var + else: + raise TypeError('"membrane_var" must be instance of string or Variable. ' + f'But we got {type(self.membrane_var)}') def filter(self, g): - return g * (self.E - self.master.post.V) + V = self.membrane_var.value + I = g * (self.E - V) + return super(COBA, self).filter(I) diff --git a/brainpy/dyn/synouts/ions.py b/brainpy/dyn/synouts/ions.py index a9b40370b..3e82234ad 100644 --- a/brainpy/dyn/synouts/ions.py +++ b/brainpy/dyn/synouts/ions.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable +from typing import Union, Callable, Optional import brainpy.math as bm from brainpy.dyn.base import SynOutput @@ -49,13 +49,16 @@ def __init__( cc_Mg: Union[float, Tensor, Callable, Initializer] = 1.2, alpha: Union[float, Tensor, Callable, Initializer] = 0.062, beta: Union[float, Tensor, Callable, Initializer] = 3.57, - name: str = None + target_var: Optional[Union[str, bm.Variable]] = 'input', + membrane_var: Union[str, bm.Variable] = 'V', + name: str = None, ): - super(MgBlock, self).__init__(name=name) + super(MgBlock, self).__init__(name=name, target_var=target_var) self.E = E self.cc_Mg = cc_Mg self.alpha = alpha self.beta = beta + self.membrane_var = membrane_var def register_master(self, master): super(MgBlock, self).register_master(master) @@ -64,6 +67,17 @@ def register_master(self, master): self.alpha = parameter(self.alpha, self.master.post.num, allow_none=False) self.beta = parameter(self.beta, self.master.post.num, allow_none=False) + if isinstance(self.membrane_var, str): + if not hasattr(self.master.post, self.membrane_var): + raise KeyError(f'Post-synaptic group does not have membrane variable: {self.membrane_var}') + self.membrane_var = getattr(self.master.post, self.membrane_var) + elif isinstance(self.membrane_var, bm.Variable): + self.membrane_var = self.membrane_var + else: + raise TypeError('"membrane_var" must be instance of string or Variable. 
' + f'But we got {type(self.membrane_var)}') + def filter(self, g): - V = self.master.post.V.value - return g * (self.E - V) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * V)) + V = self.membrane_var.value + I = g * (self.E - V) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * V)) + return super(MgBlock, self).filter(I) diff --git a/brainpy/initialize/generic.py b/brainpy/initialize/generic.py index 258e86f47..81319f4d5 100644 --- a/brainpy/initialize/generic.py +++ b/brainpy/initialize/generic.py @@ -8,7 +8,7 @@ import brainpy.math as bm from brainpy.tools.others import to_size from brainpy.types import Shape, Tensor -from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode +from brainpy.modes import Mode, NormalMode, BatchingMode from .base import Initializer diff --git a/brainpy/initialize/random_inits.py b/brainpy/initialize/random_inits.py index b215ddd15..53cf364f4 100644 --- a/brainpy/initialize/random_inits.py +++ b/brainpy/initialize/random_inits.py @@ -5,6 +5,7 @@ from brainpy import math as bm, tools from .base import InterLayerInitializer + __all__ = [ 'Normal', 'Uniform', @@ -110,16 +111,15 @@ def __call__(self, shape, dtype=None): denominator = (fan_in + fan_out) / 2 else: raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode)) - variance = bm.array(self.scale / denominator, dtype=dtype) + variance = np.array(self.scale / denominator, dtype=dtype) if self.distribution == "truncated_normal": - from scipy.stats import truncnorm - # constant is stddev of standard normal truncated to (-2, 2) - stddev = bm.sqrt(variance) / bm.array(.87962566103423978, dtype) - res = truncnorm(-2, 2).rvs(shape) * stddev + stddev = bm.array(np.sqrt(variance) / .87962566103423978, dtype) + rng = bm.random.RandomState(self.rng.randint(0, int(1e7))) + return rng.truncated_normal(-2, 2, shape, dtype) * stddev elif self.distribution == "normal": - res = self.rng.normal(size=shape) * bm.sqrt(variance) + res = self.rng.randn(*shape) * np.sqrt(variance) elif self.distribution == "uniform": - res = self.rng.uniform(low=-1, high=1, size=shape) * bm.sqrt(3 * variance) + res = self.rng.uniform(low=-1, high=1, size=shape) * np.sqrt(3 * variance) else: raise ValueError("invalid distribution for variance scaling initializer") return bm.asarray(res, dtype=dtype) diff --git a/brainpy/initialize/regular_inits.py b/brainpy/initialize/regular_inits.py index 5407ab5d0..44c37861d 100644 --- a/brainpy/initialize/regular_inits.py +++ b/brainpy/initialize/regular_inits.py @@ -5,6 +5,7 @@ __all__ = [ 'ZeroInit', + 'Constant', 'OneInit', 'Identity', ] @@ -24,8 +25,8 @@ def __repr__(self): return self.__class__.__name__ -class OneInit(InterLayerInitializer): - """One initializer. +class Constant(InterLayerInitializer): + """Constant initializer. Initialize the weights with the given values. @@ -36,7 +37,7 @@ class OneInit(InterLayerInitializer): """ def __init__(self, value=1.): - super(OneInit, self).__init__() + super(Constant, self).__init__() self.value = value def __call__(self, shape, dtype=None): @@ -47,6 +48,12 @@ def __repr__(self): return f'{self.__class__.__name__}(value={self.value})' +class OneInit(Constant): + """One initializer. + """ + pass + + class Identity(InterLayerInitializer): """Returns the identity matrix. 
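The new `Constant` initializer generalizes `OneInit`, which is now simply `Constant` with its default value of 1. A small illustrative example, assuming the usual `brainpy.init` entry point used elsewhere in this patch:

    import brainpy as bp

    init = bp.init.Constant(0.5)
    W = init((20, 10))               # 20 x 10 array, every entry 0.5
    b = bp.init.OneInit()((10,))     # identical to Constant(1.)((10,))
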
diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index cfcd286cb..2e403d604 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jax import vmap -from jax.lax import cond +from jax.lax import cond, stop_gradient from brainpy import check from brainpy.base.base import Base @@ -287,6 +287,7 @@ def __init__( delay_len: int, initial_delay_data: Union[float, int, bool, ndarray, jnp.ndarray, Callable] = None, name: str = None, + batch_axis: int = None, ): super(LengthDelay, self).__init__(name=name) @@ -296,13 +297,14 @@ def __init__( self.data: Variable = None # initialization - self.reset(delay_target, delay_len, initial_delay_data) + self.reset(delay_target, delay_len, initial_delay_data, batch_axis) def reset( self, delay_target, delay_len=None, - initial_delay_data=None + initial_delay_data=None, + batch_axis=None ): if not isinstance(delay_target, (ndarray, jnp.ndarray)): raise ValueError(f'Must be an instance of brainpy.math.ndarray ' @@ -324,9 +326,9 @@ def reset( # delay data if self.data is None: - batch_axis = None - if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None): - batch_axis = delay_target.batch_axis + 1 + if batch_axis is None: + if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None): + batch_axis = delay_target.batch_axis + 1 self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype), batch_axis=batch_axis) @@ -345,8 +347,8 @@ def reset( raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}') def _check_delay(self, delay_len): - raise ValueError(f'The request delay length should be less than the ' - f'maximum delay {self.num_delay_step}. But we got {delay_len}') + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.num_delay_step}. 
But we got {delay_len}') def __call__(self, delay_len, *indices): # check @@ -354,15 +356,17 @@ def __call__(self, delay_len, *indices): check_error_in_jit(bm.any(delay_len >= self.num_delay_step), self._check_delay, delay_len) # the delay length delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step + delay_idx = stop_gradient(delay_idx) if not jnp.issubdtype(delay_idx.dtype, jnp.integer): raise ValueError(f'"delay_len" must be integer, but we got {delay_len}') # the delay data - indices = (delay_idx, ) + tuple(indices) + indices = (delay_idx,) + tuple(indices) return self.data[indices] def update(self, value: Union[float, JaxArray, jnp.DeviceArray]): - self.data[self.idx[0]] = value - self.idx.value = (self.idx + 1) % self.num_delay_step + idx = stop_gradient(self.idx[0]) + self.data[idx] = value + self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step) class NeuLenDelay(LengthDelay): diff --git a/brainpy/math/numpy_ops.py b/brainpy/math/numpy_ops.py index 94a551638..21a267985 100644 --- a/brainpy/math/numpy_ops.py +++ b/brainpy/math/numpy_ops.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from functools import partial + from typing import Optional import jax.numpy as jnp @@ -97,7 +97,9 @@ 'savetxt', 'savez_compressed', 'show_config', 'typename', # others - 'clip_by_norm', 'as_device_array', 'as_variable', 'as_numpy', 'remove_diag', + 'clip_by_norm', 'remove_diag', + 'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', + 'as_variable', ] _min = min @@ -156,7 +158,10 @@ def as_device_array(tensor, dtype=None): return jnp.asarray(tensor, dtype=dtype) -def as_numpy(tensor, dtype=None): +as_jax = as_device_array + + +def as_ndarray(tensor, dtype=None): """Convert the input to a ``numpy.ndarray``. Parameters @@ -175,11 +180,14 @@ def as_numpy(tensor, dtype=None): is already an ndarray with matching dtype. """ if isinstance(tensor, JaxArray): - return tensor.numpy(dtype=dtype) + return tensor.to_numpy(dtype=dtype) else: return np.asarray(tensor, dtype=dtype) +as_numpy = as_ndarray + + def as_variable(tensor, dtype=None): """Convert the input to a ``brainpy.math.Variable``. 
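Together with the renames above, the conversion helpers now read `as_jax` (an alias of `as_device_array`), `as_ndarray` (with `as_numpy` kept as an alias), and `as_variable`. A short usage sketch, assuming these functions remain re-exported under `brainpy.math` as before:

    import numpy as np
    import brainpy.math as bm

    x = bm.zeros(3)
    j = bm.as_jax(x)        # unwrap to a raw JAX array
    a = bm.as_numpy(x)      # alias of as_ndarray: copy out to a numpy.ndarray
    v = bm.as_variable(x)   # wrap the data into a brainpy.math.Variable
    assert isinstance(a, np.ndarray)
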
@@ -1844,7 +1852,6 @@ def linspace(*args, **kwargs): return JaxArray(res) - @wraps(jnp.logspace) def logspace(*args, **kwargs): return JaxArray(jnp.logspace(*args, **kwargs)) @@ -1886,7 +1893,7 @@ def vander(x, N=None, increasing=False): return JaxArray(jnp.vander(x, N=N, increasing=increasing)) -@wraps(jnp.fill_diagonal) +@wraps(np.fill_diagonal) def fill_diagonal(a, val): if not isinstance(a, JaxArray): raise ValueError(f'Must be a JaxArray, but got {type(a)}') diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py index c4c195523..bccd0c692 100644 --- a/brainpy/math/operators/op_register.py +++ b/brainpy/math/operators/op_register.py @@ -3,8 +3,11 @@ from typing import Union, Sequence, Callable from jax.abstract_arrays import ShapedArray +from jax.tree_util import tree_map +from brainpy.base import Base from brainpy.math.jaxarray import JaxArray +from brainpy import tools from .utils import _check_brainpylib try: @@ -13,10 +16,70 @@ brainpylib = None __all__ = [ - 'register_op' + 'XLACustomOp', + 'register_op', ] +class XLACustomOp(Base): + def __init__(self, name=None, apply_cpu_func_to_gpu: bool = False): + _check_brainpylib(register_op.__name__) + super(XLACustomOp, self).__init__(name=name) + + # abstract evaluation function + if hasattr(self.eval_shape, 'not_customized') and self.eval_shape.not_customized: + raise ValueError('Must implement "eval_shape" for abstract evaluation.') + + # cpu function + if hasattr(self.con_compute, 'not_customized') and self.con_compute.not_customized: + if hasattr(self.cpu_func, 'not_customized') and self.cpu_func.not_customized: + raise ValueError('Must implement one of "cpu_func" or "con_compute".') + else: + cpu_func = self.cpu_func + else: + cpu_func = self.con_compute + + # gpu function + if hasattr(self.gpu_func, 'not_customized') and self.gpu_func.not_customized: + gpu_func = None + else: + gpu_func = self.gpu_func + + # register OP + self.op = brainpylib.register_op(self.name, + cpu_func=cpu_func, + gpu_func=gpu_func, + out_shapes=self.eval_shape, + apply_cpu_func_to_gpu=apply_cpu_func_to_gpu) + + @tools.not_customized + def eval_shape(self, *args, **kwargs): + raise NotImplementedError + + @staticmethod + @tools.not_customized + def con_compute(*args, **kwargs): + raise NotImplementedError + + @staticmethod + @tools.not_customized + def cpu_func(*args, **kwargs): + raise NotImplementedError + + @staticmethod + @tools.not_customized + def gpu_func(*args, **kwargs): + raise NotImplementedError + + def __call__(self, *args, **kwargs): + args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a, + args, is_leaf=lambda a: isinstance(a, JaxArray)) + kwargs = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a, + kwargs, is_leaf=lambda a: isinstance(a, JaxArray)) + res = self.op.bind(*args, **kwargs) + return res[0] if len(res) == 1 else res + + def register_op( op_name: str, cpu_func: Callable, diff --git a/brainpy/math/operators/spikegrad.py b/brainpy/math/operators/spikegrad.py index 923473920..bae6d58f0 100644 --- a/brainpy/math/operators/spikegrad.py +++ b/brainpy/math/operators/spikegrad.py @@ -11,9 +11,12 @@ __all__ = [ 'spike_with_sigmoid_grad', + 'spike_with_linear_grad', + 'spike_with_gaussian_grad', + 'spike_with_mg_grad', + 'spike2_with_sigmoid_grad', - 'spike_with_relu_grad', - 'spike2_with_relu_grad', + 'spike2_with_linear_grad', 'step_pwl' ] @@ -36,9 +39,7 @@ def spike_with_sigmoid_grad(x: Tensor, scale: float = None): z = bm.asarray(x >= 0, dtype=dftype()) def grad(dE_dz): - 
_scale = scale - if scale is None: - _scale = 100. + _scale = 100. if scale is None else scale dE_dx = dE_dz / (_scale * bm.abs(x) + 1.0) ** 2 if scale is None: return (_consistent_type(dE_dx, x),) @@ -68,9 +69,7 @@ def spike2_with_sigmoid_grad(x_new: Tensor, x_old: Tensor, scale: float = None): z = bm.asarray(bm.logical_and(x_new_comp, x_old_comp), dtype=dftype()) def grad(dE_dz): - _scale = scale - if scale is None: - _scale = 100. + _scale = 100. if scale is None else scale dx_new = (dE_dz / (_scale * bm.abs(x_new) + 1.0) ** 2) * bm.asarray(x_old_comp, dtype=dftype()) dx_old = -(dE_dz / (_scale * bm.abs(x_old) + 1.0) ** 2) * bm.asarray(x_new_comp, dtype=dftype()) if scale is None: @@ -86,7 +85,7 @@ def grad(dE_dz): @custom_gradient -def spike_with_relu_grad(x: Tensor, scale: float = None): +def spike_with_linear_grad(x: Tensor, scale: float = None): """Spike function with the relu surrogate gradient. Parameters @@ -99,22 +98,20 @@ def spike_with_relu_grad(x: Tensor, scale: float = None): z = bm.asarray(x >= 0., dtype=dftype()) def grad(dE_dz): - _scale = scale - if scale is None: _scale = 0.3 + _scale = 0.3 if scale is None else scale dE_dx = dE_dz * bm.maximum(1 - bm.abs(x), 0) * _scale if scale is None: return (_consistent_type(dE_dx, x),) else: dscale = bm.zeros_like(_scale) - return (_consistent_type(dE_dx, x), - _consistent_type(dscale, _scale)) + return (_consistent_type(dE_dx, x), _consistent_type(dscale, _scale)) return z, grad @custom_gradient -def spike2_with_relu_grad(x_new: Tensor, x_old: Tensor, scale: float = 10.): - """Spike function with the relu surrogate gradient. +def spike2_with_linear_grad(x_new: Tensor, x_old: Tensor, scale: float = 10.): + """Spike function with the linear surrogate gradient. Parameters ---------- @@ -130,9 +127,7 @@ def spike2_with_relu_grad(x_new: Tensor, x_old: Tensor, scale: float = 10.): z = bm.asarray(bm.logical_and(x_new_comp, x_old_comp), dtype=dftype()) def grad(dE_dz): - _scale = scale - if scale is None: - _scale = 0.3 + _scale = 0.3 if scale is None else scale dx_new = (dE_dz * bm.maximum(1 - bm.abs(x_new), 0) * _scale) * bm.asarray(x_old_comp, dtype=dftype()) dx_old = -(dE_dz * bm.maximum(1 - bm.abs(x_old), 0) * _scale) * bm.asarray(x_new_comp, dtype=dftype()) if scale is None: @@ -147,6 +142,113 @@ def grad(dE_dz): return z, grad +def _gaussian(x, mu, sigma): + return bm.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / bm.sqrt(2 * bm.pi) / sigma + + +@custom_gradient +def spike_with_gaussian_grad(x, sigma=None, scale=None): + """Spike function with the Gaussian surrogate gradient. + """ + z = bm.asarray(x >= 0., dtype=dftype()) + + def grad(dE_dz): + _scale = 0.5 if scale is None else scale + _sigma = 0.5 if sigma is None else sigma + dE_dx = dE_dz * _gaussian(x, 0., _sigma) * _scale + returns = (_consistent_type(dE_dx, x),) + if sigma is not None: + returns += (_consistent_type(bm.zeros_like(_sigma), sigma), ) + if scale is not None: + returns += (_consistent_type(bm.zeros_like(_scale), scale), ) + return returns + + return z, grad + + +@custom_gradient +def spike_with_mg_grad(x, h=None, s=None, sigma=None, scale=None): + """Spike function with the multi-Gaussian surrogate gradient. + + Parameters + ---------- + x: ndarray + The variable to judge spike. + h: float + The hyper-parameters of approximate function + s: float + The hyper-parameters of approximate function + sigma: float + The gaussian sigma. + scale: float + The gradient scale. 
+ """ + z = bm.asarray(x >= 0., dtype=dftype()) + + def grad(dE_dz): + _sigma = 0.5 if sigma is None else sigma + _scale = 0.5 if scale is None else scale + _s = 6.0 if s is None else s + _h = 0.15 if h is None else h + dE_dx = dE_dz * (_gaussian(x, mu=0., sigma=_sigma) * (1. + _h) + - _gaussian(x, mu=_sigma, sigma=_s * _sigma) * _h + - _gaussian(x, mu=-_sigma, sigma=_s * _sigma) * _h) * _scale + returns = (_consistent_type(dE_dx, x),) + if h is not None: + returns += (_consistent_type(bm.zeros_like(_h), h),) + if s is not None: + returns += (_consistent_type(bm.zeros_like(_s), s),) + if sigma is not None: + returns += (_consistent_type(bm.zeros_like(_sigma), sigma),) + if scale is not None: + returns += (_consistent_type(bm.zeros_like(_scale), scale),) + return returns + + return z, grad + +@custom_gradient +def spike2_with_mg_grad(x_new, x_old, h=None, s=None, sigma=None, scale=None): + """Spike function with the multi-Gaussian surrogate gradient. + + Parameters + ---------- + x: ndarray + The variable to judge spike. + h: float + The hyper-parameters of approximate function + s: float + The hyper-parameters of approximate function + sigma: float + The gaussian sigma. + scale: float + The gradient scale. + """ + x_new_comp = x_new >= 0 + x_old_comp = x_old < 0 + z = bm.asarray(bm.logical_and(x_new_comp, x_old_comp), dtype=dftype()) + + def grad(dE_dz): + _sigma = 0.5 if sigma is None else sigma + _scale = 0.5 if scale is None else scale + _s = 6.0 if s is None else s + _h = 0.15 if h is None else h + dE_dx = dE_dz * (_gaussian(x, mu=0., sigma=_sigma) * (1. + _h) + - _gaussian(x, mu=_sigma, sigma=_s * _sigma) * _h + - _gaussian(x, mu=-_sigma, sigma=_s * _sigma) * _h) * _scale + returns = (_consistent_type(dE_dx, x),) + if h is not None: + returns += (_consistent_type(bm.zeros_like(_h), h),) + if s is not None: + returns += (_consistent_type(bm.zeros_like(_s), s),) + if sigma is not None: + returns += (_consistent_type(bm.zeros_like(_sigma), sigma),) + if scale is not None: + returns += (_consistent_type(bm.zeros_like(_scale), scale),) + return returns + + return z, grad + + @custom_jvp def step_pwl(x, threshold, window=0.5, max_spikes_per_dt: int = bm.inf): """ diff --git a/brainpy/math/random.py b/brainpy/math/random.py index 6aa3f9083..e833998a5 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -1052,7 +1052,10 @@ def noncentral_f(self, dfnum, dfden, nonc, size=None, key=None): lambda aux_data, flat_contents: RandomState(*flat_contents)) # default random generator -DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32)) +__a = JaxArray(None) +__a._value = np.random.randint(0, 10000, size=2, dtype=np.uint32) +DEFAULT = RandomState(__a) +del __a @wraps(np.random.default_rng) diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index 19ce7a81b..8f88b260f 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- + import gc import types from typing import Callable, Dict, Sequence, Union diff --git a/brainpy/tools/others/others.py b/brainpy/tools/others/others.py index d3f350ac1..79ccdecf2 100644 --- a/brainpy/tools/others/others.py +++ b/brainpy/tools/others/others.py @@ -2,7 +2,7 @@ import _thread as thread import threading -from typing import Optional, Tuple +from typing import Optional, Tuple, Callable import numpy as np from jax import lax @@ -10,6 +10,7 @@ from tqdm.auto import tqdm __all__ = [ + 'not_customized', 'to_size', 'size2num', 'timeout', @@ -17,6 +18,26 @@ ] + +def 
not_customized(fun: Callable) -> Callable: + """Marks the given module method is not implemented. + + Methods wrapped in @not_customized can define submodules directly within the method. + + For instance:: + + @not_customized + init_fb(self): + ... + + @not_customized + def feedback(self): + ... + """ + fun.not_customized = True + return fun + + def size2num(size): if isinstance(size, int): return size diff --git a/brainpy/visualization/figures.py b/brainpy/visualization/figures.py index 704bc5185..faed331f7 100644 --- a/brainpy/visualization/figures.py +++ b/brainpy/visualization/figures.py @@ -9,7 +9,7 @@ ] -def get_figure(row_num, col_num, row_len=3, col_len=6): +def get_figure(row_num, col_num, row_len=3, col_len=6, name=None): """Get the constrained_layout figure. Parameters @@ -28,6 +28,9 @@ def get_figure(row_num, col_num, row_len=3, col_len=6): fig_and_gs : tuple Figure and GridSpec. """ - fig = plt.figure(figsize=(col_num * col_len, row_num * row_len), constrained_layout=True) + if name is None: + fig = plt.figure(figsize=(col_num * col_len, row_num * row_len), constrained_layout=True) + else: + fig = plt.figure(name, figsize=(col_num * col_len, row_num * row_len), constrained_layout=True) gs = GridSpec(row_num, col_num, figure=fig) return fig, gs diff --git a/examples/simulation/Wang_2002_decision_making_spiking.py b/examples/simulation/Wang_2002_decision_making_spiking.py index 485f7aec8..6be4d7e37 100644 --- a/examples/simulation/Wang_2002_decision_making_spiking.py +++ b/examples/simulation/Wang_2002_decision_making_spiking.py @@ -76,25 +76,25 @@ def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, mode=bp.modes.Norm # E neurons/pyramid neurons - A = bp.dyn.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, - tau_ref=2., V_initializer=bp.init.OneInit(-70.), mode=mode) - B = bp.dyn.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, - tau_ref=2., V_initializer=bp.init.OneInit(-70.), mode=mode) - N = bp.dyn.LIF(num_N, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, - tau_ref=2., V_initializer=bp.init.OneInit(-70.), mode=mode) + A = bp.neurons.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, + tau_ref=2., V_initializer=bp.init.OneInit(-70.), mode=mode) + B = bp.neurons.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, + tau_ref=2., V_initializer=bp.init.OneInit(-70.), mode=mode) + N = bp.neurons.LIF(num_N, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, + tau_ref=2., V_initializer=bp.init.OneInit(-70.), mode=mode) # I neurons/interneurons - I = bp.dyn.LIF(num_inh, V_rest=-70., V_reset=-55., V_th=-50., tau=10., R=0.05, - tau_ref=1., V_initializer=bp.init.OneInit(-70.), mode=mode) + I = bp.neurons.LIF(num_inh, V_rest=-70., V_reset=-55., V_th=-50., tau=10., R=0.05, + tau_ref=1., V_initializer=bp.init.OneInit(-70.), mode=mode) # poisson stimulus IA = PoissonStim(num_A, freq_var=10., t_interval=50., freq_mean=mu0 + mu0 / 100. * coherence, mode=mode) IB = PoissonStim(num_B, freq_var=10., t_interval=50., freq_mean=mu0 - mu0 / 100. 
* coherence, mode=mode) # noise neurons - self.noise_B = bp.dyn.PoissonGroup(num_B, freqs=poisson_freq, mode=mode) - self.noise_A = bp.dyn.PoissonGroup(num_A, freqs=poisson_freq, mode=mode) - self.noise_N = bp.dyn.PoissonGroup(num_N, freqs=poisson_freq, mode=mode) - self.noise_I = bp.dyn.PoissonGroup(num_inh, freqs=poisson_freq, mode=mode) + self.noise_B = bp.neurons.PoissonGroup(num_B, freqs=poisson_freq, mode=mode) + self.noise_A = bp.neurons.PoissonGroup(num_A, freqs=poisson_freq, mode=mode) + self.noise_N = bp.neurons.PoissonGroup(num_N, freqs=poisson_freq, mode=mode) + self.noise_I = bp.neurons.PoissonGroup(num_inh, freqs=poisson_freq, mode=mode) # define external inputs self.IA2A = synapses.Exponential(IA, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, @@ -112,14 +112,14 @@ def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, mode=bp.modes.Norm output=synouts.COBA(E=0.), mode=mode, **ampa_par) self.N2I_AMPA = synapses.Exponential(N, I, bp.conn.All2All(), g_max=g_E2I_AMPA, output=synouts.COBA(E=0.), mode=mode, **ampa_par) - self.N2B_NMDA = bp.dyn.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, - output=synouts.MgBlock(E=0., cc_Mg=1.), mode=mode, **nmda_par) - self.N2A_NMDA = bp.dyn.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, - output=synouts.MgBlock(E=0., cc_Mg=1.), mode=mode, **nmda_par) - self.N2N_NMDA = bp.dyn.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA, - output=synouts.MgBlock(E=0., cc_Mg=1.), mode=mode, **nmda_par) - self.N2I_NMDA = bp.dyn.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA, - output=synouts.MgBlock(E=0., cc_Mg=1.), mode=mode, **nmda_par) + self.N2B_NMDA = synapses.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, + output=synouts.MgBlock(E=0., cc_Mg=1.), mode=mode, **nmda_par) + self.N2A_NMDA = synapses.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, + output=synouts.MgBlock(E=0., cc_Mg=1.), mode=mode, **nmda_par) + self.N2N_NMDA = synapses.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA, + output=synouts.MgBlock(E=0., cc_Mg=1.), mode=mode, **nmda_par) + self.N2I_NMDA = synapses.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA, + output=synouts.MgBlock(E=0., cc_Mg=1.), mode=mode, **nmda_par) self.B2B_AMPA = synapses.Exponential(B, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, output=synouts.COBA(E=0.), mode=mode, **ampa_par) @@ -163,7 +163,7 @@ def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15, mode=bp.modes.Norm self.I2N = synapses.Exponential(I, N, bp.conn.All2All(), g_max=g_I2E_GABAa, output=synouts.COBA(E=-70.), mode=mode, **gaba_par) self.I2I = synapses.Exponential(I, I, bp.conn.All2All(), g_max=g_I2I_GABAa, - output=synouts.COBA(E=-70.), mode=mode, **gaba_par) + output=synouts.COBA(E=-70.), mode=mode, **gaba_par) # define external projections self.noise2B = synapses.Exponential(self.noise_B, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, @@ -264,7 +264,7 @@ def batching_run(): num_row, num_col = 3, 4 num_batch = 12 coherence = bm.expand_dims(bm.linspace(-100, 100., num_batch), 1) - net = DecisionMaking(scale=1., coherence=coherence, mu0=20., mode=bp.modes.BatchingMode()) + net = DecisionMaking(scale=1., coherence=coherence, mu0=20., mode=bp.modes.batching) net.reset_state(batch_size=num_batch) runner = bp.dyn.DSRunner( diff --git a/examples/simulation/whole_brain_simulation_with_fhn.py b/examples/simulation/whole_brain_simulation_with_fhn.py index 04d7a40de..3a9ef76a9 100644 --- a/examples/simulation/whole_brain_simulation_with_fhn.py +++ b/examples/simulation/whole_brain_simulation_with_fhn.py @@ -47,7 +47,7 @@ 
def __init__(self, signal_speed=20.): ) -def brain_simulation(): +def net_simulation(): net = Network() runner = bp.dyn.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72]) runner.run(6e3) @@ -62,6 +62,34 @@ def brain_simulation(): plt.show() +def net_analysis(): + net = Network() + + # get candidate points + runner = bp.dyn.DSRunner( + net, + monitors={'x': net.fhn.x, 'y': net.fhn.y}, + inputs=(net.fhn.input, 0.72), + numpy_mon_after_run=False + ) + runner.run(1e3) + candidates = dict(x=runner.mon.x, y=runner.mon.y) + + # analysis + finder = bp.analysis.SlowPointFinder( + net, + inputs=(net.fhn.input, 0.72), + target_vars={'x': net.fhn.x, 'y': net.fhn.y} + ) + finder.find_fps_with_opt_solver(candidates=candidates) + finder.filter_loss(1e-5) + finder.keep_unique(1e-3) + finder.compute_jacobians({'x': finder._fixed_points['x'][:10], + 'y': finder._fixed_points['y'][:10]}, + plot=True) + + if __name__ == '__main__': # bifurcation_analysis() - brain_simulation() + # net_simulation() + net_analysis() diff --git a/examples/simulation/whole_brain_simulation_with_sl_oscillator.py b/examples/simulation/whole_brain_simulation_with_sl_oscillator.py index 5bd0fe670..852759b9e 100644 --- a/examples/simulation/whole_brain_simulation_with_sl_oscillator.py +++ b/examples/simulation/whole_brain_simulation_with_sl_oscillator.py @@ -23,7 +23,7 @@ def bifurcation_analysis(): class Network(bp.dyn.Network): - def __init__(self): + def __init__(self, noise=0.14): super(Network, self).__init__() # Please download the processed data "hcp.npz" of the @@ -35,7 +35,7 @@ def __init__(self): bm.fill_diagonal(conn_mat, 0) gc = 0.6 # global coupling strength - self.sl = bp.rates.StuartLandauOscillator(80, x_ou_sigma=0.14, y_ou_sigma=0.14, name='sl') + self.sl = bp.rates.StuartLandauOscillator(80, x_ou_sigma=noise, y_ou_sigma=noise) self.coupling = bp.synapses.DiffusiveCoupling( self.sl.x, self.sl.x, var_to_output=self.sl.input, @@ -58,6 +58,36 @@ def simulation(): plt.show() -if __name__ == '__main__': +def net_analysis(): + import matplotlib + matplotlib.use('WebAgg') + bp.math.enable_x64() + from sklearn.decomposition import PCA + + # get candidate points + net = Network() + runner = bp.dyn.DSRunner( + net, + monitors={'x': net.sl.x, 'y': net.sl.y}, + numpy_mon_after_run=False + ) + runner.run(1e3) + candidates = dict(x=runner.mon.x, y=runner.mon.y) + + # analysis + net = Network(noise=0.) + finder = bp.analysis.SlowPointFinder( + net, target_vars={'x': net.sl.x, 'y': net.sl.y} + ) + finder.find_fps_with_opt_solver(candidates=candidates) + finder.filter_loss(1e-5) + finder.keep_unique(1e-3) + finder.compute_jacobians({'x': finder._fixed_points['x'][:10], + 'y': finder._fixed_points['y'][:10]}, + plot=True) + + +if __name__ == '__main__1': # bifurcation_analysis() - simulation() + # simulation() + net_analysis() diff --git a/examples/training/Bellec_2020_eprop_evidence_accumulation.py b/examples/training/Bellec_2020_eprop_evidence_accumulation.py index 444f1e567..906b9740b 100644 --- a/examples/training/Bellec_2020_eprop_evidence_accumulation.py +++ b/examples/training/Bellec_2020_eprop_evidence_accumulation.py @@ -9,7 +9,6 @@ """ -import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import brainpy as bp @@ -20,7 +19,7 @@ bm.set_dt(1.) # Simulation time step [ms] # training parameters -n_batch = 64 # batch size +n_batch = 128 # batch size # neuron model and simulation parameters reg_f = 1. 
# regularization coefficient for firing rate @@ -34,107 +33,57 @@ regularization_f0 = reg_rate / 1000. # mean target network firing frequency -class ALIF(bp.dyn.NeuGroup): - def __init__( - self, num_in, num_rec, tau=20., thr=0.03, - dampening_factor=0.3, tau_adaptation=200., - stop_z_gradients=False, n_refractory=1, - name=None, mode=bp.modes.training, - ): - super(ALIF, self).__init__(name=name, size=num_rec, mode=mode) - - self.n_in = num_in - self.n_rec = num_rec - self.n_regular = int(num_rec / 2) - self.n_adaptive = num_rec - self.n_regular - - self.n_refractory = n_refractory - self.tau_adaptation = tau_adaptation - # generate threshold decay time constants # - rhos = bm.exp(- bm.get_dt() / tau_adaptation) # decay factors for adaptive threshold - beta = 1.7 * (1 - rhos) / (1 - bm.exp(-1 / tau)) # this is a heuristic value - # multiplicative factors for adaptive threshold - self.beta = bm.concatenate([bm.zeros(self.n_regular), beta * bm.ones(self.n_adaptive)]) - - self.decay_b = jnp.exp(-bm.get_dt() / tau_adaptation) - self.decay = jnp.exp(-bm.get_dt() / tau) - self.dampening_factor = dampening_factor - self.stop_z_gradients = stop_z_gradients - self.tau = tau - self.thr = thr - self.mask = jnp.diag(jnp.ones(num_rec, dtype=bool)) - - # parameters - self.w_in = bm.TrainVar(bm.random.randn(num_in, self.num) / jnp.sqrt(num_in)) - self.w_rec = bm.TrainVar(bm.random.randn(self.num, self.num) / jnp.sqrt(self.num)) - - # Variables - self.v = bm.Variable(jnp.zeros((1, self.num)), batch_axis=0) - self.b = bm.Variable(jnp.zeros((1, self.num)), batch_axis=0) - self.r = bm.Variable(jnp.zeros((1, self.num)), batch_axis=0) - self.spike = bm.Variable(jnp.zeros((1, self.num)), batch_axis=0) - - def reset_state(self, batch_size=1): - self.v.value = bm.Variable(jnp.zeros((batch_size, self.n_rec))) - self.b.value = bm.Variable(jnp.zeros((batch_size, self.n_rec))) - self.r.value = bm.Variable(jnp.zeros((batch_size, self.n_rec))) - self.spike.value = bm.Variable(jnp.zeros((batch_size, self.n_rec))) - - def compute_z(self, v, b): - adaptive_thr = self.thr + b * self.beta - v_scaled = (v - adaptive_thr) / self.thr - z = bm.spike_with_relu_grad(v_scaled, self.dampening_factor) - z = z * 1 / bm.get_dt() - return z - - def update(self, sha, x): - z = self.spike.value - if self.stop_z_gradients: - z = stop_gradient(z) - - # threshold update does not have to depend on the stopped-gradient-z, it's local - new_b = self.decay_b * self.b.value + self.spike.value - - # gradients are blocked in spike transmission - i_t = jnp.matmul(x.value, self.w_in.value) + jnp.matmul(z, jnp.where(self.mask, 0, self.w_rec.value)) - i_reset = z * self.thr * bm.get_dt() - new_v = self.decay * self.v + i_t - i_reset - - # spike generation - self.spike.value = bm.where(self.r.value > 0, 0., self.compute_z(new_v, new_b)) - new_r = bm.clip(self.r.value + self.n_refractory * self.spike - 1, 0, self.n_refractory) - self.r.value = stop_gradient(new_r) - self.v.value = new_v - self.b.value = new_b - - class EligSNN(bp.dyn.Network): - def __init__(self, num_in, num_rec, num_out, stop_z_gradients=False): + def __init__(self, num_in, num_rec, num_out, eprop=True, tau_a=2e3, tau_v=2e1): super(EligSNN, self).__init__() # parameters self.num_in = num_in self.num_rec = num_rec self.num_out = num_out + self.eprop = eprop # neurons - self.r = ALIF(num_in=num_in, num_rec=num_rec, tau=20, tau_adaptation=2000, - n_refractory=5, stop_z_gradients=stop_z_gradients, thr=0.6) + self.i = bp.neurons.InputGroup(num_in) self.o = 
bp.neurons.LeakyIntegrator(num_out, tau=20, mode=bp.modes.training) + n_regular = int(num_rec / 2) + n_adaptive = num_rec - n_regular + beta1 = bm.exp(- bm.get_dt() / tau_a) + beta2 = 1.7 * (1 - beta1) / (1 - bm.exp(-1 / tau_v)) + beta = bm.concatenate([bm.ones(n_regular), bm.ones(n_adaptive) * beta2]) + self.r = bp.neurons.ALIFBellec2020( + num_rec, V_rest=0., tau_ref=5., V_th=0.6, + tau_a=tau_a, tau=tau_v, beta=beta, + V_initializer=bp.init.ZeroInit(), + a_initializer=bp.init.ZeroInit(), + mode=bp.modes.training, eprop=eprop + ) + # synapses + self.i2r = bp.layers.Dense(num_in, num_rec, + W_initializer=bp.init.KaimingNormal(), + b_initializer=None) + self.i2r.W *= tau_v + self.r2r = bp.layers.Dense(num_rec, num_rec, + W_initializer=bp.init.KaimingNormal(), + b_initializer=None) + self.r2r.W *= tau_v self.r2o = bp.layers.Dense(num_rec, num_out, W_initializer=bp.init.KaimingNormal(), b_initializer=None) - def update(self, sha, x): - self.r(sha, x) - self.o.input += self.r2o(sha, self.r.spike.value) - self.o(sha) + def update(self, shared, x): + self.r.input += self.i2r(shared, x) + z = self.r.spike if self.eprop else stop_gradient(self.r.spike.value) + self.r.input += self.r2r(shared, z) + self.r(shared) + self.o.input += self.r2o(shared, self.r.spike.value) + self.o(shared) return self.o.V.value -net = EligSNN(num_in=40, num_rec=100, num_out=2, stop_z_gradients=True) +net = EligSNN(num_in=40, num_rec=100, num_out=2, eprop=False) @bp.tools.numba_jit @@ -221,11 +170,11 @@ def loss_fun(predicts, targets): trainer = bp.train.BPTT( net, loss_fun, loss_has_aux=True, - optimizer=bp.optimizers.Adam(lr=0.005), + optimizer=bp.optimizers.Adam(lr=0.01), monitors={'r.spike': net.r.spike}, ) -trainer.fit(get_data(64, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0), - num_epoch=30, +trainer.fit(get_data(n_batch, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0), + num_epoch=40, num_report=10) # visualization @@ -243,12 +192,9 @@ def loss_fun(predicts, targets): # insert empty row n_channel = data.shape[1] // 4 zero_fill = np.zeros((data.shape[0], int(n_channel / 2))) - data = np.concatenate((data[:, 3 * n_channel:], - zero_fill, - data[:, 2 * n_channel:3 * n_channel], - zero_fill, - data[:, :n_channel], - zero_fill, + data = np.concatenate((data[:, 3 * n_channel:], zero_fill, + data[:, 2 * n_channel:3 * n_channel], zero_fill, + data[:, :n_channel], zero_fill, data[:, n_channel:2 * n_channel]), axis=1) ax_inp.set_yticklabels([]) ax_inp.add_patch(patches.Rectangle((0, 2 * n_channel + 2 * int(n_channel / 2)), diff --git a/examples/training/Gauthier_2021_ngrc_lorenz_inference.py b/examples/training/Gauthier_2021_ngrc_lorenz_inference.py index 8c04f6d5e..2d67673b6 100644 --- a/examples/training/Gauthier_2021_ngrc_lorenz_inference.py +++ b/examples/training/Gauthier_2021_ngrc_lorenz_inference.py @@ -158,7 +158,6 @@ def update(self, sha, x): # -------- # trainer = bp.train.RidgeTrainer(model, alpha=0.05) -# trainer = bp.train.ForceTrainer(model, ) # warm-up outputs = trainer.predict(X_warmup) diff --git a/examples/training/Song_2016_EI_RNN.py b/examples/training/Song_2016_EI_RNN.py index 5b0cce6ba..1c27883ce 100644 --- a/examples/training/Song_2016_EI_RNN.py +++ b/examples/training/Song_2016_EI_RNN.py @@ -128,17 +128,17 @@ def __init__(self, num_input, num_hidden, num_output, num_batch, self.mask = bm.asarray(mask, dtype=bm.dftype()) # input weight - self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden))) + self.w_ir = bm.TrainVar(w_ir(num_input, num_hidden)) # recurrent weight 
bound = 1 / num_hidden ** 0.5 - self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden))) + self.w_rr = bm.TrainVar(w_rr(num_hidden, num_hidden)) self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size) self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden)) # readout weight bound = 1 / self.e_size ** 0.5 - self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (self.e_size, num_output))) + self.w_ro = bm.TrainVar(w_ro(self.e_size, num_output)) self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output)) # variables diff --git a/examples/training/SurrogateGrad_lif.py b/examples/training/SurrogateGrad_lif.py index da805b18f..499bcdeac 100644 --- a/examples/training/SurrogateGrad_lif.py +++ b/examples/training/SurrogateGrad_lif.py @@ -35,12 +35,13 @@ def __init__(self, num_in, num_rec, num_out): # synapse: i->r self.i2r = bp.synapses.Exponential(self.i, self.r, bp.conn.All2All(), output=bp.synouts.CUBA(), tau=10., - g_max=bp.init.KaimingNormal(scale=10.), + g_max=bp.init.KaimingNormal(scale=20.), mode=bp.modes.training) # synapse: r->o self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), + # delay_step=10, output=bp.synouts.CUBA(), tau=10., - g_max=bp.init.KaimingNormal(scale=10.), + g_max=bp.init.KaimingNormal(scale=20.), mode=bp.modes.training) def update(self, tdi, spike): @@ -63,7 +64,7 @@ def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): else: ax = plt.subplot(gs[i], sharey=a0) ax.plot(mem[i]) - ax.axis("off") + # ax.axis("off") plt.tight_layout() plt.show() @@ -89,7 +90,7 @@ def print_classification_accuracy(output, target): # Before training -runner = bp.train.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) +runner = bp.dyn.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) out = runner.run(inputs=x_data, inputs_are_batching=True, reset_state=True) plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) plot_voltage_traces(out) @@ -100,7 +101,7 @@ def loss(): key = rng.split_key() X = rng.permutation(x_data, key=key) Y = rng.permutation(y_data, key=key) - looper = bp.train.DSRunner(net, numpy_mon_after_run=False, progress_bar=False) + looper = bp.dyn.DSRunner(net, numpy_mon_after_run=False, progress_bar=False) predictions = looper.run(inputs=X, inputs_are_batching=True, reset_state=True) predictions = bm.max(predictions, axis=1) return bp.losses.cross_entropy_loss(predictions, Y) @@ -119,31 +120,30 @@ def train(_): return l -f_train = bm.make_loop(train, - dyn_vars=f_opt.vars() + net.vars() + {'rng': rng}, - has_return=True) +f_train = bm.make_loop( + train, + dyn_vars=f_opt.vars() + net.vars() + {'rng': rng}, + has_return=True +) # train the network net.reset_state(num_sample) train_losses = [] -for i in range(0, 1000, 100): +for i in range(0, 3000, 100): t0 = time.time() _, ls = f_train(bm.arange(i, i + 100, 1)) print(f'Train {i + 100} epoch, loss = {bm.mean(ls):.4f}, used time {time.time() - t0:.4f} s') train_losses.append(ls) - # visualize the training losses plt.plot(bm.as_numpy(bm.concatenate(train_losses))) plt.xlabel("Epoch") plt.ylabel("Training Loss") plt.show() - # predict the output according to the input data runner = bp.dyn.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) out = runner.run(inputs=x_data, inputs_are_batching=True, reset_state=True) plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) plot_voltage_traces(out) print_classification_accuracy(out, y_data) - diff --git 
a/examples/training/SurrogateGrad_lif_fashion_mnist.py b/examples/training/SurrogateGrad_lif_fashion_mnist.py index ee7837b9f..64a80dceb 100644 --- a/examples/training/SurrogateGrad_lif_fashion_mnist.py +++ b/examples/training/SurrogateGrad_lif_fashion_mnist.py @@ -45,7 +45,9 @@ def __init__(self, num_in, num_rec, num_out): mode=bp.modes.training) # synapse: r->o self.r2o = bp.synapses.Exponential(self.r, self.o, bp.conn.All2All(), - output=bp.synouts.CUBA(), tau=10., + delay_step=int(1 / bm.get_dt()), + output=bp.synouts.CUBA(), + tau=10., g_max=bp.init.KaimingNormal(scale=2.), mode=bp.modes.training) @@ -160,10 +162,11 @@ def loss_fun(predicts, targets): loss = bp.losses.cross_entropy_loss(predicts, targets) return loss + l2_loss + l1_loss - f_opt = bp.optim.Adam(lr=lr) - trainer = bp.train.BPTT(model, loss_fun, f_opt, - monitors={'r.spike': net.r.spike}, - dyn_vars={'rand': bm.random.DEFAULT}) + trainer = bp.train.BPTT( + model, loss_fun, + optimizer=bp.optim.Adam(lr=lr), + monitors={'r.spike': net.r.spike}, + ) trainer.fit(lambda: sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs), num_epoch=nb_epochs) return trainer.train_losses @@ -172,7 +175,7 @@ def loss_fun(predicts, targets): def compute_classification_accuracy(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28): """ Computes classification accuracy on supplied data in batches. """ accs = [] - runner = bp.dyn.DSRunner(model, dyn_vars={'rand': bm.random.DEFAULT}, progress_bar=False) + runner = bp.dyn.DSRunner(model, progress_bar=False) for x_local, y_local in sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False): output = runner.predict(inputs=x_local, inputs_are_batching=True, reset_state=True) m = bm.max(output, 1) # max over time @@ -184,7 +187,6 @@ def compute_classification_accuracy(model, x_data, y_data, batch_size=128, nb_st def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28): runner = bp.dyn.DSRunner(model, - dyn_vars={'rand': bm.random.DEFAULT}, monitors={'r.spike': model.r.spike}, progress_bar=False) data = sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False) @@ -197,7 +199,7 @@ def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, net = SNN(num_in=num_input, num_rec=100, num_out=10) # load the dataset -root = r"E:\data\fashion-mnist" +root = r"D:\data\fashion-mnist" train_dataset = bp.datasets.FashionMNIST(root, train=True, transform=None, @@ -229,7 +231,6 @@ def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, print("Training accuracy: %.3f" % (compute_classification_accuracy(net, x_train, y_train, batch_size=512))) print("Test accuracy: %.3f" % (compute_classification_accuracy(net, x_test, y_test, batch_size=512))) - outs, spikes = get_mini_batch_results(net, x_train, y_train) # Let's plot the hidden layer spiking activity for some input stimuli fig = plt.figure(dpi=100) @@ -247,4 +248,3 @@ def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, plt.ylabel("Units") plt.tight_layout() plt.show() - diff --git a/extensions/brainpylib/custom_op/regis_op.py b/extensions/brainpylib/custom_op/regis_op.py index 8ededa937..cfc09ca6e 100644 --- a/extensions/brainpylib/custom_op/regis_op.py +++ b/extensions/brainpylib/custom_op/regis_op.py @@ -25,7 +25,6 @@ def register_op( gpu_func: Callable = None, batch_fun: Callable = None, apply_cpu_func_to_gpu: bool = False, - return_primitive: bool = False, ): """ 
Converting the numba-jitted function in a Jax/XLA compatible primitive. @@ -108,10 +107,6 @@ def eval_rule(*inputs): # Return the outputs return tuple(outputs) - def bind_primitive(*inputs): - result = prim.bind(*inputs) - return result[0] if len(result) == 1 else result - # cpu function prim.def_abstract_eval(abs_eval_rule) prim.def_impl(eval_rule) @@ -129,7 +124,4 @@ def bind_primitive(*inputs): if batch_fun is not None: batching.primitive_batchers[prim] = batch_fun - if return_primitive: - return bind_primitive, prim - else: - return bind_primitive + return prim From 65b48aa7869f30ec069ad0781dc7e53a15832070 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 29 Jul 2022 15:10:01 +0800 Subject: [PATCH 4/5] updates --- brainpy/math/operators/spikegrad.py | 42 ----------------------------- 1 file changed, 42 deletions(-) diff --git a/brainpy/math/operators/spikegrad.py b/brainpy/math/operators/spikegrad.py index bae6d58f0..a39b1f57d 100644 --- a/brainpy/math/operators/spikegrad.py +++ b/brainpy/math/operators/spikegrad.py @@ -206,48 +206,6 @@ def grad(dE_dz): return z, grad -@custom_gradient -def spike2_with_mg_grad(x_new, x_old, h=None, s=None, sigma=None, scale=None): - """Spike function with the multi-Gaussian surrogate gradient. - - Parameters - ---------- - x: ndarray - The variable to judge spike. - h: float - The hyper-parameters of approximate function - s: float - The hyper-parameters of approximate function - sigma: float - The gaussian sigma. - scale: float - The gradient scale. - """ - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = bm.asarray(bm.logical_and(x_new_comp, x_old_comp), dtype=dftype()) - - def grad(dE_dz): - _sigma = 0.5 if sigma is None else sigma - _scale = 0.5 if scale is None else scale - _s = 6.0 if s is None else s - _h = 0.15 if h is None else h - dE_dx = dE_dz * (_gaussian(x, mu=0., sigma=_sigma) * (1. 
+ _h) - - _gaussian(x, mu=_sigma, sigma=_s * _sigma) * _h - - _gaussian(x, mu=-_sigma, sigma=_s * _sigma) * _h) * _scale - returns = (_consistent_type(dE_dx, x),) - if h is not None: - returns += (_consistent_type(bm.zeros_like(_h), h),) - if s is not None: - returns += (_consistent_type(bm.zeros_like(_s), s),) - if sigma is not None: - returns += (_consistent_type(bm.zeros_like(_sigma), sigma),) - if scale is not None: - returns += (_consistent_type(bm.zeros_like(_scale), scale),) - return returns - - return z, grad - @custom_jvp def step_pwl(x, threshold, window=0.5, max_spikes_per_dt: int = bm.inf): From 0bdf00c3f3098700063f4f04166a51b9d694c180 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 4 Aug 2022 09:59:18 +0800 Subject: [PATCH 5/5] update apis and docs --- .github/PULL_REQUEST_TEMPLATE.md | 4 - .github/workflows/MacOS_CI.yml | 2 +- .github/workflows/Windows_CI.yml | 2 +- brainpy/algorithms/offline.py | 47 +++++--- brainpy/algorithms/online.py | 19 ++-- brainpy/base/collector.py | 13 ++- brainpy/dyn/base.py | 100 +++++++++++++----- brainpy/dyn/neurons/biological_models.py | 5 +- brainpy/dyn/synapses/abstract_models.py | 85 +++------------ brainpy/dyn/synapses/biological_models.py | 27 ++--- brainpy/dyn/synapses/gap_junction.py | 2 +- brainpy/dyn/synouts/conductances.py | 39 ++++--- brainpy/dyn/synouts/ions.py | 42 ++++---- brainpy/math/__init__.py | 3 +- brainpy/math/operators/op_register.py | 87 ++++++++------- brainpy/running/__init__.py | 1 + brainpy/{train => running}/constants.py | 10 ++ brainpy/running/multiprocess.py | 48 ++++++--- brainpy/running/runner.py | 5 +- brainpy/train/back_propagation.py | 4 +- brainpy/train/base.py | 2 +- docs/apis/datasets.rst | 3 +- docs/auto_generater.py | 11 +- docs/index.rst | 5 +- docs/tutorial_analysis/index.rst | 2 +- docs/tutorial_building/index.rst | 6 ++ docs/tutorial_simulation/index.rst | 4 +- docs/tutorial_training/index.rst | 2 +- ...2017_unified_thalamus_oscillation_model.py | 2 +- 29 files changed, 326 insertions(+), 256 deletions(-) rename brainpy/{train => running}/constants.py (57%) create mode 100644 docs/tutorial_building/index.rst diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index e5db149b4..ae3d2c86f 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -14,10 +14,6 @@ -## Screenshots(optional) - - - ## Types of changes diff --git a/.github/workflows/MacOS_CI.yml b/.github/workflows/MacOS_CI.yml index ff581ec3c..c1ff8f4dc 100644 --- a/.github/workflows/MacOS_CI.yml +++ b/.github/workflows/MacOS_CI.yml @@ -36,7 +36,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics +# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | pytest brainpy/ diff --git a/.github/workflows/Windows_CI.yml b/.github/workflows/Windows_CI.yml index 29e3f7ae5..76b24dcca 100644 --- a/.github/workflows/Windows_CI.yml +++ b/.github/workflows/Windows_CI.yml @@ -39,7 +39,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics +# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | pytest brainpy/ diff --git a/brainpy/algorithms/offline.py b/brainpy/algorithms/offline.py index d85c382b2..196daea75 100644 --- a/brainpy/algorithms/offline.py +++ b/brainpy/algorithms/offline.py @@ -39,11 +39,35 @@ class OfflineAlgorithm(Base): def __init__(self, name=None): super(OfflineAlgorithm, self).__init__(name=name) - def __call__(self, targets, inputs, outputs) -> Tensor: + def __call__(self, identifier, target, input, output): """The training procedure. Parameters ---------- + identifier: str + The variable name. + target: JaxArray, ndarray + The 2d target data with the shape of `(num_batch, num_output)`. + input: JaxArray, ndarray + The 2d input data with the shape of `(num_batch, num_input)`. + output: JaxArray, ndarray + The 2d output data with the shape of `(num_batch, num_output)`. + + Returns + ------- + weight: JaxArray + The weights after fit. + """ + return self.call(identifier, target, input, output) + + def call(self, identifier, targets, inputs, outputs) -> Tensor: + """The training procedure. + + Parameters + ---------- + identifier: str + The identifier. + inputs: JaxArray, jax.numpy.ndarray, numpy.ndarray The 3d input data with the shape of `(num_batch, num_time, num_input)`, or, the 2d input data with the shape of `(num_time, num_input)`. @@ -67,8 +91,7 @@ def __repr__(self): return self.__class__.__name__ def initialize(self, identifier, *args, **kwargs): - raise NotImplementedError('Must implement the initialize() ' - 'function by the subclass itself.') + pass def _check_data_2d_atls(x): @@ -166,7 +189,7 @@ def __init__( regularizer=Regularization(0.)) self.gradient_descent = gradient_descent - def __call__(self, targets, inputs, outputs=None): + def call(self, identifier, targets, inputs, outputs=None): # checking inputs = _check_data_2d_atls(bm.asarray(inputs)) targets = _check_data_2d_atls(bm.asarray(targets)) @@ -225,7 +248,7 @@ def __init__( regularizer=L2Regularization(alpha=alpha)) self.gradient_descent = gradient_descent - def __call__(self, targets, inputs, outputs=None): + def call(self, identifier, targets, inputs, outputs=None): # checking inputs = _check_data_2d_atls(bm.asarray(inputs)) targets = _check_data_2d_atls(bm.asarray(targets)) @@ -284,7 +307,7 @@ def __init__( assert gradient_descent self.degree = degree - def __call__(self, targets, inputs, outputs=None): + def call(self, identifier, targets, inputs, outputs=None): # checking inputs = _check_data_2d_atls(bm.asarray(inputs)) targets = _check_data_2d_atls(bm.asarray(targets)) @@ -332,7 +355,7 @@ def __init__( self.gradient_descent = gradient_descent self.sigmoid = Sigmoid() - def __call__(self, targets, inputs, outputs=None) -> Tensor: + def call(self, identifier, targets, inputs, outputs=None) -> Tensor: # prepare data inputs = _check_data_2d_atls(bm.asarray(inputs)) targets = _check_data_2d_atls(bm.asarray(targets)) @@ -395,11 +418,11 @@ def __init__( self.degree = degree self.add_bias = add_bias - def __call__(self, targets, inputs, outputs=None): + def call(self, identifier, targets, inputs, outputs=None): inputs = _check_data_2d_atls(bm.asarray(inputs)) targets = _check_data_2d_atls(bm.asarray(targets)) inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) - return super(PolynomialRegression, self).__call__(targets, inputs) + return super(PolynomialRegression, 
self).call(identifier, targets, inputs) def predict(self, W, X): X = _check_data_2d_atls(bm.asarray(X)) @@ -431,12 +454,12 @@ def __init__( self.degree = degree self.add_bias = add_bias - def __call__(self, targets, inputs, outputs=None): + def call(self, identifier, targets, inputs, outputs=None): # checking inputs = _check_data_2d_atls(bm.asarray(inputs)) targets = _check_data_2d_atls(bm.asarray(targets)) inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) - return super(PolynomialRidgeRegression, self).__call__(targets, inputs) + return super(PolynomialRidgeRegression, self).call(identifier, targets, inputs) def predict(self, W, X): X = _check_data_2d_atls(bm.asarray(X)) @@ -489,7 +512,7 @@ def __init__( self.gradient_descent = gradient_descent assert gradient_descent - def __call__(self, targets, inputs, outputs=None): + def call(self, identifier, targets, inputs, outputs=None): # checking inputs = _check_data_2d_atls(bm.asarray(inputs)) targets = _check_data_2d_atls(bm.asarray(targets)) diff --git a/brainpy/algorithms/online.py b/brainpy/algorithms/online.py index ade229823..9fd72768a 100644 --- a/brainpy/algorithms/online.py +++ b/brainpy/algorithms/online.py @@ -2,6 +2,8 @@ import brainpy.math as bm from brainpy.base import Base +from jax import vmap +import jax.numpy as jnp __all__ = [ # base class @@ -25,12 +27,12 @@ class OnlineAlgorithm(Base): def __init__(self, name=None): super(OnlineAlgorithm, self).__init__(name=name) - def __call__(self, name, target, input, output): + def __call__(self, identifier, target, input, output): """The training procedure. Parameters ---------- - name: str + identifier: str The variable name. target: JaxArray, ndarray The 2d target data with the shape of `(num_batch, num_output)`. @@ -44,11 +46,10 @@ def __call__(self, name, target, input, output): weight: JaxArray The weights after fit. """ - return self.call(name, target, input, output) + return self.call(identifier, target, input, output) def initialize(self, identifier, *args, **kwargs): - raise NotImplementedError('Must implement the initialize() ' - 'function by the subclass itself.') + pass def call(self, identifier, target, input, output): """The training procedure. @@ -146,11 +147,11 @@ def __init__(self, alpha=0.1, name=None): super(LMS, self).__init__(name=name) self.alpha = alpha - def initialize(self, identifier, *args, **kwargs): - pass - def call(self, identifier, target, input, output): - return -self.alpha * bm.dot(output - target, output) + assert target.shape[0] == input.shape[0] == output.shape[0], 'Batch size should be consistent.' 
+ error = bm.as_jax(output - target) + input = bm.as_jax(input) + return -self.alpha * bm.sum(vmap(jnp.outer)(input, error), axis=0) name2func['lms'] = LMS diff --git a/brainpy/base/collector.py b/brainpy/base/collector.py index e0a6095ba..571d7f672 100644 --- a/brainpy/base/collector.py +++ b/brainpy/base/collector.py @@ -29,9 +29,16 @@ def replace(self, key, new_value): self[key] = new_value def update(self, other, **kwargs): - assert isinstance(other, dict) - for key, value in other.items(): - self[key] = value + assert isinstance(other, (dict, list, tuple)) + if isinstance(other, dict): + for key, value in other.items(): + self[key] = value + elif isinstance(other, (tuple, list)): + num = len(self) + for i, value in enumerate(other): + self[f'_var{i+num}'] = value + else: + raise ValueError(f'Only supports dict/list/tuple, but we got {type(other)}') for key, value in kwargs.items(): self[key] = value diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index c97c2bc09..0b88fbfa1 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -34,7 +34,11 @@ 'NeuGroup', 'CondNeuGroup', # synapse models - 'SynConn', 'SynOutput', 'SynSTP', 'SynLTP', 'TwoEndConn', + 'SynConn', + 'TwoEndConn', + 'SynOut', 'NullSynOut', + 'SynSTP', 'NullSynSTP', + 'SynLTP', 'NullSynLTP', ] @@ -752,15 +756,33 @@ def update(self, tdi, pre_spike=None): class SynComponent(DynamicalSystem): master: SynConn + def __init__(self, *args, **kwargs): + super(SynComponent, self).__init__(*args, **kwargs) + + self._registered = False + + @property + def isregistered(self) -> bool: + return self._registered + + @isregistered.setter + def isregistered(self, val: bool): + if not isinstance(val, bool): + raise ValueError('Must be an instance of bool.') + self._registered = val + def reset_state(self, batch_size=None): pass def register_master(self, master: SynConn): if not isinstance(master, SynConn): raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') + if self.isregistered: + raise ValueError(f'master has been registered, but we got another master going to be registered.') if hasattr(self, 'master') and self.master != master: raise ValueError(f'master has been registered, but we got another master going to be registered.') self.master = master + self._registered = True def __repr__(self): return self.__class__.__name__ @@ -768,11 +790,14 @@ def __repr__(self): def __call__(self, *args, **kwargs): return self.filter(*args, **kwargs) + def clone(self) -> 'SynComponent': + raise NotImplementedError + def filter(self, g): raise NotImplementedError -class SynOutput(SynComponent): +class SynOut(SynComponent): """Base class for synaptic current output.""" def __init__( @@ -780,7 +805,7 @@ def __init__( name: str = None, target_var: Union[str, bm.Variable] = None, ): - super(SynOutput, self).__init__(name=name) + super(SynOut, self).__init__(name=name) # check target variable if target_var is not None: if not isinstance(target_var, (str, bm.Variable)): @@ -789,7 +814,8 @@ def __init__( self.target_var: Optional[bm.Variable] = target_var def register_master(self, master: SynConn): - super(SynOutput, self).register_master(master) + super(SynOut, self).register_master(master) + # initialize target variable to output if isinstance(self.target_var, str): if not hasattr(self.master.post, self.target_var): @@ -820,6 +846,27 @@ def update(self, tdi, pre_spike): pass +class NullSynOut(SynOut): + def clone(self): + return NullSynOut() + + +class NullSynSTP(SynSTP): + def clone(self): + return 
NullSynSTP() + + def filter(self, g): + return g + + +class NullSynLTP(SynLTP): + def clone(self): + return NullSynLTP() + + def filter(self, g): + return g + + class TwoEndConn(SynConn): """Base class to model synaptic connections. @@ -858,9 +905,9 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]] = None, - output: Optional[SynOutput] = None, - stp: Optional[SynSTP] = None, - ltp: Optional[SynLTP] = None, + output: SynOut = NullSynOut(), + stp: SynSTP = NullSynSTP(), + ltp: SynLTP = NullSynLTP(), name: str = None, mode: Mode = normal, ): @@ -871,28 +918,31 @@ def __init__( mode=mode) # synaptic output - if output is not None: - if not isinstance(output, SynOutput): - raise TypeError(f'output must be instance of {SynOutput.__name__}, ' - f'but we got {type(output)}') - output.register_master(master=self) - self.output: Optional[SynOutput] = output + output = NullSynOut() if output is None else output + if output.isregistered: output = output.clone() + if not isinstance(output, SynOut): + raise TypeError(f'output must be instance of {SynOut.__name__}, ' + f'but we got {type(output)}') + output.register_master(master=self) + self.output: SynOut = output # short-term synaptic plasticity - if stp is not None: - if not isinstance(stp, SynSTP): - raise TypeError(f'Short-term plasticity must be instance of {SynSTP.__name__}, ' - f'but we got {type(stp)}') - stp.register_master(master=self) - self.stp: Optional[SynSTP] = stp + stp = NullSynSTP() if stp is None else stp + if stp.isregistered: stp = stp.clone() + if not isinstance(stp, SynSTP): + raise TypeError(f'Short-term plasticity must be instance of {SynSTP.__name__}, ' + f'but we got {type(stp)}') + stp.register_master(master=self) + self.stp: SynSTP = stp # long-term synaptic plasticity - if ltp is not None: - if not isinstance(ltp, SynLTP): - raise TypeError(f'Long-term plasticity must be instance of {SynLTP.__name__}, ' - f'but we got {type(ltp)}') - ltp.register_master(master=self) - self.ltp: Optional[SynLTP] = ltp + ltp = NullSynLTP() if ltp is None else ltp + if ltp.isregistered: ltp = ltp.clone() + if not isinstance(ltp, SynLTP): + raise TypeError(f'Long-term plasticity must be instance of {SynLTP.__name__}, ' + f'but we got {type(ltp)}') + ltp.register_master(master=self) + self.ltp: SynLTP = ltp def init_weights( self, diff --git a/brainpy/dyn/neurons/biological_models.py b/brainpy/dyn/neurons/biological_models.py index 25b26bf9a..84b10d736 100644 --- a/brainpy/dyn/neurons/biological_models.py +++ b/brainpy/dyn/neurons/biological_models.py @@ -377,10 +377,7 @@ class MorrisLecar(NeuGroup): References ---------- - .. [4] Meier, Stephen R., Jarrett L. Lancaster, and Joseph M. Starobin. - "Bursting regimes in a reaction-diffusion system with action - potential-dependent equilibrium." PloS one 10.3 (2015): - e0122401. + .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model .. 
[6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model """ diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index b2fd11230..60ee6fc57 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -1,17 +1,17 @@ # -*- coding: utf-8 -*- -import warnings from typing import Union, Dict, Callable, Optional from jax import vmap from jax.lax import stop_gradient + import brainpy.math as bm from brainpy.connect import TwoEndConnector, All2All, One2One -from brainpy.dyn.base import NeuGroup, SynOutput, SynSTP, TwoEndConn +from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn from brainpy.initialize import Initializer, variable from brainpy.integrators import odeint, JointEq +from brainpy.modes import Mode, BatchingMode, normal from brainpy.types import Tensor -from brainpy.modes import Mode, BatchingMode, TrainingMode, normal, batching, training from ..synouts import CUBA, MgBlock __all__ = [ @@ -88,7 +88,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynOutput = None, + output: SynOut = CUBA(target_var='V'), stp: Optional[SynSTP] = None, comp_method: str = 'sparse', g_max: Union[float, Tensor, Initializer, Callable] = 1., @@ -104,7 +104,7 @@ def __init__( pre=pre, post=post, conn=conn, - output=CUBA(target_var='V') if output is None else output, + output=output, stp=stp, mode=mode) @@ -265,7 +265,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynOutput = None, + output: SynOut = CUBA(), stp: Optional[SynSTP] = None, comp_method: str = 'sparse', g_max: Union[float, Tensor, Initializer, Callable] = 1., @@ -281,7 +281,7 @@ def __init__( super(Exponential, self).__init__(pre=pre, post=post, conn=conn, - output=CUBA() if output is None else output, + output=output, stp=stp, name=name, mode=mode) @@ -449,7 +449,7 @@ def __init__( post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], stp: Optional[SynSTP] = None, - output: SynOutput = None, + output: SynOut = CUBA(), comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 1., tau_decay: Union[float, Tensor] = 10.0, @@ -465,7 +465,7 @@ def __init__( super(DualExponential, self).__init__(pre=pre, post=post, conn=conn, - output=CUBA() if output is None else output, + output=output, stp=stp, name=name, mode=mode) @@ -628,7 +628,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynOutput = None, + output: SynOut = CUBA(), stp: Optional[SynSTP] = None, comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 1., @@ -650,7 +650,7 @@ def __init__( tau_decay=tau_decay, tau_rise=tau_decay, method=method, - output=CUBA() if output is None else output, + output=output, stp=stp, name=name, mode=mode, @@ -767,29 +767,6 @@ class NMDA(TwoEndConn): The name of this synaptic projection. method: str The numerical integration methods. - E: float, JaxArray, ndarray - The reversal potential for the synaptic current. [mV] - - .. deprecated:: 2.1.13 - Parameter `E` is no longer supported. Please use :py:class:`~.MgBlock` instead. - - alpha: float, JaxArray, ndarray - Binding constant. Default 0.062 - - .. deprecated:: 2.1.13 - Parameter `alpha` is no longer supported. Please use :py:class:`~.MgBlock` instead. - - beta: float, JaxArray, ndarray - Unbinding constant. Default 3.57 - - .. 
deprecated:: 2.1.13 - Parameter `beta` is no longer supported. Please use :py:class:`~.MgBlock` instead. - - cc_Mg: float, JaxArray, ndarray - Concentration of Magnesium ion. Default 1.2 [mM]. - - .. deprecated:: 2.1.13 - Parameter `cc_Mg` is no longer supported. Please use :py:class:`~.MgBlock` instead. References ---------- @@ -812,7 +789,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynOutput = None, + output: SynOut = MgBlock(E=0., alpha=0.062, beta=3.57, cc_Mg=1.2), stp: Optional[SynSTP] = None, comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.15, @@ -826,45 +803,7 @@ def __init__( name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, - - # deprecated - alpha=None, - beta=None, - cc_Mg=None, - E=None, ): - - if output is not None: - if alpha is not None: - raise ValueError(f'Please set "alpha" in "output" argument.') - if beta is not None: - raise ValueError(f'Please set "beta" in "output" argument.') - if cc_Mg is not None: - raise ValueError(f'Please set "cc_Mg" in "output" argument.') - if E is not None: - raise ValueError(f'Please set "E" in "output" argument.') - else: - if alpha is not None: - warnings.warn('Please set "alpha" by using "output=bp.dyn.synouts.MgBlock(alpha)" instead.', - DeprecationWarning) - else: - alpha = 0.062 - if beta is not None: - warnings.warn('Please set "beta" by using "output=bp.dyn.synouts.MgBlock(beta)" instead.', - DeprecationWarning) - else: - beta = 3.57 - if cc_Mg is not None: - warnings.warn('Please set "cc_Mg" by using "output=bp.dyn.synouts.MgBlock(cc_Mg)" instead.', - DeprecationWarning) - else: - cc_Mg = 1.2 - if E is not None: - warnings.warn('Please set "E" by using "output=bp.dyn.synouts.MgBlock(E)" instead.', - DeprecationWarning) - else: - E = 0. - output = MgBlock(E=E, alpha=alpha, beta=beta, cc_Mg=cc_Mg) super(NMDA, self).__init__(pre=pre, post=post, conn=conn, diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index e1e26f1ff..3ebfcb2f7 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -8,7 +8,7 @@ import brainpy.math as bm from brainpy.connect import TwoEndConnector, All2All, One2One -from brainpy.dyn.base import NeuGroup, TwoEndConn, SynSTP, SynOutput +from brainpy.dyn.base import NeuGroup, TwoEndConn, SynSTP, SynOut from brainpy.dyn.synouts import COBA, MgBlock from brainpy.initialize import Initializer, variable from brainpy.integrators import odeint, JointEq @@ -140,7 +140,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynOutput = None, + output: SynOut = COBA(E=0.), stp: Optional[SynSTP] = None, comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.42, @@ -155,19 +155,11 @@ def __init__( name: str = None, mode: Mode = normal, stop_spike_gradient: bool = False, - - # deprecated - E: float = None, ): - _E = 0. - if E is not None: - warnings.warn('"E" is deprecated in AMPA model. 
Please define "E" with ' - 'brainpy.dyn.synouts.COBA.', DeprecationWarning) - _E = E super(AMPA, self).__init__(pre=pre, post=post, conn=conn, - output=COBA(E=_E) if output is None else output, + output=output, stp=stp, name=name, mode=mode) @@ -322,7 +314,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynOutput = None, + output: SynOut = COBA(E=-80.), stp: Optional[SynSTP] = None, comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.04, @@ -341,15 +333,10 @@ def __init__( # deprecated E: Union[float, Tensor] = None, ): - _E = -80. - if E is not None: - warnings.warn('"E" is deprecated in AMPA model. Please define "E" with ' - 'brainpy.dyn.synouts.COBA.', DeprecationWarning) - _E = E super(GABAa, self).__init__(pre=pre, post=post, conn=conn, - output=COBA(E=_E) if output is None else output, + output=output, stp=stp, comp_method=comp_method, delay_step=delay_step, @@ -490,7 +477,7 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - output: SynOutput = None, + output: SynOut = MgBlock(E=0.), stp: Optional[SynSTP] = None, comp_method: str = 'dense', g_max: Union[float, Tensor, Initializer, Callable] = 0.15, @@ -511,7 +498,7 @@ def __init__( super(BioNMDA, self).__init__(pre=pre, post=post, conn=conn, - output=MgBlock(E=0.) if output is None else output, + output=output, stp=stp, name=name, mode=mode) diff --git a/brainpy/dyn/synapses/gap_junction.py b/brainpy/dyn/synapses/gap_junction.py index 42f263612..46e304078 100644 --- a/brainpy/dyn/synapses/gap_junction.py +++ b/brainpy/dyn/synapses/gap_junction.py @@ -4,7 +4,7 @@ import brainpy.math as bm from brainpy.connect import TwoEndConnector -from brainpy.dyn.base import NeuGroup, SynOutput, SynSTP, TwoEndConn +from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn from brainpy.initialize import Initializer, parameter from brainpy.types import Tensor from ..synouts import CUBA diff --git a/brainpy/dyn/synouts/conductances.py b/brainpy/dyn/synouts/conductances.py index cd443c192..04644f451 100644 --- a/brainpy/dyn/synouts/conductances.py +++ b/brainpy/dyn/synouts/conductances.py @@ -3,7 +3,7 @@ from typing import Union, Callable, Optional from brainpy.math import Variable -from brainpy.dyn.base import SynOutput +from brainpy.dyn.base import SynOut from brainpy.initialize import parameter, Initializer from brainpy.types import Tensor @@ -13,7 +13,7 @@ ] -class CUBA(SynOutput): +class CUBA(SynOut): r"""Current-based synaptic output. Given the conductance, this model outputs the post-synaptic current with a identity function: @@ -38,10 +38,14 @@ def __init__( target_var: Optional[Union[str, Variable]] = 'input', name: str = None, ): + self._target_var = target_var super(CUBA, self).__init__(name=name, target_var=target_var) + def clone(self): + return CUBA(target_var=self._target_var) -class COBA(SynOutput): + +class COBA(SynOut): r"""Conductance-based synaptic output. 
Given the synaptic conductance, the model output the post-synaptic current with @@ -70,22 +74,29 @@ def __init__( name: str = None, ): super(COBA, self).__init__(name=name, target_var=target_var) - self.E = E - self.membrane_var = membrane_var + self._E = E + self._target_var = target_var + self._membrane_var = membrane_var + + def clone(self): + return COBA(E=self._E, target_var=self._target_var, membrane_var=self._membrane_var) def register_master(self, master): super(COBA, self).register_master(master) - self.E = parameter(self.E, self.master.post.num, allow_none=False) - - if isinstance(self.membrane_var, str): - if not hasattr(self.master.post, self.membrane_var): - raise KeyError(f'Post-synaptic group does not have membrane variable: {self.membrane_var}') - self.membrane_var = getattr(self.master.post, self.membrane_var) - elif isinstance(self.membrane_var, Variable): - self.membrane_var = self.membrane_var + + # reversal potential + self.E = parameter(self._E, self.master.post.num, allow_none=False) + + # membrane potential + if isinstance(self._membrane_var, str): + if not hasattr(self.master.post, self._membrane_var): + raise KeyError(f'Post-synaptic group does not have membrane variable: {self._membrane_var}') + self.membrane_var = getattr(self.master.post, self._membrane_var) + elif isinstance(self._membrane_var, Variable): + self.membrane_var = self._membrane_var else: raise TypeError('"membrane_var" must be instance of string or Variable. ' - f'But we got {type(self.membrane_var)}') + f'But we got {type(self._membrane_var)}') def filter(self, g): V = self.membrane_var.value diff --git a/brainpy/dyn/synouts/ions.py b/brainpy/dyn/synouts/ions.py index 3e82234ad..4c73c8efe 100644 --- a/brainpy/dyn/synouts/ions.py +++ b/brainpy/dyn/synouts/ions.py @@ -3,7 +3,7 @@ from typing import Union, Callable, Optional import brainpy.math as bm -from brainpy.dyn.base import SynOutput +from brainpy.dyn.base import SynOut from brainpy.initialize import parameter, Initializer from brainpy.types import Tensor @@ -12,7 +12,7 @@ ] -class MgBlock(SynOutput): +class MgBlock(SynOut): r"""Synaptic output based on Magnesium blocking. 
Given the synaptic conductance, the model output the post-synaptic current with @@ -54,30 +54,34 @@ def __init__( name: str = None, ): super(MgBlock, self).__init__(name=name, target_var=target_var) - self.E = E - self.cc_Mg = cc_Mg - self.alpha = alpha - self.beta = beta - self.membrane_var = membrane_var + self._E = E + self._cc_Mg = cc_Mg + self._alpha = alpha + self._beta = beta + self._membrane_var = membrane_var def register_master(self, master): super(MgBlock, self).register_master(master) - self.E = parameter(self.E, self.master.post.num, allow_none=False) - self.cc_Mg = parameter(self.cc_Mg, self.master.post.num, allow_none=False) - self.alpha = parameter(self.alpha, self.master.post.num, allow_none=False) - self.beta = parameter(self.beta, self.master.post.num, allow_none=False) - - if isinstance(self.membrane_var, str): - if not hasattr(self.master.post, self.membrane_var): - raise KeyError(f'Post-synaptic group does not have membrane variable: {self.membrane_var}') - self.membrane_var = getattr(self.master.post, self.membrane_var) - elif isinstance(self.membrane_var, bm.Variable): - self.membrane_var = self.membrane_var + + self.E = parameter(self._E, self.master.post.num, allow_none=False) + self.cc_Mg = parameter(self._cc_Mg, self.master.post.num, allow_none=False) + self.alpha = parameter(self._alpha, self.master.post.num, allow_none=False) + self.beta = parameter(self._beta, self.master.post.num, allow_none=False) + if isinstance(self._membrane_var, str): + if not hasattr(self.master.post, self._membrane_var): + raise KeyError(f'Post-synaptic group does not have membrane variable: {self._membrane_var}') + self.membrane_var = getattr(self.master.post, self._membrane_var) + elif isinstance(self._membrane_var, bm.Variable): + self.membrane_var = self._membrane_var else: raise TypeError('"membrane_var" must be instance of string or Variable. ' - f'But we got {type(self.membrane_var)}') + f'But we got {type(self._membrane_var)}') def filter(self, g): V = self.membrane_var.value I = g * (self.E - V) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * V)) return super(MgBlock, self).filter(I) + + def clone(self): + return MgBlock(E=self._E, cc_Mg=self._cc_Mg, alpha=self._alpha, + beta=self._beta, membrane_var=self._membrane_var) diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 36f8c33ce..284279290 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -13,8 +13,7 @@ - automatic differentiation for class objects - dedicated operators for brain dynamics - activation functions -- device switching -- default type switching +- device/dtype switching - and others Details in the following. diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py index bccd0c692..fdf383f37 100644 --- a/brainpy/math/operators/op_register.py +++ b/brainpy/math/operators/op_register.py @@ -7,7 +7,6 @@ from brainpy.base import Base from brainpy.math.jaxarray import JaxArray -from brainpy import tools from .utils import _check_brainpylib try: @@ -22,54 +21,68 @@ class XLACustomOp(Base): - def __init__(self, name=None, apply_cpu_func_to_gpu: bool = False): + """Creating a XLA custom call operator. + + Parameters + ---------- + name: str + The name of operator. + eval_shape: callable + The function to evaluate the shape and dtype of the output according to the input. + This function should receive the abstract information of inputs, and return the + abstract information of the outputs. 
For example: + + >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...): + >>> return out1_info, out2_info + con_compute: callable + The function to make the concrete computation. This function receives inputs, + and returns outputs. For example: + + >>> def con_compute(inp1, inp2, inp3, ...): + >>> return out1, out2 + cpu_func: callable + The function defines the computation on CPU backend. Same as ``con_compute``. + gpu_func: callable + The function defines the computation on GPU backend. Currently, this function is not supportted. + apply_cpu_func_to_gpu: bool + Whether allows to apply CPU function on GPU backend. If True, the GPU data will move to CPU, + and after calculation, the returned outputs on CPU backend will move to GPU. + """ + + def __init__( + self, + eval_shape: Callable = None, + con_compute: Callable = None, + cpu_func: Callable = None, + gpu_func: Callable = None, + apply_cpu_func_to_gpu: bool = False, + name: str = None, + ): _check_brainpylib(register_op.__name__) super(XLACustomOp, self).__init__(name=name) # abstract evaluation function - if hasattr(self.eval_shape, 'not_customized') and self.eval_shape.not_customized: - raise ValueError('Must implement "eval_shape" for abstract evaluation.') + if eval_shape is None: + raise ValueError('Must provide "eval_shape" for abstract evaluation.') # cpu function - if hasattr(self.con_compute, 'not_customized') and self.con_compute.not_customized: - if hasattr(self.cpu_func, 'not_customized') and self.cpu_func.not_customized: - raise ValueError('Must implement one of "cpu_func" or "con_compute".') - else: - cpu_func = self.cpu_func + if con_compute is None: + if cpu_func is None: + raise ValueError('Must provide one of "cpu_func" or "con_compute".') else: - cpu_func = self.con_compute + cpu_func = con_compute # gpu function - if hasattr(self.gpu_func, 'not_customized') and self.gpu_func.not_customized: + if gpu_func is None: gpu_func = None - else: - gpu_func = self.gpu_func # register OP - self.op = brainpylib.register_op(self.name, - cpu_func=cpu_func, - gpu_func=gpu_func, - out_shapes=self.eval_shape, - apply_cpu_func_to_gpu=apply_cpu_func_to_gpu) - - @tools.not_customized - def eval_shape(self, *args, **kwargs): - raise NotImplementedError - - @staticmethod - @tools.not_customized - def con_compute(*args, **kwargs): - raise NotImplementedError - - @staticmethod - @tools.not_customized - def cpu_func(*args, **kwargs): - raise NotImplementedError - - @staticmethod - @tools.not_customized - def gpu_func(*args, **kwargs): - raise NotImplementedError + _, self.op = brainpylib.register_op(self.name, + cpu_func=cpu_func, + gpu_func=gpu_func, + out_shapes=eval_shape, + apply_cpu_func_to_gpu=apply_cpu_func_to_gpu, + return_primitive=True) def __call__(self, *args, **kwargs): args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a, diff --git a/brainpy/running/__init__.py b/brainpy/running/__init__.py index e6441aaea..6168a1591 100644 --- a/brainpy/running/__init__.py +++ b/brainpy/running/__init__.py @@ -7,3 +7,4 @@ from .multiprocess import * from .runner import * +from .constants import * diff --git a/brainpy/train/constants.py b/brainpy/running/constants.py similarity index 57% rename from brainpy/train/constants.py rename to brainpy/running/constants.py index 6c26c36ad..64c525b25 100644 --- a/brainpy/train/constants.py +++ b/brainpy/running/constants.py @@ -1,5 +1,15 @@ # -*- coding: utf-8 -*- + +__all__ = [ + 'TRAIN_PHASE', + 'FIT_PHASE', + 'PREDICT_PHASE', + 'RUN_PHASE', + 'LOSS_PHASE', +] + + TRAIN_PHASE = 
'fit' FIT_PHASE = 'fit' PREDICT_PHASE = 'predict' diff --git a/brainpy/running/multiprocess.py b/brainpy/running/multiprocess.py index d0c99a7c7..89bf3dae2 100644 --- a/brainpy/running/multiprocess.py +++ b/brainpy/running/multiprocess.py @@ -2,23 +2,29 @@ import multiprocessing + __all__ = [ 'process_pool', 'process_pool_lock', + 'vectorize_map', + 'parallelize_map', ] -def process_pool(func, all_net_params, nb_process): +def process_pool(func, all_params, num_process): """Run multiple models in multi-processes. + .. Note:: + This multiprocessing function should be called within a `if __main__ == '__main__':` syntax. + Parameters ---------- func : callable The function to run model. - all_net_params : a_list, tuple + all_params : a_list, tuple The parameters of the function arguments. The parameters for each process can be a tuple, or a dictionary. - nb_process : int + num_process : int The number of the processes. Returns @@ -26,19 +32,19 @@ def process_pool(func, all_net_params, nb_process): results : list Process results. """ - print('{} jobs total.'.format(len(all_net_params))) - pool = multiprocessing.Pool(processes=nb_process) + print('{} jobs total.'.format(len(all_params))) + pool = multiprocessing.Pool(processes=num_process) results = [] - for net_params in all_net_params: - if isinstance(net_params, (list, tuple)): - results.append(pool.apply_async(func, args=tuple(net_params))) - elif isinstance(net_params, dict): - results.append(pool.apply_async(func, kwds=net_params)) + for params in all_params: + if isinstance(params, (list, tuple)): + results.append(pool.apply_async(func, args=tuple(params))) + elif isinstance(params, dict): + results.append(pool.apply_async(func, kwds=params)) else: - raise ValueError('Unknown parameter type: ', type(net_params)) + raise ValueError('Unknown parameter type: ', type(params)) pool.close() pool.join() - return results + return [r.get() for r in results] def process_pool_lock(func, all_net_params, nb_process): @@ -46,7 +52,7 @@ def process_pool_lock(func, all_net_params, nb_process): Sometimes, you want to synchronize the processes. For example, if you want to write something in a document, you cannot let - multi-process simultaneously open this same file. So, you need + multiprocess simultaneously open this same file. So, you need add a `lock` argument in your defined `func`: .. code-block:: python @@ -60,6 +66,9 @@ def some_func(..., lock, ...): In such case, you can use `process_pool_lock()` to run your model. + .. Note:: + This multiprocessing function should be called within a `if __main__ == '__main__':` syntax. + Parameters ---------- func : callable @@ -89,4 +98,15 @@ def some_func(..., lock, ...): raise ValueError('Unknown parameter type: ', type(net_params)) pool.close() pool.join() - return results + return [r.get() for r in results] + + +def vectorize_map(func, all_params, num_thread): + pass + + +def parallelize_map(func, all_params, num_process): + pass + + + diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index 8f88b260f..d4e55a9ed 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -12,6 +12,7 @@ from brainpy.errors import MonitorError, RunningError from brainpy.tools.checking import check_dict_data from brainpy.tools.others import DotDict +from . 
import constants as C __all__ = [ 'Runner', @@ -72,11 +73,11 @@ def __init__( # jit instruction self.jit = dict() if isinstance(jit, bool): - self.jit = {'predict': jit} + self.jit = {C.PREDICT_PHASE: jit} elif isinstance(jit, dict): for k, v in jit.items(): self.jit[k] = v - self.jit = {'predict': jit.pop('predict', True)} + self.jit[C.PREDICT_PHASE] = jit.pop(C.PREDICT_PHASE, True) else: raise ValueError(f'Unknown "jit" setting: {jit}') diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py index 68a530560..c4e2704f8 100644 --- a/brainpy/train/back_propagation.py +++ b/brainpy/train/back_propagation.py @@ -15,7 +15,7 @@ from brainpy.tools.checking import serialize_kwargs from brainpy.tools.others import DotDict from brainpy.types import Tensor, Output -from . import constants as c +from ..running import constants as c from .base import DSTrainer __all__ = [ @@ -350,7 +350,7 @@ def _get_data_by_callable(self, dataset: Callable, num_batch=None): def _get_data_by_tensor(self, dataset, num_batch=None, shuffle=False): if num_batch is None: - raise ValueError('Must provide "num_batch" when dataset is not a callable function.') + raise ValueError('Must provide "batch_size" when dataset is not a callable function.') assert isinstance(dataset, (tuple, list)) and len(dataset) == 2 xs, ys = dataset num_sample = self._get_batch_size(xs) diff --git a/brainpy/train/base.py b/brainpy/train/base.py index 4b1669259..b1d50eb2c 100644 --- a/brainpy/train/base.py +++ b/brainpy/train/base.py @@ -9,7 +9,7 @@ from brainpy.dyn.runners import DSRunner from brainpy.tools.checking import check_dict_data from brainpy.types import Tensor, Output -from . import constants as c +from ..running import constants as c __all__ = [ 'DSTrainer', diff --git a/docs/apis/datasets.rst b/docs/apis/datasets.rst index d4bda0638..a2de549f2 100644 --- a/docs/apis/datasets.rst +++ b/docs/apis/datasets.rst @@ -8,4 +8,5 @@ .. toctree:: :maxdepth: 1 - auto/datasets/chaotic_systems + auto/datasets/chaos + auto/datasets/vision diff --git a/docs/auto_generater.py b/docs/auto_generater.py index 173f9e328..08af8b4f4 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -252,8 +252,11 @@ def generate_datasets_docs(path='apis/auto/datasets/'): if not os.path.exists(path): os.makedirs(path) - write_module(module_name='brainpy.datasets.chaotic_systems', - filename=os.path.join(path, 'chaotic_systems.rst'), + write_module(module_name='brainpy.datasets.chaos', + filename=os.path.join(path, 'chaos.rst'), + header='Chaotic Systems') + write_module(module_name='brainpy.datasets.vision', + filename=os.path.join(path, 'vision.rst'), header='Chaotic Systems') @@ -309,7 +312,7 @@ def generate_dyn_docs(path='apis/auto/dyn/'): module_and_name = [ ('conv', 'Convolutional Layers'), ('dropout', 'Dropout Layers'), - ('dense', 'Dense Connection Layers'), + ('linear', 'Dense Connection Layers'), ('nvar', 'NVAR Layers'), ('reservoir', 'Reservoir Layers'), ('rnncells', 'Artificial Recurrent Layers'), @@ -324,7 +327,7 @@ def generate_dyn_docs(path='apis/auto/dyn/'): module_and_name = [ ('abstract_models', 'Abstract Models'), ('biological_models', 'Biological Models'), - ('couplings', 'Coupling Models'), + ('delay_couplings', 'Coupling Models'), ('gap_junction', 'Gap Junction Models'), ('learning_rules', 'Learning Rule Models'), ] diff --git a/docs/index.rst b/docs/index.rst index 726c45847..5f08e858f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,8 +9,8 @@ high-performance Brain Dynamics Programming (BDP). 
Among its key ingredients, Br stochastic differential equations (SDEs), delay differential equations (DDEs), fractional differential equations (FDEs), etc. -- **Dynamics simulation** tools for various brain objects, like - neurons, synapses, networks, soma, dendrites, channels, and even more. +- **Dynamics building** with the modular and composable programming interface. +- **Dynamics simulation** tools for various brain objects. - **Dynamics training** tools with various machine learning algorithms, like FORCE learning, ridge regression, back-propagation, etc. - **Dynamics analysis** tools for differential equations, including @@ -48,6 +48,7 @@ The code of BrainPy is open-sourced at GitHub: :caption: BDP Tutorials tutorial_basics/index + tutorial_building/index tutorial_simulation/index tutorial_training/index tutorial_analysis/index diff --git a/docs/tutorial_analysis/index.rst b/docs/tutorial_analysis/index.rst index 7ad5154f9..878684684 100644 --- a/docs/tutorial_analysis/index.rst +++ b/docs/tutorial_analysis/index.rst @@ -1,4 +1,4 @@ -Dynamics Analysis +Model Analysis ================= .. toctree:: diff --git a/docs/tutorial_building/index.rst b/docs/tutorial_building/index.rst new file mode 100644 index 000000000..ce8fad09f --- /dev/null +++ b/docs/tutorial_building/index.rst @@ -0,0 +1,6 @@ +Model Building +============== + +.. toctree:: + :maxdepth: 1 + diff --git a/docs/tutorial_simulation/index.rst b/docs/tutorial_simulation/index.rst index 0d2a57006..5cfa6b661 100644 --- a/docs/tutorial_simulation/index.rst +++ b/docs/tutorial_simulation/index.rst @@ -1,5 +1,5 @@ -Dynamics Simulation -=================== +Model Simulation +================ .. toctree:: :maxdepth: 1 diff --git a/docs/tutorial_training/index.rst b/docs/tutorial_training/index.rst index 59c741ace..d3d039942 100644 --- a/docs/tutorial_training/index.rst +++ b/docs/tutorial_training/index.rst @@ -1,4 +1,4 @@ -Dynamics Training +Model Training ================= This tutorial shows how to train a dynamical system from data or task, diff --git a/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py b/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py index 01cf44193..127d0cde3 100644 --- a/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py +++ b/examples/simulation/Li_2017_unified_thalamus_oscillation_model.py @@ -99,7 +99,7 @@ def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ): IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ca=Ca) -class MgBlock(bp.dyn.SynOutput): +class MgBlock(bp.dyn.SynOut): def __init__(self, E=0.): super(MgBlock, self).__init__() self.E = E