From bba53fbd878182ae20990cfccc1708c705483e84 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 27 Jan 2023 14:49:59 +0800
Subject: [PATCH 1/4] operation results of `Array` and `Variable` are all
 brainpy.math.Array

---
 brainpy/_src/analysis/highdim/slow_points.py  |   2 +-
 brainpy/_src/dyn/base.py                      |  41 +-
 brainpy/_src/dyn/channels/IH.py               |  21 +-
 brainpy/_src/dyn/channels/K.py                |  84 ++--
 brainpy/_src/dyn/channels/KCa.py              |  12 +-
 brainpy/_src/dyn/channels/Na.py               |  30 +-
 brainpy/_src/dyn/neurons/biological_models.py | 104 ++---
 brainpy/_src/dyn/neurons/reduced_models.py    | 241 +++++-----
 brainpy/_src/dyn/rates/populations.py         |  92 ++--
 brainpy/_src/dyn/synapses/abstract_models.py  |  45 +-
 .../_src/dyn/synapses/biological_models.py    |  46 +-
 brainpy/_src/initialize/random_inits.py       |  10 +-
 brainpy/_src/inputs/currents.py               |  18 +-
 brainpy/_src/integrators/ode/exponential.py   |   8 +-
 .../ode/tests/test_ode_method_exp_euler.py    |  39 +-
 brainpy/_src/math/delayvars.py                |   6 +-
 brainpy/_src/math/ndarray.py                  | 434 +-----------------
 brainpy/_src/math/object_transform/base.py    |   7 +-
 ...Bellec_2020_eprop_evidence_accumulation.py |   8 +-
 .../spikebased_bp_for_cifar10.py              |   2 +-
 20 files changed, 421 insertions(+), 829 deletions(-)

diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py
index a86fecc0e..4515b9d64 100644
--- a/brainpy/_src/analysis/highdim/slow_points.py
+++ b/brainpy/_src/analysis/highdim/slow_points.py
@@ -462,7 +462,7 @@ def filter_loss(self, tolerance: float = 1e-5):
     else:
       num_fps = self._fixed_points.shape[0]
     ids = self._losses < tolerance
-    keep_ids = bm.as_jax(jnp.where(ids)[0])
+    keep_ids = bm.as_jax(bm.where(ids)[0])
     self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points)
     self._losses = self._losses[keep_ids]
     self._selected_ids = self._selected_ids[keep_ids]
diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py
index 0f38324d5..29eea1588 100644
--- a/brainpy/_src/dyn/base.py
+++ b/brainpy/_src/dyn/base.py
@@ -4,6 +4,7 @@
 import gc
 from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any
 
+import jax
 import jax.numpy as jnp
 import numpy as np
 
@@ -18,8 +19,6 @@
 from brainpy.errors import NoImplementationError, UnsupportedError
 from brainpy.types import ArrayType, Shape
 
-
-
 __all__ = [
   # general class
   'DynamicalSystem',
@@ -170,14 +169,14 @@ def register_delay(
       raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support '
                        f'integer, array of integers, callable function, brainpy.init.Initializer.')
     if delay_type == 'heter':
-      if delay_step.dtype not in [jnp.int32, jnp.int64]:
+      if delay_step.dtype not in [bm.int32, bm.int64]:
         raise ValueError('Only support delay steps of int32, int64. If your '
                          'provide delay time length, please divide the "dt" '
                          'then provide us the number of delay steps.')
      if delay_target.shape[0] != delay_step.shape[0]:
         raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
     if delay_type != 'none':
-      max_delay_step = int(jnp.max(delay_step))
+      max_delay_step = int(bm.max(delay_step))
 
     # delay target
     if delay_type != 'none':
@@ -207,8 +206,8 @@
   def get_delay_data(
       self,
       identifier: str,
-      delay_step: Optional[Union[int, bm.Array, jnp.DeviceArray]],
-      *indices: Union[int, slice, bm.Array, jnp.DeviceArray],
+      delay_step: Optional[Union[int, bm.Array, jax.Array]],
+      *indices: Union[int, slice, bm.Array, jax.Array],
   ):
     """Get delay data according to the provided delay steps.
@@ -230,19 +229,19 @@ def get_delay_data(
       return self.global_delay_data[identifier][1].value
 
     if identifier in self.global_delay_data:
-      if jnp.ndim(delay_step) == 0:
+      if bm.ndim(delay_step) == 0:
         return self.global_delay_data[identifier][0](delay_step, *indices)
       else:
         if len(indices) == 0:
-          indices = (jnp.arange(delay_step.size),)
+          indices = (bm.arange(delay_step.size),)
         return self.global_delay_data[identifier][0](delay_step, *indices)
 
     elif identifier in self.local_delay_vars:
-      if jnp.ndim(delay_step) == 0:
+      if bm.ndim(delay_step) == 0:
         return self.local_delay_vars[identifier](delay_step)
       else:
         if len(indices) == 0:
-          indices = (jnp.arange(delay_step.size),)
+          indices = (bm.arange(delay_step.size),)
         return self.local_delay_vars[identifier](delay_step, *indices)
 
     else:
@@ -878,7 +877,7 @@ def __init__(
     # ------------
     if isinstance(conn, TwoEndConnector):
       self.conn = conn(pre.size, post.size)
-    elif isinstance(conn, (bm.ndarray, np.ndarray, jnp.ndarray)):
+    elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)):
       if (pre.num, post.num) != conn.shape:
         raise ValueError(f'"conn" is provided as a matrix, and it is expected '
                          f'to be an array with shape of (pre.num, post.num) = '
@@ -1157,11 +1156,11 @@ def _init_weights(
     return weight, conn_mask
 
   def _syn2post_with_all2all(self, syn_value, syn_weight):
-    if jnp.ndim(syn_weight) == 0:
+    if bm.ndim(syn_weight) == 0:
       if isinstance(self.mode, bm.BatchingMode):
-        post_vs = jnp.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
+        post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
       else:
-        post_vs = jnp.sum(syn_value)
+        post_vs = bm.sum(syn_value)
       if not self.conn.include_self:
         post_vs = post_vs - syn_value
       post_vs = syn_weight * post_vs
@@ -1173,7 +1172,7 @@ def _syn2post_with_one2one(self, syn_value, syn_weight):
     return syn_value * syn_weight
 
   def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
-    if jnp.ndim(syn_weight) == 0:
+    if bm.ndim(syn_weight) == 0:
       post_vs = (syn_weight * syn_value) @ conn_mat
     else:
       post_vs = syn_value @ (syn_weight * conn_mat)
@@ -1253,8 +1252,8 @@ def __init__(
 
     # variables
     self.V = variable(V_initializer, self.mode, self.varshape)
-    self.input = variable(jnp.zeros, self.mode, self.varshape)
-    self.spike = variable(lambda s: jnp.zeros(s, dtype=bool), self.mode, self.varshape)
+    self.input = variable(bm.zeros, self.mode, self.varshape)
+    self.spike = variable(lambda s: bm.zeros(s, dtype=bool), self.mode, self.varshape)
 
     # function
     if self.noise is None:
@@ -1271,8 +1270,8 @@ def derivative(self, V, t):
 
   def reset_state(self, batch_size=None):
     self.V.value = variable(self._V_initializer, batch_size, self.varshape)
-    self.spike.value = variable(lambda s: jnp.zeros(s, dtype=bool), batch_size, self.varshape)
-    self.input.value = variable(jnp.zeros, batch_size, self.varshape)
+    self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
+    self.input.value = variable(bm.zeros, batch_size, self.varshape)
     for channel in self.nodes(level=1, include_self=False).subset(Channel).unique().values():
       channel.reset_state(self.V.value, batch_size=batch_size)
 
@@ -1286,7 +1285,7 @@ def update(self, tdi, *args, **kwargs):
     # update variables
     for node in channels.values():
       node.update(tdi, self.V.value)
-    self.spike.value = jnp.logical_and(V >= self.V_th, self.V < self.V_th)
+    self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
     self.V.value = V
 
   def register_implicit_nodes(self, *channels, **named_channels):
@@ -1295,7 +1294,7 @@ def register_implicit_nodes(self, *channels, **named_channels):
 
   def clear_input(self):
     """Useful for monitoring inputs. """
-    self.input.value = jnp.zeros_like(self.input.value)
+    self.input.value = bm.zeros_like(self.input.value)
 
 
 class Channel(DynamicalSystem):
diff --git a/brainpy/_src/dyn/channels/IH.py b/brainpy/_src/dyn/channels/IH.py
index 3454db160..47804b0f6 100644
--- a/brainpy/_src/dyn/channels/IH.py
+++ b/brainpy/_src/dyn/channels/IH.py
@@ -7,7 +7,6 @@
 
 from typing import Union, Callable
 
-import jax.numpy as jnp
 import brainpy.math as bm
 from brainpy._src.initialize import Initializer, parameter, variable
 from brainpy._src.integrators import odeint, JointEq
@@ -76,7 +75,7 @@ def __init__(
     self.E = parameter(E, self.varshape, allow_none=False)
 
     # variable
-    self.p = variable(jnp.zeros, self.mode, self.varshape)
+    self.p = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(self.derivative, method=method)
@@ -96,10 +95,10 @@ def current(self, V):
     return self.g_max * self.p * (self.E - V)
 
   def f_p_inf(self, V):
-    return 1. / (1. + jnp.exp((V + 75.) / 5.5))
+    return 1. / (1. + bm.exp((V + 75.) / 5.5))
 
   def f_p_tau(self, V):
-    return 1. / (jnp.exp(-0.086 * V - 14.59) + jnp.exp(0.0701 * V - 1.87))
+    return 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87))
 
 
 class Ih_De1996(IhChannel, CalciumChannel):
@@ -200,9 +199,9 @@ def __init__(
     self.g_inc = parameter(g_inc, self.varshape, allow_none=False)
 
     # variable
-    self.O = variable(jnp.zeros, self.mode, self.varshape)
-    self.OL = variable(jnp.zeros, self.mode, self.varshape)
-    self.P1 = variable(jnp.zeros, self.mode, self.varshape)
+    self.O = variable(bm.zeros, self.mode, self.varshape)
+    self.OL = variable(bm.zeros, self.mode, self.varshape)
+    self.P1 = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(JointEq(self.dO, self.dOL, self.dP1), method=method)
@@ -229,7 +228,7 @@ def current(self, V, C_Ca, E_Ca):
 
   def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
     varshape = self.varshape if (batch_size is None) else ((batch_size,) + self.varshape)
-    self.P1.value = jnp.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape)
+    self.P1.value = bm.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape)
     inf = self.f_inf(V)
     tau = self.f_tau(V)
     alpha = inf / tau
@@ -242,8 +241,8 @@ def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
       assert self.OL.shape[0] == batch_size
 
   def f_inf(self, V):
-    return 1 / (1 + jnp.exp((V + 75 - self.V_sh) / 5.5))
+    return 1 / (1 + bm.exp((V + 75 - self.V_sh) / 5.5))
 
   def f_tau(self, V):
-    return (20. + 1000 / (jnp.exp((V + 71.5 - self.V_sh) / 14.2) +
-                          jnp.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi
+    return (20. + 1000 / (bm.exp((V + 71.5 - self.V_sh) / 14.2) +
+                          bm.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi
diff --git a/brainpy/_src/dyn/channels/K.py b/brainpy/_src/dyn/channels/K.py
index 49e43695f..4bfa25507 100644
--- a/brainpy/_src/dyn/channels/K.py
+++ b/brainpy/_src/dyn/channels/K.py
@@ -7,8 +7,6 @@
 
 from typing import Union, Callable, Optional
 
-import jax.numpy as jnp
-
 import brainpy.math as bm
 from brainpy._src.initialize import Initializer, parameter, variable
 from brainpy._src.integrators import odeint, JointEq
@@ -86,7 +84,7 @@ def __init__(
     self.phi = parameter(phi, self.varshape, allow_none=False)
 
     # variables
-    self.p = variable(jnp.zeros, self.mode, self.varshape)
+    self.p = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(self.derivative, method=method)
@@ -192,10 +190,10 @@ def __init__(
 
   def f_p_alpha(self, V):
     tmp = V - self.V_sh - 15.
-    return 0.032 * tmp / (1. - jnp.exp(-tmp / 5.))
+    return 0.032 * tmp / (1. - bm.exp(-tmp / 5.))
 
   def f_p_beta(self, V):
-    return 0.5 * jnp.exp(-(V - self.V_sh - 10.) / 40.)
+    return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
 
 
 class IK_TM1991(_IK_p4_markov):
@@ -262,10 +260,10 @@ def __init__(
 
   def f_p_alpha(self, V):
     c = 15 - V + self.V_sh
-    return 0.032 * c / (jnp.exp(c / 5) - 1.)
+    return 0.032 * c / (bm.exp(c / 5) - 1.)
 
   def f_p_beta(self, V):
-    return 0.5 * jnp.exp((10 - V + self.V_sh) / 40)
+    return 0.5 * bm.exp((10 - V + self.V_sh) / 40)
 
 
 class IK_HH1952(_IK_p4_markov):
@@ -333,10 +331,10 @@ def __init__(
 
   def f_p_alpha(self, V):
     temp = V - self.V_sh + 10
-    return 0.01 * temp / (1 - jnp.exp(-temp / 10))
+    return 0.01 * temp / (1 - bm.exp(-temp / 10))
 
   def f_p_beta(self, V):
-    return 0.125 * jnp.exp(-(V - self.V_sh + 20) / 80)
+    return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)
 
 
 class _IKA_p4q_ss(PotassiumChannel):
@@ -405,8 +403,8 @@ def __init__(
     self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
 
     # variables
-    self.p = variable(jnp.zeros, self.mode, self.varshape)
-    self.q = variable(jnp.zeros, self.mode, self.varshape)
+    self.p = variable(bm.zeros, self.mode, self.varshape)
+    self.q = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(JointEq(self.dp, self.dq), method=method)
@@ -523,19 +521,19 @@ def __init__(
     self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
 
   def f_p_inf(self, V):
-    return 1. / (1. + jnp.exp(-(V - self.V_sh + 60.) / 8.5))
+    return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5))
 
   def f_p_tau(self, V):
-    return 1. / (jnp.exp((V - self.V_sh + 35.8) / 19.7) +
-                 jnp.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37
+    return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) +
+                 bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37
 
   def f_q_inf(self, V):
-    return 1. / (1. + jnp.exp((V - self.V_sh + 78.) / 6.))
+    return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.))
 
   def f_q_tau(self, V):
-    return jnp.where(V < -63 + self.V_sh,
-                     1. / (jnp.exp((V - self.V_sh + 46.) / 5.) +
-                           jnp.exp(-(V - self.V_sh + 238.) / 37.5)),
+    return bm.where(V < -63 + self.V_sh,
+                    1. / (bm.exp((V - self.V_sh + 46.) / 5.) +
+                          bm.exp(-(V - self.V_sh + 238.) / 37.5)),
                      19.)
@@ -618,19 +616,19 @@ def __init__(
     self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
 
   def f_p_inf(self, V):
-    return 1. / (1. + jnp.exp(-(V - self.V_sh + 36.) / 20.))
+    return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.))
 
   def f_p_tau(self, V):
-    return 1. / (jnp.exp((V - self.V_sh + 35.8) / 19.7) +
-                 jnp.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37
+    return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) +
+                 bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37
 
   def f_q_inf(self, V):
-    return 1. / (1. + jnp.exp((V - self.V_sh + 78.) / 6.))
+    return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.))
 
   def f_q_tau(self, V):
-    return jnp.where(V < -63 + self.V_sh,
-                     1. / (jnp.exp((V - self.V_sh + 46.) / 5.) +
-                           jnp.exp(-(V - self.V_sh + 238.) / 37.5)),
+    return bm.where(V < -63 + self.V_sh,
+                    1. / (bm.exp((V - self.V_sh + 46.) / 5.) +
+                          bm.exp(-(V - self.V_sh + 238.) / 37.5)),
                      19.)
@@ -700,8 +698,8 @@ def __init__(
     self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
 
     # variables
-    self.p = variable(jnp.zeros, self.mode, self.varshape)
-    self.q = variable(jnp.zeros, self.mode, self.varshape)
+    self.p = variable(bm.zeros, self.mode, self.varshape)
+    self.q = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(JointEq(self.dp, self.dq), method=method)
@@ -814,18 +812,18 @@ def __init__(
     self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
 
   def f_p_inf(self, V):
-    raise 1. / (1. + jnp.exp(-(V - self.V_sh + 43.) / 17.))
+    return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.))
 
   def f_p_tau(self, V):
-    return 1. / (jnp.exp((V - self.V_sh - 81.) / 25.6) +
-                 jnp.exp(-(V - self.V_sh + 132) / 18.)) + 9.9
+    return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) +
+                 bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9
 
   def f_q_inf(self, V):
-    raise 1. / (1. + jnp.exp((V - self.V_sh + 58.) / 10.6))
+    return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6))
 
   def f_q_tau(self, V):
-    raise 1. / (jnp.exp((V - self.V_sh - 1329.) / 200.) +
-                jnp.exp(-(V - self.V_sh + 130.) / 7.1))
+    return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
+                 bm.exp(-(V - self.V_sh + 130.) / 7.1))
 
 
 class IKK2B_HM1992(_IKK2_pq_ss):
@@ -905,19 +903,19 @@ def __init__(
     self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
 
   def f_p_inf(self, V):
-    raise 1. / (1. + jnp.exp(-(V - self.V_sh + 43.) / 17.))
+    return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.))
 
   def f_p_tau(self, V):
-    return 1. / (jnp.exp((V - self.V_sh - 81.) / 25.6) +
-                 jnp.exp(-(V - self.V_sh + 132) / 18.)) + 9.9
+    return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) +
+                 bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9
 
   def f_q_inf(self, V):
-    raise 1. / (1. + jnp.exp((V - self.V_sh + 58.) / 10.6))
+    return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6))
 
   def f_q_tau(self, V):
-    raise jnp.where(V < -70 + self.V_sh,
-                    1. / (jnp.exp((V - self.V_sh - 1329.) / 200.) +
-                          jnp.exp(-(V - self.V_sh + 130.) / 7.1)),
+    return bm.where(V < -70 + self.V_sh,
+                    1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
+                          bm.exp(-(V - self.V_sh + 130.) / 7.1)),
                     8.9)
@@ -991,7 +989,7 @@ def __init__(
     self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
 
     # variables
-    self.p = variable(jnp.zeros, self.mode, self.varshape)
+    self.p = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(self.dp, method=method)
@@ -1012,8 +1010,8 @@ def reset_state(self, V, batch_size=None):
       assert self.p.shape[0] == batch_size
 
   def f_p_inf(self, V):
-    raise 1. / (1. + jnp.exp(-(V - self.V_sh + 35.) / 10.))
+    return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.))
 
   def f_p_tau(self, V):
     temp = V - self.V_sh + 35.
-    raise self.tau_max / (3.3 * jnp.exp(temp / 20.) + jnp.exp(-temp / 20.))
+    return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.))
diff --git a/brainpy/_src/dyn/channels/KCa.py b/brainpy/_src/dyn/channels/KCa.py
index 103561d97..9165d647e 100644
--- a/brainpy/_src/dyn/channels/KCa.py
+++ b/brainpy/_src/dyn/channels/KCa.py
@@ -8,8 +8,6 @@
 
 from typing import Union, Callable
 
-import jax.numpy as jnp
-
 import brainpy.math as bm
 from brainpy._src.initialize import Initializer, parameter, variable
 from brainpy._src.integrators.ode.generic import odeint
@@ -101,13 +99,13 @@ def __init__(
     self.phi = parameter(phi, self.varshape, allow_none=False)
 
     # variables
-    self.p = variable(jnp.zeros, self.mode, self.varshape)
+    self.p = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(self.dp, method=method)
 
   def dp(self, p, t, C_Ca):
-    C2 = self.alpha * jnp.power(bm.as_jax(C_Ca), self.n)
+    C2 = self.alpha * bm.power(bm.as_jax(C_Ca), self.n)
     C3 = C2 + self.beta
     return self.phi * (C2 / C3 - p) * C3
 
@@ -119,10 +117,10 @@ def current(self, V, C_Ca, E_Ca):
     return self.g_max * self.p * self.p * (self.E - V)
 
   def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
-    C2 = self.alpha * jnp.power(C_Ca, self.n)
+    C2 = self.alpha * bm.power(C_Ca, self.n)
     C3 = C2 + self.beta
     if batch_size is None:
-      self.p.value = jnp.broadcast_to(C2 / C3, self.varshape)
+      self.p.value = bm.broadcast_to(C2 / C3, self.varshape)
     else:
-      self.p.value = jnp.broadcast_to(C2 / C3, (batch_size,) + self.varshape)
+      self.p.value = bm.broadcast_to(C2 / C3, (batch_size,) + self.varshape)
       assert self.p.shape[0] == batch_size
diff --git a/brainpy/_src/dyn/channels/Na.py b/brainpy/_src/dyn/channels/Na.py
index d867d5334..533af4057 100644
--- a/brainpy/_src/dyn/channels/Na.py
+++ b/brainpy/_src/dyn/channels/Na.py
@@ -7,8 +7,6 @@
 
 from typing import Union, Callable
 
-import jax.numpy as jnp
-
 import brainpy.math as bm
 from brainpy._src.initialize import Initializer, parameter, variable
 from brainpy._src.integrators import odeint, JointEq
@@ -74,8 +72,8 @@ def __init__(
     self.g_max = parameter(g_max, self.varshape, allow_none=False)
 
     # variables
-    self.p = variable(jnp.zeros, self.mode, self.varshape)
-    self.q = variable(jnp.zeros, self.mode, self.varshape)
+    self.p = variable(bm.zeros, self.mode, self.varshape)
+    self.q = variable(bm.zeros, self.mode, self.varshape)
 
     # function
     self.integral = odeint(JointEq([self.dp, self.dq]), method=method)
@@ -186,17 +184,17 @@ def __init__(
 
   def f_p_alpha(self, V):
     temp = V - self.V_sh - 13.
-    return 0.32 * temp / (1. - jnp.exp(-temp / 4.))
+    return 0.32 * temp / (1. - bm.exp(-temp / 4.))
 
   def f_p_beta(self, V):
     temp = V - self.V_sh - 40.
-    return -0.28 * temp / (1. - jnp.exp(temp / 5.))
+    return -0.28 * temp / (1. - bm.exp(temp / 5.))
 
   def f_q_alpha(self, V):
-    return 0.128 * jnp.exp(-(V - self.V_sh - 17.) / 18.)
+    return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.)
 
   def f_q_beta(self, V):
-    return 4. / (1. + jnp.exp(-(V - self.V_sh - 40.) / 5.))
+    return 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.))
 
 
 class INa_TM1991(_INa_p3q_markov):
@@ -272,17 +270,17 @@ def __init__(
 
   def f_p_alpha(self, V):
     temp = 13 - V + self.V_sh
-    return 0.32 * temp / (jnp.exp(temp / 4) - 1.)
+    return 0.32 * temp / (bm.exp(temp / 4) - 1.)
 
   def f_p_beta(self, V):
     temp = V - self.V_sh - 40
-    return 0.28 * temp / (jnp.exp(temp / 5) - 1)
+    return 0.28 * temp / (bm.exp(temp / 5) - 1)
 
   def f_q_alpha(self, V):
-    return 0.128 * jnp.exp((17 - V + self.V_sh) / 18)
+    return 0.128 * bm.exp((17 - V + self.V_sh) / 18)
 
   def f_q_beta(self, V):
-    return 4. / (1 + jnp.exp(-(V - self.V_sh - 40) / 5))
+    return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5))
 
 
 class INa_HH1952(_INa_p3q_markov):
@@ -359,13 +357,13 @@ def __init__(
 
   def f_p_alpha(self, V):
     temp = V - self.V_sh - 5
-    return 0.1 * temp / (1 - jnp.exp(-temp / 10))
+    return 0.1 * temp / (1 - bm.exp(-temp / 10))
 
   def f_p_beta(self, V):
-    return 4.0 * jnp.exp(-(V - self.V_sh + 20) / 18)
+    return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18)
 
   def f_q_alpha(self, V):
-    return 0.07 * jnp.exp(-(V - self.V_sh + 20) / 20.)
+    return 0.07 * bm.exp(-(V - self.V_sh + 20) / 20.)
 
   def f_q_beta(self, V):
-    return 1 / (1 + jnp.exp(-(V - self.V_sh - 10) / 10))
+    return 1 / (1 + bm.exp(-(V - self.V_sh - 10) / 10))
diff --git a/brainpy/_src/dyn/neurons/biological_models.py b/brainpy/_src/dyn/neurons/biological_models.py
index aabddf6fa..f016420f4 100644
--- a/brainpy/_src/dyn/neurons/biological_models.py
+++ b/brainpy/_src/dyn/neurons/biological_models.py
@@ -2,8 +2,6 @@
 
 from typing import Union, Callable, Optional
 
-import jax.numpy as jnp
-
 import brainpy.math as bm
 from brainpy import check
 from brainpy._src.dyn.base import NeuGroup
@@ -124,8 +122,8 @@ class HH(NeuGroup):
   .. plot::
     :include-source: True
 
-    >>> import jax.numpy as jnp
     >>> import brainpy as bp
+    >>> import brainpy.math as bm
     >>> import matplotlib.pyplot as plt
     >>>
     >>> group = bp.neurons.HH(2)
     >>> I2 = bp.inputs.spike_input(sp_times=[600., 900, 950, 1500], sp_lens=5, sp_sizes=5., duration=2000, )
     >>> I1 += bp.math.random.normal(0, 3, size=I1.shape)
     >>> I2 += bp.math.random.normal(0, 3, size=I2.shape)
-    >>> I = jnp.stack((I1, I2), axis=-1)
+    >>> I = bm.stack((I1, I2), axis=-1)
     >>>
     >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', I, 'iter'))
     >>> runner.run(2000.)
@@ -257,8 +255,8 @@ def __init__(
     self.n = (bm.Variable(self.n_inf(self.V.value))
               if n_initializer is None else
               variable_(self._n_initializer, self.varshape, self.mode))
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
 
     # integral
     if self.noise is None:
       self.integral = odeint(method=method, f=self.derivative)
     else:
       self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
 
   # m channel
-  m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - jnp.exp(-(V + 40) / 10))
-  m_beta = lambda self, V: 4.0 * jnp.exp(-(V + 65) / 18)
+  m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
+  m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
   m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
   dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
 
   # h channel
-  h_alpha = lambda self, V: 0.07 * jnp.exp(-(V + 65) / 20.)
-  h_beta = lambda self, V: 1 / (1 + jnp.exp(-(V + 35) / 10))
+  h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.)
+  h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10))
   h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V))
   dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h
 
   # n channel
-  n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - jnp.exp(-(V + 55) / 10))
-  n_beta = lambda self, V: 0.125 * jnp.exp(-(V + 65) / 80)
+  n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
+  n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
   n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
   dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
 
@@ -298,8 +296,8 @@ def reset_state(self, batch_size=None):
       self.n.value = self.n_inf(self.V.value)
     else:
       self.n.value = variable_(self._n_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def dV(self, V, t, m, h, n, I_ext):
     I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
@@ -316,7 +314,7 @@ def update(self, tdi, x=None):
     t, dt = tdi['t'], tdi['dt']
     if x is not None: self.input += x
     V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, self.input, dt)
-    self.spike.value = jnp.logical_and(self.V < self.V_th, V >= self.V_th)
+    self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
     self.V.value = V
     self.m.value = m
     self.h.value = h
@@ -458,8 +456,8 @@ def __init__(
     # variables
     self.W = variable_(self._W_initializer, self.varshape, self.mode)
     self.V = variable_(self._V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -470,11 +468,11 @@
   def reset_state(self, batch_size=None):
     self.W.value = variable_(self._W_initializer, self.varshape, batch_size)
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def dV(self, V, t, W, I_ext):
-    M_inf = (1 / 2) * (1 + jnp.tanh((V - self.V1) / self.V2))
+    M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
     I_Ca = self.g_Ca * M_inf * (V - self.V_Ca)
     I_K = self.g_K * W * (V - self.V_K)
     I_Leak = self.g_leak * (V - self.V_leak)
@@ -482,8 +480,8 @@ def dV(self, V, t, W, I_ext):
     dVdt = (- I_Ca - I_K - I_Leak + I_ext) / self.C
     return dVdt
 
   def dW(self, W, t, V):
-    tau_W = 1 / (self.phi * jnp.cosh((V - self.V3) / (2 * self.V4)))
-    W_inf = (1 / 2) * (1 + jnp.tanh((V - self.V3) / self.V4))
+    tau_W = 1 / (self.phi * bm.cosh((V - self.V3) / (2 * self.V4)))
+    W_inf = (1 / 2) * (1 + bm.tanh((V - self.V3) / self.V4))
     dWdt = (W_inf - W) / tau_W
     return dWdt
 
@@ -495,7 +493,7 @@ def update(self, tdi, x=None):
     t, dt = tdi['t'], tdi['dt']
     if x is not None: self.input += x
     V, self.W.value = self.integral(self.V, self.W, t, self.input, dt)
-    spike = jnp.logical_and(self.V < self.V_th, V >= self.V_th)
+    spike = bm.logical_and(self.V < self.V_th, V >= self.V_th)
     self.V.value = V
     self.spike.value = spike
 
@@ -730,8 +728,8 @@ def __init__(
     self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None)
     self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None)
     self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None)
-    self.Id = variable_(jnp.zeros, self.varshape, self.mode)  # input to soma
-    self.Is = variable_(jnp.zeros, self.varshape, self.mode)  # input to dendrite
+    self.Id = variable_(bm.zeros, self.varshape, self.mode)  # input to soma
+    self.Is = variable_(bm.zeros, self.varshape, self.mode)  # input to dendrite
     # self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))
 
     # integral
@@ -750,8 +748,8 @@ def reset_state(self, batch_size=None):
     self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis)
     self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis)
     self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis)
-    self.Id.value = variable_(jnp.zeros, self.varshape, batch_size)
-    self.Is.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.Id.value = variable_(bm.zeros, self.varshape, batch_size)
+    self.Is.value = variable_(bm.zeros, self.varshape, batch_size)
     # self.spike[:] = False
 
   def dCa(self, Ca, t, s, Vd):
@@ -785,7 +783,7 @@ def dVd(self, Vd, t, s, q, c, Ca, Vs):
     I_leak = self.gL * (Vd - self.EL)
     I_Ca = self.gCa * s * s * (Vd - self.ECa)
     I_AHP = self.gAHP * q * (Vd - self.EK)
-    I_C = self.gC * jnp.minimum(Ca / 250., 1.) * (Vd - self.EK)
+    I_C = self.gC * bm.minimum(Ca / 250., 1.) * (Vd - self.EK)
     p = 1 - self.p
     I_gj = self.gc / p * (Vs - Vd)
     dVdt = (- I_leak - I_Ca - I_AHP - I_C + I_gj + self.Id / p) / self.Cm
@@ -821,10 +819,10 @@ def clear_input(self):
     self.Is[:] = 0.
 
   def alpha_m(self, Vs):
-    return 0.32 * (13.1 - (Vs + 60.)) / (jnp.exp((13.1 - (Vs + 60.)) / 4.) - 1.)
+    return 0.32 * (13.1 - (Vs + 60.)) / (bm.exp((13.1 - (Vs + 60.)) / 4.) - 1.)
 
   def beta_m(self, Vs):
-    return 0.28 * ((Vs + 60.) - 40.1) / (jnp.exp(((Vs + 60.) - 40.1) / 5.) - 1.)
+    return 0.28 * ((Vs + 60.) - 40.1) / (bm.exp(((Vs + 60.) - 40.1) / 5.) - 1.)
 
   def inf_m(self, Vs):
     alpha = self.alpha_m(Vs)
@@ -832,10 +830,10 @@ def inf_m(self, Vs):
     return alpha / (alpha + beta)
 
   def alpha_n(self, Vs):
-    return 0.016 * (35.1 - (Vs + 60.)) / (jnp.exp((35.1 - (Vs + 60.)) / 5) - 1)
+    return 0.016 * (35.1 - (Vs + 60.)) / (bm.exp((35.1 - (Vs + 60.)) / 5) - 1)
 
   def beta_n(self, Vs):
-    return 0.25 * jnp.exp(0.5 - 0.025 * (Vs + 60.))
+    return 0.25 * bm.exp(0.5 - 0.025 * (Vs + 60.))
 
   def inf_n(self, Vs):
     alpha = self.alpha_n(Vs)
@@ -843,10 +841,10 @@ def inf_n(self, Vs):
     return alpha / (alpha + beta)
 
   def alpha_h(self, Vs):
-    return 0.128 * jnp.exp((17. - (Vs + 60.)) / 18.)
+    return 0.128 * bm.exp((17. - (Vs + 60.)) / 18.)
 
   def beta_h(self, Vs):
-    return 4. / (1 + jnp.exp((40. - (Vs + 60.)) / 5))
+    return 4. / (1 + bm.exp((40. - (Vs + 60.)) / 5))
 
   def inf_h(self, Vs):
     alpha = self.alpha_h(Vs)
@@ -854,10 +852,10 @@ def inf_h(self, Vs):
     return alpha / (alpha + beta)
 
   def alpha_s(self, Vd):
-    return 1.6 / (1 + jnp.exp(-0.072 * ((Vd + 60.) - 65.)))
+    return 1.6 / (1 + bm.exp(-0.072 * ((Vd + 60.) - 65.)))
 
   def beta_s(self, Vd):
-    return 0.02 * ((Vd + 60.) - 51.1) / (jnp.exp(((Vd + 60.) - 51.1) / 5.) - 1.)
+    return 0.02 * ((Vd + 60.) - 51.1) / (bm.exp(((Vd + 60.) - 51.1) / 5.) - 1.)
 
   def inf_s(self, Vd):
     alpha = self.alpha_s(Vd)
@@ -865,13 +863,13 @@ def inf_s(self, Vd):
     return alpha / (alpha + beta)
 
   def alpha_c(self, Vd):
-    return jnp.where((Vd + 60.) <= 50.,
-                     (jnp.exp(((Vd + 60.) - 10.) / 11.) - jnp.exp(((Vd + 60.) - 6.5) / 27.)) / 18.975,
-                     2. * jnp.exp((6.5 - (Vd + 60.)) / 27.))
+    return bm.where((Vd + 60.) <= 50.,
+                    (bm.exp(((Vd + 60.) - 10.) / 11.) - bm.exp(((Vd + 60.) - 6.5) / 27.)) / 18.975,
+                    2. * bm.exp((6.5 - (Vd + 60.)) / 27.))
 
   def beta_c(self, Vd):
-    alpha_c = (jnp.exp(((Vd + 60.) - 10.) / 11.) - jnp.exp(((Vd + 60.) - 6.5) / 27.)) / 18.975
-    return jnp.where((Vd + 60.) <= 50., 2. * jnp.exp((6.5 - (Vd + 60.)) / 27.) - alpha_c, 0.)
+    alpha_c = (bm.exp(((Vd + 60.) - 10.) / 11.) - bm.exp(((Vd + 60.) - 6.5) / 27.)) / 18.975
+    return bm.where((Vd + 60.) <= 50., 2. * bm.exp((6.5 - (Vd + 60.)) / 27.) - alpha_c, 0.)
 
   def inf_c(self, Vd):
     alpha_c = self.alpha_c(Vd)
@@ -879,7 +877,7 @@ def inf_c(self, Vd):
     beta_c = self.beta_c(Vd)
     return alpha_c / (alpha_c + beta_c)
 
   def alpha_q(self, Ca):
-    return jnp.minimum(2e-5 * Ca, 1e-2)
+    return bm.minimum(2e-5 * Ca, 1e-2)
 
   def beta_q(self, Ca):
     return 1e-3
@@ -1024,8 +1022,8 @@ def __init__(
     self.h = variable_(self._h_initializer, self.varshape, self.mode)
     self.n = variable_(self._n_initializer, self.varshape, self.mode)
     self.V = variable_(self._V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -1037,23 +1035,23 @@ def reset_state(self, batch_size=None):
     self.h.value = variable_(self._h_initializer, self.varshape, batch_size)
     self.n.value = variable_(self._n_initializer, self.varshape, batch_size)
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def m_inf(self, V):
-    alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1)
-    beta = 4. * jnp.exp(-(V + 60.) / 18.)
+    alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
+    beta = 4. * bm.exp(-(V + 60.) / 18.)
     return alpha / (alpha + beta)
 
   def dh(self, h, t, V):
-    alpha = 0.07 * jnp.exp(-(V + 58) / 20)
-    beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1)
+    alpha = 0.07 * bm.exp(-(V + 58) / 20)
+    beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1)
     dhdt = alpha * (1 - h) - beta * h
     return self.phi * dhdt
 
   def dn(self, n, t, V):
-    alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1)
-    beta = 0.125 * jnp.exp(-(V + 44) / 80)
+    alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
+    beta = 0.125 * bm.exp(-(V + 44) / 80)
     dndt = alpha * (1 - n) - beta * n
     return self.phi * dndt
 
@@ -1072,7 +1070,7 @@ def update(self, tdi, x=None):
     t, dt = tdi['t'], tdi['dt']
     if x is not None: self.input += x
     V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt)
-    self.spike.value = jnp.logical_and(self.V < self.V_th, V >= self.V_th)
+    self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
     self.V.value = V
     self.h.value = h
     self.n.value = n
diff --git a/brainpy/_src/dyn/neurons/reduced_models.py b/brainpy/_src/dyn/neurons/reduced_models.py
index 35cec630c..099b63c8f 100644
--- a/brainpy/_src/dyn/neurons/reduced_models.py
+++ b/brainpy/_src/dyn/neurons/reduced_models.py
@@ -3,7 +3,6 @@
 from functools import partial
 from typing import Union, Callable, Optional
 
-import jax.numpy as jnp
 from jax.lax import stop_gradient
 
 import brainpy.math as bm
@@ -102,7 +101,7 @@ def __init__(
 
     # variables
     self.V = variable_(self._V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -115,7 +114,7 @@ def derivative(self, V, t, I_ext):
 
   def reset_state(self, batch_size=None):
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
 
   def update(self, tdi, x=None):
     if x is not None: self.input += x
@@ -227,12 +226,12 @@ def __init__(
 
     # variables
     self.V = variable_(self._V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool  # the gradient of spike is a float
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode)
     if self.tau_ref is not None:
-      self.t_last_spike = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, self.mode)
-      self.refractory = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
+      self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode)
+      self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -245,12 +244,12 @@ def derivative(self, V, t, I_ext):
 
   def reset_state(self, batch_size=None):
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
     if self.tau_ref is not None:
-      self.t_last_spike.value = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, batch_size)
-      self.refractory.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+      self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
+      self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def update(self, tdi, x=None):
     t, dt = tdi.t, tdi.dt
@@ -264,7 +263,7 @@ def update(self, tdi, x=None):
       refractory = (t - self.t_last_spike) <= self.tau_ref
       if isinstance(self.mode, bm.TrainingMode):
         refractory = stop_gradient(refractory)
-      V = jnp.where(refractory, self.V.value, V)
+      V = bm.where(refractory, self.V.value, V)
 
       # spike, refractory, spiking time, and membrane potential reset
       if isinstance(self.mode, bm.TrainingMode):
@@ -273,13 +272,13 @@ def update(self, tdi, x=None):
         V += (self.V_reset - V) * spike_no_grad
         spike_ = spike_no_grad > 0.
         # will be used in other place, like Delta Synapse, so stop its gradient
-        refractory = stop_gradient(jnp.logical_or(refractory, spike_).value)
-        t_last_spike = stop_gradient(jnp.where(spike_, t, self.t_last_spike.value))
+        refractory = stop_gradient(bm.logical_or(refractory, spike_).value)
+        t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
       else:
         spike = V >= self.V_th
-        V = jnp.where(spike, self.V_reset, V)
-        refractory = jnp.logical_or(refractory, spike)
-        t_last_spike = jnp.where(spike, t, self.t_last_spike.value)
+        V = bm.where(spike, self.V_reset, V)
+        refractory = bm.logical_or(refractory, spike)
+        t_last_spike = bm.where(spike, t, self.t_last_spike.value)
       self.V.value = V
       self.spike.value = spike
      self.refractory.value = refractory
@@ -293,7 +292,7 @@ def update(self, tdi, x=None):
         V += (self.V_reset - V) * spike_no_grad
       else:
         spike = V >= self.V_th
-        V = jnp.where(spike, self.V_reset, V)
+        V = bm.where(spike, self.V_reset, V)
       self.V.value = V
       self.spike.value = spike
 
@@ -441,12 +440,12 @@ def __init__(
 
     # variables
     self.V = variable_(V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode)
-    self.t_last_spike = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode)
+    self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode)
     if self.tau_ref is not None:
-      self.refractory = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
+      self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -456,15 +455,15 @@
 
   def reset_state(self, batch_size=None):
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, batch_size)
-    self.t_last_spike.value = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
+    self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
     if self.tau_ref is not None:
-      self.refractory.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+      self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def derivative(self, V, t, I_ext):
-    exp_v = self.delta_T * jnp.exp((V - self.V_T) / self.delta_T)
+    exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
     dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau
     return dvdt
@@ -475,15 +474,15 @@ def update(self, tdi, x=None):
 
     if self.tau_ref is not None:
       refractory = (t - self.t_last_spike) <= self.tau_ref
-      V = jnp.where(refractory, self.V.value, V)
+      V = bm.where(refractory, self.V.value, V)
       spike = self.V_th <= V
-      t_last_spike = jnp.where(spike, t, self.t_last_spike.value)
-      V = jnp.where(spike, self.V_reset, V)
-      self.refractory.value = jnp.logical_or(refractory, spike)
+      t_last_spike = bm.where(spike, t, self.t_last_spike.value)
+      V = bm.where(spike, self.V_reset, V)
+      self.refractory.value = bm.logical_or(refractory, spike)
     else:
       spike = self.V_th <= V
-      t_last_spike = jnp.where(spike, t, self.t_last_spike.value)
-      V = jnp.where(spike, self.V_reset, V)
+      t_last_spike = bm.where(spike, t, self.t_last_spike.value)
+      V = bm.where(spike, self.V_reset, V)
     self.V.value = V
     self.spike.value = spike
@@ -619,12 +618,12 @@ def __init__(
     # variables
     self.V = variable_(V_initializer, self.varshape, self.mode)
     self.w = variable_(w_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
     sp_type = bm.float_ if isinstance(self.mode, bm.BatchingMode) else bool
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode)
     if self.tau_ref is not None:
-      self.refractory = variable_(partial(jnp.zeros, dtype=bool), self.varshape, self.mode)
-      self.t_last_spike = variable_(lambda s: jnp.ones(s) * -1e8, self.varshape, self.mode)
+      self.refractory = variable_(partial(bm.zeros, dtype=bool), self.varshape, self.mode)
+      self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, self.mode)
 
     # functions
     if self.noise is None:
@@ -635,19 +634,19 @@ def reset_state(self, batch_size=None):
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
     self.w.value = variable_(self._w_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
     self.spike.value = variable_(
-      lambda s: jnp.zeros(s, dtype=(bm.float_
+      lambda s: bm.zeros(s, dtype=(bm.float_
                                     if isinstance(self.mode, bm.TrainingMode)
                                     else bool)),
       self.varshape,
       batch_size
     )
     if self.tau_ref is not None:
-      self.refractory.value = variable_(partial(jnp.zeros, dtype=bool), self.varshape, batch_size)
-      self.t_last_spike.value = variable_(lambda s: jnp.ones(s) * -1e8, self.varshape, batch_size)
+      self.refractory.value = variable_(partial(bm.zeros, dtype=bool), self.varshape, batch_size)
+      self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, batch_size)
 
   def dV(self, V, t, w, I_ext):
-    exp = self.delta_T * jnp.exp((V - self.V_T) / self.delta_T)
+    exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
     dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I_ext) / self.tau
     return dVdt
@@ -665,14 +664,14 @@ def update(self, tdi, x=None):
     V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt)
     if self.tau_ref is not None:
       refractory = (t - self.t_last_spike) <= self.tau_ref
-      V = jnp.where(refractory, self.V.value, V)
+      V = bm.where(refractory, self.V.value, V)
     spike = V >= self.V_th
-    self.V.value = jnp.where(spike, self.V_reset, V)
-    self.w.value = jnp.where(spike, w + self.b, w)
+    self.V.value = bm.where(spike, self.V_reset, V)
+    self.w.value = bm.where(spike, w + self.b, w)
     self.spike.value = spike
     if self.tau_ref is not None:
-      self.refractory.value = jnp.logical_or(refractory, spike)
-      self.t_last_spike.value = jnp.where(spike, t, self.t_last_spike.value)
+      self.refractory.value = bm.logical_or(refractory, spike)
+      self.t_last_spike.value = bm.where(spike, t, self.t_last_spike.value)
 
   def clear_input(self):
     self.input[:] = 0.
@@ -787,12 +786,12 @@ def __init__(
 
     # variables
     self.V = variable_(V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode)
-    self.t_last_spike = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode)
+    self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode)
     if self.tau_ref is not None:
-      self.refractory = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
+      self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -802,12 +801,12 @@ def reset_state(self, batch_size=None):
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, batch_size)
-    self.t_last_spike.value = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
+    self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
     if self.tau_ref is not None:
-      self.refractory.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+      self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def derivative(self, V, t, I_ext):
     dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau
     return dVdt
@@ -819,15 +818,15 @@ def update(self, tdi, x=None):
     V = self.integral(self.V.value, t, self.input.value, dt)
     if self.tau_ref is not None:
       refractory = (t - self.t_last_spike) <= self.tau_ref
-      V = jnp.where(refractory, self.V.value, V)
+      V = bm.where(refractory, self.V.value, V)
       spike = self.V_th <= V
-      t_last_spike = jnp.where(spike, t, self.t_last_spike.value)
-      V = jnp.where(spike, self.V_reset, V)
-      self.refractory.value = jnp.logical_or(refractory, spike)
+      t_last_spike = bm.where(spike, t, self.t_last_spike.value)
+      V = bm.where(spike, self.V_reset, V)
+      self.refractory.value = bm.logical_or(refractory, spike)
     else:
       spike = self.V_th <= V
-      t_last_spike = jnp.where(spike, t, self.t_last_spike.value)
-      V = jnp.where(spike, self.V_reset, V)
+      t_last_spike = bm.where(spike, t, self.t_last_spike.value)
+      V = bm.where(spike, self.V_reset, V)
     self.V.value = V
     self.spike.value = spike
     self.t_last_spike.value = t_last_spike
@@ -960,10 +959,10 @@ def __init__(
     # variables
     self.V = variable_(V_initializer, self.varshape, self.mode)
     self.w = variable_(w_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode)
-    self.refractory = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode)
+    self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -974,10 +973,10 @@ def reset_state(self, batch_size=None):
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
     self.w.value = variable_(self._w_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, batch_size)
-    self.refractory.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
+    self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def dV(self, V, t, w, I_ext):
     dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau
@@ -996,8 +995,8 @@ def update(self, tdi, x=None):
     if x is not None: self.input += x
     V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt)
     spike = self.V_th <= V
-    self.V.value = jnp.where(spike, self.V_reset, V)
-    self.w.value = jnp.where(spike, w + self.b, w)
+    self.V.value = bm.where(spike, self.V_reset, V)
+    self.w.value = bm.where(spike, w + self.b, w)
     self.spike.value = spike
 
   def clear_input(self):
@@ -1156,9 +1155,9 @@ def __init__(
     self.I2 = variable_(I2_initializer, self.varshape, self.mode)
     self.V_th = variable_(Vth_initializer, self.varshape, self.mode)
     self.V = variable_(V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -1171,9 +1170,9 @@ def reset_state(self, batch_size=None):
     self.I2.value = variable_(self._I2_initializer, self.varshape, batch_size)
     self.V_th.value = variable_(self._Vth_initializer, self.varshape, batch_size)
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
 
   def dI1(self, I1, t):
     return - self.k1 * I1
@@ -1208,11 +1207,11 @@ def update(self, tdi, x=None):
       V_th += reset_th * (self.V_th_reset - V_th)
     else:
       spike = self.V_th <= V
-      V = jnp.where(spike, self.V_reset, V)
-      I1 = jnp.where(spike, self.R1 * I1 + self.A1, I1)
-      I2 = jnp.where(spike, self.R2 * I2 + self.A2, I2)
-      reset_th = jnp.logical_and(V_th < self.V_th_reset, spike)
-      V_th = jnp.where(reset_th, self.V_th_reset, V_th)
+      V = bm.where(spike, self.V_reset, V)
+      I1 = bm.where(spike, self.R1 * I1 + self.A1, I1)
+      I2 = bm.where(spike, self.R2 * I2 + self.A2, I2)
+      reset_th = bm.logical_and(V_th < self.V_th_reset, spike)
+      V_th = bm.where(reset_th, self.V_th_reset, V_th)
     self.spike.value = spike
     self.I1.value = I1
     self.I2.value = I2
@@ -1309,12 +1308,12 @@ def __init__(
     # variables
     self.a = variable_(a_initializer, self.varshape, self.mode)
     self.V = variable_(V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode)
     if self.tau_ref is not None:
-      self.t_last_spike = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, self.mode)
-      self.refractory = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
+      self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode)
+      self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
 
     # integral
     if self.noise is None:
@@ -1335,12 +1334,12 @@ def derivative(self):
  def reset_state(self, batch_size=None):
     self.a.value = variable_(self._a_initializer, self.varshape, batch_size)
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
     if self.tau_ref is not None:
-      self.t_last_spike.value = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, batch_size)
-      self.refractory.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+      self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
+      self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def update(self, tdi, x=None):
     t, dt = tdi.t, tdi.dt
@@ -1354,19 +1353,19 @@ def update(self, tdi, x=None):
       refractory = (t - self.t_last_spike) <= self.tau_ref
       if isinstance(self.mode, bm.TrainingMode):
         refractory = stop_gradient(refractory)
-      V = jnp.where(refractory, self.V.value, V)
+      V = bm.where(refractory, self.V.value, V)
 
     # spike and reset
     if isinstance(self.mode, bm.TrainingMode):
       spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
       V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
       # will be used in other place, like Delta Synapse, so stop its gradient
       spike_ = spike > 0.
-      refractory = stop_gradient(jnp.logical_or(refractory, spike_))
-      t_last_spike = stop_gradient(jnp.where(spike_, t, self.t_last_spike.value))
+      refractory = stop_gradient(bm.logical_or(refractory, spike_))
+      t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
     else:
       spike = V >= (self.V_th + self.beta * self.a)
-      refractory = jnp.logical_or(refractory, spike)
-      t_last_spike = jnp.where(spike, t, self.t_last_spike.value)
+      refractory = bm.logical_or(refractory, spike)
+      t_last_spike = bm.where(spike, t, self.t_last_spike.value)
       V -= self.V_th * spike
     self.refractory.value = refractory
     self.t_last_spike.value = t_last_spike
@@ -1499,12 +1498,12 @@ def __init__(
     # variables
     self.u = variable_(u_initializer, self.varshape, self.mode)
     self.V = variable_(V_initializer, self.varshape, self.mode)
-    self.input = variable_(jnp.zeros, self.varshape, self.mode)
+    self.input = variable_(bm.zeros, self.varshape, self.mode)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode)
+    self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode)
     if self.tau_ref is not None:
-      self.t_last_spike = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, self.mode)
-      self.refractory = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode)
+      self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode)
+      self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
 
     # functions
     if self.noise is None:
@@ -1515,12 +1514,12 @@ def reset_state(self, batch_size=None):
     self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
     self.u.value = variable_(self._u_initializer, self.varshape, batch_size)
-    self.input.value = variable_(jnp.zeros, self.varshape, batch_size)
+    self.input.value = variable_(bm.zeros, self.varshape, batch_size)
     sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
-    self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, batch_size)
+    self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
     if self.tau_ref is not None:
-      self.t_last_spike.value = variable_(lambda s: jnp.ones(s) * -1e7, self.varshape, batch_size)
-      self.refractory.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+      self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
+      self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
 
   def dV(self, V, t, u, I_ext):
     dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext
@@ -1541,7 +1540,7 @@ def update(self, tdi, x=None):
       refractory = (t - self.t_last_spike) <= self.tau_ref
       if isinstance(self.mode, bm.TrainingMode):
         refractory = stop_gradient(refractory)
-      V = jnp.where(refractory, self.V.value, V)
+      V = bm.where(refractory, self.V.value, V)
 
       # spike, refractory, and reset membrane potential
       if isinstance(self.mode, bm.TrainingMode):
@@ -1550,14 +1549,14 @@ def update(self, tdi, x=None):
        V += spike_no_grad * (self.c - self.V_th)
        u += spike_no_grad * self.d
        spike_ = spike_no_grad > 0.
- refractory = stop_gradient(jnp.logical_or(refractory, spike_)) - t_last_spike = stop_gradient(jnp.where(spike_, t, self.t_last_spike.value)) + refractory = stop_gradient(bm.logical_or(refractory, spike_)) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) else: spike = self.V_th <= V - V = jnp.where(spike, self.c, V) - u = jnp.where(spike, u + self.d, u) - refractory = jnp.logical_or(refractory, spike) - t_last_spike = jnp.where(spike, t, self.t_last_spike.value) + V = bm.where(spike, self.c, V) + u = bm.where(spike, u + self.d, u) + refractory = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) self.refractory.value = refractory self.t_last_spike.value = t_last_spike @@ -1570,8 +1569,8 @@ def update(self, tdi, x=None): u += spike_no_grad * self.d else: spike = self.V_th <= V - V = jnp.where(spike, self.c, V) - u = jnp.where(spike, u + self.d, u) + V = bm.where(spike, self.c, V) + u = bm.where(spike, u + self.d, u) # finally self.V.value = V @@ -1611,7 +1610,7 @@ class HindmarshRose(NeuGroup): .. plot:: :include-source: True - >>> import jax.numpy as jnp + >>> import brainpy.math as bm >>> import brainpy as bp >>> import matplotlib.pyplot as plt >>> @@ -1619,8 +1618,8 @@ class HindmarshRose(NeuGroup): >>> bp.ode.set_default_odeint('rk4') >>> >>> types = ['quiescence', 'spiking', 'bursting', 'irregular_spiking', 'irregular_bursting'] - >>> bs = jnp.array([1.0, 3.5, 2.5, 2.95, 2.8]) - >>> Is = jnp.array([2.0, 5.0, 3.0, 3.3, 3.7]) + >>> bs = bm.array([1.0, 3.5, 2.5, 2.95, 2.8]) + >>> Is = bm.array([2.0, 5.0, 3.0, 3.3, 3.7]) >>> >>> # define neuron type >>> group = bp.neurons.HindmarshRose(len(types), b=bs) @@ -1735,9 +1734,9 @@ def __init__( self.V = variable_(self._V_initializer, self.varshape, self.mode) self.y = variable_(self._y_initializer, self.varshape, self.mode) self.z = variable_(self._z_initializer, self.varshape, self.mode) - self.input = variable_(jnp.zeros, self.varshape, self.mode) + self.input = variable_(bm.zeros, self.varshape, self.mode) sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, self.mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) # integral if self.noise is None: @@ -1749,9 +1748,9 @@ def reset_state(self, batch_size=None): self.V.value = variable_(self._V_initializer, self.varshape, batch_size) self.y.value = variable_(self._y_initializer, self.varshape, batch_size) self.z.value = variable_(self._z_initializer, self.varshape, batch_size) - self.input.value = variable_(jnp.zeros, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dV(self, V, t, y, z, I_ext): return y - self.a * V * V * V + self.b * V * V - z + I_ext @@ -1773,7 +1772,7 @@ def update(self, tdi, x=None): if isinstance(self.mode, bm.TrainingMode): self.spike.value = self.spike_fun(V - self.V_th, self.V - self.V_th) else: - self.spike.value = jnp.logical_and(V >= self.V_th, self.V < self.V_th) + self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th) self.V.value = V self.y.value = y self.z.value = z @@ -1905,8 +1904,8 @@ def __init__( # variables self.V = 
variable_(self._V_initializer, self.varshape, self.mode) self.w = variable_(self._w_initializer, self.varshape, self.mode) - self.input = variable_(jnp.zeros, self.varshape, self.mode) - self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode) + self.input = variable_(bm.zeros, self.varshape, self.mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) # integral if self.noise is None: @@ -1917,8 +1916,8 @@ def __init__( def reset_state(self, batch_size=None): self.V.value = variable_(self._V_initializer, self.varshape, batch_size) self.w.value = variable_(self._w_initializer, self.varshape, batch_size) - self.input.value = variable_(jnp.zeros, self.varshape, batch_size) - self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, w, I_ext): return V - V * V * V / 3 - w + I_ext @@ -1934,7 +1933,7 @@ def update(self, tdi, x=None): t, dt = tdi.t, tdi.dt if x is not None: self.input += x V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt=dt) - self.spike.value = jnp.logical_and(V >= self.Vth, self.V < self.Vth) + self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) self.V.value = V self.w.value = w diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/dyn/rates/populations.py index a9e9f3316..170f8a385 100644 --- a/brainpy/_src/dyn/rates/populations.py +++ b/brainpy/_src/dyn/rates/populations.py @@ -2,8 +2,6 @@ from typing import Union, Callable -import jax.numpy as jnp - from brainpy import check, math as bm from brainpy._src.dyn.base import NeuGroup from brainpy._src.dyn.neurons.noise_groups import OUProcess @@ -125,18 +123,18 @@ def __init__( # variables self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) - self.input = variable(jnp.zeros, self.mode, self.varshape) - self.input_y = variable(jnp.zeros, self.mode, self.varshape) + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None - if jnp.any(self.x_ou_mean > 0.) or jnp.any(self.x_ou_sigma > 0.): + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): self.x_ou = OUProcess(self.varshape, self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, method=method) - if jnp.any(self.y_ou_mean > 0.) or jnp.any(self.y_ou_sigma > 0.): + if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): self.y_ou = OUProcess(self.varshape, self.y_ou_mean, self.y_ou_sigma, @@ -149,8 +147,8 @@ def __init__( def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.input.value = variable(jnp.zeros, batch_size, self.varshape) - self.input_y.value = variable(jnp.zeros, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: @@ -311,18 +309,18 @@ def __init__( self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) self.x_delay = bm.TimeDelay(self.x, self.delay, dt=self.dt, interp_method='round') - self.input = variable(jnp.zeros, self.mode, self.varshape) - self.input_y = variable(jnp.zeros, self.mode, self.varshape) + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None - if jnp.any(self.x_ou_mean > 0.) or jnp.any(self.x_ou_sigma > 0.): + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): self.x_ou = OUProcess(self.varshape, self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, method=method) - if jnp.any(self.y_ou_mean > 0.) or jnp.any(self.y_ou_sigma > 0.): + if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): self.y_ou = OUProcess(self.varshape, self.y_ou_mean, self.y_ou_sigma, @@ -338,8 +336,8 @@ def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) self.x_delay.reset(self.x, self.delay) - self.input = variable(jnp.zeros, batch_size, self.varshape) - self.input_y = variable(jnp.zeros, batch_size, self.varshape) + self.input = variable(bm.zeros, batch_size, self.varshape) + self.input_y = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: @@ -360,7 +358,7 @@ def update(self, tdi, x=None): t = tdi['t'] dt = tdi['dt'] if check.is_checking(): - jit_error_checking(not jnp.isclose(dt, self.dt), self._check_dt, dt) + jit_error_checking(not bm.isclose(dt, self.dt), self._check_dt, dt) if x is not None: self.input += x if self.x_ou is not None: @@ -504,18 +502,18 @@ def __init__( # variables self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) - self.input = variable(jnp.zeros, self.mode, self.varshape) - self.input_y = variable(jnp.zeros, self.mode, self.varshape) + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None - if jnp.any(self.x_ou_mean > 0.) or jnp.any(self.x_ou_sigma > 0.): + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): self.x_ou = OUProcess(self.varshape, self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, method=method) - if jnp.any(self.y_ou_mean > 0.) or jnp.any(self.y_ou_sigma > 0.): + if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): self.y_ou = OUProcess(self.varshape, self.y_ou_mean, self.y_ou_sigma, @@ -528,19 +526,19 @@ def __init__( def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.input.value = variable(jnp.zeros, batch_size, self.varshape) - self.input_y.value = variable(jnp.zeros, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: self.y_ou.reset_state(batch_size) def dy(self, y, t, x, y_ext): - return (self.delta / (jnp.pi * self.tau) + 2. * x * y + y_ext) / self.tau + return (self.delta / (bm.pi * self.tau) + 2. * x * y + y_ext) / self.tau def dx(self, x, t, y, x_ext): return (x ** 2 + self.eta + x_ext + self.J * y * self.tau - - (jnp.pi * y * self.tau) ** 2) / self.tau + (bm.pi * y * self.tau) ** 2) / self.tau def update(self, tdi, x=None): t, dt = tdi['t'], tdi['dt'] @@ -640,18 +638,18 @@ def __init__( # variables self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) - self.input = variable(jnp.zeros, self.mode, self.varshape) - self.input_y = variable(jnp.zeros, self.mode, self.varshape) + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None - if jnp.any(self.x_ou_mean > 0.) or jnp.any(self.x_ou_sigma > 0.): + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): self.x_ou = OUProcess(self.varshape, self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, method=method) - if jnp.any(self.y_ou_mean > 0.) or jnp.any(self.y_ou_sigma > 0.): + if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): self.y_ou = OUProcess(self.varshape, self.y_ou_mean, self.y_ou_sigma, @@ -664,8 +662,8 @@ def __init__( def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.input.value = variable(jnp.zeros, batch_size, self.varshape) - self.input_y.value = variable(jnp.zeros, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: @@ -801,18 +799,18 @@ def __init__( # variables self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) - self.input = variable(jnp.zeros, self.mode, self.varshape) - self.input_y = variable(jnp.zeros, self.mode, self.varshape) + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None - if jnp.any(self.x_ou_mean > 0.) or jnp.any(self.x_ou_sigma > 0.): + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): self.x_ou = OUProcess(self.varshape, self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, method=method) - if jnp.any(self.y_ou_mean > 0.) or jnp.any(self.y_ou_sigma > 0.): + if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): self.y_ou = OUProcess(self.varshape, self.y_ou_mean, self.y_ou_sigma, @@ -825,15 +823,15 @@ def __init__( def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.input.value = variable(jnp.zeros, batch_size, self.varshape) - self.input_y.value = variable(jnp.zeros, batch_size, self.varshape) + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: self.y_ou.reset_state(batch_size) def F(self, x, a, theta): - return 1 / (1 + jnp.exp(-a * (x - theta))) - 1 / (1 + jnp.exp(a * theta)) + return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta)) def dx(self, x, t, y, x_ext): x = self.wEE * x - self.wIE * y + x_ext @@ -944,9 +942,9 @@ def __init__( # variables self.e = variable(e_initializer, self.mode, self.varshape) # Firing rate of excitatory population self.i = variable(i_initializer, self.mode, self.varshape) # Firing rate of inhibitory population - self.Ie = variable(jnp.zeros, self.mode, self.varshape) # Input of excitaory population - self.Ii = variable(jnp.zeros, self.mode, self.varshape) # Input of inhibitory population - if jnp.any(self.noise_e != 0) or jnp.any(self.noise_i != 0): + self.Ie = variable(bm.zeros, self.mode, self.varshape) # Input of excitatory population + self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population + if bm.any(self.noise_e != 0) or bm.any(self.noise_i != 0): self.rng = bm.random.default_rng(seed) def reset(self, batch_size=None): @@ -956,24 +954,24 @@ def reset(self, batch_size=None): def reset_state(self, batch_size=None): self.e.value = variable(self._e_initializer, batch_size, self.varshape) self.i.value = variable(self._i_initializer, batch_size, self.varshape) - self.Ie.value = variable(jnp.zeros, batch_size, self.varshape) - self.Ii.value = variable(jnp.zeros, batch_size, self.varshape) + self.Ie.value = variable(bm.zeros, batch_size, self.varshape) + self.Ii.value = variable(bm.zeros, batch_size, self.varshape) def update(self, tdi, x=None): t, dt = tdi['t'], tdi['dt'] if x is not None: self.Ie += x - de = -self.e + self.beta_e * jnp.maximum(self.Ie, 0.) - if jnp.any(self.noise_e != 0.): + de = -self.e + self.beta_e * bm.maximum(self.Ie, 0.) + if bm.any(self.noise_e != 0.): de += self.rng.randn(self.varshape) * self.noise_e de = de / self.tau_e - self.e.value = jnp.maximum(self.e + de * dt, 0.) + self.e.value = bm.maximum(self.e + de * dt, 0.) - di = -self.i + self.beta_i * jnp.maximum(self.Ii, 0.) - if jnp.any(self.noise_i != 0.): + di = -self.i + self.beta_i * bm.maximum(self.Ii, 0.) + if bm.any(self.noise_i != 0.): di += self.rng.randn(self.varshape) * self.noise_i di = di / self.tau_i - self.i.value = jnp.maximum(self.i + di * dt, 0.) + self.i.value = bm.maximum(self.i + di * dt, 0.)
def clear_input(self): self.Ie.value = bm.zeros_like(self.Ie) diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py index 3e05602ff..f4aeee3cc 100644 --- a/brainpy/_src/dyn/synapses/abstract_models.py +++ b/brainpy/_src/dyn/synapses/abstract_models.py @@ -5,7 +5,6 @@ import brainpylib as bl from jax import vmap from jax.lax import stop_gradient -import jax.numpy as jnp import brainpy.math as bm from brainpy._src.connect import TwoEndConnector, All2All, One2One @@ -142,12 +141,12 @@ def update(self, tdi, pre_spike=None): # synaptic values onto the post if isinstance(self.conn, All2All): - syn_value = jnp.asarray(pre_spike, dtype=bm.float_) + syn_value = bm.asarray(pre_spike, dtype=bm.float_) if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self._syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - syn_value = jnp.asarray(pre_spike, dtype=bm.float_) + syn_value = bm.asarray(pre_spike, dtype=bm.float_) if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self._syn2post_with_one2one(syn_value, self.g_max) @@ -166,7 +165,7 @@ def update(self, tdi, pre_spike=None): # if self.trainable: f2 = vmap(f2) # post_vs *= f2(stp_value) else: - syn_value = jnp.asarray(pre_spike, dtype=bm.float_) + syn_value = bm.asarray(pre_spike, dtype=bm.float_) if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) @@ -301,21 +300,21 @@ def __init__( self.stop_spike_gradient = stop_spike_gradient self.comp_method = comp_method self.tau = tau - if jnp.size(self.tau) != 1: + if bm.size(self.tau) != 1: raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}') # connections and weights self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr') # variables - self.g = variable_(jnp.zeros, self.post.num, self.mode) + self.g = variable_(bm.zeros, self.post.num, self.mode) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # function self.integral = odeint(lambda g, t: -g / self.tau, method=method) def reset_state(self, batch_size=None): - self.g.value = variable_(jnp.zeros, self.post.num, batch_size) + self.g.value = variable_(bm.zeros, self.post.num, batch_size) self.output.reset_state(batch_size) if self.stp is not None: self.stp.reset_state(batch_size) @@ -335,11 +334,11 @@ def update(self, tdi, pre_spike=None): # post values if isinstance(self.conn, All2All): - syn_value = jnp.asarray(pre_spike, dtype=bm.float_) + syn_value = bm.asarray(pre_spike, dtype=bm.float_) if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self._syn2post_with_all2all(syn_value, self.g_max) elif isinstance(self.conn, One2One): - syn_value = jnp.asarray(pre_spike, dtype=bm.float_) + syn_value = bm.asarray(pre_spike, dtype=bm.float_) if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: @@ -354,7 +353,7 @@ def update(self, tdi, pre_spike=None): # if not isinstance(self.stp, _NullSynSTP): # raise NotImplementedError() else: - syn_value = jnp.asarray(pre_spike, dtype=bm.float_) + syn_value = bm.asarray(pre_spike, dtype=bm.float_) if self.stp is not None: syn_value = self.stp(syn_value) post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # updates @@ -492,10 +491,10 @@ def __init__( self.comp_method = comp_method self.tau_rise = tau_rise self.tau_decay = 
tau_decay - if jnp.size(self.tau_rise) != 1: + if bm.size(self.tau_rise) != 1: raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. ' f'But we got {self.tau_rise}') - if jnp.size(self.tau_decay) != 1: + if bm.size(self.tau_decay) != 1: raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. ' f'But we got {self.tau_decay}') @@ -503,16 +502,16 @@ def __init__( self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr') # variables - self.h = variable_(jnp.zeros, self.pre.num, self.mode) - self.g = variable_(jnp.zeros, self.pre.num, self.mode) + self.h = variable_(bm.zeros, self.pre.num, self.mode) + self.g = variable_(bm.zeros, self.pre.num, self.mode) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # integral self.integral = odeint(method=method, f=JointEq([self.dg, self.dh])) def reset_state(self, batch_size=None): - self.h.value = variable_(jnp.zeros, self.pre.num, batch_size) - self.g.value = variable_(jnp.zeros, self.pre.num, batch_size) + self.h.value = variable_(bm.zeros, self.pre.num, batch_size) + self.g.value = variable_(bm.zeros, self.pre.num, batch_size) self.output.reset_state(batch_size) if self.stp is not None: self.stp.reset_state(batch_size) @@ -836,11 +835,11 @@ def __init__( self.tau_decay = tau_decay self.tau_rise = tau_rise self.a = a - if jnp.size(a) != 1: + if bm.size(a) != 1: raise ValueError(f'"a" must be a scalar or a tensor with size of 1. But we got {a}') - if jnp.size(tau_decay) != 1: + if bm.size(tau_decay) != 1: raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. But we got {tau_decay}') - if jnp.size(tau_rise) != 1: + if bm.size(tau_rise) != 1: raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. But we got {tau_rise}') self.comp_method = comp_method self.stop_spike_gradient = stop_spike_gradient @@ -849,8 +848,8 @@ def __init__( self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr') # variables - self.g = variable_(jnp.zeros, self.pre.num, self.mode) - self.x = variable_(jnp.zeros, self.pre.num, self.mode) + self.g = variable_(bm.zeros, self.pre.num, self.mode) + self.x = variable_(bm.zeros, self.pre.num, self.mode) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # integral @@ -863,8 +862,8 @@ def dx(self, x, t): return -x / self.tau_rise def reset_state(self, batch_size=None): - self.g.value = variable_(jnp.zeros, self.pre.num, batch_size) - self.x.value = variable_(jnp.zeros, self.pre.num, batch_size) + self.g.value = variable_(bm.zeros, self.pre.num, batch_size) + self.x.value = variable_(bm.zeros, self.pre.num, batch_size) self.output.reset_state(batch_size) if self.stp is not None: self.stp.reset_state(batch_size) diff --git a/brainpy/_src/dyn/synapses/biological_models.py b/brainpy/_src/dyn/synapses/biological_models.py index e532697ef..a21f0dc51 100644 --- a/brainpy/_src/dyn/synapses/biological_models.py +++ b/brainpy/_src/dyn/synapses/biological_models.py @@ -3,7 +3,7 @@ from typing import Union, Dict, Callable, Optional import brainpylib as bl -from jax import vmap, numpy as jnp +from jax import vmap from jax.lax import stop_gradient import brainpy.math as bm @@ -170,29 +170,29 @@ def __init__( self.beta = beta self.T = T self.T_duration = T_duration - if jnp.size(alpha) != 1: + if bm.size(alpha) != 1: raise ValueError(f'"alpha" must be a scalar or a tensor with size of 1. 
But we got {alpha}') - if jnp.size(beta) != 1: + if bm.size(beta) != 1: raise ValueError(f'"beta" must be a scalar or a tensor with size of 1. But we got {beta}') - if jnp.size(T) != 1: + if bm.size(T) != 1: raise ValueError(f'"T" must be a scalar or a tensor with size of 1. But we got {T}') - if jnp.size(T_duration) != 1: + if bm.size(T_duration) != 1: raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}') # connection self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij') # variables - self.g = variable(jnp.zeros, self.mode, self.pre.num) - self.spike_arrival_time = variable(lambda s: jnp.ones(s) * -1e7, self.mode, self.pre.num) + self.g = variable(bm.zeros, self.mode, self.pre.num) + self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, self.mode, self.pre.num) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # functions self.integral = odeint(method=method, f=self.dg) def reset_state(self, batch_size=None): - self.g = variable(jnp.zeros, batch_size, self.pre.num) - self.spike_arrival_time = variable(lambda s: jnp.ones(s) * -1e7, batch_size, self.pre.num) + self.g = variable(bm.zeros, batch_size, self.pre.num) + self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num) self.output.reset_state(batch_size) if self.stp is not None: self.stp.reset_state(batch_size) @@ -215,7 +215,7 @@ def update(self, tdi, pre_spike=None): if self.stp is not None: self.stp.update(tdi, pre_spike) # update synaptic variables - self.spike_arrival_time.value = jnp.where(pre_spike, t, self.spike_arrival_time.value) + self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time.value) if isinstance(self.mode, bm.TrainingMode): self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value) TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T @@ -504,17 +504,17 @@ def __init__( self.alpha2 = alpha2 self.T_0 = T_0 self.T_dur = T_dur - if jnp.size(alpha1) != 1: + if bm.size(alpha1) != 1: raise ValueError(f'"alpha1" must be a scalar or a tensor with size of 1. But we got {alpha1}') - if jnp.size(beta1) != 1: + if bm.size(beta1) != 1: raise ValueError(f'"beta1" must be a scalar or a tensor with size of 1. But we got {beta1}') - if jnp.size(alpha2) != 1: + if bm.size(alpha2) != 1: raise ValueError(f'"alpha2" must be a scalar or a tensor with size of 1. But we got {alpha2}') - if jnp.size(beta2) != 1: + if bm.size(beta2) != 1: raise ValueError(f'"beta2" must be a scalar or a tensor with size of 1. But we got {beta2}') - if jnp.size(T_0) != 1: + if bm.size(T_0) != 1: raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}') - if jnp.size(T_dur) != 1: + if bm.size(T_dur) != 1: raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. 
But we got {T_dur}') self.comp_method = comp_method self.stop_spike_gradient = stop_spike_gradient @@ -523,18 +523,18 @@ def __init__( self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij') # variables - self.g = variable(jnp.zeros, self.mode, self.pre.num) - self.x = variable(jnp.zeros, self.mode, self.pre.num) - self.spike_arrival_time = variable(lambda s: jnp.ones(s) * -1e7, self.mode, self.pre.num) + self.g = variable(bm.zeros, self.mode, self.pre.num) + self.x = variable(bm.zeros, self.mode, self.pre.num) + self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, self.mode, self.pre.num) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # integral self.integral = odeint(method=method, f=JointEq([self.dg, self.dx])) def reset_state(self, batch_size=None): - self.g = variable(jnp.zeros, batch_size, self.pre.num) - self.x = variable(jnp.zeros, batch_size, self.pre.num) - self.spike_arrival_time = variable(lambda s: jnp.ones(s) * -1e7, batch_size, self.pre.num) + self.g = variable(bm.zeros, batch_size, self.pre.num) + self.x = variable(bm.zeros, batch_size, self.pre.num) + self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num) self.output.reset_state(batch_size) if self.stp is not None: self.stp.reset_state(batch_size) @@ -559,7 +559,7 @@ def update(self, tdi, pre_spike=None): if self.stp is not None: self.stp.update(tdi, pre_spike) # update synapse variables - self.spike_arrival_time.value = jnp.where(pre_spike, t, self.spike_arrival_time.value) + self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time.value) if isinstance(self.mode, bm.TrainingMode): self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value) T = ((t - self.spike_arrival_time) < self.T_dur) * self.T_0 diff --git a/brainpy/_src/initialize/random_inits.py b/brainpy/_src/initialize/random_inits.py index 99419eaa6..9c4488101 100644 --- a/brainpy/_src/initialize/random_inits.py +++ b/brainpy/_src/initialize/random_inits.py @@ -114,7 +114,7 @@ def __init__(self, mean=0., scale=1., seed=None): def __call__(self, *shape, dtype=None): shape = _format_shape(shape) weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale) - return bm.as_jax(weights, dtype=dtype) + return bm.asarray(weights, dtype=dtype) def __repr__(self): return f'{self.__class__.__name__}(scale={self.scale}, rng={self.rng})' @@ -140,7 +140,7 @@ def __init__(self, min_val: float = 0., max_val: float = 1., seed=None): def __call__(self, shape, dtype=None): shape = _format_shape(shape) r = self.rng.uniform(low=self.min_val, high=self.max_val, size=shape) - return bm.as_jax(r, dtype=dtype) + return bm.asarray(r, dtype=dtype) def __repr__(self): return (f'{self.__class__.__name__}(min_val={self.min_val}, ' @@ -187,7 +187,7 @@ def __call__(self, shape, dtype=None): res = self.rng.uniform(low=-1, high=1, size=shape) * jnp.sqrt(3 * variance).astype(dtype) else: raise ValueError("invalid distribution for variance scaling initializer") - return bm.as_jax(res, dtype=dtype) + return bm.asarray(res, dtype=dtype) def __repr__(self): name = self.__class__.__name__ @@ -336,7 +336,7 @@ def __call__(self, shape, dtype=None): q_mat = q_mat.T q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis))) q_mat = jnp.moveaxis(q_mat, 0, self.axis) - return self.scale * bm.as_jax(q_mat, dtype=dtype) + return self.scale * bm.asarray(q_mat, dtype=dtype) def __repr__(self): return 
f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, rng={self.rng})' @@ -372,7 +372,7 @@ def __call__(self, shape, dtype=None): else: k1, k2, k3 = shape[:3] W[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2, ...] = ortho_matrix - return W.value + return W def __repr__(self): return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis})' diff --git a/brainpy/_src/inputs/currents.py b/brainpy/_src/inputs/currents.py index 555f989b8..dbaf57956 100644 --- a/brainpy/_src/inputs/currents.py +++ b/brainpy/_src/inputs/currents.py @@ -70,9 +70,9 @@ def section_input(values, durations, dt=None, return_length=False): start += length if return_length: - return I_current.value, I_duration + return I_current, I_duration else: - return I_current.value + return I_current def constant_input(I_and_duration, dt=None): @@ -118,7 +118,7 @@ def constant_input(I_and_duration, dt=None): length = int(duration / dt) I_current[start: start + length] = c_size start += length - return I_current.value, I_duration + return I_current, I_duration def constant_current(*args, **kwargs): @@ -177,7 +177,7 @@ def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None): pp = int(time / dt) p_len = int(dur / dt) current[pp: pp + p_len] = size - return current.value + return current def spike_current(*args, **kwargs): @@ -223,7 +223,7 @@ def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None): p2 = int(np.ceil(t_end / dt)) cc = jnp.array(jnp.linspace(c_start, c_end, p2 - p1)) current[p1: p2] = cc - return current.value + return current def ramp_current(*args, **kwargs): @@ -267,7 +267,7 @@ def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None): noises = rng.standard_normal((i_end - i_start, n)) * jnp.sqrt(dt) currents = bm.zeros((int(duration / dt), n)) currents[i_start: i_end] = noises - return currents.value + return currents def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, seed=None): @@ -316,7 +316,7 @@ def _f(t): i_end = int(t_end / dt) currents = bm.zeros((int(duration / dt), n)) currents[i_start: i_end] = noises - return currents.value + return currents def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end=None, bias=False): @@ -351,7 +351,7 @@ def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end= if bias: sin_inputs += amplitude currents = bm.zeros(int(duration / dt)) currents[start_i:end_i] = sin_inputs - return currents.value + return currents def _square(t, duty=0.5): @@ -412,4 +412,4 @@ def square_input(amplitude, frequency, duration, dt=None, bias=False, t_start=0. 
start_i = int(t_start / dt) end_i = int(t_end / dt) currents[start_i:end_i] = bm.asarray(sin_inputs) - return currents.value + return currents diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py index 5c4e90ad7..85fad0fb0 100644 --- a/brainpy/_src/integrators/ode/exponential.py +++ b/brainpy/_src/integrators/ode/exponential.py @@ -107,8 +107,6 @@ import logging -import jax.numpy as jnp - from functools import wraps from brainpy import errors from brainpy._src import math as bm @@ -366,9 +364,9 @@ def integral(*args, **kwargs): assert len(args) > 0 dt = kwargs.pop(C.DT, self.dt) linear, derivative = value_and_grad(*args, **kwargs) - phi = jnp.where(linear == 0., - jnp.ones_like(linear), - (jnp.exp(dt * linear) - 1) / (dt * linear)) + phi = bm.where(linear == 0., + bm.ones_like(linear), + (bm.exp(dt * linear) - 1) / (dt * linear)) return args[0] + dt * phi * derivative return [(integral, vars, pars), ] diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py index 9e4577441..d950c509c 100644 --- a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py +++ b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py @@ -4,7 +4,6 @@ import matplotlib.pyplot as plt -import jax.numpy as jnp import brainpy as bp import brainpy.math as bm from brainpy._src.integrators.ode.exponential import ExponentialEuler @@ -15,16 +14,16 @@ class TestExpnentialEuler(unittest.TestCase): def test_hh_model(self): def drivative(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C): - alpha = 0.1 * (V + 40) / (1 - jnp.exp(-(V + 40) / 10)) - beta = 4.0 * jnp.exp(-(V + 65) / 18) + alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) + beta = 4.0 * bm.exp(-(V + 65) / 18) dmdt = alpha * (1 - m) - beta * m - alpha = 0.07 * jnp.exp(-(V + 65) / 20.) - beta = 1 / (1 + jnp.exp(-(V + 35) / 10)) + alpha = 0.07 * bm.exp(-(V + 65) / 20.) + beta = 1 / (1 + bm.exp(-(V + 35) / 10)) dhdt = alpha * (1 - h) - beta * h - alpha = 0.01 * (V + 55) / (1 - jnp.exp(-(V + 55) / 10)) - beta = 0.125 * jnp.exp(-(V + 65) / 80) + alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) + beta = 0.125 * bm.exp(-(V + 65) / 80) dndt = alpha * (1 - n) - beta * n I_Na = (gNa * m ** 3.0 * h) * (V - ENa) @@ -39,7 +38,7 @@ def drivative(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C): def test1(self): def dev(x, t): - dx = jnp.power(x, 3) + dx = bm.power(x, 3) return dx ExponentialEuler(f=dev, show_code=True, dt=0.01) @@ -64,29 +63,29 @@ def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., self.phi = phi # variables - self.V = bm.Variable(jnp.ones(size) * -65.) - self.h = bm.Variable(jnp.ones(size) * 0.6) - self.n = bm.Variable(jnp.ones(size) * 0.32) - self.spike = bm.Variable(jnp.zeros(size, dtype=bool)) - self.input = bm.Variable(jnp.zeros(size)) + self.V = bm.Variable(bm.ones(size) * -65.) 
+ self.h = bm.Variable(bm.ones(size) * 0.6) + self.n = bm.Variable(bm.ones(size) * 0.32) + self.spike = bm.Variable(bm.zeros(size, dtype=bool)) + self.input = bm.Variable(bm.zeros(size)) self.integral = bp.odeint(bp.JointEq(self.dV, self.dh, self.dn), method=method, show_code=True) def dh(self, h, t, V): - alpha = 0.07 * jnp.exp(-(V + 58) / 20) - beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) + alpha = 0.07 * bm.exp(-(V + 58) / 20) + beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) dhdt = self.phi * (alpha * (1 - h) - beta * h) return dhdt def dn(self, n, t, V): - alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) - beta = 0.125 * jnp.exp(-(V + 44) / 80) + alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * bm.exp(-(V + 44) / 80) dndt = self.phi * (alpha * (1 - n) - beta * n) return dndt def dV(self, V, t, h, n, Iext): - m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) - m_beta = 4 * jnp.exp(-(V + 60) / 18) + m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) + m_beta = 4 * bm.exp(-(V + 60) / 18) m = m_alpha / (m_alpha + m_beta) INa = self.gNa * m ** 3 * h * (V - self.ENa) IK = self.gK * n ** 4 * (V - self.EK) @@ -98,7 +97,7 @@ def dV(self, V, t, h, n, Iext): def update(self, tdi): t, dt = tdi.t, tdi.dt V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt) - self.spike.value = jnp.logical_and(self.V < self.V_th, V >= self.V_th) + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.h.value = h self.n.value = n diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 8d57646a2..bc028396b 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -12,6 +12,8 @@ from .object_transform.base import BrainPyObject from .environment import get_dt, get_float from .ndarray import ndarray, Variable, Array +from .arrayinterporate import as_jax +from . import arraycompatible as bm __all__ = [ 'AbstractDelay', @@ -448,12 +450,12 @@ def update(self, value: Union[float, int, bool, Array, jnp.DeviceArray]): The value of the latest data, used to update this delay variable. """ if self.update_method == ROTATION_UPDATING: - self.idx.value = stop_gradient((self.idx - 1) % self.num_delay_step) + self.idx.value = stop_gradient(as_jax((self.idx - 1) % self.num_delay_step)) self.data[self.idx[0]] = value elif self.update_method == CONCAT_UPDATING: if self.num_delay_step >= 2: - self.data.value = jnp.vstack([jnp.broadcast_to(value, self.data.shape[1:]), self.data[1:]]) + self.data.value = bm.vstack([bm.broadcast_to(value, self.data.shape[1:]), self.data[1:]]) else: self.data[:] = value diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 451abc9c6..3d33200cc 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -104,6 +104,7 @@ def is_return_bparray(): _return_bp_array = True + def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj @@ -827,13 +828,13 @@ def take(self, indices, axis=None, mode=None): """Return an array formed from the elements of a at the given indices.""" return _return(self.value.take(indices=_as_jax_array_(indices), axis=axis, mode=mode)) - def tobytes(self, order='C'): + def tobytes(self): """Construct Python bytes containing the raw data bytes in the array. Constructs Python bytes showing a copy of the raw contents of data memory. The bytes object is produced in C-order by default. 
This behavior is controlled by the ``order`` parameter.""" - return _return(self.value.tobytes(order=order)) + return _return(self.value.tobytes()) def tolist(self): """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. @@ -1009,13 +1010,13 @@ class Variable(Array): batch_axis: optional, int The batch axis. """ - __slots__ = ('_value', '_batch_axis') + __slots__ = ('_value', '_batch_axis', 'requires_grad') def __init__( self, value_or_size, - dtype=None, - batch_axis: int = None + dtype: type =None, + batch_axis: int = None, ): if isinstance(value_or_size, int): value = jnp.zeros(value_or_size, dtype=dtype) @@ -1181,419 +1182,19 @@ def sort(self, axis=-1, kind=None, order=None): """Sort an array in-place.""" self._value = self.value.sort(axis=axis, kind=kind, order=order) - # ---------- # - # operations # - # ---------- # - - def __bool__(self) -> bool: - return self.value.__bool__() - - def __len__(self) -> int: - return len(self.value) - - def __neg__(self): - return self.value.__neg__() - - def __pos__(self): - return self.value.__pos__() - - def __abs__(self): - return self.value.__abs__() - - def __invert__(self): - return self.value.__invert__() - - def __eq__(self, oc): - return self.value == _check_input_array(oc) - - def __ne__(self, oc): - return self.value != _check_input_array(oc) - - def __lt__(self, oc): - return self.value < _check_input_array(oc) - - def __le__(self, oc): - return self.value <= _check_input_array(oc) - - def __gt__(self, oc): - return self.value > _check_input_array(oc) - - def __ge__(self, oc): - return self.value >= _check_input_array(oc) - - def __add__(self, oc): - return self.value + _check_input_array(oc) - - def __radd__(self, oc): - return self.value + _check_input_array(oc) - - def __sub__(self, oc): - return self.value - _check_input_array(oc) - - def __rsub__(self, oc): - return _check_input_array(oc) - self.value - - def __mul__(self, oc): - return self.value * _check_input_array(oc) - - def __rmul__(self, oc): - return _check_input_array(oc) * self.value - - def __rdiv__(self, oc): - return _check_input_array(oc) / self.value - - def __truediv__(self, oc): - return self.value / _check_input_array(oc) - - def __rtruediv__(self, oc): - return _check_input_array(oc) / self.value - - def __floordiv__(self, oc): - return self.value // _check_input_array(oc) - - def __rfloordiv__(self, oc): - return _check_input_array(oc) // self.value - - def __divmod__(self, oc): - return self.value.__divmod__(_check_input_array(oc)) - - def __rdivmod__(self, oc): - return self.value.__rdivmod__(_check_input_array(oc)) - - def __mod__(self, oc): - return self.value % _check_input_array(oc) - - def __rmod__(self, oc): - return _check_input_array(oc) % self.value - - def __pow__(self, oc): - return self.value ** _check_input_array(oc) - - def __rpow__(self, oc): - return _check_input_array(oc) ** self.value - - def __matmul__(self, oc): - return self.value @ _check_input_array(oc) - - def __rmatmul__(self, oc): - return _check_input_array(oc) @ self.value - - def __and__(self, oc): - return self.value & _check_input_array(oc) - - def __rand__(self, oc): - return _check_input_array(oc) & self.value - - def __or__(self, oc): - return self.value | _check_input_array(oc) - - def __ror__(self, oc): - return _check_input_array(oc) | self.value - - def __xor__(self, oc): - return self.value ^ _check_input_array(oc) - - def __rxor__(self, oc): - return _check_input_array(oc) ^ self.value - - def __lshift__(self, oc): - return self.value << 
_check_input_array(oc) - - def __rlshift__(self, oc): - return _check_input_array(oc) << self.value - - def __rshift__(self, oc): - return self.value >> _check_input_array(oc) - - def __rrshift__(self, oc): - return _check_input_array(oc) >> self.value - - def __round__(self, ndigits=None): - return self.value.__round__(ndigits) - - # ----------------------- # - # NumPy methods # - # ----------------------- # - - def all(self, axis=None, keepdims=False): - """Returns True if all elements evaluate to True.""" - return self.value.all(axis=axis, keepdims=keepdims) - - def any(self, axis=None, keepdims=False): - """Returns True if any of the elements of a evaluate to True.""" - return self.value.any(axis=axis, keepdims=keepdims) - - def argmax(self, axis=None): - """Return indices of the maximum values along the given axis.""" - return self.value.argmax(axis=axis) - - def argmin(self, axis=None): - """Return indices of the minimum values along the given axis.""" - return self.value.argmin(axis=axis) - - def argpartition(self, kth, axis=-1, kind='introselect', order=None): - """Returns the indices that would partition this array.""" - return self.value.argpartition(kth=kth, axis=axis, kind=kind, order=order) - - def argsort(self, axis=-1, kind=None, order=None): - """Returns the indices that would sort this array.""" - return self.value.argsort(axis=axis, kind=kind, order=order) - - def astype(self, dtype): - """Copy of the array, cast to a specified type. - - Parameters - ---------- - dtype: str, dtype - Typecode or data-type to which the array is cast. - """ - return self.value.astype(dtype=dtype) - - def byteswap(self, inplace=False): - """Swap the bytes of the array elements - - Toggle between low-endian and big-endian data representation by - returning a byteswapped array, optionally swapped in-place. - Arrays of byte-strings are not swapped. The real and imaginary - parts of a complex number are swapped individually.""" - return self.value.byteswap(inplace=inplace) - - def choose(self, choices, mode='raise'): - """Use an index array to construct a new array from a set of choices.""" - choices = choices.value if isinstance(choices, Array) else choices - return self.value.choose(choices=choices, mode=mode) - - def clip(self, min=None, max=None): - """Return an array whose values are limited to [min, max]. 
One of max or min must be given.""" - return self.value.clip(min=min, max=max) - - def compress(self, condition, axis=None): - """Return selected slices of this array along given axis.""" - condition = condition.value if isinstance(condition, Array) else condition - return self.value.compress(condition=condition, axis=axis) - - def conj(self): - """Complex-conjugate all elements.""" - return self.value.conj() - - def conjugate(self): - """Return the complex conjugate, element-wise.""" - return self.value.conjugate() - - def copy(self): - """Return a copy of the array.""" - return self.value.copy() - - def cumprod(self, axis=None, dtype=None): - """Return the cumulative product of the elements along the given axis.""" - return self.value.cumprod(axis=axis, dtype=dtype) - - def cumsum(self, axis=None, dtype=None): - """Return the cumulative sum of the elements along the given axis.""" - return self.value.cumsum(axis=axis, dtype=dtype) - - def diagonal(self, offset=0, axis1=0, axis2=1): - """Return specified diagonals.""" - return self.value.diagonal(offset=offset, axis1=axis1, axis2=axis2) - - def dot(self, b): - """Dot product of two arrays.""" - return self.value.dot(b.value if isinstance(b, Array) else b) - - def flatten(self): - return self.value.flatten() - - def item(self, *args): - """Copy an element of an array to a standard Python scalar and return it.""" - return self.value.item(*args) - - def max(self, axis=None, keepdims=False, *args, **kwargs): - """Return the maximum along a given axis.""" - return self.value.max(axis=axis, keepdims=keepdims, *args, **kwargs) - - def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs): - """Returns the average of the array elements along given axis.""" - return self.value.mean(axis=axis, dtype=dtype, keepdims=keepdims, *args, **kwargs) - - def min(self, axis=None, keepdims=False, *args, **kwargs): - """Return the minimum along a given axis.""" - return self.value.min(axis=axis, keepdims=keepdims, *args, **kwargs) - - def nonzero(self): - """Return the indices of the elements that are non-zero.""" - return self.value.nonzero() - - def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True): - """Return the product of the array elements over the given axis.""" - return self.value.prod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - - def ptp(self, axis=None, keepdims=False): - """Peak to peak (maximum - minimum) value along a given axis.""" - return self.value.ptp(axis=axis, keepdims=keepdims) - - def ravel(self, order=None): - """Return a flattened array.""" - return self.value.ravel(order=order) - - def repeat(self, repeats, axis=None): - """Repeat elements of an array.""" - return self.value.repeat(repeats=repeats, axis=axis) - - def reshape(self, *shape, order='C'): - """Returns an array containing the same data with a new shape.""" - return self.value.reshape(*shape, order=order) - - def round(self, decimals=0): - """Return ``a`` with each element rounded to the given number of decimals.""" - return self.value.round(decimals=decimals) - - def searchsorted(self, v, side='left', sorter=None): - """Find indices where elements should be inserted to maintain order. - - Find the indices into a sorted array `a` such that, if the - corresponding elements in `v` were inserted before the indices, the - order of `a` would be preserved. 
- - Assuming that `a` is sorted: - - ====== ============================ - `side` returned index `i` satisfies - ====== ============================ - left ``a[i-1] < v <= a[i]`` - right ``a[i-1] <= v < a[i]`` - ====== ============================ - - Parameters - ---------- - v : array_like - Values to insert into `a`. - side : {'left', 'right'}, optional - If 'left', the index of the first suitable location found is given. - If 'right', return the last such index. If there is no suitable - index, return either 0 or N (where N is the length of `a`). - sorter : 1-D array_like, optional - Optional array of integer indices that sort array a into ascending - order. They are typically the result of argsort. - - Returns - ------- - indices : array of ints - Array of insertion points with the same shape as `v`. - """ - v = v.value if isinstance(v, Array) else v - return self.value.searchsorted(v=v, side=side, sorter=sorter) - - def squeeze(self, axis=None): - """Remove axes of length one from ``a``.""" - return self.value.squeeze(axis=axis) - - def std(self, axis=None, dtype=None, ddof=0, keepdims=False): - """Compute the standard deviation along the specified axis. - - Returns the standard deviation, a measure of the spread of a distribution, - of the array elements. The standard deviation is computed for the - flattened array by default, otherwise over the specified axis. - - Parameters - ---------- - axis : None or int or tuple of ints, optional - Axis or axes along which the standard deviation is computed. The - default is to compute the standard deviation of the flattened array. - If this is a tuple of ints, a standard deviation is performed over - multiple axes, instead of a single axis or all the axes as before. - dtype : dtype, optional - Type to use in computing the standard deviation. For arrays of - integer type the default is float64, for arrays of float types it is - the same as the array type. - ddof : int, optional - Means Delta Degrees of Freedom. The divisor used in calculations - is ``N - ddof``, where ``N`` represents the number of elements. - By default `ddof` is zero. - keepdims : bool, optional - If this is set to True, the axes which are reduced are left - in the result as dimensions with size one. With this option, - the result will broadcast correctly against the input array. - - If the default value is passed, then `keepdims` will not be - passed through to the `std` method of sub-classes of - `ndarray`, however any non-default value will be. If the - sub-class' method does not implement `keepdims` any - exceptions will be raised. - - Returns - ------- - standard_deviation : ndarray, see dtype parameter above. - If `out` is None, return a new array containing the standard deviation, - otherwise return a reference to the output array. - """ - return self.value.std(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) - - def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True): - """Return the sum of the array elements over the given axis.""" - return self.value.sum(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - - def swapaxes(self, axis1, axis2): - """Return a view of the array with `axis1` and `axis2` interchanged.""" - return self.value.swapaxes(axis1, axis2) - - def split(self, indices_or_sections, axis=0): - """Split an array into multiple sub-arrays as views into ``ary``. 
- """ - return self.value.split(indices_or_sections, axis=axis) - - def take(self, indices, axis=None, mode=None): - """Return an array formed from the elements of a at the given indices.""" - indices = indices.value if isinstance(indices, Array) else indices - return self.value.take(indices=indices, axis=axis, mode=mode) - - def tobytes(self, order='C'): - """Construct Python bytes containing the raw data bytes in the array. - - Constructs Python bytes showing a copy of the raw contents of data memory. - The bytes object is produced in C-order by default. This behavior is - controlled by the ``order`` parameter.""" - return self.value.tobytes(order=order) - - def tolist(self): - """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. - - Return a copy of the array data as a (nested) Python list. - Data items are converted to the nearest compatible builtin Python type, via - the `~numpy.ndarray.item` function. - - If ``a.ndim`` is 0, then since the depth of the nested list is 0, it will - not be a list at all, but a simple Python scalar. - """ - return self.value.tolist() - - def trace(self, offset=0, axis1=0, axis2=1, dtype=None): - """Return the sum along diagonals of the array.""" - return self.value.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) - - def transpose(self, *axes): - """Returns a view of the array with axes transposed. - """ - return self.value.transpose(*axes) - - def tile(self, reps): - return self.value.tile(reps.value if isinstance(reps, Array) else reps) - - def var(self, axis=None, dtype=None, ddof=0, keepdims=False): - """Returns the variance of the array elements, along given axis.""" - return self.value.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) - - def view(self, dtype=None, *args, **kwargs): - """New view of array with the same data.""" - return self.value.view(dtype=dtype, *args, **kwargs) - class TrainVar(Variable): """The pointer to specify the trainable variable. """ __slots__ = ('_value', '_batch_axis') - def __init__(self, value_or_size, dtype=None, batch_axis: int = None): - super(TrainVar, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis) + def __init__(self, + value_or_size, + dtype: type = None, + batch_axis: int = None): + super(TrainVar, self).__init__(value_or_size, + dtype=dtype, + batch_axis=batch_axis) class Parameter(Variable): @@ -1601,8 +1202,13 @@ class Parameter(Variable): """ __slots__ = ('_value', '_batch_axis') - def __init__(self, value_or_size, dtype=None, batch_axis: int = None): - super(Parameter, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis) + def __init__(self, + value_or_size, + dtype: type = None, + batch_axis: int = None): + super(Parameter, self).__init__(value_or_size, + dtype=dtype, + batch_axis=batch_axis) class ParallelVariable(Variable): diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 220bdb6b1..3e237ca4d 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -458,8 +458,11 @@ def to(self, device: Optional[Any]): Args: device: The device. 
""" - for var in self.vars().unique().values(): - var.value = jax.device_put(var.value, device=device) + for key, var in self.state_dict().items(): + if isinstance(var, Array): + var.value = jax.device_put(var.value, device=device) + else: + setattr(self, key, jax.device_put(var, device=device)) def cpu(self): """Move all variable into the CPU device.""" diff --git a/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py b/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py index 4f1e38227..9671d1abb 100644 --- a/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py +++ b/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py @@ -16,10 +16,7 @@ from jax.lax import stop_gradient from matplotlib import patches -bm.set_environment(mode=bm.training_mode) - - -bm.set_dt(1.) # Simulation time step [ms] +bm.set_environment(mode=bm.training_mode, dt=1.) # training parameters n_batch = 128 # batch size @@ -171,7 +168,8 @@ def loss_fun(predicts, targets): # Training trainer = bp.BPTT( - net, loss_fun, + net, + loss_fun, loss_has_aux=True, optimizer=bp.optimizers.Adam(lr=0.01), monitors={'r.spike': net.r.spike}, diff --git a/examples/training_snn_models/spikebased_bp_for_cifar10.py b/examples/training_snn_models/spikebased_bp_for_cifar10.py index 0c279641f..c80927d0c 100644 --- a/examples/training_snn_models/spikebased_bp_for_cifar10.py +++ b/examples/training_snn_models/spikebased_bp_for_cifar10.py @@ -247,7 +247,7 @@ def main(): def loss_fun(x, y, fit=True): yy = bm.one_hot(y, 10, dtype=bm.float_) # poisson encoding - x = (bm.random.rand(num_time, *x.shape) < jnp.abs(x)).astype(bm.float_) * jnp.sign(x) + x = (bm.random.rand(num_time, *x.shape) < bm.abs(x)).astype(bm.float_) * bm.sign(x) # loop over time s = {'fit': fit} for i in range(num_time): From df2ee19981bc334d745e240ee45b56008e97adb7 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 27 Jan 2023 15:07:24 +0800 Subject: [PATCH 2/4] fix bug --- brainpy/_src/integrators/fde/Caputo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/integrators/fde/Caputo.py b/brainpy/_src/integrators/fde/Caputo.py index e463c3341..29db2b29b 100644 --- a/brainpy/_src/integrators/fde/Caputo.py +++ b/brainpy/_src/integrators/fde/Caputo.py @@ -326,7 +326,7 @@ def __init__( raise UnsupportedError(f'Only support the fractional order in (0, 1), ' f'but we got {self.alpha}.') from scipy.special import gamma - self.gamma_alpha = jnp.asarray(gamma(bm.as_numpy(2 - self.alpha))) + self.gamma_alpha = bm.asarray(gamma(bm.as_numpy(2 - self.alpha))) # initial values inits = check_inits(inits, self.variables) @@ -336,11 +336,11 @@ def __init__( # coefficients ranges = jnp.asarray([jnp.arange(1, num_memory + 2) for _ in self.variables]).T coef = jnp.diff(jnp.power(ranges, 1 - self.alpha), axis=0) - self.coef = jnp.flip(coef, axis=0) + self.coef = bm.flip(coef, axis=0) # variable states self.diff_states = {v + "_diff": bm.Variable(jnp.zeros((num_memory,) + self.inits[v].shape, - dtype=self.inits[v].dtype)) + dtype=self.inits[v].dtype)) for v in self.variables} self.register_implicit_vars(self.diff_states) self.idx = bm.Variable(jnp.asarray([self.num_memory - 1])) From df4ec695ac9563e679193c58b3d0b3528eec7fea Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 27 Jan 2023 16:41:29 +0800 Subject: [PATCH 3/4] fix bug and update examples --- brainpy/_src/dyn/synapses/gap_junction.py | 4 +- examples/dynamics_analysis/1d_system.py | 96 --- .../2d_decision_making_model.py | 93 --- 
.../2d_wilson_cowan_model.py | 85 --- .../3d_hindmarsh_rose_model.py | 67 -- examples/dynamics_analysis/highdim_CANN.py | 153 ---- .../highdim_gj_coupled_fhn.py | 115 ---- .../Bazhenov_1998_thalamus_aug_response.py | 38 - .../dynamics_simulation/Brette_2007_COBA.py | 101 --- .../dynamics_simulation/Brette_2007_COBAHH.py | 73 -- .../dynamics_simulation/COBA_for_benchmark.py | 68 -- ...022_gj_coupled_bursting_pituitary_cells.py | 101 --- .../JR_1995_jansen_rit_model.py | 139 ---- ...2017_unified_thalamus_oscillation_model.py | 353 ---------- .../Susin_2021_gamma_oscillation_nets.py | 651 ------------------ .../Vreeswijk_1996_EI_net.py | 57 -- .../Wang_2002_decision_making_spiking.py | 289 -------- .../dynamics_simulation/Wu_2008_CANN_1D.py | 127 ---- .../Wu_2008_CANN_1D_oscillatory_tracking.py | 94 --- .../dynamics_simulation/Wu_2008_CANN_2D.py | 114 --- .../dynamics_simulation/multi_scale_COBAHH.py | 3 - ...Bellec_2020_eprop_evidence_accumulation.py | 226 ------ .../Gauthier_2021_ngrc_double_scroll.py | 142 ---- .../Gauthier_2021_ngrc_lorenz.py | 148 ---- .../Gauthier_2021_ngrc_lorenz_inference.py | 179 ----- 25 files changed, 2 insertions(+), 3514 deletions(-) delete mode 100644 examples/dynamics_analysis/1d_system.py delete mode 100644 examples/dynamics_analysis/2d_decision_making_model.py delete mode 100644 examples/dynamics_analysis/2d_wilson_cowan_model.py delete mode 100644 examples/dynamics_analysis/3d_hindmarsh_rose_model.py delete mode 100644 examples/dynamics_analysis/highdim_CANN.py delete mode 100644 examples/dynamics_analysis/highdim_gj_coupled_fhn.py delete mode 100644 examples/dynamics_simulation/Bazhenov_1998_thalamus_aug_response.py delete mode 100644 examples/dynamics_simulation/Brette_2007_COBA.py delete mode 100644 examples/dynamics_simulation/Brette_2007_COBAHH.py delete mode 100644 examples/dynamics_simulation/COBA_for_benchmark.py delete mode 100644 examples/dynamics_simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py delete mode 100644 examples/dynamics_simulation/JR_1995_jansen_rit_model.py delete mode 100644 examples/dynamics_simulation/Li_2017_unified_thalamus_oscillation_model.py delete mode 100644 examples/dynamics_simulation/Susin_2021_gamma_oscillation_nets.py delete mode 100644 examples/dynamics_simulation/Vreeswijk_1996_EI_net.py delete mode 100644 examples/dynamics_simulation/Wang_2002_decision_making_spiking.py delete mode 100644 examples/dynamics_simulation/Wu_2008_CANN_1D.py delete mode 100644 examples/dynamics_simulation/Wu_2008_CANN_1D_oscillatory_tracking.py delete mode 100644 examples/dynamics_simulation/Wu_2008_CANN_2D.py delete mode 100644 examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py delete mode 100644 examples/dynamics_training/Gauthier_2021_ngrc_double_scroll.py delete mode 100644 examples/dynamics_training/Gauthier_2021_ngrc_lorenz.py delete mode 100644 examples/dynamics_training/Gauthier_2021_ngrc_lorenz_inference.py diff --git a/brainpy/_src/dyn/synapses/gap_junction.py b/brainpy/_src/dyn/synapses/gap_junction.py index 8334ba09c..7c6aa428d 100644 --- a/brainpy/_src/dyn/synapses/gap_junction.py +++ b/brainpy/_src/dyn/synapses/gap_junction.py @@ -51,9 +51,9 @@ def update(self, tdi): if self.comp_method == 'dense': # pre -> post diff = (self.pre.V.reshape((-1, 1)) - self.post.V) * self.conn_mat * self.weights - self.post.input += jnp.einsum('ij->j', diff) + self.post.input += bm.einsum('ij->j', diff) # post -> pre - self.pre.input += jnp.einsum('ij->i', -diff) + self.pre.input += bm.einsum('ij->i', -diff) else: 
diff = (self.pre.V[self.pre_ids] - self.post.V[self.post_ids]) * self.weights self.post.input += bm.syn2post_sum(diff, self.post_ids, self.post.num) diff --git a/examples/dynamics_analysis/1d_system.py b/examples/dynamics_analysis/1d_system.py deleted file mode 100644 index 270181cf7..000000000 --- a/examples/dynamics_analysis/1d_system.py +++ /dev/null @@ -1,96 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy as bp - -bp.math.enable_x64() -bp.math.set_platform('cpu') - - -def quadratic_system1(): - int_x = bp.odeint(lambda x, t: -x ** 2) - analyzer = bp.analysis.PhasePlane1D(model=int_x, - target_vars={'x': [-2, 2]}, - resolutions=0.001) - analyzer.plot_vector_field() - analyzer.plot_fixed_point(show=True) - - int_x = bp.odeint(lambda x, t: x ** 2) - analyzer = bp.analysis.PhasePlane1D(model=int_x, - target_vars={'x': [-2, 2]}, - resolutions=0.001) - analyzer.plot_vector_field() - analyzer.plot_fixed_point(show=True) - - -def cubic_system1(): - int_x = bp.odeint(lambda x, t: -x ** 3) - analyzer = bp.analysis.PhasePlane1D(model=int_x, - target_vars={'x': [-2, 2]}, - resolutions=0.001) - analyzer.plot_vector_field() - analyzer.plot_fixed_point(show=True) - - int_x = bp.odeint(lambda x, t: x ** 3) - analyzer = bp.analysis.PhasePlane1D(model=int_x, - target_vars={'x': [-2, 2]}, - resolutions=0.001) - analyzer.plot_vector_field() - analyzer.plot_fixed_point(show=True) - - -def cubic_system_2(): - @bp.odeint - def int_x(x, t, Iext): - return x ** 3 - x + Iext - - analyzer = bp.analysis.PhasePlane1D(model=int_x, - target_vars={'x': [-2, 2]}, - pars_update={'Iext': 0.}, - resolutions=0.001) - analyzer.plot_vector_field() - analyzer.plot_fixed_point(show=True) - - -def sin_1d(): - @bp.odeint - def int_x(x, t, Iext): - return bp.math.sin(x) + Iext - - pp = bp.analysis.PhasePlane1D(model=int_x, - target_vars={'x': [-5, 5]}, - pars_update={'Iext': 0.9}, - resolutions=0.001) - pp.plot_vector_field() - pp.plot_fixed_point(show=True) - - bf = bp.analysis.Bifurcation1D(model=int_x, - target_vars={'x': [-5, 5]}, - target_pars={'Iext': [0., 1.5]}, - resolutions=0.001) - bf.plot_bifurcation(show=True, tol_aux=1e-7) - - -def sincos_1d(): - @bp.odeint - def int_x(x, t, a=1., b=1.): - return bp.math.sin(a * x) + bp.math.cos(b * x) - - pp = bp.analysis.PhasePlane1D( - model=int_x, - target_vars={'x': [-bp.math.pi, bp.math.pi]}, - resolutions=0.001 - ) - pp.plot_vector_field() - pp.plot_fixed_point(show=True) - - bf = bp.analysis.Bifurcation1D( - model=int_x, - target_vars={'x': [-bp.math.pi, bp.math.pi]}, - target_pars={'a': [0.5, 1.5], 'b': [0.5, 1.5]}, - resolutions={'a': 0.01, 'b': 0.01} - ) - bf.plot_bifurcation(show=True) - - -if __name__ == '__main__': - sin_1d() diff --git a/examples/dynamics_analysis/2d_decision_making_model.py b/examples/dynamics_analysis/2d_decision_making_model.py deleted file mode 100644 index dd8651b23..000000000 --- a/examples/dynamics_analysis/2d_decision_making_model.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy as bp -import brainpy.math as bm - -bp.math.enable_x64() - -# parameters -gamma = 0.641 # Saturation factor for gating variable -tau = 0.06 # Synaptic time constant [sec] -a = 270. -b = 108. -d = 0.154 - -JE = 0.3725 # self-coupling strength [nA] -JI = -0.1137 # cross-coupling strength [nA] -JAext = 0.00117 # Stimulus input strength [nA] - -mu = 20. 
# Stimulus firing rate [spikes/sec] -coh = 0.5 # Stimulus coherence [%] -Ib1 = 0.3297 -Ib2 = 0.3297 - - -@bp.odeint -def int_s1(s1, t, s2, coh=0.5, mu=20.): - I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh) - r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b))) - return - s1 / tau + (1. - s1) * gamma * r1 - - -@bp.odeint -def int_s2(s2, t, s1, coh=0.5, mu=20.): - I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh) - r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b))) - return - s2 / tau + (1. - s2) * gamma * r2 - - -def phase_plane_analysis(): - # phase plane analysis - analyzer = bp.analysis.PhasePlane2D( - model=[int_s1, int_s2], - target_vars={'s1': [0, 1], 's2': [0, 1]}, - resolutions=0.001, - ) - analyzer.plot_vector_field() - analyzer.plot_nullcline(coords=dict(s2='s2-s1')) - analyzer.plot_fixed_point() - analyzer.show_figure() - - -def bifurcation_analysis(): - # codimension 1 bifurcation - analyzer = bp.analysis.Bifurcation2D( - model=[int_s1, int_s2], - target_vars={'s1': [0., 1.], 's2': [0., 1.]}, - target_pars={'coh': [0., 1.]}, - pars_update={'mu': 40.}, - resolutions={'coh': 0.005}, - ) - analyzer.plot_bifurcation(num_par_segments=1, - num_fp_segment=1, - select_candidates='aux_rank', - num_rank=50) - analyzer.show_figure() - - -def fixed_point_finder(): - def step(s): - ds1 = int_s1.f(s[0], 0., s[1]) - ds2 = int_s2.f(s[1], 0., s[0]) - return bm.asarray([ds1, ds2]) - - finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS) - # finder.find_fps_with_gd_method( - # candidates=bm.random.random((1000, 2)), - # tolerance=1e-8, - # num_batch=200, - # optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.01, 1, 0.9999)), - # ) - finder.find_fps_with_opt_solver(bm.random.random((1000, 2))) - finder.filter_loss(1e-14) - finder.keep_unique() - - print('fixed_points: ', finder.fixed_points) - print('losses:', finder.losses) - print('jacobians: ', finder.compute_jacobians(finder.fixed_points)) - - -if __name__ == '__main__': - phase_plane_analysis() - bifurcation_analysis() - fixed_point_finder() diff --git a/examples/dynamics_analysis/2d_wilson_cowan_model.py b/examples/dynamics_analysis/2d_wilson_cowan_model.py deleted file mode 100644 index 4298feabb..000000000 --- a/examples/dynamics_analysis/2d_wilson_cowan_model.py +++ /dev/null @@ -1,85 +0,0 @@ -import brainpy as bp -import brainpy.math as bm - -bp.math.enable_x64() - - -class WilsonCowanModel(bp.DynamicalSystem): - def __init__(self, num, method='exp_auto'): - super(WilsonCowanModel, self).__init__() - - # Connection weights - self.wEE = 12 - self.wEI = 4 - self.wIE = 13 - self.wII = 11 - - # Refractory parameter - self.r = 1 - - # Excitatory parameters - self.E_tau = 1 # Timescale of excitatory population - self.E_a = 1.2 # Gain of excitatory population - self.E_theta = 2.8 # Threshold of excitatory population - - # Inhibitory parameters - self.I_tau = 1 # Timescale of inhibitory population - self.I_a = 1 # Gain of inhibitory population - self.I_theta = 4 # Threshold of inhibitory population - - # variables - self.i = bm.Variable(bm.ones(num)) - self.e = bm.Variable(bm.ones(num)) - self.Iext = bm.Variable(bm.zeros(num)) - - # functions - def F(x, a, theta): - return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta)) - - def de(e, t, i, Iext=0.): - x = self.wEE * e - self.wEI * i + Iext - return (-e + (1 - self.r * e) * F(x, self.E_a, self.E_theta)) / self.E_tau - - def di(i, t, e): - x = self.wIE * e - self.wII * i - return (-i + (1 - self.r * i) * F(x, self.I_a, self.I_theta)) / 
self.I_tau - - self.int_e = bp.odeint(de, method=method) - self.int_i = bp.odeint(di, method=method) - - def update(self, tdi): - t, dt = tdi['t'], tdi['dt'] - self.e.value = self.int_e(self.e, t, self.i, self.Iext, dt) - self.i.value = self.int_i(self.i, t, self.e, dt) - self.Iext[:] = 0. - - -model = WilsonCowanModel(2) -model.e[:] = [-0.2, 1.] -model.i[:] = [0.0, 1.] - -# simulation -runner = bp.DSRunner(model, monitors=['e', 'i']) -runner.run(100) - -fig, gs = bp.visualize.get_figure(2, 1, 3, 8) -fig.add_subplot(gs[0, 0]) -bp.visualize.line_plot(runner.mon.ts, runner.mon.e, plot_ids=[0], legend='e', linestyle='--') -bp.visualize.line_plot(runner.mon.ts, runner.mon.i, plot_ids=[0], legend='i', linestyle='--') -fig.add_subplot(gs[1, 0]) -bp.visualize.line_plot(runner.mon.ts, runner.mon.e, plot_ids=[1], legend='e') -bp.visualize.line_plot(runner.mon.ts, runner.mon.i, plot_ids=[1], legend='i', show=True) - - -# phase plane analysis -pp = bp.analysis.PhasePlane2D( - model, - target_vars={'e': [-0.2, 1.], 'i': [-0.2, 1.]}, - resolutions=0.001, -) -pp.plot_vector_field() -pp.plot_nullcline(coords={'i': 'i-e'}) -pp.plot_fixed_point() -pp.plot_trajectory(initials={'i': [0.5, 0.6], 'e': [-0.1, 0.4]}, - duration=10, dt=0.1) -pp.show_figure() diff --git a/examples/dynamics_analysis/3d_hindmarsh_rose_model.py b/examples/dynamics_analysis/3d_hindmarsh_rose_model.py deleted file mode 100644 index c8970ddf9..000000000 --- a/examples/dynamics_analysis/3d_hindmarsh_rose_model.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- - -import matplotlib.pyplot as plt -import numpy as np - -import brainpy as bp - -bp.math.enable_x64() - - -def simulation(): - model = bp.neurons.HindmarshRose(1) - runner = bp.DSRunner( - model, - monitors=['V', 'y', 'z'], - inputs=[model.input, 1.5], - ) - runner.run(2000.) 
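  # A clarifying note (not from the original file): in the runner call above,
  # `inputs=[model.input, 1.5]` adds a constant current of 1.5 to `model.input`
  # at every integration step (the time-varying form used elsewhere in these
  # examples passes an array together with the 'iter' flag instead), and each
  # name listed in `monitors` becomes an array on `runner.mon`, e.g. the
  # `runner.mon.V` used in the plots below.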
- bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V') - # bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y') - # bp.visualize.line_plot(runner.mon.ts, runner.mon.z, legend='z') - plt.show() - - -def bifurcation_analysis(): - model = bp.neurons.HindmarshRose(1) - analyzer = bp.analysis.FastSlow2D( - model, - fast_vars={'V': [-3, 2], 'y': [-20., 3.]}, - slow_vars={'z': [-0.5, 3.]}, - pars_update={'I_ext': 1.5}, - resolutions={'z': 0.01}, - # options={bp.analysis.C.y_by_x_in_fy: lambda x: model.c - model.d * x * x} - ) - analyzer.plot_bifurcation(num_rank=20) - analyzer.plot_trajectory({'V': [1.], 'y': [1.], 'z': [1.]}, - duration=1700, - plot_durations=[360, 1680]) - analyzer.show_figure() - - -def phase_plane_analysis(): - model = bp.neurons.HindmarshRose(1) - for z in np.arange(0., 2.5, 0.3): - analyzer = bp.analysis.PhasePlane2D( - model, - target_vars={'V': [-3, 2], 'y': [-20., 3.]}, - pars_update={'I_ext': 1.5, 'z': z}, - resolutions={'V': 0.01, 'y': 0.01}, - ) - analyzer.plot_nullcline() - analyzer.plot_vector_field() - fps = analyzer.plot_fixed_point(with_return=True) - analyzer.plot_trajectory({'V': [fps[-1, 0] + 0.1], - 'y': [fps[-1, 0] + 0.1]}, - duration=500, - plot_durations=[400, 500]) - plt.title(f'z={z:.2f}') - plt.show() - # plt.savefig(f'data/z={z:.2f}.png') - plt.close() - - -if __name__ == '__main__': - simulation() - bifurcation_analysis() - phase_plane_analysis() diff --git a/examples/dynamics_analysis/highdim_CANN.py b/examples/dynamics_analysis/highdim_CANN.py deleted file mode 100644 index c6519e2fa..000000000 --- a/examples/dynamics_analysis/highdim_CANN.py +++ /dev/null @@ -1,153 +0,0 @@ -# -*- coding: utf-8 -*- - -import matplotlib.pyplot as plt -from sklearn.decomposition import PCA - -import brainpy as bp -import brainpy.math as bm - -bm.set_platform('cpu') - - -class CANN1D(bp.NeuGroup): - def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4., - z_min=-bm.pi, z_max=bm.pi, **kwargs): - super(CANN1D, self).__init__(size=num, **kwargs) - - # parameters - self.tau = tau # The synaptic time constant - self.k = k # Degree of the rescaled inhibition - self.a = a # Half-width of the range of excitatory connections - self.A = A # Magnitude of the external input - self.J0 = J0 # maximum connection value - - # feature space - self.z_min = z_min - self.z_max = z_max - self.z_range = z_max - z_min - self.x = bm.linspace(z_min, z_max, num) # The encoded feature values - self.rho = num / self.z_range # The neural density - self.dx = self.z_range / num # The stimulus density - - # variables - self.u = bm.Variable(bm.zeros(num)) - self.input = bm.Variable(bm.zeros(num)) - - # The connection matrix - self.conn_mat = self.make_conn(self.x) - - # function - self.integral = bp.odeint(self.derivative) - - def derivative(self, u, t, Iext): - r1 = bm.square(u) - r2 = 1.0 + self.k * bm.sum(r1) - r = r1 / r2 - Irec = bm.dot(self.conn_mat, r) - du = (-u + Irec + Iext) / self.tau - return du - - def dist(self, d): - d = bm.remainder(d, self.z_range) - d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d) - return d - - def make_conn(self, x): - assert bm.ndim(x) == 1 - x_left = bm.reshape(x, (-1, 1)) - x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0) - d = self.dist(x_left - x_right) - Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a) - return Jxx - - def get_stimulus_by_pos(self, pos): - return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a)) - - def update(self, tdi): - t, dt = tdi.get('t'), 
tdi.get('dt') - self.u.value = self.integral(self.u, t, self.input, dt) - self.input[:] = 0. - - -def find_fixed_points(pars=None, verbose=False, opt_method='gd', cand_method='random', tolerance=1e-6): - if pars is None: pars = dict() - cann = CANN1D(num=512, **pars) - - if cand_method == 'random': - candidates = bm.random.uniform(0, 20., (1000, cann.num)) - elif cand_method == 'bump': - candidates = cann.get_stimulus_by_pos(bm.arange(-bm.pi, bm.pi, 0.01).reshape((-1, 1))) - candidates += bm.random.normal(0., 0.01, candidates.shape) - else: - raise ValueError - - finder = bp.analysis.SlowPointFinder(f_cell=cann, target_vars={'u': cann.u}, dt=1.) - if opt_method == 'gd': - finder.find_fps_with_gd_method( - candidates={'u': candidates}, - tolerance=tolerance, - num_batch=200, - optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.2, 1, 0.999)), - ) - elif opt_method == 'BFGS': - finder.find_fps_with_opt_solver({'u': candidates}) - else: - raise ValueError() - finder.filter_loss(tolerance) - finder.keep_unique(5e-3) - - if verbose: - print(finder.fixed_points) - print(finder.losses) - print(finder.selected_ids) - - return finder.fixed_points, finder - - -def visualize_fixed_points(fixed_points): - bp.visualize.animate_1D( - dynamical_vars={'ys': fixed_points['u'], - 'xs': bm.linspace(-bm.pi, bm.pi, fixed_points['u'].shape[1]), - 'legend': 'fixed point'}, - frame_step=1, - frame_delay=100, - show=True, - # save_path='cann_fps.gif' - ) - - -def verify_fixed_points_through_simulation(fixed_points, pars=None, num=3): - if pars is None: pars = dict() - cann = CANN1D(num=512, **pars) - - for i in range(num): - cann.u[:] = fixed_points['u'][i] - runner = bp.DSRunner(cann, monitors=['u'], dyn_vars=cann.vars()) - runner.run(100.) - plt.plot(runner.mon.ts, runner.mon.u.max(axis=1)) - plt.ylim(0, runner.mon.u.max() + 1) - plt.show() - - -def pca_reduction(fixed_points): - pca = PCA(2) - pca.fit(fixed_points['u']) - fixedpoints_pc = pca.transform(fixed_points['u']) - plt.plot(fixedpoints_pc[:, 0], fixedpoints_pc[:, 1], 'x', label='fixed points') - - plt.xlabel('PC 1') - plt.ylabel('PC 2') - plt.legend() - plt.show() - - -if __name__ == '__main__': - params = dict(k=0.1, a=0.5, A=20) - fps, finder = find_fixed_points(params, cand_method='bump', tolerance=1e-7) - # fps, finder = find_fixed_points(params, cand_method='random', opt_method='gd', tolerance=1e-7) - # fps, finder = find_fixed_points(params, cand_method='random', opt_method='BFGS', tolerance=1e-5) - visualize_fixed_points(fps) - verify_fixed_points_through_simulation(fps, params) - finder.compute_jacobians(fps['u'][:6], plot=True) - pca_reduction(fps) - diff --git a/examples/dynamics_analysis/highdim_gj_coupled_fhn.py b/examples/dynamics_analysis/highdim_gj_coupled_fhn.py deleted file mode 100644 index dffd09b60..000000000 --- a/examples/dynamics_analysis/highdim_gj_coupled_fhn.py +++ /dev/null @@ -1,115 +0,0 @@ -# -*- coding: utf-8 -*- - - -import numpy as np -import matplotlib.pyplot as plt - -import brainpy as bp -import brainpy.math as bm - -bp.math.enable_x64() - - -class GJCoupledFHN(bp.DynamicalSystem): - def __init__(self, num=2, method='exp_auto'): - super(GJCoupledFHN, self).__init__() - - # parameters - self.num = num - self.a = 0.7 - self.b = 0.8 - self.tau = 12.5 - self.gjw = 0.0001 - - # variables - self.V = bm.Variable(bm.random.uniform(-2, 2, num)) - self.w = bm.Variable(bm.random.uniform(-2, 2, num)) - self.Iext = bm.Variable(bm.zeros(num)) - - # functions - self.int_V = bp.odeint(self.dV, method=method) - self.int_w = 
bp.odeint(self.dw, method=method) - - def dV(self, V, t, w, Iext=0.): - gj = (V.reshape((-1, 1)) - V).sum(axis=0) * self.gjw - dV = V - V * V * V / 3 - w + Iext + gj - return dV - - def dw(self, w, t, V): - dw = (V + self.a - self.b * w) / self.tau - return dw - - def update(self, tdi): - t, dt = tdi.get('t'), tdi.get('dt') - self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt) - self.w.value = self.int_w(self.w, t, self.V, dt) - self.Iext[:] = 0. - - -def d4_system(): - model = GJCoupledFHN(2) - model.gjw = 0.01 - # Iext = bm.asarray([0., 0.1]) - Iext = bm.asarray([0., 0.6]) - - # simulation - runner = bp.DSRunner(model, monitors=['V'], inputs=['Iext', Iext]) - runner.run(300.) - bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', - plot_ids=list(range(model.num)), show=True) - - # analysis - finder = bp.analysis.SlowPointFinder(f_cell=model, - target_vars={'V': model.V, 'w': model.w}, - inputs=['Iext', Iext]) - # finder.find_fps_with_gd_method( - # candidates={'V': bm.random.normal(0., 2., (1000, model.num)), - # 'w': bm.random.normal(0., 2., (1000, model.num))}, - # tolerance=1e-7, - # num_batch=200, - # optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)) - # ) - finder.find_fps_with_opt_solver(candidates={'V': bm.random.normal(0., 2., (1000, model.num)), - 'w': bm.random.normal(0., 2., (1000, model.num))}) - finder.filter_loss(1e-7) - finder.keep_unique() - - print('fixed_points: ', finder.fixed_points) - print('losses:', finder.losses) - jac = finder.compute_jacobians(finder.fixed_points, plot=True) - - -def d8_system(): - model = GJCoupledFHN(4) - model.gjw = 0.1 - Iext = bm.asarray([0., 0., 0., 0.6]) - - # simulation - runner = bp.DSRunner(model, monitors=['V'], inputs=['Iext', Iext]) - runner.run(300.) - - bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', - plot_ids=list(range(model.num)), - show=True) - - finder = bp.analysis.SlowPointFinder(f_cell=model, - target_vars={'V': model.V, 'w': model.w}, - inputs=[model.Iext, Iext]) - finder.find_fps_with_gd_method( - candidates={'V': bm.random.normal(0., 2., (1000, model.num)), - 'w': bm.random.normal(0., 2., (1000, model.num))}, - tolerance=1e-6, - num_batch=200, - optimizer=bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.05, 1, 0.9999)), - ) - finder.filter_loss(1e-7) - finder.keep_unique() - - print('fixed_points: ', finder.fixed_points) - print('losses:', finder.losses) - jac = finder.compute_jacobians(finder.fixed_points, plot=True) - - -if __name__ == '__main__': - d4_system() - d8_system() diff --git a/examples/dynamics_simulation/Bazhenov_1998_thalamus_aug_response.py b/examples/dynamics_simulation/Bazhenov_1998_thalamus_aug_response.py deleted file mode 100644 index d277d8bec..000000000 --- a/examples/dynamics_simulation/Bazhenov_1998_thalamus_aug_response.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Implementation of the model: - -- Bazhenov, Maxim, et al. "Cellular and network models for - intrathalamic augmenting responses during 10-Hz stimulation." - Journal of Neurophysiology 79.5 (1998): 2730-2748. -""" - -import brainpy as bp - - -class RE(bp.CondNeuGroup): - def __init__(self, size): - super(RE, self).__init__(size, A=1.43e-4) - - self.IL = bp.channels.IL(size, ) - self.IKL = bp.channels.IKL(size, ) - self.INa = bp.channels.INa_TM1991(size, V_sh=-50.) - self.IK = bp.channels.IK_TM1991(size, V_sh=-50.) - self.IT = bp.channels.ICaT_HP1992(size, V_sh=0., phi_q=3., phi_p=3.) 
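(A minimal usage sketch, added for illustration: a conductance-based group composed of channels, like `RE` above, plugs directly into `bp.DSRunner`, following the same pattern as the other simulation examples in this patch. The group size, input amplitude, and run length below are arbitrary assumptions, not values from the paper.)

import brainpy as bp

cells = RE(10)                                    # the reticular group defined above
runner = bp.DSRunner(cells,
                     monitors=['V'],              # record membrane potentials
                     inputs=[cells.input, 0.05])  # constant injected current
runner.run(200.)                                  # simulate 200 ms
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)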
- - -class TC(bp.CondNeuGroup): - def __init__(self, size): - super(TC, self).__init__(size, A=2.9e-4) - - self.IL = bp.channels.IL(size, ) - self.IKL = bp.channels.IKL(size, ) - self.INa = bp.channels.INa_TM1991(size, V_sh=-50.) - self.IK = bp.channels.IK_TM1991(size, V_sh=-50.) - self.IT = bp.channels.ICaT_HM1992(size, V_sh=0., ) - self.IA = bp.channels.IKA1_HM1992(size, V_sh=0., phi_q=3.7255, phi_p=3.7) - - self.Ih = bp.channels.Ih_De1996(size, ) - self.Ca = bp.channels.CalciumFirstOrder(size, ) - diff --git a/examples/dynamics_simulation/Brette_2007_COBA.py b/examples/dynamics_simulation/Brette_2007_COBA.py deleted file mode 100644 index aedbccece..000000000 --- a/examples/dynamics_simulation/Brette_2007_COBA.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy as bp -import brainpy.math as bm - -bp.math.set_platform('cpu') - - -class EINet_V1(bp.Network): - def __init__(self, scale=1.0, method='exp_auto'): - super(EINet_V1, self).__init__() - - # network size - num_exc = int(3200 * scale) - num_inh = int(800 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - self.E = bp.neurons.LIF(num_exc, **pars, method=method) - self.I = bp.neurons.LIF(num_inh, **pars, method=method) - self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. - self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.E2I = bp.synapses.Exponential(self.E, self.I, bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.I2E = bp.synapses.Exponential(self.I, self.E, bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - self.I2I = bp.synapses.Exponential(self.I, self.I, bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - - -def run_model_v1(): - net = EINet_V1(scale=1., method='exp_auto') - # simulation - runner = bp.DSRunner( - net, - monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)] - ) - runner.run(100.) - - # visualization - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) - - -class EINet_V2(bp.Network): - def __init__(self, scale=1.0, method='exp_auto'): - super(EINet_V2, self).__init__() - - # network size - num_exc = int(3200 * scale) - num_inh = int(800 * scale) - - # neurons - self.N = bp.neurons.LIF(num_exc + num_inh, - V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - method=method, V_initializer=bp.initialize.Normal(-55., 2.)) - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.Esyn = bp.synapses.Exponential(pre=self.N[:num_exc], - post=self.N, - conn=bp.connect.FixedProb(0.02), - g_max=we, tau=5., - output=bp.synouts.COBA(E=0.), - method=method) - self.Isyn = bp.synapses.Exponential(pre=self.N[num_exc:], - post=self.N, - conn=bp.connect.FixedProb(0.02), - g_max=wi, tau=10., - output=bp.synouts.COBA(E=-80.), - method=method) - - -def run_model_v2(): - net = EINet_V2(scale=1., method='exp_auto') - # simulation - runner = bp.DSRunner( - net, - monitors={'spikes': net.N.spike}, - inputs=[(net.N.input, 20.)] - ) - runner.run(100.) 
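  # A clarifying note (not from the original file): because `monitors` above is
  # given as a dict, the key 'spikes' is the name under which the recording is
  # stored, which is why the raster plot below reads it back as
  # runner.mon['spikes'].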
- - # visualization - bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True) - - -if __name__ == '__main__': - run_model_v1() - run_model_v2() diff --git a/examples/dynamics_simulation/Brette_2007_COBAHH.py b/examples/dynamics_simulation/Brette_2007_COBAHH.py deleted file mode 100644 index 8a720639e..000000000 --- a/examples/dynamics_simulation/Brette_2007_COBAHH.py +++ /dev/null @@ -1,73 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy as bp - -bp.math.set_platform('cpu') - - -class HH(bp.CondNeuGroup): - def __init__(self, size): - super(HH, self).__init__(size, ) - self.INa = bp.channels.INa_TM1991(size, g_max=100., V_sh=-63.) - self.IK = bp.channels.IK_TM1991(size, g_max=30., V_sh=-63.) - self.IL = bp.channels.IL(size, E=-60., g_max=0.05) - - -class EINet_v1(bp.Network): - def __init__(self, scale=1.): - super(EINet_v1, self).__init__() - self.E = HH(int(3200 * scale)) - self.I = HH(int(800 * scale)) - prob = 0.02 - self.E2E = bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob), - g_max=0.03 / scale, tau=5, - output=bp.synouts.COBA(E=0.)) - self.E2I = bp.synapses.Exponential(self.E, self.I, bp.conn.FixedProb(prob), - g_max=0.03 / scale, tau=5., - output=bp.synouts.COBA(E=0.)) - self.I2E = bp.synapses.Exponential(self.I, self.E, bp.conn.FixedProb(prob), - g_max=0.335 / scale, tau=10., - output=bp.synouts.COBA(E=-80)) - self.I2I = bp.synapses.Exponential(self.I, self.I, bp.conn.FixedProb(prob), - g_max=0.335 / scale, tau=10., - output=bp.synouts.COBA(E=-80.)) - - -class EINet_v2(bp.Network): - def __init__(self, scale=1.): - super(EINet_v2, self).__init__() - - prob = 0.02 - self.num_exc = int(3200 * scale) - self.num_inh = int(800 * scale) - - self.N = HH(self.num_exc + self.num_inh) - self.Esyn = bp.synapses.Exponential(self.N[:self.num_exc], - self.N, - bp.conn.FixedProb(prob), - g_max=0.03 / scale, tau=5, - output=bp.synouts.COBA(E=0.)) - self.Isyn = bp.synapses.Exponential(self.N[self.num_exc:], - self.N, - bp.conn.FixedProb(prob), - g_max=0.335 / scale, tau=10., - output=bp.synouts.COBA(E=-80)) - - -def run_ei_v1(): - net = EINet_v1(scale=1) - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}) - runner.run(100.) - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) - - -def run_ei_v2(): - net = EINet_v2(scale=1) - runner = bp.DSRunner(net, monitors={'spikes': net.N.spike}) - runner.run(100.) 
- bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True) - - -if __name__ == '__main__': - run_ei_v1() - run_ei_v2() diff --git a/examples/dynamics_simulation/COBA_for_benchmark.py b/examples/dynamics_simulation/COBA_for_benchmark.py deleted file mode 100644 index c610d8b4a..000000000 --- a/examples/dynamics_simulation/COBA_for_benchmark.py +++ /dev/null @@ -1,68 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy as bp -import brainpy.math as bm - -bp.math.set_platform('cpu') - - -class ExpCOBA(bp.dyn.TwoEndConn): - def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., - method='exp_auto'): - super(ExpCOBA, self).__init__(pre=pre, post=post, conn=conn) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - - # parameters - self.E = E - self.tau = tau - self.delay = delay - self.g_max = g_max - self.pre2post = self.conn.require('pre2post') - - # variables - self.g = bm.Variable(bm.zeros(self.post.num)) - - # function - self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method) - - def update(self, tdi): - self.g.value = self.integral(self.g, tdi.t, tdi.dt) - self.g += bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max) - self.post.input += self.g * (self.E - self.post.V) - - -class EINet(bp.dyn.Network): - def __init__(self, scale=1.0, method='exp_auto'): - # network size - num_exc = int(3200 * scale) - num_inh = int(800 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - E = bp.neurons.LIF(num_exc, **pars, method=method) - I = bp.neurons.LIF(num_inh, **pars, method=method) - E.V[:] = bp.math.random.randn(num_exc) * 2 - 55. - I.V[:] = bp.math.random.randn(num_inh) * 2 - 55. - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - E2E = ExpCOBA(E, E, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method) - E2I = ExpCOBA(E, I, bp.conn.FixedProb(prob=0.02), E=0., g_max=we, tau=5., method=method) - I2E = ExpCOBA(I, E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method) - I2I = ExpCOBA(I, I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=wi, tau=10., method=method) - - super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) - - -net = EINet(scale=10., method='euler') -# simulation -runner = bp.dyn.DSRunner(net, - # monitors=['E.spike'], - inputs=[('E.input', 20.), ('I.input', 20.)]) -t = runner.run(10000.) -print(t) - -# visualization -# bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) diff --git a/examples/dynamics_simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py b/examples/dynamics_simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py deleted file mode 100644 index 89bad5eab..000000000 --- a/examples/dynamics_simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- - - -"""" -Implementation of the paper: - -- Fazli, Mehran, and Richard Bertram. "Network Properties of Electrically - Coupled Bursting Pituitary Cells." Frontiers in Endocrinology 13 (2022). 
-""" - -import brainpy as bp -import brainpy.math as bm - - -class PituitaryCell(bp.NeuGroup): - def __init__(self, size, name=None): - super(PituitaryCell, self).__init__(size, name=name) - - # parameter values - self.vn = -5 - self.kc = 0.12 - self.ff = 0.005 - self.vca = 60 - self.vk = -75 - self.vl = -50.0 - self.gk = 2.5 - self.cm = 5 - self.gbk = 1 - self.gca = 2.1 - self.gsk = 2 - self.vm = -20 - self.vb = -5 - self.sn = 10 - self.sm = 12 - self.sbk = 2 - self.taun = 30 - self.taubk = 5 - self.ks = 0.4 - self.alpha = 0.0015 - self.gl = 0.2 - - # variables - self.V = bm.Variable(bm.random.random(self.num) * -90 + 20) - self.n = bm.Variable(bm.random.random(self.num) / 2) - self.b = bm.Variable(bm.random.random(self.num) / 2) - self.c = bm.Variable(bm.random.random(self.num)) - self.input = bm.Variable(self.num) - - # integrators - self.integral = bp.odeint(bp.JointEq(self.dV, self.dn, self.dc, self.db), method='exp_euler') - - def dn(self, n, t, V): - ninf = 1 / (1 + bm.exp((self.vn - V) / self.sn)) - return (ninf - n) / self.taun - - def db(self, b, t, V): - bkinf = 1 / (1 + bm.exp((self.vb - V) / self.sbk)) - return (bkinf - b) / self.taubk - - def dc(self, c, t, V): - minf = 1 / (1 + bm.exp((self.vm - V) / self.sm)) - ica = self.gca * minf * (V - self.vca) - return -self.ff * (self.alpha * ica + self.kc * c) - - def dV(self, V, t, n, b, c): - minf = 1 / (1 + bm.exp((self.vm - V) / self.sm)) - cinf = c ** 2 / (c ** 2 + self.ks * self.ks) - ica = self.gca * minf * (V - self.vca) - isk = self.gsk * cinf * (V - self.vk) - ibk = self.gbk * b * (V - self.vk) - ikdr = self.gk * n * (V - self.vk) - il = self.gl * (V - self.vl) - return -(ica + isk + ibk + ikdr + il + self.input) / self.cm - - def update(self, tdi, x=None): - V, n, c, b = self.integral(self.V.value, self.n.value, self.c.value, self.b.value, tdi.t, tdi.dt) - self.V.value = V - self.n.value = n - self.c.value = c - self.b.value = b - - def clear_input(self): - self.input.value = bm.zeros_like(self.input) - - -class PituitaryNetwork(bp.Network): - def __init__(self, num, gc): - super(PituitaryNetwork, self).__init__() - - self.N = PituitaryCell(num) - self.gj = bp.synapses.GapJunction(self.N, self.N, bp.conn.All2All(include_self=False), g_max=gc) - - -if __name__ == '__main__': - net = PituitaryNetwork(2, 0.002) - runner = bp.DSRunner(net, monitors={'V': net.N.V}, dt=0.5) - runner.run(10 * 1e3) - - fig, gs = bp.visualize.get_figure(1, 1, 6, 10) - fig.add_subplot(gs[0, 0]) - bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=(0, 1), show=True) diff --git a/examples/dynamics_simulation/JR_1995_jansen_rit_model.py b/examples/dynamics_simulation/JR_1995_jansen_rit_model.py deleted file mode 100644 index a909ab76c..000000000 --- a/examples/dynamics_simulation/JR_1995_jansen_rit_model.py +++ /dev/null @@ -1,139 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy as bp -import brainpy.math as bm - - -class JansenRitModel(bp.DynamicalSystem): - r"""The Jansen-Rit model, a neural mass model of the dynamic - interactions between 3 populations: - - - pyramidal cells (PCs) - - excitatory interneurons (EINs) - - inhibitory interneurons (IINs) - - Originally, the model has been developed to describe the waxing-and-waning - of EEG activity in the alpha frequency range (8-12 Hz) in the visual cortex [1]. - In the past years, however, it has been used as a generic model to describe - the macroscopic electrophysiological activity within a cortical column [2]. 
- - By using the linearity of the convolution operation, the dynamic interactions between PCs, EINs and IINs can be - expressed via 6 coupled ordinary differential equations that are composed of the two operators defined above: - - .. math:: - - \dot V_{pce} &= I_{pce}, \\ - \dot I_{pce} &= \frac{H_e}{\tau_e} c_4 S(c_3 V_{in}) - \frac{2 I_{pce}}{\tau_e} - \frac{V_{pce}}{\tau_e^2}, \\ - \dot V_{pci} &= I_{pci}, \\ - \dot I_{pci} &= \frac{H_i}{\tau_i} c_2 S(c_1 V_{in}) - \frac{2 I_{pci}}{\tau_i} - \frac{V_{pci}}{\tau_i^2}, \\ - \dot V_{in} &= I_{in}, \\ - \dot I_{in} &= \frac{H_e}{\tau_e} S(V_{pce} - V_{pci}) - \frac{2 I_{in}}{\tau_e} - \frac{V_{in}}{\tau_e^2}, - - where :math:`V_{pce}`, :math:`V_{pci}`, :math:`V_{in}` are used to represent the average membrane potential - deflection caused by the excitatory synapses at the PC population, the inhibitory synapses at the PC - population, and the excitatory synapses at both interneuron populations, respectively. - - References - ---------- - .. [1] B.H. Jansen & V.G. Rit (1995) Electroencephalogram and visual evoked - potential generation in a mathematical model of coupled cortical - columns. Biological Cybernetics, 73(4): 357-366. - .. [2] A. Spiegler, S.J. Kiebel, F.M. Atay, T.R. Knösche (2010) Bifurcation analysis of neural - mass models: Impact of extrinsic inputs and dendritic time constants. NeuroImage, 52(3): - 1041-1058, https://doi.org/10.1016/j.neuroimage.2009.12.081. - - """ - - def __init__(self, num, C=135., method='exp_auto'): - super(JansenRitModel, self).__init__() - - self.num = num - - # parameters # - self.v_max = 5. # maximum firing rate - self.v0 = 6. # firing threshold - self.r = 0.56 # slope of the sigmoid - # other parameters - self.A = 3.25 - self.B = 22. - self.a = 100. - self.tau_e = 0.01 # second - self.tau_i = 0.02 # second - self.b = 50. - self.e0 = 2.5 - # The connectivity constants - self.C1 = C - self.C2 = 0.8 * C - self.C3 = 0.25 * C - self.C4 = 0.25 * C - - # variables # - # y0, y1 and y2 representing the firing rate of - # pyramidal, excitatory and inhibitory neurones. - self.y0 = bm.Variable(bm.zeros(self.num)) - self.y1 = bm.Variable(bm.zeros(self.num)) - self.y2 = bm.Variable(bm.zeros(self.num)) - self.y3 = bm.Variable(bm.zeros(self.num)) - self.y4 = bm.Variable(bm.zeros(self.num)) - self.y5 = bm.Variable(bm.zeros(self.num)) - self.p = bm.Variable(bm.ones(self.num) * 220.) - - # integral function - self.derivative = bp.JointEq([self.dy0, self.dy1, self.dy2, self.dy3, self.dy4, self.dy5]) - self.integral = bp.odeint(self.derivative, method=method) - - def sigmoid(self, x): - return self.v_max / (1. 
+ bm.exp(self.r * (self.v0 - x))) - - def dy0(self, y0, t, y3): return y3 - - def dy1(self, y1, t, y4): return y4 - - def dy2(self, y2, t, y5): return y5 - - def dy3(self, y3, t, y0, y1, y2): - return (self.A * self.sigmoid(y1 - y2) - 2 * y3 - y0 / self.tau_e) / self.tau_e - - def dy4(self, y4, t, y0, y1, p): - return (self.A * (p + self.C2 * self.sigmoid(self.C1 * y0)) - 2 * y4 - y1 / self.tau_e) / self.tau_e - - def dy5(self, y5, t, y0, y2): - return (self.B * self.C4 * self.sigmoid(self.C3 * y0) - 2 * y5 - y2 / self.tau_i) / self.tau_i - - def update(self, tdi): - t, dt = tdi['t'], tdi['dt'] - self.y0.value, self.y1.value, self.y2.value, self.y3.value, self.y4.value, self.y5.value = \ - self.integral(self.y0, self.y1, self.y2, self.y3, self.y4, self.y5, t, p=self.p, dt=dt) - - -def simulation(duration=5.): - dt = 0.1 / 1e3 - # random input uniformly distributed between 120 and 320 pulses per second - all_ps = bm.random.uniform(120, 320, size=(int(duration / dt), 1)) - jrm = JansenRitModel(num=6, C=bm.array([68., 128., 135., 270., 675., 1350.])) - runner = bp.DSRunner(jrm, - monitors=['y0', 'y1', 'y2', 'y3', 'y4', 'y5'], - inputs=['p', all_ps, 'iter', '='], - dt=dt) - runner.run(duration) - - start, end = int(2 / dt), int(duration / dt) - fig, gs = bp.visualize.get_figure(6, 3, 2, 3) - for i in range(6): - fig.add_subplot(gs[i, 0]) - title = 'E' if i == 0 else None - xlabel = 'time [s]' if i == 5 else None - bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y1[start: end, i], - title=title, xlabel=xlabel, ylabel='Hz') - fig.add_subplot(gs[i, 1]) - title = 'P' if i == 0 else None - bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y0[start: end, i], - title=title, xlabel=xlabel) - fig.add_subplot(gs[i, 2]) - title = 'I' if i == 0 else None - bp.visualize.line_plot(runner.mon.ts[start: end], runner.mon.y2[start: end, i], - title=title, show=i == 5, xlabel=xlabel) - - -if __name__ == '__main__': - simulation() diff --git a/examples/dynamics_simulation/Li_2017_unified_thalamus_oscillation_model.py b/examples/dynamics_simulation/Li_2017_unified_thalamus_oscillation_model.py deleted file mode 100644 index c1a5f99f9..000000000 --- a/examples/dynamics_simulation/Li_2017_unified_thalamus_oscillation_model.py +++ /dev/null @@ -1,353 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Implementation of the model: - -- Li, Guoshi, Craig S. Henriquez, and Flavio Fröhlich. "Unified - thalamic model generates multiple distinct oscillations with - state-dependent entrainment by stimulation." PLoS computational - biology 13.10 (2017): e1005797. -""" - -from typing import Dict - -import matplotlib.pyplot as plt -import numpy as np - -import brainpy as bp -import brainpy.math as bm - - -class HTC(bp.CondNeuGroup): - def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-65.), ): - gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125) - IL = bp.channels.IL(size, g_max=gL, E=-70) - IKL = bp.channels.IKL(size, g_max=gKL) - INa = bp.channels.INa_Ba2002(size, V_sh=-30) - IDR = bp.channels.IKDR_Ba2002(size, V_sh=-30., phi=0.25) - Ih = bp.channels.Ih_HM1992(size, g_max=0.01, E=-43) - - ICaL = bp.channels.ICaL_IS2008(size, g_max=0.5) - IAHP = bp.channels.IAHP_De1994(size, g_max=0.3, E=-90.) 
- ICaN = bp.channels.ICaN_IS2008(size, g_max=0.5) - ICaT = bp.channels.ICaT_HM1992(size, g_max=2.1) - ICaHT = bp.channels.ICaHT_HM1992(size, g_max=3.0) - Ca = bp.channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, ICaL=ICaL, - IAHP=IAHP, ICaN=ICaN, ICaT=ICaT, ICaHT=ICaHT) - - super(HTC, self).__init__(size, A=2.9e-4, V_initializer=V_initializer, V_th=20., - IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca) - - -class RTC(bp.CondNeuGroup): - def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-65.), ): - gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125) - IL = bp.channels.IL(size, g_max=gL, E=-70) - IKL = bp.channels.IKL(size, g_max=gKL) - INa = bp.channels.INa_Ba2002(size, V_sh=-40) - IDR = bp.channels.IKDR_Ba2002(size, V_sh=-40, phi=0.25) - Ih = bp.channels.Ih_HM1992(size, g_max=0.01, E=-43) - - ICaL = bp.channels.ICaL_IS2008(size, g_max=0.3) - IAHP = bp.channels.IAHP_De1994(size, g_max=0.1, E=-90.) - ICaN = bp.channels.ICaN_IS2008(size, g_max=0.6) - ICaT = bp.channels.ICaT_HM1992(size, g_max=2.1) - ICaHT = bp.channels.ICaHT_HM1992(size, g_max=0.6) - Ca = bp.channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, ICaL=ICaL, - IAHP=IAHP, ICaN=ICaN, ICaT=ICaT, ICaHT=ICaHT) - - super(RTC, self).__init__(size, A=2.9e-4, V_initializer=V_initializer, V_th=20., - IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca) - - -class IN(bp.CondNeuGroup): - def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ): - gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125) - IL = bp.channels.IL(size, g_max=gL, E=-60) - IKL = bp.channels.IKL(size, g_max=gKL) - INa = bp.channels.INa_Ba2002(size, V_sh=-30) - IDR = bp.channels.IKDR_Ba2002(size, V_sh=-30, phi=0.25) - Ih = bp.channels.Ih_HM1992(size, g_max=0.05, E=-43) - - IAHP = bp.channels.IAHP_De1994(size, g_max=0.2, E=-90.) - ICaN = bp.channels.ICaN_IS2008(size, g_max=0.1) - ICaHT = bp.channels.ICaHT_HM1992(size, g_max=2.5) - Ca = bp.channels.CalciumDetailed(size, C_rest=5e-5, tau=10., d=0.5, - IAHP=IAHP, ICaN=ICaN, ICaHT=ICaHT) - - super(IN, self).__init__(size, A=1.7e-4, V_initializer=V_initializer, V_th=20., - IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ih=Ih, Ca=Ca) - - -class TRN(bp.CondNeuGroup): - def __init__(self, size, gKL=0.01, V_initializer=bp.init.OneInit(-70.), ): - gL = 0.01 if size == 1 else bp.init.Uniform(0.0075, 0.0125) - IL = bp.channels.IL(size, g_max=gL, E=-60) - IKL = bp.channels.IKL(size, g_max=gKL) - INa = bp.channels.INa_Ba2002(size, V_sh=-40) - IDR = bp.channels.IKDR_Ba2002(size, V_sh=-40) - - IAHP = bp.channels.IAHP_De1994(size, g_max=0.2, E=-90.) 
- ICaN = bp.channels.ICaN_IS2008(size, g_max=0.2) - ICaT = bp.channels.ICaT_HP1992(size, g_max=1.3) - Ca = bp.channels.CalciumDetailed(size, - C_rest=5e-5, tau=100., d=0.5, - IAHP=IAHP, ICaN=ICaN, ICaT=ICaT) - - super(TRN, self).__init__(size, - A=1.43e-4, V_th=20., - V_initializer=V_initializer, - IL=IL, IKL=IKL, INa=INa, IDR=IDR, Ca=Ca - ) - - -class MgBlock(bp.SynOut): - def __init__(self, E=0.): - super(MgBlock, self).__init__() - self.E = E - - def filter(self, g): - V = self.master.post.V.value - return g * (self.E - V) / (1 + bm.exp(-(V + 25) / 12.5)) - - -class Thalamus(bp.Network): - def __init__( - self, - g_input: Dict[str, float], - g_KL: Dict[str, float], - HTC_V_init=bp.init.OneInit(-65.), - RTC_V_init=bp.init.OneInit(-65.), - IN_V_init=bp.init.OneInit(-70.), - RE_V_init=bp.init.OneInit(-70.), - ): - super(Thalamus, self).__init__() - - # populations - self.HTC = HTC(size=(7, 7), gKL=g_KL['TC'], V_initializer=HTC_V_init) - self.RTC = RTC(size=(12, 12), gKL=g_KL['TC'], V_initializer=RTC_V_init) - self.RE = TRN(size=(10, 10), gKL=g_KL['RE'], V_initializer=IN_V_init) - self.IN = IN(size=(8, 8), gKL=g_KL['IN'], V_initializer=RE_V_init) - - # noises - self.poisson_HTC = bp.neurons.PoissonGroup(self.HTC.size, freqs=100) - self.poisson_RTC = bp.neurons.PoissonGroup(self.RTC.size, freqs=100) - self.poisson_IN = bp.neurons.PoissonGroup(self.IN.size, freqs=100) - self.poisson_RE = bp.neurons.PoissonGroup(self.RE.size, freqs=100) - self.noise2HTC = bp.synapses.Exponential(self.poisson_HTC, self.HTC, bp.conn.One2One(), - output=bp.synouts.COBA(E=0.), tau=5., - g_max=g_input['TC']) - self.noise2RTC = bp.synapses.Exponential(self.poisson_RTC, self.RTC, bp.conn.One2One(), - output=bp.synouts.COBA(E=0.), tau=5., - g_max=g_input['TC']) - self.noise2IN = bp.synapses.Exponential(self.poisson_IN, self.IN, bp.conn.One2One(), - output=bp.synouts.COBA(E=0.), tau=5., - g_max=g_input['IN']) - self.noise2RE = bp.synapses.Exponential(self.poisson_RE, self.RE, bp.conn.One2One(), - output=bp.synouts.COBA(E=0.), tau=5., - g_max=g_input['RE']) - - # HTC cells were connected with gap junctions - self.gj_HTC = bp.synapses.GapJunction(self.HTC, self.HTC, - bp.conn.ProbDist(dist=2., prob=0.3, ), - comp_method='sparse', - g_max=1e-2) - - # HTC provides feedforward excitation to INs - self.HTC2IN_ampa = bp.synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - alpha=0.94, - beta=0.18, - g_max=6e-3) - self.HTC2IN_nmda = bp.synapses.AMPA(self.HTC, self.IN, bp.conn.FixedProb(0.3), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - output=MgBlock(), - alpha=1., - beta=0.0067, - g_max=3e-3) - - # INs delivered feedforward inhibition to RTC cells - self.IN2RTC = bp.synapses.GABAa(self.IN, self.RTC, bp.conn.FixedProb(0.3), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - output=bp.synouts.COBA(E=-80), - alpha=10.5, - beta=0.166, - g_max=3e-3) - - # 20% RTC cells electrically connected with HTC cells - self.gj_RTC2HTC = bp.synapses.GapJunction(self.RTC, self.HTC, - bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2), - comp_method='sparse', - g_max=1 / 300) - - # Both HTC and RTC cells sent glutamatergic bp.synapses to RE neurons, while - # receiving GABAergic feedback inhibition from the RE population - self.HTC2RE_ampa = bp.synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - alpha=0.94, - beta=0.18, - 
g_max=4e-3) - self.RTC2RE_ampa = bp.synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - alpha=0.94, - beta=0.18, - g_max=4e-3) - self.HTC2RE_nmda = bp.synapses.AMPA(self.HTC, self.RE, bp.conn.FixedProb(0.2), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - output=MgBlock(), - alpha=1., - beta=0.0067, - g_max=2e-3) - self.RTC2RE_nmda = bp.synapses.AMPA(self.RTC, self.RE, bp.conn.FixedProb(0.2), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - output=MgBlock(), - alpha=1., - beta=0.0067, - g_max=2e-3) - self.RE2HTC = bp.synapses.GABAa(self.RE, self.HTC, bp.conn.FixedProb(0.2), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - output=bp.synouts.COBA(E=-80), - alpha=10.5, - beta=0.166, - g_max=3e-3) - self.RE2RTC = bp.synapses.GABAa(self.RE, self.RTC, bp.conn.FixedProb(0.2), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - output=bp.synouts.COBA(E=-80), - alpha=10.5, - beta=0.166, - g_max=3e-3) - - # RE neurons were connected with both gap junctions and GABAergic bp.synapses - self.gj_RE = bp.synapses.GapJunction(self.RE, self.RE, - bp.conn.ProbDist(dist=2., prob=0.3, pre_ratio=0.2), - comp_method='sparse', - g_max=1 / 300) - self.RE2RE = bp.synapses.GABAa(self.RE, self.RE, bp.conn.FixedProb(0.2), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - output=bp.synouts.COBA(E=-70), - alpha=10.5, beta=0.166, - g_max=1e-3) - - # 10% RE neurons project GABAergic bp.synapses to local interneurons - # probability (0.05) was used for the RE->IN bp.synapses according to experimental data - self.RE2IN = bp.synapses.GABAa(self.RE, self.IN, bp.conn.FixedProb(0.05, pre_ratio=0.1), - delay_step=int(2 / bm.get_dt()), - stp=bp.synplast.STD(tau=700, U=0.07), - output=bp.synouts.COBA(E=-80), - alpha=10.5, beta=0.166, - g_max=1e-3, ) - - -states = { - 'delta': dict(g_input={'IN': 1e-4, 'RE': 1e-4, 'TC': 1e-4}, - g_KL={'TC': 0.035, 'RE': 0.03, 'IN': 0.01}), - 'spindle': dict(g_input={'IN': 3e-4, 'RE': 3e-4, 'TC': 3e-4}, - g_KL={'TC': 0.01, 'RE': 0.02, 'IN': 0.015}), - 'alpha': dict(g_input={'IN': 1.5e-3, 'RE': 1.5e-3, 'TC': 1.5e-3}, - g_KL={'TC': 0., 'RE': 0.01, 'IN': 0.02}), - 'gamma': dict(g_input={'IN': 1.5e-3, 'RE': 1.5e-3, 'TC': 1.7e-2}, - g_KL={'TC': 0., 'RE': 0.01, 'IN': 0.02}), -} - - -def rhythm_const_input(amp, freq, length, duration, t_start=0., t_end=None, dt=None): - if t_end is None: t_end = duration - if length > duration: - raise ValueError(f'Expected length <= duration, while we got {length} > {duration}') - sec_length = 1e3 / freq - values, durations = [0.], [t_start] - for t in np.arange(t_start, t_end, sec_length): - values.append(amp) - if t + length <= t_end: - durations.append(length) - values.append(0.) - if t + sec_length <= t_end: - durations.append(sec_length - length) - else: - durations.append(t_end - t - length) - else: - durations.append(t_end - t) - values.append(0.) 
- durations.append(duration - t_end) - return bp.inputs.section_input(values=values, durations=durations, dt=dt, ) - - -def try_trn_neuron(): - with bm.environment(dt=0.01): - trn = TRN(1) - inputs = bp.inputs.section_input(values=[0, -0.05, 0], durations=[100, 100, 500]) - - @bm.to_dynsys(child_objs=trn) - def update(s, inp): - trn.input += inp - trn.update(s) - return trn.input.value - - runner = bp.DSRunner(update, monitors={'V': trn.V}) - I = runner.run(inputs=inputs) - - bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) - bp.visualize.line_plot(runner.mon.ts, I, show=True) - - -def try_network(): - duration = 3e3 - net = Thalamus( - IN_V_init=bp.init.OneInit(-70.), - RE_V_init=bp.init.OneInit(-70.), - HTC_V_init=bp.init.OneInit(-80.), - RTC_V_init=bp.init.OneInit(-80.), - **states['delta'], - ) - net.reset() - - # currents = rhythm_const_input(2e-4, freq=4., length=10., duration=duration, - # t_end=2e3, t_start=1e3) - # plt.plot(currents) - # plt.show() - - runner = bp.DSRunner( - net, - monitors=['HTC.spike', 'RTC.spike', 'RE.spike', 'IN.spike', - 'HTC.V', 'RTC.V', 'RE.V', 'IN.V', ], - # inputs=[('HTC.input', currents, 'iter'), - # ('RTC.input', currents, 'iter'), - # ('IN.input', currents, 'iter')], - ) - runner.run(duration) - - fig, gs = bp.visualize.get_figure(4, 2, 2, 5) - fig.add_subplot(gs[0, 0]) - bp.visualize.line_plot(runner.mon.ts, runner.mon.get('HTC.V'), ylabel='HTC', xlim=(0, duration)) - fig.add_subplot(gs[1, 0]) - bp.visualize.line_plot(runner.mon.ts, runner.mon.get('RTC.V'), ylabel='RTC', xlim=(0, duration)) - fig.add_subplot(gs[2, 0]) - bp.visualize.line_plot(runner.mon.ts, runner.mon.get('IN.V'), ylabel='IN', xlim=(0, duration)) - fig.add_subplot(gs[3, 0]) - bp.visualize.line_plot(runner.mon.ts, runner.mon.get('RE.V'), ylabel='RE', xlim=(0, duration)) - - fig.add_subplot(gs[0, 1]) - bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('HTC.spike'), xlim=(0, duration)) - fig.add_subplot(gs[1, 1]) - bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('RTC.spike'), xlim=(0, duration)) - fig.add_subplot(gs[2, 1]) - bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('IN.spike'), xlim=(0, duration)) - fig.add_subplot(gs[3, 1]) - bp.visualize.raster_plot(runner.mon.ts, runner.mon.get('RE.spike'), xlim=(0, duration)) - - plt.show() - - -if __name__ == '__main__': - try_trn_neuron() - try_network() diff --git a/examples/dynamics_simulation/Susin_2021_gamma_oscillation_nets.py b/examples/dynamics_simulation/Susin_2021_gamma_oscillation_nets.py deleted file mode 100644 index 7e99eb22e..000000000 --- a/examples/dynamics_simulation/Susin_2021_gamma_oscillation_nets.py +++ /dev/null @@ -1,651 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Implementation of the paper: - -- Susin, Eduarda, and Alain Destexhe. "Integration, coincidence detection and - resonance in networks of spiking neurons expressing gamma oscillations and - asynchronous states." PLoS computational biology 17.9 (2021): e1009416. - -""" - -import matplotlib.pyplot as plt -import numpy as np -from scipy.signal import kaiserord, lfilter, firwin, hilbert - -import brainpy as bp -import brainpy.math as bm - -# Table 1: specific neuron model parameters -RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65, - E_e=0., E_i=-80.) -FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65, - E_e=0., E_i=-80.) 
-Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65, - E_e=0., E_i=-80.) - - -class AdEx(bp.NeuGroup): - def __init__( - self, - size, - - # neuronal parameters - Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, - gL=10, EL=-65, V_reset=-65, V_sp_th=-30., - - # synaptic parameters - tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80., - - # other parameters - name=None, method='exp_euler', - V_initializer=bp.init.Uniform(-65, -50), - w_initializer=bp.init.Constant(0.), - ): - super(AdEx, self).__init__(size=size, name=name) - - # neuronal parameters - self.Vth = Vth - self.delta = delta - self.tau_ref = tau_ref - self.tau_w = tau_w - self.a = a - self.b = b - self.C = C - self.gL = gL - self.EL = EL - self.V_reset = V_reset - self.V_sp_th = V_sp_th - - # synaptic parameters - self.tau_e = tau_e - self.tau_i = tau_i - self.E_e = E_e - self.E_i = E_i - - # neuronal variables - self.V = bp.init.variable_(V_initializer, self.num) - self.w = bp.init.variable_(w_initializer, self.num) - self.spike = bm.Variable(self.num, dtype=bool) - self.refractory = bm.Variable(self.num, dtype=bool) - self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8) - - # synaptic parameters - self.ge = bm.Variable(self.num) - self.gi = bm.Variable(self.num) - - # integral - self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method) - - def dge(self, ge, t): - return -ge / self.tau_e - - def dgi(self, gi, t): - return -gi / self.tau_i - - def dV(self, V, t, w, ge, gi, Iext=None): - I = ge * (self.E_e - V) + gi * (self.E_i - V) - if Iext is not None: I += Iext - dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta) - - w + self.gL * (self.EL - V) + I) / self.C - return dVdt - - def dw(self, w, t, V): - dwdt = (self.a * (V - self.EL) - w) / self.tau_w - return dwdt - - def update(self, tdi, x=None): - V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value, - tdi.t, Iext=x, dt=tdi.dt) - refractory = (tdi.t - self.t_last_spike) <= self.tau_ref - V = bm.where(refractory, self.V.value, V) - spike = V >= self.V_sp_th - self.V.value = bm.where(spike, self.V_reset, V) - self.w.value = bm.where(spike, w + self.b, w) - self.ge.value = ge - self.gi.value = gi - self.spike.value = spike - self.refractory.value = bm.logical_or(refractory, spike) - self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike) - - -class PINGNet(bp.Network): - def __init__(self, ext_varied_rates, ext_weight=4., method='exp_euler', dt=bm.get_dt()): - super(PINGNet, self).__init__() - - self.num_exc = 20000 - self.num_inh = 5000 - self.exc_syn_tau = 1. # ms - self.inh_syn_tau = 7.5 # ms - self.exc_syn_weight = 5. 
# nS - self.inh_syn_weight = 3.34 # nS - self.num_delay_step = int(1.5 / dt) - self.ext_varied_rates = ext_varied_rates - - # neuronal populations - RS_par_ = RS_par.copy() - FS_par_ = FS_par.copy() - RS_par_.update(Vth=-50, V_sp_th=-40) - FS_par_.update(Vth=-50, V_sp_th=-40) - self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_) - self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_) - self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1)) - - # Poisson inputs - self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=ext_weight) - self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=ext_weight) - - # synaptic projections - self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight, - delay_step=self.num_delay_step) - self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight, - delay_step=self.num_delay_step) - - def change_freq(self, tdi): - self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i] - - -class AINet(bp.Network): - def __init__(self, ext_varied_rates, ext_weight=1., method='exp_euler', dt=bm.get_dt()): - super(AINet, self).__init__() - - self.num_exc = 20000 - self.num_inh = 5000 - self.exc_syn_tau = 5. # ms - self.inh_syn_tau = 5. # ms - self.exc_syn_weight = 1. # nS - self.inh_syn_weight = 5. 
# nS - self.num_delay_step = int(1.5 / dt) - self.ext_varied_rates = ext_varied_rates - - # neuronal populations - RS_par_ = RS_par.copy() - FS_par_ = FS_par.copy() - RS_par_.update(Vth=-50, V_sp_th=-40) - FS_par_.update(Vth=-50, V_sp_th=-40) - self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_) - self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_) - self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1)) - - # Poisson inputs - self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=ext_weight) - self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=ext_weight) - - # synaptic projections - self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight, - delay_step=self.num_delay_step) - self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight, - delay_step=self.num_delay_step) - - def change_freq(self, tdi): - self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i] - - -class INGNet(bp.Network): - def __init__(self, ext_varied_rates, ext_weight=0.9, method='exp_euler', dt=bm.get_dt()): - super(INGNet, self).__init__() - - self.num_rs = 20000 - self.num_fs = 4000 - self.num_fs2 = 1000 - self.exc_syn_tau = 5. # ms - self.inh_syn_tau = 5. # ms - self.exc_syn_weight = 1. # nS - self.inh_syn_weight = 5. 
# nS - self.num_delay_step = int(1.5 / dt) - self.ext_varied_rates = ext_varied_rates - - # neuronal populations - RS_par_ = RS_par.copy() - FS_par_ = FS_par.copy() - FS2_par_ = FS_par.copy() - RS_par_.update(Vth=-50, V_sp_th=-40) - FS_par_.update(Vth=-50, V_sp_th=-40) - FS2_par_.update(Vth=-50, V_sp_th=-40) - self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_) - self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_) - self.fs2_pop = AdEx(self.num_fs2, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS2_par_) - self.ext_pop = bp.neurons.PoissonGroup(self.num_rs, freqs=bm.Variable(1)) - - # Poisson inputs - self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=ext_weight) - self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=ext_weight) - self.ext_to_FS2 = bp.synapses.Delta(self.ext_pop, self.fs2_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=ext_weight) - - # synaptic projections - self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.RS_to_FS2 = bp.synapses.Delta(self.rs_pop, self.fs2_pop, bp.conn.FixedProb(0.15), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - - self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight, - delay_step=self.num_delay_step) - self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight, - delay_step=self.num_delay_step) - self.FS_to_FS2 = bp.synapses.Delta(self.fs_pop, self.fs2_pop, bp.conn.FixedProb(0.03), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight, - delay_step=self.num_delay_step) - - self.FS2_to_RS = bp.synapses.Delta(self.fs2_pop, self.rs_pop, bp.conn.FixedProb(0.15), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.FS2_to_FS = bp.synapses.Delta(self.fs2_pop, self.fs_pop, bp.conn.FixedProb(0.15), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.FS2_to_FS2 = bp.synapses.Delta(self.fs2_pop, self.fs2_pop, bp.conn.FixedProb(0.6), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - - def change_freq(self, tdi): - self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i] - - -class CHINGNet(bp.Network): - def __init__(self, ext_varied_rates, method='exp_euler', dt=bm.get_dt()): - super(CHINGNet, self).__init__() - - self.num_rs = 19000 - self.num_fs = 5000 - self.num_ch = 1000 - self.exc_syn_tau = 5. # ms - self.inh_syn_tau = 5. # ms - self.exc_syn_weight = 1. # nS - self.inh_syn_weight1 = 7. # nS - self.inh_syn_weight2 = 5. # nS - self.ext_weight1 = 1. 
# nS - self.ext_weight2 = 0.75 # nS - self.num_delay_step = int(1.5 / dt) - self.ext_varied_rates = ext_varied_rates - - # neuronal populations - RS_par_ = RS_par.copy() - FS_par_ = FS_par.copy() - Ch_par_ = Ch_par.copy() - RS_par_.update(Vth=-50, V_sp_th=-40) - FS_par_.update(Vth=-50, V_sp_th=-40) - Ch_par_.update(Vth=-50, V_sp_th=-40) - self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_) - self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_) - self.ch_pop = AdEx(self.num_ch, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **Ch_par_) - self.ext_pop = bp.neurons.PoissonGroup(self.num_rs, freqs=bm.Variable(1)) - - # Poisson inputs - self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.ext_weight2) - self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.ext_weight1) - self.ext_to_CH = bp.synapses.Delta(self.ext_pop, self.ch_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.ext_weight1) - - # synaptic projections - self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.RS_to_Ch = bp.synapses.Delta(self.rs_pop, self.ch_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - - self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight1, - delay_step=self.num_delay_step) - self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight2, - delay_step=self.num_delay_step) - self.FS_to_Ch = bp.synapses.Delta(self.fs_pop, self.ch_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='gi'), - g_max=self.inh_syn_weight1, - delay_step=self.num_delay_step) - - self.Ch_to_RS = bp.synapses.Delta(self.ch_pop, self.rs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.Ch_to_FS = bp.synapses.Delta(self.ch_pop, self.fs_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - self.Ch_to_Ch = bp.synapses.Delta(self.ch_pop, self.ch_pop, bp.conn.FixedProb(0.02), - output=bp.synouts.CUBA(target_var='ge'), - g_max=self.exc_syn_weight, - delay_step=self.num_delay_step) - - def change_freq(self, tdi): - self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i] - - -def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None): - dt = bm.get_dt() if dt is None else dt - t = 0 - num_gap = int(t_gap / dt) - num_total = int(t_total / dt) - num_transition = int(t_transition / dt) - - inputs = [] - ramp_up = np.linspace(c_low, c_high, num_transition) - ramp_down = np.linspace(c_high, c_low, num_transition) - plato_base = np.ones(num_gap) * c_low - while t < num_total: - num_plato = 
int(np.random.uniform(low=t_min_plato, high=t_max_plato, size=1) / dt)
-    inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
-    t += (num_gap + num_transition + num_plato + num_transition)
-  return bm.asarray(np.concatenate(inputs)[:num_total])
-
-
-def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
-  # sampling_space: in seconds (no units)
-  # signal_time: in seconds (no units)
-  # low_cut: in Hz (no units) (band to filter)
-  # high_cut: in Hz (no units) (band to filter)
-
-  signal = signal - np.mean(signal)
-  width = 5.0  # the desired width, in Hz, of the transition from pass to stop
-  ripple_db = 60.0  # the desired attenuation in the stop band, in dB
-  sampling_rate = 1. / sampling_space
-  Nyquist = sampling_rate / 2.
-
-  num_taps, beta = kaiserord(ripple_db, width / Nyquist)
-  if num_taps % 2 == 0:
-    num_taps = num_taps + 1  # num_taps must be odd
-  taps = firwin(num_taps,
-                [low_cut / Nyquist, high_cut / Nyquist],
-                window=('kaiser', beta),
-                nyq=1.0,
-                pass_zero=False,
-                scale=True)
-  filtered_signal = lfilter(taps, 1.0, signal)
-  delay = 0.5 * (num_taps - 1) / sampling_rate  # group delay to correct for zero phase
-  delay_index = int(np.floor(delay * sampling_rate))
-  # correct the delay and discard the "corrupted" part of the filtered signal
-  filtered_signal = filtered_signal[num_taps - 1:]
-  filtered_time = signal_time[num_taps - 1:] - delay
-  cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]
-
-  # --------------------------------------------------------------------------
-  # The Hilbert transform is very slow when the signal has an odd length.
-  # This part checks whether the length is odd and, if so, appends a zero to
-  # all the vectors related to the filtered signal:
-  if len(filtered_signal) % 2 != 0:  # if the length is odd
-    tmp1 = filtered_signal.tolist()
-    tmp1.append(0)
-    tmp2 = filtered_time.tolist()
-    tmp2.append((len(filtered_time) + 1) * sampling_space + filtered_time[0])
-    tmp3 = cutted_signal.tolist()
-    tmp3.append(0)
-    filtered_signal = np.asarray(tmp1)
-    filtered_time = np.asarray(tmp2)
-    cutted_signal = np.asarray(tmp3)
-  # --------------------------------------------------------------------------
-
-  ht_filtered_signal = hilbert(filtered_signal)
-  envelope = np.abs(ht_filtered_signal)
-  phase = np.angle(ht_filtered_signal)  # the phase is in radians, between -pi and pi
-
-  return filtered_time, filtered_signal, cutted_signal, envelope, phase
-
-
-def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
-                                 xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
-  fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
-  # 1. input firing rate
-  ax = fig.add_subplot(gs[0])
-  plt.plot(times, varied_rates)
-  if xlim is None:
-    xlim = (0, times[-1])
-  ax.set_xlim(*xlim)
-  ax.set_xticks([])
-  ax.set_ylabel('External\nRate (Hz)')
-
-  # 2. spike raster of each population
-  ax = fig.add_subplot(gs[1: 3])
-  i = 0
-  y_ticks = ([], [])
-  for key, (sp_matrix, sp_type) in spikes.items():
-    iis, sps = np.where(sp_matrix)
-    tts = times[iis]
-    plt.plot(tts, sps + i, '.', markersize=1, label=key)
-    y_ticks[0].append(i + sp_matrix.shape[1] / 2)
-    y_ticks[1].append(key)
-    i += sp_matrix.shape[1]
-  ax.set_xlim(*xlim)
-  ax.set_xlabel('')
-  ax.set_ylabel('Neuron Index')
-  ax.set_xticks([])
-  ax.set_yticks(*y_ticks)
-  # ax.legend()
-
-  # 3. 
example membrane potential - ax = fig.add_subplot(gs[3: 5]) - for key, potential in example_potentials.items(): - vs = np.where(spikes[key][0][:, 0], 0, potential) - plt.plot(times, vs, label=key) - ax.set_xlim(*xlim) - ax.set_xticks([]) - ax.set_ylabel('V (mV)') - ax.legend() - - # 4. LFP - ax = fig.add_subplot(gs[5:7]) - ax.set_xlim(*xlim) - t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0 - t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times) - times = times[t1: t2] - lfp = 0 - for sp_matrix, sp_type in spikes.values(): - lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type) - phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50, - bm.get_dt() * 1e-3) - plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP') - plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)") - plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope") - plt.legend(loc='best') - plt.xlabel('Time (ms)') - - # save or show - if filename: - plt.savefig(filename, dpi=500) - plt.show() - - -def simulate_single_neuron(duration=4e3): - input_currents = get_inputs(0., 500, 50, 500, 600, 2e3, duration) - - RS_cell = AdEx(1, V_sp_th=RS_par['Vth'], **RS_par) - runner = bp.DSRunner(RS_cell, monitors=['V']) - runner.run(inputs=input_currents) - - FS_cell = AdEx(1, V_sp_th=FS_par['Vth'], **FS_par) - runner2 = bp.DSRunner(FS_cell, monitors=['V']) - runner2.run(inputs=input_currents) - - fig, gs = bp.visualize.get_figure(3, 1, 3, 10) - ax = fig.add_subplot(gs[0, 0]) - bp.visualize.line_plot(runner.mon.ts, input_currents) - ax.set_xlim(1600, 3000) - ax.set_title('Input Current') - - ax = fig.add_subplot(gs[1, 0]) - ax.set_xlim(1600, 3000) - ax.set_title('RS Neuron') - bp.visualize.line_plot(runner.mon.ts, runner.mon.V) - - ax = fig.add_subplot(gs[2, 0]) - ax.set_xlim(1600, 3000) - ax.set_title('FS Neuron') - bp.visualize.line_plot(runner2.mon.ts, runner2.mon.V) - plt.show() - - -def simulate_ping_net(): - duration = 6e3 - varied_rates = get_inputs(2., 3., 50., 150, 600, 1e3, duration) - - net = PINGNet(varied_rates, ext_weight=4.) - runner = bp.DSRunner( - net, - inputs=net.change_freq, - monitors={'FS.V0': lambda tdi: net.fs_pop.V[0], - 'RS.V0': lambda tdi: net.rs_pop.V[0], - 'FS.spike': lambda tdi: net.fs_pop.spike, - 'RS.spike': lambda tdi: net.rs_pop.spike} - ) - runner.run(duration) - - visualize_simulation_results(times=runner.mon.ts, - spikes={'FS': (runner.mon['FS.spike'], 'inh'), - 'RS': (runner.mon['RS.spike'], 'exc')}, - example_potentials={'FS': runner.mon['FS.V0'], - 'RS': runner.mon['RS.V0']}, - varied_rates=varied_rates.to_numpy(), - xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3) - - -def simulate_ai_net(): - duration = 2e3 - varied_rates = get_inputs(2., 2., 50., 150, 600, 1e3, duration) - - net = PINGNet(varied_rates, ext_weight=5.) 
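-  # Unlike the PING runs, the external drive here is constant in time
-  # (c_low == c_high == 2 Hz in `get_inputs` above), so this run probes the
-  # asynchronous-irregular ("AI") regime of the network rather than a
-  # stimulus-locked gamma episode.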
- runner = bp.DSRunner( - net, - inputs=net.change_freq, - monitors={'FS.V0': lambda tdi: net.fs_pop.V[0], - 'RS.V0': lambda tdi: net.rs_pop.V[0], - 'FS.spike': lambda tdi: net.fs_pop.spike, - 'RS.spike': lambda tdi: net.rs_pop.spike} - ) - runner.run(duration) - - visualize_simulation_results(times=runner.mon.ts, - spikes={'FS': (runner.mon['FS.spike'], 'inh'), - 'RS': (runner.mon['RS.spike'], 'exc')}, - example_potentials={'FS': runner.mon['FS.V0'], - 'RS': runner.mon['RS.V0']}, - varied_rates=varied_rates.to_numpy()) - - -def simulate_ing_net(): - duration = 6e3 - varied_rates = get_inputs(2., 3., 50., 350, 600, 1e3, duration) - - net = INGNet(varied_rates, ext_weight=0.9) - runner = bp.DSRunner( - net, - inputs=net.change_freq, - monitors={'FS.V0': lambda tdi: net.fs_pop.V[0], - 'FS2.V0': lambda tdi: net.fs2_pop.V[0], - 'RS.V0': lambda tdi: net.rs_pop.V[0], - 'FS.spike': lambda tdi: net.fs_pop.spike, - 'FS2.spike': lambda tdi: net.fs2_pop.spike, - 'RS.spike': lambda tdi: net.rs_pop.spike} - ) - runner.run(duration) - - visualize_simulation_results(times=runner.mon.ts, - spikes={'FS': (runner.mon['FS.spike'], 'inh'), - 'FS2': (runner.mon['FS2.spike'], 'inh'), - 'RS': (runner.mon['RS.spike'], 'exc')}, - example_potentials={'FS': runner.mon['FS.V0'], - 'FS2': runner.mon['FS2.V0'], - 'RS': runner.mon['RS.V0']}, - varied_rates=varied_rates.to_numpy(), - xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3) - - -def simulate_ching_net(): - duration = 6e3 - varied_rates = get_inputs(1., 2., 50., 150, 600, 1e3, duration) - - net = CHINGNet(varied_rates) - runner = bp.DSRunner( - net, - inputs=net.change_freq, - monitors={'FS.V0': lambda tdi: net.fs_pop.V[0], - 'CH.V0': lambda tdi: net.ch_pop.V[0], - 'RS.V0': lambda tdi: net.rs_pop.V[0], - 'FS.spike': lambda tdi: net.fs_pop.spike, - 'CH.spike': lambda tdi: net.ch_pop.spike, - 'RS.spike': lambda tdi: net.rs_pop.spike} - ) - runner.run(duration) - - visualize_simulation_results(times=runner.mon.ts, - spikes={'FS': (runner.mon['FS.spike'], 'inh'), - 'CH': (runner.mon['CH.spike'], 'exc'), - 'RS': (runner.mon['RS.spike'], 'exc')}, - example_potentials={'FS': runner.mon['FS.V0'], - 'CH': runner.mon['CH.V0'], - 'RS': runner.mon['RS.V0']}, - varied_rates=varied_rates.to_numpy(), - xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3) - - -if __name__ == '__main__': - # simulate_single_neuron() - simulate_ping_net() - simulate_ai_net() - simulate_ing_net() - simulate_ching_net() diff --git a/examples/dynamics_simulation/Vreeswijk_1996_EI_net.py b/examples/dynamics_simulation/Vreeswijk_1996_EI_net.py deleted file mode 100644 index 9b1e31ed0..000000000 --- a/examples/dynamics_simulation/Vreeswijk_1996_EI_net.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy as bp -import brainpy.math as bm - -import matplotlib.pyplot as plt - -bm.set_platform('cpu') - - -class EINet(bp.Network): - def __init__(self, num_exc, num_inh, prob, JE, JI): - # neurons - pars = dict(V_rest=-52., V_th=-50., V_reset=-60., tau=10., tau_ref=0.) 
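-    # Both populations share the same LIF parameters; the membrane potentials
-    # are then initialized uniformly between V_rest and V_th so the network
-    # does not start in a synchronized state.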
- E = bp.neurons.LIF(num_exc, **pars) - I = bp.neurons.LIF(num_inh, **pars) - E.V[:] = bm.random.random(num_exc) * (E.V_th - E.V_rest) + E.V_rest - I.V[:] = bm.random.random(num_inh) * (E.V_th - E.V_rest) + E.V_rest - - # synapses - E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob), g_max=JE, tau=2., - output=bp.synouts.CUBA()) - E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob), g_max=JE, tau=2., - output=bp.synouts.CUBA()) - I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob), g_max=JI, tau=2., - output=bp.synouts.CUBA()) - I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob), g_max=JI, tau=2., - output=bp.synouts.CUBA()) - - super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) - - -num_exc = 500 -num_inh = 500 -prob = 0.5 - -Ib = 3. -JE = 1 / bp.math.sqrt(prob * num_exc) -JI = -1 / bp.math.sqrt(prob * num_inh) - -net = EINet(num_exc, num_inh, prob=prob, JE=JE, JI=JI) - -runner = bp.DSRunner(net, - monitors=['E.spike'], - inputs=[('E.input', Ib), ('I.input', Ib)]) -t = runner.run(1000.) - -fig, gs = bp.visualize.get_figure(4, 1, 2, 10) - -fig.add_subplot(gs[:3, 0]) -bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], xlim=(50, 950)) - -fig.add_subplot(gs[3, 0]) -rates = bp.measure.firing_rate(runner.mon['E.spike'], 5.) -plt.plot(runner.mon.ts, rates) -plt.xlim(50, 950) -plt.show() diff --git a/examples/dynamics_simulation/Wang_2002_decision_making_spiking.py b/examples/dynamics_simulation/Wang_2002_decision_making_spiking.py deleted file mode 100644 index da1440ed2..000000000 --- a/examples/dynamics_simulation/Wang_2002_decision_making_spiking.py +++ /dev/null @@ -1,289 +0,0 @@ -# -*- coding: utf-8 -*- - - -import matplotlib.pyplot as plt - -import brainpy as bp -import brainpy.math as bm - - -# bm.set_platform('cpu') - - -class PoissonStim(bp.NeuGroup): - def __init__(self, size, freq_mean, freq_var, t_interval): - super(PoissonStim, self).__init__(size=size) - - # parameters - self.freq_mean = freq_mean - self.freq_var = freq_var - self.t_interval = t_interval - self.dt = bm.get_dt() / 1000. - - # variables - self.freq = bp.init.variable_(bm.zeros, 1, self.mode) - self.freq_t_last_change = bp.init.variable_(lambda s: bm.ones(s) * -1e7, 1, self.mode) - self.spike = bp.init.variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) - self.rng = bm.random.RandomState() - - def reset_state(self, batch_size=None): - self.freq.value = bp.init.variable_(bm.zeros, 1, batch_size) - self.freq_t_last_change.value = bp.init.variable_(lambda s: bm.ones(s) * -1e7, 1, batch_size) - self.spike.value = bp.init.variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - - def update(self, tdi): - t, dt = tdi['t'], tdi['dt'] - in_interval = bm.logical_and(pre_stimulus_period < t, t < pre_stimulus_period + stimulus_period) - in_interval = bm.ones_like(self.freq, dtype=bool) * in_interval - prev_freq = bm.where(in_interval, self.freq, 0.) 
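-    # While the stimulus window is on, redraw the Poisson rate from
-    # N(freq_mean, freq_var) every `t_interval` ms; outside the window the
-    # rate is clamped to zero.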
- in_interval = bm.logical_and(in_interval, (t - self.freq_t_last_change) >= self.t_interval) - self.freq.value = bm.where(in_interval, self.rng.normal(self.freq_mean, self.freq_var, self.freq.shape), prev_freq) - self.freq_t_last_change.value = bm.where(in_interval, t, self.freq_t_last_change) - shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, bm.BatchingMode) else self.varshape - self.spike.value = self.rng.random(shape) < self.freq * self.dt - - -class DecisionMaking(bp.Network): - def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15): - super(DecisionMaking, self).__init__() - - num_exc = int(1600 * scale) - num_inh = int(400 * scale) - num_A = int(f * num_exc) - num_B = int(f * num_exc) - num_N = num_exc - num_A - num_B - print(f'Total network size: {num_exc + num_inh}') - - poisson_freq = 2400. # Hz - w_pos = 1.7 - w_neg = 1. - f * (w_pos - 1.) / (1. - f) - g_ext2E_AMPA = 2.1 # nS - g_ext2I_AMPA = 1.62 # nS - g_E2E_AMPA = 0.05 / scale # nS - g_E2I_AMPA = 0.04 / scale # nS - g_E2E_NMDA = 0.165 / scale # nS - g_E2I_NMDA = 0.13 / scale # nS - g_I2E_GABAa = 1.3 / scale # nS - g_I2I_GABAa = 1.0 / scale # nS - - ampa_par = dict(delay_step=int(0.5 / bm.get_dt()), tau=2.0) - gaba_par = dict(delay_step=int(0.5 / bm.get_dt()), tau=5.0) - nmda_par = dict(delay_step=int(0.5 / bm.get_dt()), tau_decay=100, tau_rise=2., a=0.5) - - # E neurons/pyramid neurons - A = bp.neurons.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, - tau_ref=2., V_initializer=bp.init.OneInit(-70.)) - B = bp.neurons.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, - tau_ref=2., V_initializer=bp.init.OneInit(-70.)) - N = bp.neurons.LIF(num_N, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, - tau_ref=2., V_initializer=bp.init.OneInit(-70.)) - # I neurons/interneurons - I = bp.neurons.LIF(num_inh, V_rest=-70., V_reset=-55., V_th=-50., tau=10., R=0.05, - tau_ref=1., V_initializer=bp.init.OneInit(-70.)) - - # poisson stimulus - IA = PoissonStim(num_A, freq_var=10., t_interval=50., freq_mean=mu0 + mu0 / 100. * coherence) - IB = PoissonStim(num_B, freq_var=10., t_interval=50., freq_mean=mu0 - mu0 / 100. 
* coherence) - - # noise neurons - self.noise_B = bp.neurons.PoissonGroup(num_B, freqs=poisson_freq) - self.noise_A = bp.neurons.PoissonGroup(num_A, freqs=poisson_freq) - self.noise_N = bp.neurons.PoissonGroup(num_N, freqs=poisson_freq) - self.noise_I = bp.neurons.PoissonGroup(num_inh, freqs=poisson_freq) - - # define external inputs - self.IA2A = bp.synapses.Exponential(IA, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.IB2B = bp.synapses.Exponential(IB, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - - # define E->E/I conn - - self.N2B_AMPA = bp.synapses.Exponential(N, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.N2A_AMPA = bp.synapses.Exponential(N, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.N2N_AMPA = bp.synapses.Exponential(N, N, bp.conn.All2All(), g_max=g_E2E_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.N2I_AMPA = bp.synapses.Exponential(N, I, bp.conn.All2All(), g_max=g_E2I_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.N2B_NMDA = bp.synapses.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.N2A_NMDA = bp.synapses.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.N2N_NMDA = bp.synapses.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.N2I_NMDA = bp.synapses.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - - self.B2B_AMPA = bp.synapses.Exponential(B, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.B2A_AMPA = bp.synapses.Exponential(B, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.B2N_AMPA = bp.synapses.Exponential(B, N, bp.conn.All2All(), g_max=g_E2E_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.B2I_AMPA = bp.synapses.Exponential(B, I, bp.conn.All2All(), g_max=g_E2I_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.B2B_NMDA = bp.synapses.NMDA(B, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.B2A_NMDA = bp.synapses.NMDA(B, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.B2N_NMDA = bp.synapses.NMDA(B, N, bp.conn.All2All(), g_max=g_E2E_NMDA, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.B2I_NMDA = bp.synapses.NMDA(B, I, bp.conn.All2All(), g_max=g_E2I_NMDA, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - - self.A2B_AMPA = bp.synapses.Exponential(A, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.A2A_AMPA = bp.synapses.Exponential(A, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.A2N_AMPA = bp.synapses.Exponential(A, N, bp.conn.All2All(), g_max=g_E2E_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.A2I_AMPA = bp.synapses.Exponential(A, I, bp.conn.All2All(), g_max=g_E2I_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.A2B_NMDA = bp.synapses.NMDA(A, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.A2A_NMDA = bp.synapses.NMDA(A, A, bp.conn.All2All(), g_max=g_E2E_NMDA * 
w_pos, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.A2N_NMDA = bp.synapses.NMDA(A, N, bp.conn.All2All(), g_max=g_E2E_NMDA, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - self.A2I_NMDA = bp.synapses.NMDA(A, I, bp.conn.All2All(), g_max=g_E2I_NMDA, - output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) - - # define I->E/I conn - self.I2B = bp.synapses.Exponential(I, B, bp.conn.All2All(), g_max=g_I2E_GABAa, - output=bp.synouts.COBA(E=-70.), **gaba_par) - self.I2A = bp.synapses.Exponential(I, A, bp.conn.All2All(), g_max=g_I2E_GABAa, - output=bp.synouts.COBA(E=-70.), **gaba_par) - self.I2N = bp.synapses.Exponential(I, N, bp.conn.All2All(), g_max=g_I2E_GABAa, - output=bp.synouts.COBA(E=-70.), **gaba_par) - self.I2I = bp.synapses.Exponential(I, I, bp.conn.All2All(), g_max=g_I2I_GABAa, - output=bp.synouts.COBA(E=-70.), **gaba_par) - - # define external projections - self.noise2B = bp.synapses.Exponential(self.noise_B, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.noise2A = bp.synapses.Exponential(self.noise_A, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.noise2N = bp.synapses.Exponential(self.noise_N, N, bp.conn.One2One(), g_max=g_ext2E_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - self.noise2I = bp.synapses.Exponential(self.noise_I, I, bp.conn.One2One(), g_max=g_ext2I_AMPA, - output=bp.synouts.COBA(E=0.), **ampa_par) - - # nodes - self.B = B - self.A = A - self.N = N - self.I = I - self.IA = IA - self.IB = IB - - -def visualize_raster(ax, mon, t_start=0., title=None): - bp.visualize.raster_plot(mon['ts'], mon['A.spike'], markersize=1, ax=ax, color='', label="Group A") - bp.visualize.raster_plot(mon['ts'], mon['B.spike'], markersize=1, ax=ax, color='', label="Group B") - if title: - ax.set_title(title) - ax.set_ylabel("Neuron Index") - ax.set_xlim(t_start, total_period + 1) - ax.axvline(pre_stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') - ax.legend() - - -def visualize_results(axes, mon, t_start=0., title=None): - ax = axes[0] - bp.visualize.raster_plot(mon['ts'], mon['A.spike'], markersize=1, ax=ax) - if title: - ax.set_title(title) - ax.set_ylabel("Group A") - ax.set_xlim(t_start, total_period + 1) - ax.axvline(pre_stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') - - ax = axes[1] - bp.visualize.raster_plot(mon['ts'], mon['B.spike'], markersize=1, ax=ax) - ax.set_ylabel("Group B") - ax.set_xlim(t_start, total_period + 1) - ax.axvline(pre_stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') - - ax = axes[2] - rateA = bp.measure.firing_rate(mon['A.spike'], width=10.) - rateB = bp.measure.firing_rate(mon['B.spike'], width=10.) 
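-  # smooth the spike trains into population rates with a 10-ms window; the
-  # pool (A or B) whose rate ramps up and stays elevated through the delay
-  # period encodes the network's choice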
- ax.plot(mon['ts'], rateA, label="Group A") - ax.plot(mon['ts'], rateB, label="Group B") - ax.set_ylabel('Population activity [Hz]') - ax.set_xlim(t_start, total_period + 1) - ax.axvline(pre_stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') - ax.legend() - - ax = axes[3] - ax.plot(mon['ts'], mon['IA.freq'], label="group A") - ax.plot(mon['ts'], mon['IB.freq'], label="group B") - ax.set_ylabel("Input activity [Hz]") - ax.set_xlim(t_start, total_period + 1) - ax.axvline(pre_stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') - ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') - ax.legend() - ax.set_xlabel("Time [ms]") - - -pre_stimulus_period = 100. -stimulus_period = 1000. -delay_period = 500. -total_period = pre_stimulus_period + stimulus_period + delay_period - - -def single_run(): - net = DecisionMaking(scale=1., coherence=-80., mu0=50.) - - runner = bp.DSRunner( - net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'] - ) - runner.run(total_period) - - fig, gs = bp.visualize.get_figure(4, 1, 3, 10) - axes = [fig.add_subplot(gs[i, 0]) for i in range(4)] - visualize_results(axes, mon=runner.mon) - plt.show() - - -def batching_run(): - num_row, num_col = 3, 4 - num_batch = 12 - coherence = bm.expand_dims(bm.linspace(-100, 100., num_batch), 1) - - with bm.batching_environment(): - net = DecisionMaking(scale=1., coherence=coherence, mu0=20.) - net.reset_state(batch_size=num_batch) - runner = bp.DSRunner( - net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'], data_first_axis='B' - ) - runner.run(total_period) - - coherence = bm.as_numpy(coherence) - fig, gs = bp.visualize.get_figure(num_row, num_col, 3, 4) - for i in range(num_row): - for j in range(num_col): - idx = i * num_col + j - if idx < num_batch: - mon = {'A.spike': runner.mon['A.spike'][idx], - 'B.spike': runner.mon['B.spike'][idx], - 'IA.freq': runner.mon['IA.freq'][idx], - 'IB.freq': runner.mon['IB.freq'][idx], - 'ts': runner.mon['ts']} - ax = fig.add_subplot(gs[i, j]) - visualize_raster(ax, mon=mon, title=f'coherence={coherence[idx, 0]}%') - plt.show() - - -if __name__ == '__main__': - single_run() - batching_run() diff --git a/examples/dynamics_simulation/Wu_2008_CANN_1D.py b/examples/dynamics_simulation/Wu_2008_CANN_1D.py deleted file mode 100644 index 8adbabbd7..000000000 --- a/examples/dynamics_simulation/Wu_2008_CANN_1D.py +++ /dev/null @@ -1,127 +0,0 @@ -# Implementation of the paper: -# - Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. "Dynamics and computation -# of continuous attractors." Neural computation 20.4 (2008): 994-1025. 
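-#
-# In the discretized form integrated below, the dynamics are
-#     r_i = u_i^2 / (1 + k * sum_j u_j^2)
-#     tau * du_i/dt = -u_i + sum_j J_ij * r_j + I_ext_i
-# with a translation-invariant Gaussian kernel on the ring [-pi, pi):
-#     J_ij = J0 / (sqrt(2*pi) * a) * exp(-dist(x_i - x_j)^2 / (2 * a^2))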
- -import brainpy as bp -import brainpy.math as bm - -bm.set_platform('cpu') - - -class CANN1D(bp.NeuGroup): - def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4., - z_min=-bm.pi, z_max=bm.pi, **kwargs): - super(CANN1D, self).__init__(size=num, **kwargs) - - # parameters - self.tau = tau # The synaptic time constant - self.k = k # Degree of the rescaled inhibition - self.a = a # Half-width of the range of excitatory connections - self.A = A # Magnitude of the external input - self.J0 = J0 # maximum connection value - - # feature space - self.z_min = z_min - self.z_max = z_max - self.z_range = z_max - z_min - self.x = bm.linspace(z_min, z_max, num) # The encoded feature values - self.rho = num / self.z_range # The neural density - self.dx = self.z_range / num # The stimulus density - - # The connection matrix - self.conn_mat = self.make_conn() - - # variables - self.r = bm.Variable(bm.zeros(num)) - self.u = bm.Variable(bm.zeros(num)) - self.input = bm.Variable(bm.zeros(num)) - - def dist(self, d): - d = bm.remainder(d, self.z_range) - d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d) - return d - - def make_conn(self): - x_left = bm.reshape(self.x, (-1, 1)) - x_right = bm.repeat(self.x.reshape((1, -1)), len(self.x), axis=0) - d = self.dist(x_left - x_right) - conn = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a) - return conn - - def get_stimulus_by_pos(self, pos): - return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a)) - - def update(self, tdi, x=None): - if x is not None: - self.input[:] = x - r1 = bm.square(self.u) - r2 = 1.0 + self.k * bm.sum(r1) - self.r.value = r1 / r2 - Irec = bm.dot(self.conn_mat, self.r) - self.u.value = self.u + (-self.u + Irec + self.input) / self.tau * tdi.dt - self.input[:] = 0. - - -cann = CANN1D(num=512, k=0.1) - -# Population coding - -# %% -I1 = cann.get_stimulus_by_pos(0.) -Iext, duration = bp.inputs.section_input(values=[0., I1, 0.], - durations=[1., 8., 8.], - return_length=True) -runner = bp.DSRunner(cann, monitors=['u']) -runner(inputs=Iext) -bp.visualize.animate_1D( - dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'}, - {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}], - frame_step=1, - frame_delay=100, - show=True, -) - -# Template matching # -# The cann can perform efficient population decoding by achieving template-matching. - -# %% -cann.k = 8.1 - -dur1, dur2, dur3 = 10., 30., 0. -num1 = int(dur1 / bm.get_dt()) -num2 = int(dur2 / bm.get_dt()) -num3 = int(dur3 / bm.get_dt()) -Iext = bm.zeros((num1 + num2 + num3,) + cann.size) -Iext[:num1] = cann.get_stimulus_by_pos(0.5) -Iext[num1:num1 + num2] = cann.get_stimulus_by_pos(0.) -Iext[num1:num1 + num2] += 0.1 * cann.A * bm.random.randn(num2, *cann.size) - -runner = bp.DSRunner(cann, monitors=['u']) -runner(inputs=Iext) -bp.visualize.animate_1D( - dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'}, - {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}], - frame_step=5, - frame_delay=50, - show=True -) - -# Smooth tracking # -dur1, dur2, dur3 = 10., 100., 20. -num1 = int(dur1 / bm.get_dt()) -num2 = int(dur2 / bm.get_dt()) -num3 = int(dur3 / bm.get_dt()) -position = bm.zeros(num1 + num2 + num3) -position[num1: num1 + num2] = bm.linspace(0., 20., num2) -position[num1 + num2:] = 20. 
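-# The stimulus position ramps linearly from 0 to 20 during `dur2`, and the
-# activity bump tracks the moving stimulus with a small lag set by the
-# network time constant.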
-position = position.reshape((-1, 1)) -Iext = cann.get_stimulus_by_pos(position) -runner = bp.DSRunner(cann, monitors=['u']) -runner(inputs=Iext) -bp.visualize.animate_1D( - dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'}, - {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}], - frame_step=5, - frame_delay=50, - show=True, -) diff --git a/examples/dynamics_simulation/Wu_2008_CANN_1D_oscillatory_tracking.py b/examples/dynamics_simulation/Wu_2008_CANN_1D_oscillatory_tracking.py deleted file mode 100644 index 6b949133a..000000000 --- a/examples/dynamics_simulation/Wu_2008_CANN_1D_oscillatory_tracking.py +++ /dev/null @@ -1,94 +0,0 @@ -# Implementation of the paper: -# -# - Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. "Dynamics and computation -# of continuous attractors." Neural computation 20.4 (2008): 994-1025. - -import brainpy as bp -import brainpy.math as bm - - -class CANN1D(bp.NeuGroup): - def __init__(self, num, tau=1., tau_v=50., k=1., a=0.3, A=0.2, J0=1., - z_min=-bm.pi, z_max=bm.pi, m=0.3): - super(CANN1D, self).__init__(size=num) - - # parameters - self.tau = tau # The synaptic time constant - self.tau_v = tau_v - self.k = k # Degree of the rescaled inhibition - self.a = a # Half-width of the range of excitatory connections - self.A = A # Magnitude of the external input - self.J0 = J0 # maximum connection value - self.m = m - - # feature space - self.z_min = z_min - self.z_max = z_max - self.z_range = z_max - z_min - self.x = bm.linspace(z_min, z_max, num) # The encoded feature values - self.rho = num / self.z_range # The neural density - self.dx = self.z_range / num # The stimulus density - - # The connection matrix - self.conn_mat = self.make_conn() - - # variables - self.r = bm.Variable(bm.zeros(num)) - self.u = bm.Variable(bm.zeros(num)) - self.v = bm.Variable(bm.zeros(num)) - self.input = bm.Variable(bm.zeros(num)) - - def dist(self, d): - d = bm.remainder(d, self.z_range) - d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d) - return d - - def make_conn(self): - x_left = bm.reshape(self.x, (-1, 1)) - x_right = bm.repeat(self.x.reshape((1, -1)), len(self.x), axis=0) - d = self.dist(x_left - x_right) - conn = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a) - return conn - - def get_stimulus_by_pos(self, pos): - return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a)) - - def update(self, tdi): - r1 = bm.square(self.u) - r2 = 1.0 + self.k * bm.sum(r1) - self.r.value = r1 / r2 - Irec = bm.dot(self.conn_mat, self.r) - self.u.value = self.u + (-self.u + Irec + self.input - self.v) / self.tau * tdi.dt - self.v.value = self.v + (-self.v + self.m * self.u) / self.tau_v * tdi.dt - self.input[:] = 0. - - -cann = CANN1D(num=512) - -# Smooth tracking # -dur1, dur2, dur3 = 100., 2000., 500. 
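-# With the slow adaptation variable v (strength m), the bump does not follow
-# the moving stimulus smoothly but oscillates around it: the "oscillatory
-# tracking" that the animation below visualizes.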
-num1 = int(dur1 / bm.get_dt()) -num2 = int(dur2 / bm.get_dt()) -num3 = int(dur3 / bm.get_dt()) -position = bm.zeros(num1 + num2 + num3) -final_pos = cann.a / cann.tau_v * 0.6 * dur2 -position[num1: num1 + num2] = bm.linspace(0., final_pos, num2) -position[num1 + num2:] = final_pos -position = position.reshape((-1, 1)) -Iext = cann.get_stimulus_by_pos(position) -runner = bp.DSRunner(cann, - inputs=('input', Iext, 'iter'), - monitors=['u', 'v'], - dyn_vars=cann.vars()) -runner(dur1 + dur2 + dur3) -bp.visualize.animate_1D( - dynamical_vars=[ - {'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'}, - {'ys': runner.mon.v, 'xs': cann.x, 'legend': 'v'}, - {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'} - ], - frame_step=30, - frame_delay=5, - show=True, - save_path='./cann_1d_oscillatory_tracking.gif' -) diff --git a/examples/dynamics_simulation/Wu_2008_CANN_2D.py b/examples/dynamics_simulation/Wu_2008_CANN_2D.py deleted file mode 100644 index 0d9731ed7..000000000 --- a/examples/dynamics_simulation/Wu_2008_CANN_2D.py +++ /dev/null @@ -1,114 +0,0 @@ -# Implementation of the paper: -# - Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. "Dynamics and computation -# of continuous attractors." Neural computation 20.4 (2008): 994-1025. - -import jax -import matplotlib.pyplot as plt -import numpy as np - -import brainpy as bp -import brainpy.math as bm - -bm.set_platform('cpu') - - -class CANN2D(bp.NeuGroup): - def __init__(self, length, tau=1., k=8.1, a=0.5, A=10., J0=4., - z_min=-bm.pi, z_max=bm.pi, name=None): - super(CANN2D, self).__init__(size=(length, length), name=name) - - # parameters - self.length = length - self.tau = tau # The synaptic time constant - self.k = k # Degree of the rescaled inhibition - self.a = a # Half-width of the range of excitatory connections - self.A = A # Magnitude of the external input - self.J0 = J0 # maximum connection value - - # feature space - self.z_min = z_min - self.z_max = z_max - self.z_range = z_max - z_min - self.x = bm.linspace(z_min, z_max, length) # The encoded feature values - self.rho = length / self.z_range # The neural density - self.dx = self.z_range / length # The stimulus density - - # The connections - self.conn_mat = self.make_conn() - - # variables - self.r = bm.Variable(bm.zeros((length, length))) - self.u = bm.Variable(bm.zeros((length, length))) - self.input = bm.Variable(bm.zeros((length, length))) - - def show_conn(self): - plt.imshow(np.asarray(self.conn_mat)) - plt.colorbar() - plt.show() - - def dist(self, d): - v_size = bm.asarray([self.z_range, self.z_range]) - return bm.where(d > v_size / 2, v_size - d, d) - - def make_conn(self): - x1, x2 = bm.meshgrid(self.x, self.x) - value = bm.stack([x1.flatten(), x2.flatten()]).T - - @jax.vmap - def get_J(v): - d = self.dist(bm.abs(v - value)) - d = bm.linalg.norm(d, axis=1) - # d = d.reshape((self.length, self.length)) - Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a) - return Jxx - - return get_J(value) - - def get_stimulus_by_pos(self, pos): - assert bm.size(pos) == 2 - x1, x2 = bm.meshgrid(self.x, self.x) - value = bm.stack([x1.flatten(), x2.flatten()]).T - d = self.dist(bm.abs(bm.asarray(pos) - value)) - d = bm.linalg.norm(d, axis=1) - d = d.reshape((self.length, self.length)) - return self.A * bm.exp(-0.25 * bm.square(d / self.a)) - - def update(self, tdi): - r1 = bm.square(self.u) - r2 = 1.0 + self.k * bm.sum(r1) - self.r.value = r1 / r2 - interaction = (self.r.flatten() @ self.conn_mat).reshape((self.length, self.length)) - self.u.value = self.u + (-self.u + 
self.input + interaction) / self.tau * tdi.dt - self.input[:] = 0. - - -cann = CANN2D(length=100, k=0.1) -cann.show_conn() - -# encoding -Iext, length = bp.inputs.section_input( - values=[cann.get_stimulus_by_pos([0., 0.]), 0.], - durations=[10., 20.], - return_length=True -) -runner = bp.DSRunner(cann, - inputs=['input', Iext, 'iter'], - monitors=['r'], - dyn_vars=cann.vars()) -runner.run(length) - -bp.visualize.animate_2D(values=runner.mon.r.reshape((-1, cann.num)), - net_size=(cann.length, cann.length)) - -# tracking -length = 20 -positions = bp.inputs.ramp_input(-bm.pi, bm.pi, duration=length, t_start=0) -positions = bm.stack([positions, positions]).T -Iext = jax.vmap(cann.get_stimulus_by_pos)(positions) -runner = bp.DSRunner(cann, - inputs=['input', Iext, 'iter'], - monitors=['r']) -runner.run(length) - -bp.visualize.animate_2D(values=runner.mon.r.reshape((-1, cann.num)), - net_size=(cann.length, cann.length)) diff --git a/examples/dynamics_simulation/multi_scale_COBAHH.py b/examples/dynamics_simulation/multi_scale_COBAHH.py index ddb070139..1ee6475e4 100644 --- a/examples/dynamics_simulation/multi_scale_COBAHH.py +++ b/examples/dynamics_simulation/multi_scale_COBAHH.py @@ -12,7 +12,6 @@ from brainpy.synouts import COBA from brainpy.connect import FixedProb from jax import vmap -import seaborn as sns comp_method = 'sparse' @@ -324,8 +323,6 @@ def visualize(seed=20873, gc=1., gEE=0.0060, gEI=0.0060, gIE=.26800, gII=0.26800 # plt.ylabel('Current') # plt.show() - sns.set_theme(font_scale=1.5) - fig, gs = bp.visualize.get_figure(2, 1, 2.25 * 1, 6 * 1) plot_ids = [0, 2, 4, 8] fig.add_subplot(gs[0, 0]) diff --git a/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py b/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py deleted file mode 100644 index 9671d1abb..000000000 --- a/examples/dynamics_training/Bellec_2020_eprop_evidence_accumulation.py +++ /dev/null @@ -1,226 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Implementation of the paper: - -- Bellec, G., Scherr, F., Subramoney, A., Hajek, E., Salaj, D., Legenstein, R., - & Maass, W. (2020). A solution to the learning dilemma for recurrent networks - of spiking neurons. Nature communications, 11(1), 1-15. - -""" - -import matplotlib.pyplot as plt -import numpy as np -import brainpy as bp -import brainpy.math as bm -from jax.lax import stop_gradient -from matplotlib import patches - -bm.set_environment(mode=bm.training_mode, dt=1.) - -# training parameters -n_batch = 128 # batch size - -# neuron model and simulation parameters -reg_f = 1. # regularization coefficient for firing rate -reg_rate = 10 # target firing rate for regularization [Hz] - -# Experiment parameters -t_cue_spacing = 150 # distance between two consecutive cues in ms - -# Frequencies -input_f0 = 40. / 1000. # poisson firing rate of input neurons in khz -regularization_f0 = reg_rate / 1000. 
# mean target network firing frequency - - -class EligSNN(bp.Network): - def __init__(self, num_in, num_rec, num_out, eprop=True, tau_a=2e3, tau_v=2e1): - super(EligSNN, self).__init__() - - # parameters - self.num_in = num_in - self.num_rec = num_rec - self.num_out = num_out - self.eprop = eprop - - # neurons - self.i = bp.neurons.InputGroup(num_in) - self.o = bp.neurons.LeakyIntegrator(num_out, tau=20) - - n_regular = int(num_rec / 2) - n_adaptive = num_rec - n_regular - beta1 = bm.exp(- bm.get_dt() / tau_a) - beta2 = 1.7 * (1 - beta1) / (1 - bm.exp(-1 / tau_v)) - beta = bm.concatenate([bm.ones(n_regular), bm.ones(n_adaptive) * beta2]) - self.r = bp.neurons.ALIFBellec2020( - num_rec, V_rest=0., tau_ref=5., V_th=0.6, - tau_a=tau_a, tau=tau_v, beta=beta, - V_initializer=bp.init.ZeroInit(), - a_initializer=bp.init.ZeroInit(), - eprop=eprop - ) - - # synapses - self.i2r = bp.layers.Dense(num_in, num_rec, - W_initializer=bp.init.KaimingNormal(), - b_initializer=None) - self.i2r.W *= tau_v - self.r2r = bp.layers.Dense(num_rec, num_rec, - W_initializer=bp.init.KaimingNormal(), - b_initializer=None) - self.r2r.W *= tau_v - self.r2o = bp.layers.Dense(num_rec, num_out, - W_initializer=bp.init.KaimingNormal(), - b_initializer=None) - - def update(self, shared, x): - self.r.input += self.i2r(shared, x) - z = stop_gradient(self.r.spike.value) if self.eprop else self.r.spike.value - self.r.input += self.r2r(shared, z) - self.r(shared) - self.o.input += self.r2o(shared, self.r.spike.value) - self.o(shared) - return self.o.V.value - - -net = EligSNN(num_in=40, num_rec=100, num_out=2, eprop=False) - - -@bp.tools.numba_jit -def generate_click_task_data(batch_size, seq_len, n_neuron, recall_duration, prob, f0=0.5, - n_cues=7, t_cue=100, t_interval=150, n_input_symbols=4): - n_channel = n_neuron // n_input_symbols - - # assign input spike probabilities - probs = np.where(np.random.random((batch_size, 1)) < 0.5, prob, 1 - prob) - - # for each example in batch, draw which cues are going to be active (left or right) - cue_assignments = np.asarray(np.random.random(n_cues) > probs, dtype=np.int_) - - # generate input nums - 0: left, 1: right, 2:recall, 3:background noise - input_nums = 3 * np.ones((batch_size, seq_len), dtype=np.int_) - input_nums[:, :n_cues] = cue_assignments - input_nums[:, -1] = 2 - - # generate input spikes - input_spike_prob = np.zeros((batch_size, seq_len, n_neuron)) - d_silence = t_interval - t_cue - for b in range(batch_size): - for k in range(n_cues): - # input channels only fire when they are selected (left or right) - c = cue_assignments[b, k] - # reverse order of cues - i_seq = d_silence + k * t_interval - i_neu = c * n_channel - input_spike_prob[b, i_seq:i_seq + t_cue, i_neu:i_neu + n_channel] = f0 - # recall cue - input_spike_prob[:, -recall_duration:, 2 * n_channel:3 * n_channel] = f0 - # background noise - input_spike_prob[:, :, 3 * n_channel:] = f0 / 4. 
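-  # turn the per-time-step firing probabilities into Bernoulli spike trains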
- input_spikes = input_spike_prob > np.random.rand(*input_spike_prob.shape) - - # generate targets - target_mask = np.zeros((batch_size, seq_len), dtype=np.bool_) - target_mask[:, -1] = True - target_nums = (np.sum(cue_assignments, axis=1) > n_cues / 2).astype(np.int_) - return input_spikes, input_nums, target_nums, target_mask - - -def get_data(batch_size, n_in, t_interval, f0): - # used for obtaining a new randomly generated batch of examples - def generate_data(): - for _ in range(10): - seq_len = int(t_interval * 7 + 1200) - spk_data, _, target_data, _ = generate_click_task_data( - batch_size=batch_size, seq_len=seq_len, n_neuron=n_in, recall_duration=150, - prob=0.3, t_cue=100, n_cues=7, t_interval=t_interval, f0=f0, n_input_symbols=4 - ) - yield spk_data, target_data - - return generate_data - - -def loss_fun(predicts, targets): - predicts, mon = predicts - - # we only use network output at the end for classification - output_logits = predicts[:, -t_cue_spacing:] - - # Define the accuracy - y_predict = bm.argmax(bm.mean(output_logits, axis=1), axis=1) - accuracy = bm.equal(targets, y_predict).astype(bm.float_).mean() - - # loss function - tiled_targets = bm.tile(bm.expand_dims(targets, 1), (1, t_cue_spacing)) - loss_cls = bm.mean(bp.losses.cross_entropy_loss(output_logits, tiled_targets)) - - # Firing rate regularization: - # For historical reason we often use this regularization, - # but the other one is easier to implement in an "online" fashion by a single agent. - av = bm.mean(mon['r.spike'], axis=(0, 1)) / bm.get_dt() - loss_reg_f = bm.sum(bm.square(av - regularization_f0) * reg_f) - - # Aggregate the losses # - loss = loss_reg_f + loss_cls - loss_res = {'loss': loss, 'loss reg': loss_reg_f, 'accuracy': accuracy} - return loss, loss_res - - -# Training -trainer = bp.BPTT( - net, - loss_fun, - loss_has_aux=True, - optimizer=bp.optimizers.Adam(lr=0.01), - monitors={'r.spike': net.r.spike}, -) -trainer.fit(get_data(n_batch, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0), - num_epoch=40, - num_report=10) - -# visualization -dataset, _ = next(get_data(20, n_in=net.num_in, t_interval=t_cue_spacing, f0=input_f0)()) -runner = bp.DSTrainer(net, monitors={'spike': net.r.spike}) -outs = runner.predict(dataset, reset_state=True) - -for i in range(10): - fig, gs = bp.visualize.get_figure(3, 1, 2., 6.) 
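-  # three stacked panels: input spikes (top), recurrent spiking activity
-  # (middle), and the two readout traces (bottom)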
- ax_inp = fig.add_subplot(gs[0, 0]) - ax_rec = fig.add_subplot(gs[1, 0]) - ax_out = fig.add_subplot(gs[2, 0]) - - data = dataset[i] - # insert empty row - n_channel = data.shape[1] // 4 - zero_fill = np.zeros((data.shape[0], int(n_channel / 2))) - data = np.concatenate((data[:, 3 * n_channel:], zero_fill, - data[:, 2 * n_channel:3 * n_channel], zero_fill, - data[:, :n_channel], zero_fill, - data[:, n_channel:2 * n_channel]), axis=1) - ax_inp.set_yticklabels([]) - ax_inp.add_patch(patches.Rectangle((0, 2 * n_channel + 2 * int(n_channel / 2)), - data.shape[0], n_channel, - facecolor="red", alpha=0.1)) - ax_inp.add_patch(patches.Rectangle((0, 3 * n_channel + 3 * int(n_channel / 2)), - data.shape[0], n_channel, - facecolor="blue", alpha=0.1)) - bp.visualize.raster_plot(runner.mon.ts, data, ax=ax_inp, marker='|') - ax_inp.set_ylabel('Input Activity') - ax_inp.set_xticklabels([]) - ax_inp.set_xticks([]) - - # spiking activity - bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'][i], ax=ax_rec, marker='|') - ax_rec.set_ylabel('Spiking Activity') - ax_rec.set_xticklabels([]) - ax_rec.set_xticks([]) - # decision activity - ax_out.set_yticks([0, 0.5, 1]) - ax_out.set_ylabel('Output Activity') - ax_out.plot(runner.mon.ts, outs[i, :, 0], label='Readout 0', alpha=0.7) - ax_out.plot(runner.mon.ts, outs[i, :, 1], label='Readout 1', alpha=0.7) - ax_out.set_xticklabels([]) - ax_out.set_xticks([]) - ax_out.set_xlabel('Time [ms]') - plt.legend() - plt.show() diff --git a/examples/dynamics_training/Gauthier_2021_ngrc_double_scroll.py b/examples/dynamics_training/Gauthier_2021_ngrc_double_scroll.py deleted file mode 100644 index 2063a7b77..000000000 --- a/examples/dynamics_training/Gauthier_2021_ngrc_double_scroll.py +++ /dev/null @@ -1,142 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Implementation of the paper: - -- Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir - computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2 - -The main task is forecasting the double-scroll system. 
-""" - - -import matplotlib.pyplot as plt -import numpy as np - -import jax.numpy as jnp -import brainpy as bp -import brainpy.math as bm -import brainpy_datasets as bd - -bm.set_environment(bm.batching_mode, x64=True) - - -def get_subset(data, start, end): - res = {'x': data.xs[start: end], - 'y': data.ys[start: end], - 'z': data.zs[start: end]} - res = jnp.hstack([res['x'], res['y'], res['z']]) - return res.reshape((1,) + res.shape) - - -def plot_weights(Wout, coefs, bias=None): - Wout = np.asarray(Wout) - if bias is not None: - bias = np.asarray(bias) - Wout = np.concatenate([bias.reshape((1, 3)), Wout], axis=0) - coefs.insert(0, 'bias') - x_Wout, y_Wout, z_Wout = Wout[:, 0], Wout[:, 1], Wout[:, 2] - - fig = plt.figure(figsize=(10, 10)) - ax = fig.add_subplot(131) - ax.grid(axis="y") - ax.set_xlabel("$[W_{out}]_x$") - ax.set_ylabel("Features") - ax.set_yticks(np.arange(len(coefs))) - ax.set_yticklabels(coefs) - ax.barh(np.arange(x_Wout.size), x_Wout) - - ax1 = fig.add_subplot(132) - ax1.grid(axis="y") - ax1.set_yticks(np.arange(len(coefs))) - ax1.set_xlabel("$[W_{out}]_y$") - ax1.barh(np.arange(y_Wout.size), y_Wout) - - ax2 = fig.add_subplot(133) - ax2.grid(axis="y") - ax2.set_yticks(np.arange(len(coefs))) - ax2.set_xlabel("$[W_{out}]_z$") - ax2.barh(np.arange(z_Wout.size), z_Wout) - - plt.show() - - -def plot_double_scroll(ground_truth, predictions): - fig = plt.figure(figsize=(15, 10)) - ax = fig.add_subplot(121, projection='3d') - ax.set_title("Generated attractor") - ax.set_xlabel("$x$") - ax.set_ylabel("$y$") - ax.set_zlabel("$z$") - ax.grid(False) - ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2]) - - ax2 = fig.add_subplot(122, projection='3d') - ax2.set_title("Real attractor") - ax2.grid(False) - ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2]) - plt.show() - - -dt = 0.02 -t_warmup = 10. # ms -t_train = 100. # ms -t_test = 800. 
# ms -num_warmup = int(t_warmup / dt) # warm up NVAR -num_train = int(t_train / dt) -num_test = int(t_test / dt) - -# Datasets # -# -------- # -data_series = bd.chaos.DoubleScrollEq(t_warmup + t_train + t_test, dt=dt) - -X_warmup = get_subset(data_series, 0, num_warmup - 1) -Y_warmup = get_subset(data_series, 1, num_warmup) -X_train = get_subset(data_series, num_warmup - 1, num_warmup + num_train - 1) -# Target: Lorenz[t] - Lorenz[t - 1] -dX_train = get_subset(data_series, num_warmup, num_warmup + num_train) - X_train -X_test = get_subset(data_series, - num_warmup + num_train - 1, - num_warmup + num_train + num_test - 1) -Y_test = get_subset(data_series, - num_warmup + num_train, - num_warmup + num_train + num_test) - -# Model # -# ----- # - - -class NGRC(bp.DynamicalSystem): - def __init__(self, num_in): - super(NGRC, self).__init__() - self.r = bp.layers.NVAR(num_in, delay=2, order=3) - self.di = bp.layers.Dense(self.r.num_out, num_in, mode=bm.training_mode) - - def update(self, sha, x): - di = self.di(sha, self.r(sha, x)) - return x + di - - -model = NGRC(3) - - -# Training # -# -------- # - -# warm-up -trainer = bp.RidgeTrainer(model, alpha=1e-5, jit=True) -outputs = trainer.predict(X_warmup) -print('Warmup NMS: ', bp.losses.mean_squared_error(outputs, Y_warmup)) - -# training -trainer.fit([X_train, {'di': dX_train}]) -plot_weights(model.di.W, model.r.get_feature_names(for_plot=True), model.di.b) - -# prediction -shared = dict() -model_jit = bm.jit(model) -outputs = [model_jit(shared, X_test[:, 0])] -for i in range(1, X_test.shape[1]): - outputs.append(model_jit(shared, outputs[i - 1])) -outputs = bm.asarray(outputs).squeeze() -print('Prediction NMS: ', bp.losses.mean_squared_error(outputs, Y_test)) -plot_double_scroll(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs)) diff --git a/examples/dynamics_training/Gauthier_2021_ngrc_lorenz.py b/examples/dynamics_training/Gauthier_2021_ngrc_lorenz.py deleted file mode 100644 index b9b400b73..000000000 --- a/examples/dynamics_training/Gauthier_2021_ngrc_lorenz.py +++ /dev/null @@ -1,148 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Implementation of the paper: - -- Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir - computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2 - -The main task is forecasting the Lorenz63 strange attractor. 
-""" - - -import matplotlib.pyplot as plt -import numpy as np - -import jax.numpy as jnp -import brainpy as bp -import brainpy.math as bm -import brainpy_datasets as bd - - -bm.set_environment(bm.batching_mode, x64=True) - - -def get_subset(data, start, end): - res = {'x': data.xs[start: end], - 'y': data.ys[start: end], - 'z': data.zs[start: end]} - res = jnp.hstack([res['x'], res['y'], res['z']]) - return res.reshape((1, ) + res.shape) - - -def plot_weights(Wout, coefs, bias=None): - Wout = np.asarray(Wout) - if bias is not None: - bias = np.asarray(bias) - Wout = np.concatenate([bias.reshape((1, 3)), Wout], axis=0) - coefs.insert(0, 'bias') - x_Wout, y_Wout, z_Wout = Wout[:, 0], Wout[:, 1], Wout[:, 2] - - fig = plt.figure(figsize=(10, 10)) - ax = fig.add_subplot(131) - ax.grid(axis="y") - ax.set_xlabel("$[W_{out}]_x$") - ax.set_ylabel("Features") - ax.set_yticks(np.arange(len(coefs))) - ax.set_yticklabels(coefs) - ax.barh(np.arange(x_Wout.size), x_Wout) - - ax1 = fig.add_subplot(132) - ax1.grid(axis="y") - ax1.set_yticks(np.arange(len(coefs))) - ax1.set_xlabel("$[W_{out}]_y$") - ax1.barh(np.arange(y_Wout.size), y_Wout) - - ax2 = fig.add_subplot(133) - ax2.grid(axis="y") - ax2.set_yticks(np.arange(len(coefs))) - ax2.set_xlabel("$[W_{out}]_z$") - ax2.barh(np.arange(z_Wout.size), z_Wout) - - plt.show() - - -def plot_lorenz(ground_truth, predictions): - fig = plt.figure(figsize=(15, 10)) - ax = fig.add_subplot(121, projection='3d') - ax.set_title("Generated attractor") - ax.set_xlabel("$x$") - ax.set_ylabel("$y$") - ax.set_zlabel("$z$") - ax.grid(False) - ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2]) - - ax2 = fig.add_subplot(122, projection='3d') - ax2.set_title("Real attractor") - ax2.grid(False) - ax2.plot(ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2]) - plt.show() - - -dt = 0.01 -t_warmup = 5. # ms -t_train = 10. # ms -t_test = 120. 
# ms
-num_warmup = int(t_warmup / dt)  # warm up NVAR
-num_train = int(t_train / dt)
-num_test = int(t_test / dt)
-
-# Datasets #
-# -------- #
-lorenz_series = bd.chaos.LorenzEq(t_warmup + t_train + t_test,
-                                  dt=dt,
-                                  inits={'x': 17.67715816276679,
-                                         'y': 12.931379185960404,
-                                         'z': 43.91404334248268})
-
-X_warmup = get_subset(lorenz_series, 0, num_warmup - 1)
-Y_warmup = get_subset(lorenz_series, 1, num_warmup)
-X_train = get_subset(lorenz_series, num_warmup - 1, num_warmup + num_train - 1)
-# Target: Lorenz[t] - Lorenz[t - 1]
-dX_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train) - X_train
-X_test = get_subset(lorenz_series,
-                    num_warmup + num_train - 1,
-                    num_warmup + num_train + num_test - 1)
-Y_test = get_subset(lorenz_series,
-                    num_warmup + num_train,
-                    num_warmup + num_train + num_test)
-
-
-# Model #
-# ----- #
-class NGRC(bp.DynamicalSystem):
-  def __init__(self, num_in):
-    super(NGRC, self).__init__()
-    self.r = bp.layers.NVAR(num_in, delay=2, order=2, constant=True)
-    self.di = bp.layers.Dense(self.r.num_out, num_in, b_initializer=None,
-                              mode=bm.training_mode)
-
-  def update(self, sha, x):
-    dx = self.di(sha, self.r(sha, x))
-    return x + dx
-
-
-model = NGRC(3)
-print(model.r.get_feature_names(for_plot=True))
-
-
-# Training #
-# -------- #
-
-# warm-up
-trainer = bp.RidgeTrainer(model)
-outputs = trainer.predict(X_warmup)
-print('Warmup MSE: ', bp.losses.mean_squared_error(outputs, Y_warmup))
-
-# training
-trainer.fit([X_train, dX_train])
-plot_weights(model.di.W, model.r.get_feature_names(for_plot=True), model.di.b)
-
-# prediction
-shared = dict()
-model_jit = bm.jit(model)
-outputs = [model_jit(shared, X_test[:, 0])]
-for i in range(1, X_test.shape[1]):
-  outputs.append(model_jit(shared, outputs[i - 1]))
-outputs = bm.asarray(outputs)
-print('Prediction MSE: ', bp.losses.mean_squared_error(outputs, Y_test))
-plot_lorenz(bm.as_numpy(Y_test).squeeze(), bm.as_numpy(outputs).squeeze())
diff --git a/examples/dynamics_training/Gauthier_2021_ngrc_lorenz_inference.py b/examples/dynamics_training/Gauthier_2021_ngrc_lorenz_inference.py
deleted file mode 100644
index 9577e8e78..000000000
--- a/examples/dynamics_training/Gauthier_2021_ngrc_lorenz_inference.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""Implementation of the paper:
-
-- Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir
-  computing. Nat Commun 12, 5564 (2021). https://doi.org/10.1038/s41467-021-25801-2
-
-The main task is inferring the unmeasured z variable of the Lorenz63 system from x and y.
-""" - -import brainpy_datasets as bp_data -import matplotlib.pyplot as plt -import numpy as np - -import jax.numpy as jnp -import brainpy as bp -import brainpy.math as bm - - -bm.set_environment(bm.batching_mode, x64=True) - - -def get_subset(data, start, end): - res = {'x': data.xs[start: end], - 'y': data.ys[start: end], - 'z': data.zs[start: end]} - X = jnp.hstack([res['x'], res['y']]) - X = X.reshape((1,) + X.shape) - Y = res['z'] - Y = Y.reshape((1,) + Y.shape) - return X, Y - - -def plot_lorenz(x, y, true_z, predict_z, linewidth=.8): - fig1 = plt.figure() - fig1.set_figheight(8) - fig1.set_figwidth(12) - - t_all = t_warmup + t_train + t_test - ts = np.arange(0, t_all, dt) - - h = 240 - w = 2 - - # top left of grid is 0,0 - axs1 = plt.subplot2grid(shape=(h, w), loc=(0, 0), colspan=2, rowspan=30) - axs2 = plt.subplot2grid(shape=(h, w), loc=(36, 0), colspan=2, rowspan=30) - axs3 = plt.subplot2grid(shape=(h, w), loc=(72, 0), colspan=2, rowspan=30) - axs4 = plt.subplot2grid(shape=(h, w), loc=(132, 0), colspan=2, rowspan=30) - axs5 = plt.subplot2grid(shape=(h, w), loc=(168, 0), colspan=2, rowspan=30) - axs6 = plt.subplot2grid(shape=(h, w), loc=(204, 0), colspan=2, rowspan=30) - - # training phase x - axs1.set_title('training phase') - axs1.plot(ts[num_warmup:num_warmup + num_train], - x[num_warmup:num_warmup + num_train], - color='b', linewidth=linewidth) - axs1.set_ylabel('x') - axs1.axes.xaxis.set_ticklabels([]) - axs1.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) - axs1.axes.set_ybound(-21., 21.) - axs1.text(-.14, .9, 'a)', ha='left', va='bottom', transform=axs1.transAxes) - - # training phase y - axs2.plot(ts[num_warmup:num_warmup + num_train], - y[num_warmup:num_warmup + num_train], - color='b', linewidth=linewidth) - axs2.set_ylabel('y') - axs2.axes.xaxis.set_ticklabels([]) - axs2.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) - axs2.axes.set_ybound(-26., 26.) - axs2.text(-.14, .9, 'b)', ha='left', va='bottom', transform=axs2.transAxes) - - # training phase z - axs3.plot(ts[num_warmup:num_warmup + num_train], - true_z[num_warmup:num_warmup + num_train], - color='b', linewidth=linewidth) - axs3.plot(ts[num_warmup:num_warmup + num_train], - predict_z[num_warmup:num_warmup + num_train], - color='r', linewidth=linewidth) - axs3.set_ylabel('z') - axs3.set_xlabel('time') - axs3.axes.set_xbound(t_warmup - .08, t_warmup + t_train + .05) - axs3.axes.set_ybound(3., 48.) - axs3.text(-.14, .9, 'c)', ha='left', va='bottom', transform=axs3.transAxes) - - # testing phase x - axs4.set_title('testing phase') - axs4.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], - x[num_warmup + num_train:num_warmup + num_train + num_test], - color='b', linewidth=linewidth) - axs4.set_ylabel('x') - axs4.axes.xaxis.set_ticklabels([]) - axs4.axes.set_ybound(-21., 21.) - axs4.axes.set_xbound(t_warmup + t_train - .5, t_all + .5) - axs4.text(-.14, .9, 'd)', ha='left', va='bottom', transform=axs4.transAxes) - - # testing phase y - axs5.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test], - y[num_warmup + num_train:num_warmup + num_train + num_test], - color='b', linewidth=linewidth) - axs5.set_ylabel('y') - axs5.axes.xaxis.set_ticklabels([]) - axs5.axes.set_ybound(-26., 26.) 
-  axs5.axes.set_xbound(t_warmup + t_train - .5, t_all + .5)
-  axs5.text(-.14, .9, 'e)', ha='left', va='bottom', transform=axs5.transAxes)
-
-  # testing phase z
-  axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
-            true_z[num_warmup + num_train:num_warmup + num_train + num_test],
-            color='b', linewidth=linewidth)
-  axs6.plot(ts[num_warmup + num_train:num_warmup + num_train + num_test],
-            predict_z[num_warmup + num_train:num_warmup + num_train + num_test],
-            color='r', linewidth=linewidth)
-  axs6.set_ylabel('z')
-  axs6.set_xlabel('time')
-  axs6.axes.set_ybound(3., 48.)
-  axs6.axes.set_xbound(t_warmup + t_train - .5, t_all + .5)
-  axs6.text(-.14, .9, 'f)', ha='left', va='bottom', transform=axs6.transAxes)
-
-  plt.show()
-
-
-dt = 0.02
-t_warmup = 10.  # ms
-t_train = 20.  # ms
-t_test = 50.  # ms
-num_warmup = int(t_warmup / dt)  # warm up NVAR
-num_train = int(t_train / dt)
-num_test = int(t_test / dt)
-
-# Datasets #
-# -------- #
-lorenz_series = bp_data.chaos.LorenzEq(t_warmup + t_train + t_test,
-                                       dt=dt,
-                                       inits={'x': 17.67715816276679,
-                                              'y': 12.931379185960404,
-                                              'z': 43.91404334248268})
-
-X_warmup, Y_warmup = get_subset(lorenz_series, 0, num_warmup)
-X_train, Y_train = get_subset(lorenz_series, num_warmup, num_warmup + num_train)
-X_test, Y_test = get_subset(lorenz_series, 0, num_warmup + num_train + num_test)
-
-
-# Model #
-# ----- #
-
-class NGRC(bp.DynamicalSystem):
-  def __init__(self, num_in):
-    super(NGRC, self).__init__()
-    self.r = bp.layers.NVAR(num_in, delay=4, order=2, stride=5)
-    self.o = bp.layers.Dense(self.r.num_out, 1, mode=bm.training_mode)
-
-  def update(self, sha, x):
-    return self.o(sha, self.r(sha, x))
-
-
-model = NGRC(2)
-
-# Training #
-# -------- #
-
-trainer = bp.RidgeTrainer(model, alpha=0.05)
-
-# warm-up
-outputs = trainer.predict(X_warmup)
-print('Warmup MSE: ', bp.losses.mean_squared_error(outputs, Y_warmup))
-
-# training
-trainer.fit([X_train, Y_train])
-
-# prediction
-outputs = trainer.predict(X_test, reset_state=True)
-print('Prediction MSE: ', bp.losses.mean_squared_error(outputs, Y_test))
-
-plot_lorenz(x=bm.as_numpy(lorenz_series.xs).flatten(),
-            y=bm.as_numpy(lorenz_series.ys).flatten(),
-            true_z=bm.as_numpy(lorenz_series.zs).flatten(),
-            predict_z=bm.as_numpy(outputs).flatten())

From 9b133217b7c577f7b624196877864c995d75d6e1 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 27 Jan 2023 16:52:41 +0800
Subject: [PATCH 4/4] fix jax import bug when `jax>=0.4.2`

---
 .../_src/math/object_transform/autograd.py    | 116 +++++++++++++++++-
 1 file changed, 111 insertions(+), 5 deletions(-)

diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py
index 911f20cdb..15dc24341 100644
--- a/brainpy/_src/math/object_transform/autograd.py
+++ b/brainpy/_src/math/object_transform/autograd.py
@@ -1,17 +1,16 @@
 # -*- coding: utf-8 -*-
 
+import inspect
 from functools import partial, wraps
 from typing import Union, Callable, Dict, Sequence, Any, Optional
 
 import jax
 import numpy as np
-from jax import linear_util, dtypes, vmap, numpy as jnp
-from jax._src.api import (_vjp, _jvp,
-                          _check_callable,
-                          _check_output_dtype_jacrev, _check_input_dtype_jacrev,
-                          _check_output_dtype_jacfwd, _check_input_dtype_jacfwd, )
+from jax import linear_util, dtypes, vmap, numpy as jnp, core
+from jax._src.api import (_vjp, _jvp)
 from jax.api_util import argnums_partial
 from jax.errors import UnexpectedTracerError
+from jax.interpreters import xla
 from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
                            tree_transpose, tree_structure)
@@ -711,3 +710,110 @@ def vector_grad(
     argnums=argnums,
     return_value=return_value,
     has_aux=False if has_aux is None else has_aux)
+
+
+# The private helpers below were imported from `jax._src.api` before, but they
+# were removed or relocated in `jax>=0.4.2`.  They are vendored here (adapted
+# from the JAX source) so the autograd transforms keep working across versions.
+def _check_callable(fun):
+  # In Python 3.10+, the only thing stopping us from supporting staticmethods
+  # is that we can't take weak references to them, which the C++ JIT requires.
+  if isinstance(fun, staticmethod):
+    raise TypeError(f"staticmethod arguments are not supported, got {fun}")
+  if not callable(fun):
+    raise TypeError(f"Expected a callable value, got {fun}")
+  if _isgeneratorfunction(fun):
+    raise TypeError(f"Expected a function, got a generator function: {fun}")
+
+
+def _isgeneratorfunction(fun):
+  # re-implemented here because of https://bugs.python.org/issue33261
+  while inspect.ismethod(fun):
+    fun = fun.__func__
+  while isinstance(fun, partial):
+    fun = fun.func
+  return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR)
+
+
+def _check_arg(arg):
+  if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
+    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
+
+
+def _valid_jaxtype(arg):
+  try:
+    xla.abstractify(arg)  # faster than core.get_aval
+  except TypeError:
+    return core.valid_jaxtype(arg)
+  else:
+    return True
+
+
+# Shared dtype checks for reverse-mode derivatives (`jacrev`-style transforms);
+# specialized below via `functools.partial`.
+def _check_output_dtype_revderiv(name, holomorphic, x):
+  aval = core.get_aval(x)
+  if core.is_opaque_dtype(aval.dtype):
+    raise TypeError(
+      f"{name} with output element type {aval.dtype.name}")
+  if holomorphic:
+    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
+      raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
+                      f"but got {aval.dtype.name}.")
+  elif dtypes.issubdtype(aval.dtype, np.complexfloating):
+    raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
+                    f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
+                    "For holomorphic differentiation, pass holomorphic=True. "
+                    "For differentiation of non-holomorphic functions involving complex "
+                    "outputs, use jax.vjp directly.")
+  elif not dtypes.issubdtype(aval.dtype, np.floating):
+    raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
+                    f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
+                    "For differentiation of functions with integer outputs, use "
+                    "jax.vjp directly.")
+
+
+def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
+  _check_arg(x)
+  aval = core.get_aval(x)
+  if core.is_opaque_dtype(aval.dtype):
+    raise TypeError(
+      f"{name} with input element type {aval.dtype.name}")
+  if holomorphic:
+    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
+      raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
+                      f"but got {aval.dtype.name}.")
+  if (dtypes.issubdtype(aval.dtype, np.integer) or
+      dtypes.issubdtype(aval.dtype, np.bool_)):
+    if not allow_int:
+      raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
+                      f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. 
" + "If you want to use Boolean- or integer-valued inputs, use vjp " + "or set allow_int to True.") + elif not dtypes.issubdtype(aval.dtype, np.inexact): + raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a " + f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.") + + +_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev") +_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev") + + +def _check_output_dtype_jacfwd(holomorphic, x): + aval = core.get_aval(x) + if holomorphic: + if not dtypes.issubdtype(aval.dtype, np.complexfloating): + raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, " + f"but got {aval.dtype.name}.") + + +def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None: + _check_arg(x) + aval = core.get_aval(x) + if core.is_opaque_dtype(aval.dtype): + raise TypeError(f"jacfwd with input element type {aval.dtype.name}") + if holomorphic: + if not dtypes.issubdtype(aval.dtype, np.complexfloating): + raise TypeError("jacfwd with holomorphic=True requires inputs with complex " + f"dtype, but got {aval.dtype.name}.") + elif not dtypes.issubdtype(aval.dtype, np.floating): + raise TypeError("jacfwd requires real-valued inputs (input dtype that is " + f"a sub-dtype of np.floating), but got {aval.dtype.name}. " + "For holomorphic differentiation, pass holomorphic=True. " + "For differentiation of non-holomorphic functions involving " + "complex inputs or integer inputs, use jax.jvp directly.")