From b59f0c3b100a8d0ff6b70c20f12d92549b89ffbc Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 27 Jun 2022 20:53:15 +0800 Subject: [PATCH 1/2] unify `trainable` setting between brainpy.layers and other modules --- brainpy/dyn/layers/conv.py | 3 ++ brainpy/dyn/layers/dropout.py | 8 +--- brainpy/dyn/layers/linear.py | 13 +++--- brainpy/dyn/layers/nvar.py | 71 +++++++++++++++++++++------------ brainpy/dyn/layers/reservoir.py | 14 +++---- brainpy/dyn/layers/rnncells.py | 15 ++++--- brainpy/dyn/training.py | 9 ++--- brainpy/train/base.py | 5 --- brainpy/train/offline.py | 11 ++++- brainpy/train/online.py | 15 +++++-- 10 files changed, 93 insertions(+), 71 deletions(-) diff --git a/brainpy/dyn/layers/conv.py b/brainpy/dyn/layers/conv.py index 97d05d1ea..6f43fd403 100644 --- a/brainpy/dyn/layers/conv.py +++ b/brainpy/dyn/layers/conv.py @@ -139,6 +139,9 @@ def update(self, sha, x): return y return y + self.b.value + def reset_state(self, batch_size=None): + pass + class Conv1D(GeneralConv): def __init__( diff --git a/brainpy/dyn/layers/dropout.py b/brainpy/dyn/layers/dropout.py index 07e3a46c1..b3dce1abf 100644 --- a/brainpy/dyn/layers/dropout.py +++ b/brainpy/dyn/layers/dropout.py @@ -36,7 +36,7 @@ class Dropout(TrainingSystem): neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. """ - def __init__(self, prob, seed=None, trainable=False, name=None): + def __init__(self, prob, seed=None, trainable=True, name=None): super(Dropout, self).__init__(trainable=trainable, name=name) self.prob = prob self.rng = bm.random.RandomState(seed=seed) @@ -47,9 +47,3 @@ def update(self, sha, x): return bm.where(keep_mask, x / self.prob, 0.) else: return x - - def reset(self, batch_size=1): - pass - - def reset_state(self, batch_size=1): - pass diff --git a/brainpy/dyn/layers/linear.py b/brainpy/dyn/layers/linear.py index 855288811..d890569ae 100644 --- a/brainpy/dyn/layers/linear.py +++ b/brainpy/dyn/layers/linear.py @@ -47,10 +47,15 @@ def __init__( W_initializer: Union[Initializer, Callable, Tensor] = XavierNormal(), b_initializer: Optional[Union[Initializer, Callable, Tensor]] = ZeroInit(), trainable: bool = True, - name: str = None + name: str = None, + fit_online: bool = False, + fit_offline: bool = False, ): super(Dense, self).__init__(trainable=trainable, name=name) + self.fit_online = fit_online + self.fit_offline = fit_offline + # shape self.num_in = num_in self.num_out = num_out @@ -90,12 +95,6 @@ def update(self, sha, x): self.fit_record['output'] = res return res - def reset(self, batch_size=1): - pass - - def reset_state(self, batch_size=1): - pass - def online_init(self): if self.b is None: num_input = self.num_in diff --git a/brainpy/dyn/layers/nvar.py b/brainpy/dyn/layers/nvar.py index 114c9516d..0235706a5 100644 --- a/brainpy/dyn/layers/nvar.py +++ b/brainpy/dyn/layers/nvar.py @@ -7,9 +7,8 @@ import numpy as np import brainpy.math as bm -from brainpy.tools.checking import (check_integer, check_sequence) from brainpy.dyn.training import TrainingSystem - +from brainpy.tools.checking import (check_integer, check_sequence) __all__ = [ 'NVAR' @@ -69,8 +68,8 @@ def __init__( order: Union[int, Sequence[int]] = None, stride: int = 1, constant: bool = False, - trainable: bool = False, - name: str = None + trainable: bool = True, + name: str = None, ): super(NVAR, self).__init__(trainable=trainable, name=name) @@ -93,8 +92,11 @@ def __init__( # delay variables self.idx = bm.Variable(jnp.asarray([0])) - batch_size = 1 # first initialize the state with batch 
size = 1 - self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in))) + if trainable: + batch_size = 1 # first initialize the state with batch size = 1 + self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1) + else: + self.store = bm.Variable(jnp.zeros((self.num_delay, self.num_in))) # linear dimension self.linear_dim = self.delay * num_in @@ -115,35 +117,52 @@ def __init__( if self.constant: self.num_out += 1 - def reset(self, batch_size=1): - self.idx[0] = 0 - self.reset_state(batch_size) - - def reset_state(self, batch_size=1): + def reset_state(self, batch_size=None): """Reset the node state which depends on batch size.""" + self.idx[0] = 0 # To store the last inputs. # Note, the batch axis is not in the first dimension, so we # manually handle the state of NVAR, rather return it. - self.store._value = jnp.zeros((self.num_delay, batch_size, self.num_in)) + if batch_size is None: + self.store.value = jnp.zeros((self.num_delay, self.num_in)) + else: + self.store.value = jnp.zeros((self.num_delay, batch_size, self.num_in)) def update(self, sha, x): all_parts = [] + select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay # 1. Store the current input self.store[self.idx[0]] = x - # 2. Linear part: - # select all previous inputs, including the current, with strides - select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay - linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature) - linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1)) - # 3. constant - if self.constant: - constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype) - all_parts.append(constant) - all_parts.append(linear_parts) - # 3. Nonlinear part: - # select monomial terms and compute them - for ids in self.comb_ids: - all_parts.append(jnp.prod(linear_parts[:, ids], axis=2)) + + if self.trainable: + # 2. Linear part: + # select all previous inputs, including the current, with strides + linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature) + linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1)) + # 3. constant + if self.constant: + constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype) + all_parts.append(constant) + all_parts.append(linear_parts) + # 3. Nonlinear part: + # select monomial terms and compute them + for ids in self.comb_ids: + all_parts.append(jnp.prod(linear_parts[:, ids], axis=2)) + + else: + # 2. Linear part: + # select all previous inputs, including the current, with strides + linear_parts = self.store[select_ids].flatten() # (num_time x num_feature,) + # 3. constant + if self.constant: + constant = jnp.ones((1,), dtype=x.dtype) + all_parts.append(constant) + all_parts.append(linear_parts) + # 3. Nonlinear part: + # select monomial terms and compute them + for ids in self.comb_ids: + all_parts.append(jnp.prod(linear_parts[ids], axis=1)) + # 4. 
Finally self.idx.value = (self.idx + 1) % self.num_delay return jnp.concatenate(all_parts, axis=-1) diff --git a/brainpy/dyn/layers/reservoir.py b/brainpy/dyn/layers/reservoir.py index 26232c31e..a7a069900 100644 --- a/brainpy/dyn/layers/reservoir.py +++ b/brainpy/dyn/layers/reservoir.py @@ -3,7 +3,7 @@ from typing import Optional, Union, Callable, Tuple import brainpy.math as bm -from brainpy.initialize import Normal, ZeroInit, Initializer, parameter +from brainpy.initialize import Normal, ZeroInit, Initializer, parameter, variable from brainpy.tools.checking import check_float, check_initializer, check_string from brainpy.tools.others import to_size from brainpy.dyn.training import TrainingSystem @@ -102,7 +102,7 @@ def __init__( noise_rec: float = 0., noise_type: str = 'normal', seed: Optional[int] = None, - trainable: bool = False, + trainable: bool = True, name: str = None ): super(Reservoir, self).__init__(trainable=trainable, name=name) @@ -179,14 +179,10 @@ def __init__( self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) # initialize state - batch_size = 1 - self.state = bm.Variable(bm.zeros((batch_size,) + self.output_shape)) + self.state = variable(bm.zeros, trainable, self.output_shape) - def reset(self, batch_size=1): - self.state._value = bm.zeros((batch_size,) + self.output_shape).value - - def reset_state(self, batch_size=1): - pass + def reset_state(self, batch_size=None): + self.state.value = variable(bm.zeros, batch_size, self.output_shape) def update(self, sha, x): """Feedforward output.""" diff --git a/brainpy/dyn/layers/rnncells.py b/brainpy/dyn/layers/rnncells.py index 5a44c853f..c9c6a3a71 100644 --- a/brainpy/dyn/layers/rnncells.py +++ b/brainpy/dyn/layers/rnncells.py @@ -4,13 +4,15 @@ from typing import Union, Callable import brainpy.math as bm +from brainpy.dyn.training import TrainingSystem from brainpy.initialize import (XavierNormal, ZeroInit, Orthogonal, parameter, + variable, Initializer) -from brainpy.tools.checking import (check_integer, check_initializer) -from brainpy.dyn.training import TrainingSystem +from brainpy.tools.checking import (check_integer, + check_initializer) from brainpy.types import Tensor __all__ = [ @@ -37,16 +39,13 @@ def __init__(self, self.train_state = train_state # state - self.state = bm.Variable(bm.zeros((1, self.num_out))) + self.state = variable(bm.zeros, trainable, self.num_out) if train_state and self.trainable: self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False)) self.state[:] = self.state2train - def reset(self, batch_size=1): - self.reset_state(batch_size) - - def reset_state(self, batch_size=1): - self.state._value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False) + def reset_state(self, batch_size=None): + self.state.value = parameter(self._state_initializer, (batch_size, self.num_out), allow_none=False) if self.train_state: self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) self.state[:] = self.state2train diff --git a/brainpy/dyn/training.py b/brainpy/dyn/training.py index f1f1f4043..ff51b6468 100644 --- a/brainpy/dyn/training.py +++ b/brainpy/dyn/training.py @@ -49,11 +49,7 @@ def __init__(self, name: str = None, trainable: bool = False): self.offline_fit_by = None self.fit_record = dict() - def reset(self, batch_size=1): - for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values(): - node.reset(batch_size=batch_size) - - def reset_state(self, 
batch_size=1): + def reset_state(self, batch_size=None): for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values(): node.reset_state(batch_size=batch_size) @@ -214,3 +210,6 @@ def update(self, sha: dict, x: Any) -> Tensor: x = node(sha, x) return x + def reset(self, batch_size=1): + for node in self.nodes(level=1, include_self=False).unique().subset(TrainingSystem).values(): + node.reset(batch_size=batch_size) diff --git a/brainpy/train/base.py b/brainpy/train/base.py index 6a264448c..14a278c17 100644 --- a/brainpy/train/base.py +++ b/brainpy/train/base.py @@ -82,11 +82,6 @@ def fit( ) -> Output: # need to be implemented by subclass raise NotImplementedError('Must implement the fit function. ') - def _get_trainable_nodes(self) -> Tuple[TrainingSystem, ...]: - # check trainable nodes - nodes = self.target.nodes(level=-1, include_self=True).subset(TrainingSystem).unique() - return tuple([node for node in nodes.values() if node.trainable]) - def _check_ys(self, ys, num_batch, num_step, move_axis=False): if isinstance(ys, (bm.ndarray, jnp.ndarray)): if len(self.train_nodes) == 1: diff --git a/brainpy/train/offline.py b/brainpy/train/offline.py index 97ea8e1ad..e015cf62c 100644 --- a/brainpy/train/offline.py +++ b/brainpy/train/offline.py @@ -56,7 +56,16 @@ def __init__( super(OfflineTrainer, self).__init__(target=target, **kwargs) # get all trainable nodes - self.train_nodes = self._get_trainable_nodes() + nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique() + self.train_nodes = tuple([node for node in nodes.values() + if (hasattr(node, 'fit_offline') and node.fit_offline)]) + if len(self.train_nodes) == 0: + self.train_nodes = tuple([node for node in nodes.values() + if (hasattr(node, 'offline_fit') and + callable(node.offline_fit) and + (not hasattr(node.offline_fit, 'not_implemented')))]) + if len(self.train_nodes) == 0: + raise ValueError('Found no trainable nodes.') # training method if fit_method is None: diff --git a/brainpy/train/online.py b/brainpy/train/online.py index 730fed1be..f0fae9cdb 100644 --- a/brainpy/train/online.py +++ b/brainpy/train/online.py @@ -56,7 +56,16 @@ def __init__( super(OnlineTrainer, self).__init__(target=target, **kwargs) # get all trainable nodes - self.train_nodes = self._get_trainable_nodes() + nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique() + self.train_nodes = tuple([node for node in nodes.values() + if (hasattr(node, 'fit_online') and node.fit_online)]) + if len(self.train_nodes) == 0: + self.train_nodes = tuple([node for node in nodes.values() + if (hasattr(node, 'online_fit') and + callable(node.online_fit) and + (not hasattr(node.online_fit, 'not_implemented')))]) + if len(self.train_nodes) == 0: + raise ValueError('Found no trainable nodes.') # training method if fit_method is None: @@ -214,7 +223,7 @@ def _fit( A tuple of pair of (outputs, hists). 
""" _fit_func = self._get_fit_func(shared_args) - hists = _fit_func(xs + (ys, )) + hists = _fit_func(xs + (ys,)) hists = tree_map(lambda x: bm.moveaxis(x, 0, 1), hists, is_leaf=lambda x: isinstance(x, bm.JaxArray)) return hists @@ -241,7 +250,7 @@ def _step_func(all_inputs): # update step shared.update(shared_args) - args = (shared, ) if x is None else (shared, x) + args = (shared,) if x is None else (shared, x) out = self.target(*args) # monitor step From 525c4a6decc6a8c0739e945fb319e8d7116af7d6 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 1 Jul 2022 20:02:14 +0800 Subject: [PATCH 2/2] update and fix bugs --- brainpy/dyn/base.py | 26 ++++--- brainpy/dyn/channels/K.py | 22 +++--- brainpy/dyn/channels/Na.py | 22 +++--- brainpy/dyn/runners.py | 2 +- brainpy/math/setting.py | 14 +++- brainpy/tools/others/dicts.py | 5 ++ examples/simulation/hh_model.py | 4 +- examples/simulation/multi_scale_COBAHH.py | 77 +++++++++++-------- examples/training/echo_state_network.py | 26 +++---- extensions/brainpylib/custom_op/regis_op.py | 21 +++-- extensions/brainpylib/event_sum.py | 24 +++++- extensions/brainpylib/tests/test_event_sum.py | 38 +++++++++ 12 files changed, 192 insertions(+), 89 deletions(-) diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index ba59af43d..c62487ff6 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -148,7 +148,12 @@ def register_delay( self.global_delay_data[identifier] = (delay, delay_target) self.local_delay_vars[identifier] = delay else: - if self.global_delay_data[identifier][0].num_delay_step - 1 < max_delay_step: + delay = self.global_delay_data[identifier][0] + if delay is None: + delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data) + self.global_delay_data[identifier] = (delay, delay_target) + self.local_delay_vars[identifier] = delay + elif delay.num_delay_step - 1 < max_delay_step: self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data) else: self.global_delay_data[identifier] = (None, delay_target) @@ -181,7 +186,8 @@ def get_delay_data( return self.global_delay_data[identifier][1].value if identifier in self.global_delay_data: - if isinstance(delay_step, (int, np.integer)): + # if isinstance(delay_step, (int, np.integer)): + if bm.ndim(delay_step) == 0: return self.global_delay_data[identifier][0](delay_step, *indices) else: if len(indices) == 0: @@ -189,7 +195,7 @@ def get_delay_data( return self.global_delay_data[identifier][0](delay_step, *indices) elif identifier in self.local_delay_vars: - if isinstance(delay_step, (int, np.integer)): + if bm.ndim(delay_step) == 0: return self.local_delay_vars[identifier](delay_step) else: if len(indices) == 0: @@ -685,11 +691,11 @@ def __init__( self.output.register_master(master=self) # synaptic plasticity - if stp is None: stp = SynSTP() - if not isinstance(stp, SynSTP): - raise TypeError(f'plasticity must be instance of {SynSTP.__name__}, but we got {type(stp)}') - self.stp: SynSTP = stp - self.stp.register_master(master=self) + if stp is not None: + if not isinstance(stp, SynSTP): + raise TypeError(f'plasticity must be instance of {SynSTP.__name__}, but we got {type(stp)}') + stp.register_master(master=self) + self.stp: Optional[SynSTP] = stp def init_weights( self, @@ -734,7 +740,7 @@ def init_weights( return weight, conn_mask def syn2post_with_all2all(self, syn_value, syn_weight): - if bm.size(syn_weight) == 1: + if bm.ndim(syn_weight) == 0: if self.trainable: post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) 
else: @@ -750,7 +756,7 @@ def syn2post_with_one2one(self, syn_value, syn_weight): return syn_value * syn_weight def syn2post_with_dense(self, syn_value, syn_weight, conn_mat): - if bm.size(syn_weight) == 1: + if bm.ndim(syn_weight) == 0: post_vs = (syn_weight * syn_value) @ conn_mat else: post_vs = syn_value @ (syn_weight * conn_mat) diff --git a/brainpy/dyn/channels/K.py b/brainpy/dyn/channels/K.py index f42e21f4d..4d8aeba4e 100644 --- a/brainpy/dyn/channels/K.py +++ b/brainpy/dyn/channels/K.py @@ -17,7 +17,7 @@ 'IK_p4_markov', 'IKDR_Ba2002', 'IK_TM1991', - 'IK_HH', + 'IK_HH1952', 'IKA_p4q_ss', 'IKA1_HM1992', @@ -269,7 +269,7 @@ def f_p_beta(self, V): return 0.5 * bm.exp((10 - V + self.V_sh) / 40) -class IK_HH(IK_p4_markov): +class IK_HH1952(IK_p4_markov): r"""The potassium channel described by Hodgkin–Huxley model [1]_. The dynamics of this channel is given by: @@ -307,7 +307,7 @@ class IK_HH(IK_p4_markov): See Also -------- - INa_HH + INa_HH1952 """ def __init__( @@ -322,14 +322,14 @@ def __init__( name: str = None, trainable: bool = False, ): - super(IK_HH, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - E=E, - g_max=g_max, - trainable=trainable) + super(IK_HH1952, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + E=E, + g_max=g_max, + trainable=trainable) self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): diff --git a/brainpy/dyn/channels/Na.py b/brainpy/dyn/channels/Na.py index 20e8fa877..062ff71eb 100644 --- a/brainpy/dyn/channels/Na.py +++ b/brainpy/dyn/channels/Na.py @@ -17,7 +17,7 @@ 'INa_p3q_markov', 'INa_Ba2002', 'INa_TM1991', - 'INa_HH', + 'INa_HH1952', ] @@ -284,7 +284,7 @@ def f_q_beta(self, V): return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5)) -class INa_HH(INa_p3q_markov): +class INa_HH1952(INa_p3q_markov): r"""The sodium current model described by Hodgkin–Huxley model [1]_. 
The dynamics of this sodium current model is given by: @@ -331,7 +331,7 @@ class INa_HH(INa_p3q_markov): See Also -------- - IK_HH + IK_HH1952 """ def __init__( @@ -346,14 +346,14 @@ def __init__( name: str = None, trainable: bool = False, ): - super(INa_HH, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - E=E, - phi=phi, - g_max=g_max, - trainable=trainable) + super(INa_HH1952, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + E=E, + phi=phi, + g_max=g_max, + trainable=trainable) self.V_sh = parameter(V_sh, self.varshape, allow_none=False) def f_p_alpha(self, V): diff --git a/brainpy/dyn/runners.py b/brainpy/dyn/runners.py index 9617e1f57..a0505ea2f 100644 --- a/brainpy/dyn/runners.py +++ b/brainpy/dyn/runners.py @@ -583,7 +583,7 @@ def f_predict(self, shared_args: Dict = None): def _step_func(inputs): t, i, x = inputs # input step - shared = DotDict(t=t, i=t, dt=self.dt) + shared = DotDict(t=t, i=i, dt=self.dt) self._input_step(shared) # dynamics update step shared.update(shared_args) diff --git a/brainpy/math/setting.py b/brainpy/math/setting.py index a28c0f21a..66d334c0e 100644 --- a/brainpy/math/setting.py +++ b/brainpy/math/setting.py @@ -13,7 +13,7 @@ 'set_host_device_count', # device memory - 'clear_live_buffers', + 'clear_device_memory', 'disable_gpu_memory_preallocation', 'enable_gpu_memory_preallocation', @@ -125,15 +125,23 @@ def set_host_device_count(n): os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags) -def clear_live_buffers(): +def clear_device_memory(platform=None): """Clear all on-device buffers. + This function will be very useful when you call models in a Python loop, + because it can clear all cached arrays, and clear device memory. + .. warning:: This operation may cause errors when you use a deleted buffer. Therefore, regenerate data always. + + Parameters + ---------- + platform: str + The device to clear its memory. 
""" - for buf in xla_bridge.get_backend().live_buffers(): + for buf in xla_bridge.get_backend(platform=platform).live_buffers(): buf.delete() diff --git a/brainpy/tools/others/dicts.py b/brainpy/tools/others/dicts.py index 7cacf87da..caa3fde98 100644 --- a/brainpy/tools/others/dicts.py +++ b/brainpy/tools/others/dicts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- +import numpy as np from jax.tree_util import register_pytree_node from jax.util import safe_zip @@ -29,6 +30,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__dict__ = self + def to_numpy(self): + for key in tuple(self.keys()): + self[key] = np.asarray(self[key]) + register_pytree_node( DotDict, diff --git a/examples/simulation/hh_model.py b/examples/simulation/hh_model.py index 1469df818..5040e2370 100644 --- a/examples/simulation/hh_model.py +++ b/examples/simulation/hh_model.py @@ -9,8 +9,8 @@ class HH(dyn.CondNeuGroup): def __init__(self, size): super(HH, self).__init__(size) - self.INa = channels.INa_HH(size, ) - self.IK = channels.IK_HH(size, ) + self.INa = channels.INa_HH1952(size, ) + self.IK = channels.IK_HH1952(size, ) self.IL = channels.IL(size, E=-54.387, g_max=0.03) diff --git a/examples/simulation/multi_scale_COBAHH.py b/examples/simulation/multi_scale_COBAHH.py index 5168319dd..b80d4f668 100644 --- a/examples/simulation/multi_scale_COBAHH.py +++ b/examples/simulation/multi_scale_COBAHH.py @@ -7,6 +7,7 @@ from brainpy.dyn.channels import INa_TM1991, IL from brainpy.dyn.synapses import Exponential from brainpy.dyn.synouts import COBA +from brainpy.connect import FixedProb class IK2(bp.dyn.channels.IK_p4_markov): @@ -42,27 +43,44 @@ def current(self, V): return self.g_max * self.p ** 4 * (self.E - V) +# class IK(bp.dyn.Channel): +# def __init__(self, size, E=-90., g_max=10.): +# super(IK, self).__init__(size) +# def dp(p, t, V): +# alpha = 0.032*(V+48)/(1-bm.exp(-(V + 48)/5.)) +# beta = 0.5 * bm.exp(-(V + 53) / 40.) +# return alpha * (1. - p) - beta * p +# self.integral = bp.odeint(dp, method='exp_euler') +# self.g_max, self.E = g_max, E +# self.p = bm.Variable(bm.zeros(size)) +# +# def update(self, tdi, V): +# self.p.value = self.integral(self.p, tdi.t, V, tdi.dt) +# +# def current(self, V): +# return self.g_max * self.p ** 4 * (self.E - V) + + class HH(bp.dyn.CondNeuGroup): def __init__(self, size): super(HH, self).__init__(size, ) - self.INa = INa_TM1991(size, g_max=100., V_sh=-63.) self.IK = IK(size, g_max=30., V_sh=-63.) + self.INa = INa_TM1991(size, g_max=100., V_sh=-63.) 
self.IL = IL(size, E=-60., g_max=0.05) class Network(bp.dyn.Network): - def __init__(self, num_E, num_I, ): + def __init__(self, num_E, num_I, g_e2e=0.03, g_e2i=0.03, e_i2e=0.335, g_i2i=0.335): super(Network, self).__init__() - self.E = HH(num_E) - self.I = HH(num_I) - self.E2E = Exponential(self.E, self.E, bp.conn.FixedProb(0.02), - g_max=0.03, tau=5, output=COBA(E=0.)) - self.E2I = Exponential(self.E, self.I, bp.conn.FixedProb(0.02), - g_max=0.03, tau=5., output=COBA(E=0.)) - self.I2E = Exponential(self.I, self.E, bp.conn.FixedProb(0.02), - g_max=0.335, tau=10., output=COBA(E=-80)) - self.I2I = Exponential(self.I, self.I, bp.conn.FixedProb(0.02), - g_max=0.335, tau=10., output=COBA(E=-80.)) + self.E, self.I = HH(num_E), HH(num_I) + self.E2E = Exponential(self.E, self.E, FixedProb(0.02), + g_max=g_e2e, tau=5, output=COBA(E=0.)) + self.E2I = Exponential(self.E, self.I, FixedProb(0.02), + g_max=g_e2i, tau=5., output=COBA(E=0.)) + self.I2E = Exponential(self.I, self.E, FixedProb(0.02), + g_max=e_i2e, tau=10., output=COBA(E=-80)) + self.I2I = Exponential(self.I, self.I, FixedProb(0.02), + g_max=g_i2i, tau=10., output=COBA(E=-80.)) class Projection(bp.dyn.DynamicalSystem): @@ -85,10 +103,8 @@ def update(self, tdi): class Circuit(bp.dyn.Network): - def __init__(self, conn, delay): + def __init__(self, conn, delay, num_area): super(Circuit, self).__init__() - - num_area = conn.shape[0] self.areas = [Network(3200, 800) for _ in range(num_area)] self.projections = [] for i in range(num_area): @@ -106,7 +122,7 @@ def __init__(self, conn, delay): conn_data = data['conn'] delay_data = (data['delay'] / bm.get_dt()).astype(int) -circuit = Circuit(conn_data, delay_data) +circuit = Circuit(conn_data, delay_data, conn_data.shape[0]) f1 = lambda tdi: bm.concatenate([area.E.spike for area in circuit.areas]) f2 = lambda tdi: bm.concatenate([area.I.spike for area in circuit.areas]) I, duration = bp.inputs.section_input([0, 0.8, 0.], [50., 50., 100.], return_length=True) @@ -126,18 +142,17 @@ def __init__(self, conn, delay): fig.add_subplot(gs[1, 0]) bp.visualize.raster_plot(runner.mon['ts'], runner.mon.get('inh.spike'), show=True) -import seaborn as sns - -sns.set_theme(font_scale=1.5) - -fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) -fig.add_subplot(gs[0, 0]) -bp.visualize.line_plot(runner.mon['ts'], runner.mon['K.p'], show=True, plot_ids=(4, 5, 1)) - -fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) -fig.add_subplot(gs[0, 0]) -bp.visualize.line_plot(runner.mon['ts'], runner.mon['A0.V'], show=True, plot_ids=(4, 5, 1)) - -fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) -fig.add_subplot(gs[0, 0]) -bp.visualize.raster_plot(runner.mon['ts'], runner.mon['A0.spike'], show=True) +# import seaborn as sns +# sns.set_theme(font_scale=1.5) +# +# fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) +# fig.add_subplot(gs[0, 0]) +# bp.visualize.line_plot(runner.mon['ts'], runner.mon['K.p'], show=True, plot_ids=(4, 5, 1)) +# +# fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) +# fig.add_subplot(gs[0, 0]) +# bp.visualize.line_plot(runner.mon['ts'], runner.mon['A0.V'], show=True, plot_ids=(4, 5, 1)) +# +# fig, gs = bp.visualize.get_figure(1, 1, 4.5, 6) +# fig.add_subplot(gs[0, 0]) +# bp.visualize.raster_plot(runner.mon['ts'], runner.mon['A0.spike'], show=True) diff --git a/examples/training/echo_state_network.py b/examples/training/echo_state_network.py index e05a2a172..76631e34a 100644 --- a/examples/training/echo_state_network.py +++ b/examples/training/echo_state_network.py @@ -15,8 +15,8 @@ def __init__(self, num_in, 
num_hidden, num_out): conn_type='dense') self.o = bp.layers.Dense(num_hidden, num_out, W_initializer=bp.init.Normal()) - def forward(self, x, shared_args=None): - return self.o(self.r(x, shared_args), shared_args) + def update(self, shared_args, x): + return self.o(shared_args, self.r(shared_args, x)) class NGRC(bp.dyn.TrainingSystem): @@ -28,15 +28,15 @@ def __init__(self, num_in, num_out): W_initializer=bp.init.Normal(0.1), trainable=True) - def forward(self, x, shared_args=None): - return self.o(self.r(x, shared_args), shared_args) + def update(self, shared_args, x): + return self.o(shared_args, self.r(shared_args, x)) def train_esn_with_ridge(num_in=100, num_out=30): model = ESN(num_in, 2000, num_out) # input-output - print(model(bm.ones((1, num_in)))) + print(model(dict(), bm.ones((1, num_in)))) X = bm.random.random((1, 200, num_in)) Y = bm.random.random((1, 200, num_out)) @@ -67,7 +67,7 @@ def train_esn_with_force(num_in=100, num_out=30): model = ESN(num_in, 2000, num_out) # input-output - print(model(bm.ones((1, num_in)))) + print(model(dict(), bm.ones((1, num_in)))) X = bm.random.random((1, 200, num_in)) Y = bm.random.random((1, 200, num_out)) @@ -77,8 +77,8 @@ def train_esn_with_force(num_in=100, num_out=30): trainer.fit([X, Y]) # prediction - runner = bp.train.DSRunner(model, monitors=['r.state'], jit=True) - outputs = runner.predict(X) + runner = bp.dyn.DSRunner(model, monitors=['r.state'], jit=True) + outputs = runner.predict(inputs=X, inputs_are_batching=True) print(runner.mon['r.state'].shape) print(bp.losses.mean_absolute_error(outputs, Y)) print() @@ -93,11 +93,11 @@ def ngrc(num_in=10, num_out=30): X = bm.random.random((1, 200, num_in)) # (num_batch, num_time, num_feature) Y = bm.random.random((1, 200, num_out)) trainer = bp.train.RidgeTrainer(model, alpha=1e-6) - outputs = trainer.predict(X) + outputs = trainer.predict(inputs=X) print(outputs.shape) print(bp.losses.mean_absolute_error(outputs, Y)) trainer.fit([X, Y]) - outputs = trainer.predict(X) + outputs = trainer.predict(inputs=X) print(bp.losses.mean_absolute_error(outputs, Y)) @@ -107,7 +107,7 @@ def ngrc_bacth(num_in=10, num_out=30): model.reset_state(batch_size) X = bm.random.random((batch_size, 200, num_in)) Y = bm.random.random((batch_size, 200, num_out)) - trainer = bp.train.RidgeTrainer(model, beta=1e-6) + trainer = bp.train.RidgeTrainer(model, alpha=1e-6) outputs = trainer.predict(X) print(bp.losses.mean_absolute_error(outputs, Y)) trainer.fit([X, Y]) @@ -116,7 +116,7 @@ def ngrc_bacth(num_in=10, num_out=30): if __name__ == '__main__': - train_esn_with_ridge(10, 30) - train_esn_with_force(10, 30) + # train_esn_with_ridge(10, 30) + # train_esn_with_force(10, 30) ngrc(10, 30) ngrc_bacth() diff --git a/extensions/brainpylib/custom_op/regis_op.py b/extensions/brainpylib/custom_op/regis_op.py index 9e65fb556..8ededa937 100644 --- a/extensions/brainpylib/custom_op/regis_op.py +++ b/extensions/brainpylib/custom_op/regis_op.py @@ -8,7 +8,7 @@ import numpy as np from jax import core from jax.abstract_arrays import ShapedArray -from jax.interpreters import xla +from jax.interpreters import xla, batching from numba import cuda from numba.core.dispatcher import Dispatcher @@ -21,9 +21,11 @@ def register_op( op_name: str, cpu_func: Callable, + out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]], gpu_func: Callable = None, - out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]] = None, - apply_cpu_func_to_gpu: bool = False + batch_fun: Callable = None, + apply_cpu_func_to_gpu: bool = False, + 
return_primitive: bool = False, ): """ Converting the numba-jitted function in a Jax/XLA compatible primitive. @@ -110,17 +112,24 @@ def bind_primitive(*inputs): result = prim.bind(*inputs) return result[0] if len(result) == 1 else result - # binding + # cpu function prim.def_abstract_eval(abs_eval_rule) prim.def_impl(eval_rule) - # registering xla.backend_specific_translations['cpu'][prim] = partial(func_cpu_translation, cpu_func, abs_eval_rule) if apply_cpu_func_to_gpu: xla.backend_specific_translations['gpu'][prim] = partial(func_gpu_translation, cpu_func, abs_eval_rule) + # gpu function if gpu_func is not None: if not isinstance(gpu_func, Dispatcher): gpu_func = cuda.jit(gpu_func) xla.backend_specific_translations['gpu'][prim] = partial(func_gpu_translation, gpu_func, abs_eval_rule) - return bind_primitive + # batching + if batch_fun is not None: + batching.primitive_batchers[prim] = batch_fun + + if return_primitive: + return bind_primitive, prim + else: + return bind_primitive diff --git a/extensions/brainpylib/event_sum.py b/extensions/brainpylib/event_sum.py index fbbd985f9..b9b6691a6 100644 --- a/extensions/brainpylib/event_sum.py +++ b/extensions/brainpylib/event_sum.py @@ -10,8 +10,9 @@ import jax.numpy as jnp import numpy as np from jax import core -from jax.interpreters import xla +from jax.interpreters import xla, batching from jax.lib import xla_client +from jax.lax import scan try: from . import gpu_ops @@ -130,6 +131,27 @@ def _event_sum_translation(c, events, indices, indptr, values, out, *, platform= xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu") xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu") + +def _event_sum_batch(args, axes): + batch_axes, batch_args, non_batch_args = [], {}, {} + for ax_i, ax in enumerate(axes): + if ax is None: + non_batch_args[f'ax{ax_i}'] = args[ax_i] + else: + batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0) + batch_axes.append(ax_i) + + def f(_, x): + pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}']) + for i in range(len(axes))]) + return 0, _event_sum_prim.bind(*pars) + _, outs = scan(f, 0, batch_args) + return outs, 0 + + +batching.primitive_batchers[_event_sum_prim] = _event_sum_batch + + # --------------------------- # event sum kernel 2 # --------------------------- diff --git a/extensions/brainpylib/tests/test_event_sum.py b/extensions/brainpylib/tests/test_event_sum.py index 58a7a211e..af6aabfdb 100644 --- a/extensions/brainpylib/tests/test_event_sum.py +++ b/extensions/brainpylib/tests/test_event_sum.py @@ -7,6 +7,7 @@ import numpy as np import pytest import unittest +from jax import vmap from brainpylib import event_sum import brainpy as bp import brainpy.math as bm @@ -29,6 +30,26 @@ def test_homo_values(self): a = event_sum(sps, (post_ids.value, indptr.value), size, value) print(a) + def test_homo_values_batching(self): + bp.math.random.seed(1345) + size = 200 + conn = bp.conn.FixedProb(prob=0.5, seed=123) + + conn(pre_size=size, post_size=size) + post_ids, indptr = conn.require('pre2post') + sps = bm.random.random((10, size)).value < 0.5 + value = 3.0233 + f = vmap(bm.pre2post_event_sum, in_axes=(0, None, None, None)) + a1 = f(sps, (post_ids.value, indptr.value), size, value) + + print(a1) + + f = vmap(lambda events: bm.pre2post_event_sum(events, (post_ids.value, indptr.value), size, value)) + a2 = f(sps) + + print(a2) + 
self.assertTrue(jnp.array_equal(a1, a2)) + def test_heter_value(self): bp.math.random.seed(3) size = 200 @@ -43,6 +64,23 @@ def test_heter_value(self): a = event_sum(sps, (post_ids.value, indptr.value), size, values.value) print(a) + def test_heter_values_batching(self): + bp.math.random.seed(1345) + size = 200 + conn = bp.conn.FixedProb(prob=0.5, seed=123) + + conn(pre_size=size, post_size=size) + post_ids, indptr = conn.require('pre2post') + sps = bm.random.random((10, size)).value < 0.5 + values = bm.random.rand(post_ids.size) + f = vmap(bm.pre2post_event_sum, in_axes=(0, None, None, None)) + a1 = f(sps, (post_ids.value, indptr.value), size, values) + + f = vmap(lambda events: bm.pre2post_event_sum(events, (post_ids.value, indptr.value), size, values)) + a2 = f(sps) + + self.assertTrue(jnp.array_equal(a1, a2)) + # def test1(): # bm.random.seed(123)
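
Note on the batching rule registered above for `_event_sum_prim`: the pattern is generic — move each batched argument's batch axis to the front, close over the unbatched arguments, and run the single-sample kernel once per batch element with `lax.scan`, so the outputs are stacked along a new leading axis. Below is a minimal, self-contained sketch of that pattern, not part of this patch: the names `batch_with_scan`, `kernel`, `events`, and `weights` are illustrative only, and the toy kernel merely stands in for `_event_sum_prim.bind`.

    import jax.numpy as jnp
    from jax.lax import scan

    def batch_with_scan(kernel, args, axes):
        # batched args (axis moved to the front) vs. static (unbatched) args,
        # mirroring the split done in `_event_sum_batch`
        batched = {i: (a if ax == 0 else jnp.moveaxis(a, ax, 0))
                   for i, (a, ax) in enumerate(zip(args, axes)) if ax is not None}
        static = {i: a for i, (a, ax) in enumerate(zip(args, axes)) if ax is None}

        def step(carry, slices):
            # rebuild the positional argument list for one batch element
            call_args = [slices[i] if i in slices else static[i]
                         for i in range(len(args))]
            return carry, kernel(*call_args)

        _, out = scan(step, 0, batched)
        return out, 0  # results stacked along a new leading (batch) axis

    # usage with a toy single-sample kernel: only the first argument is batched
    events = jnp.ones((10, 200))
    weights = jnp.full((200,), 0.5)
    out, out_axis = batch_with_scan(lambda e, w: w * e.sum(),
                                    (events, weights), (0, None))
    assert out.shape == (10, 200) and out_axis == 0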