diff --git a/brainpy/__init__.py b/brainpy/__init__.py index f2ed909ce..8c666517e 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.2.3.1" +__version__ = "2.2.3.2" try: @@ -28,9 +28,9 @@ >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -Note that the versions of "jax" and "jaxlib" should be consistent, like "jax=0.3.14", "jaxlib=0.3.14". +Note that the versions of "jax" and "jaxlib" should be consistent, like "jax=0.3.14" and "jaxlib=0.3.14". -More detail installation instruction, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax +For more detail installation instructions, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax ''') from None diff --git a/brainpy/dyn/neurons/biological_models.py b/brainpy/dyn/neurons/biological_models.py index fdd7e6a27..24606ee66 100644 --- a/brainpy/dyn/neurons/biological_models.py +++ b/brainpy/dyn/neurons/biological_models.py @@ -4,11 +4,11 @@ import brainpy.math as bm from brainpy.dyn.base import NeuGroup -from brainpy.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable +from brainpy.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_ from brainpy.integrators.joint_eq import JointEq from brainpy.integrators.ode import odeint from brainpy.integrators.sde import sdeint -from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check_mode +from brainpy.modes import Mode, BatchingMode, NormalMode, normal, check_mode from brainpy.tools.checking import check_initializer from brainpy.types import Shape, Array @@ -243,18 +243,18 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable(self._V_initializer, mode, self.varshape) + self.V = variable_(self._V_initializer, self.varshape, mode) self.m = (bm.Variable(self.m_inf(self.V.value)) if m_initializer is None else - variable(self._m_initializer, mode, self.varshape)) + variable_(self._m_initializer, self.varshape, mode)) self.h = (bm.Variable(self.h_inf(self.V.value)) if h_initializer is None else - variable(self._h_initializer, mode, self.varshape)) + variable_(self._h_initializer, self.varshape, mode)) self.n = (bm.Variable(self.n_inf(self.V.value)) if n_initializer is None else - variable(self._n_initializer, mode, self.varshape)) - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) + variable_(self._n_initializer, self.varshape, mode)) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) # integral if self.noise is None: @@ -281,21 +281,21 @@ def __init__( dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n def reset_state(self, batch_size=None): - self.V.value = variable(self._V_initializer, batch_size, self.varshape) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) if self._m_initializer is None: self.m.value = self.m_inf(self.V.value) else: - self.m.value = variable(self._m_initializer, batch_size, self.varshape) + self.m.value = variable_(self._m_initializer, self.varshape, batch_size) if self._h_initializer is None: self.h.value = self.h_inf(self.V.value) else: - self.h.value = variable(self._h_initializer, batch_size, self.varshape) + self.h.value = 
variable_(self._h_initializer, self.varshape, batch_size) if self._n_initializer is None: self.n.value = self.n_inf(self.V.value) else: - self.n.value = variable(self._n_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) - self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.n.value = variable_(self._n_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, m, h, n, I_ext): I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) @@ -452,10 +452,10 @@ def __init__( self._V_initializer = V_initializer # variables - self.W = variable(self._W_initializer, mode, self.varshape) - self.V = variable(self._V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.W = variable_(self._W_initializer, self.varshape, mode) + self.V = variable_(self._V_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -464,10 +464,10 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.W.value = variable(self._W_initializer, batch_size, self.varshape) - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) - self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.W.value = variable_(self._W_initializer, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, W, I_ext): M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) @@ -718,16 +718,16 @@ def __init__( self._Ca_initializer = Ca_initializer # variables - self.Vs = variable(self._Vs_initializer, mode, self.varshape) - self.Vd = variable(self._Vd_initializer, mode, self.varshape) - self.Ca = variable(self._Ca_initializer, mode, self.varshape) + self.Vs = variable_(self._Vs_initializer, self.varshape, mode) + self.Vd = variable_(self._Vd_initializer, self.varshape, mode) + self.Ca = variable_(self._Ca_initializer, self.varshape, mode) self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None) self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None) self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None) self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None) self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(mode, BatchingMode) else None) - self.Id = variable(bm.zeros, mode, self.varshape) # input to soma - self.Is = variable(bm.zeros, mode, self.varshape) # input to dendrite + self.Id = variable_(bm.zeros, self.varshape, mode) # input to soma + self.Is = variable_(bm.zeros, self.varshape, mode) # input to dendrite # self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool)) # integral @@ -737,17 +737,17 @@ def __init__( 
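The recurring edit in these neuron-model hunks swaps ``variable(init, mode, shape)`` for ``variable_(init, shape, mode_or_batch_size)``: the variable shape now comes first, and the mode (at construction time) or the batch size (in ``reset_state``) comes second, matching the ``variable_`` wrapper added to ``brainpy/initialize/generic.py`` later in this patch. A minimal sketch of both call sites, with an illustrative shape:

.. code-block:: python

    import brainpy.math as bm
    from brainpy.initialize import variable_
    from brainpy.modes import normal

    varshape = (10,)  # illustrative shape
    # construction time: the mode decides whether a batch axis is prepended
    V = variable_(bm.zeros, varshape, normal)
    spike = variable_(lambda s: bm.zeros(s, dtype=bool), varshape, normal)
    # reset time: an explicit batch size (or None) plays the same role
    V_reset = variable_(bm.zeros, varshape, None)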
self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.Vd.value = variable(self._Vd_initializer, batch_size, self.varshape) - self.Vs.value = variable(self._Vs_initializer, batch_size, self.varshape) - self.Ca.value = variable(self._Ca_initializer, batch_size, self.varshape) + self.Vd.value = variable_(self._Vd_initializer, self.varshape, batch_size) + self.Vs.value = variable_(self._Vs_initializer, self.varshape, batch_size) + self.Ca.value = variable_(self._Ca_initializer, self.varshape, batch_size) batch_axis = 0 if isinstance(self.mode, BatchingMode) else None self.h.value = bm.Variable(self.inf_h(self.Vs), batch_axis=batch_axis) self.n.value = bm.Variable(self.inf_n(self.Vs), batch_axis=batch_axis) self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis) self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis) self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis) - self.Id.value = variable(bm.zeros, batch_size, self.varshape) - self.Is.value = variable(bm.zeros, batch_size, self.varshape) + self.Id.value = variable_(bm.zeros, self.varshape, batch_size) + self.Is.value = variable_(bm.zeros, self.varshape, batch_size) # self.spike[:] = False def dCa(self, Ca, t, s, Vd): @@ -1017,11 +1017,11 @@ def __init__( self._V_initializer = V_initializer # variables - self.h = variable(self._h_initializer, mode, self.varshape) - self.n = variable(self._n_initializer, mode, self.varshape) - self.V = variable(self._V_initializer, mode, self.varshape) - self.input = variable(bm.zeros, mode, self.varshape) - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.h = variable_(self._h_initializer, self.varshape, mode) + self.n = variable_(self._n_initializer, self.varshape, mode) + self.V = variable_(self._V_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -1030,11 +1030,11 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.h.value = variable(self._h_initializer, batch_size, self.varshape) - self.n.value = variable(self._n_initializer, batch_size, self.varshape) - self.V.value = variable(self._V_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) - self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.h.value = variable_(self._h_initializer, self.varshape, batch_size) + self.n.value = variable_(self._n_initializer, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def m_inf(self, V): alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) diff --git a/brainpy/dyn/neurons/input_groups.py b/brainpy/dyn/neurons/input_groups.py index 7eba286ff..fb95db1b8 100644 --- a/brainpy/dyn/neurons/input_groups.py +++ b/brainpy/dyn/neurons/input_groups.py @@ -7,7 +7,7 @@ import brainpy.math as bm from brainpy.dyn.base import NeuGroup from brainpy.errors import ModelBuildError -from brainpy.initialize import Initializer, parameter, variable +from brainpy.initialize import Initializer, parameter, variable_ from brainpy.modes import Mode, 
BatchingMode, normal from brainpy.types import Shape, Array @@ -139,7 +139,7 @@ def __init__( # variables self.i = bm.Variable(bm.zeros(1, dtype=bm.ditype())) - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) if need_sort: sort_idx = bm.argsort(self.times) self.indices.value = self.indices[sort_idx] @@ -162,7 +162,7 @@ def body_fun(t): def reset_state(self, batch_size=None): self.i[0] = 1 - self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def update(self, tdi, x=None): self.spike[:] = False @@ -193,7 +193,7 @@ def __init__( self.freqs = parameter(freqs, self.num, allow_none=False) # variables - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) self.rng = bm.random.RandomState(seed=seed) def update(self, tdi, x=None): @@ -205,5 +205,5 @@ def reset(self, batch_size=None): self.reset_state(batch_size) def reset_state(self, batch_size=None): - self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) diff --git a/brainpy/dyn/neurons/noise_groups.py b/brainpy/dyn/neurons/noise_groups.py index 0e6de7aeb..6fd7240df 100644 --- a/brainpy/dyn/neurons/noise_groups.py +++ b/brainpy/dyn/neurons/noise_groups.py @@ -63,13 +63,13 @@ def __init__( self.tau = init.parameter(tau, self.varshape, allow_none=False) # variables - self.x = init.variable(lambda s: bm.ones(s) * self.mean, mode, self.varshape) + self.x = init.variable_(lambda s: bm.ones(s) * self.mean, self.varshape, mode) # integral functions self.integral = sdeint(f=self.df, g=self.dg, method=method) def reset_state(self, batch_size=None): - self.x.value = init.variable(lambda s: bm.ones(s) * self.mean, batch_size, self.varshape) + self.x.value = init.variable_(lambda s: bm.ones(s) * self.mean, self.varshape, batch_size) def df(self, x, t): return (self.mean - x) / self.tau diff --git a/brainpy/dyn/neurons/reduced_models.py b/brainpy/dyn/neurons/reduced_models.py index f21dffcc7..53971b08c 100644 --- a/brainpy/dyn/neurons/reduced_models.py +++ b/brainpy/dyn/neurons/reduced_models.py @@ -8,7 +8,7 @@ import brainpy.math as bm from brainpy.dyn.base import NeuGroup from brainpy.initialize import (ZeroInit, OneInit, Initializer, - parameter, variable, variable2, noise as init_noise) + parameter, variable_, noise as init_noise) from brainpy.integrators import sdeint, odeint, JointEq from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode, normal, check_mode from brainpy.tools.checking import check_initializer, check_callable @@ -101,8 +101,8 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable2(self._V_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.V = variable_(self._V_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) # integral if self.noise is None: @@ -114,8 +114,8 @@ def derivative(self, V, t, I_ext): return (-V + self.V_rest + self.R * I_ext) / self.tau def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + 
self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) def update(self, tdi, x=None): if x is not None: self.input += x @@ -224,13 +224,13 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable2(self._V_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.V = variable_(self._V_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool # the gradient of spike is a float - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) if self.tau_ref is not None: - self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) - self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -242,13 +242,13 @@ def derivative(self, V, t, I_ext): return (-V + self.V_rest + self.R * I_ext) / self.tau def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) - self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def update(self, tdi, x=None): t, dt = tdi.t, tdi.dt @@ -438,13 +438,13 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable2(V_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.V = variable_(V_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) - self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode) if self.tau_ref is not None: - self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -453,13 +453,13 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.input.value = 
variable2(bm.zeros, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) - self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) if self.tau_ref is not None: - self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def derivative(self, V, t, I_ext): exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) @@ -615,14 +615,14 @@ def __init__( self._w_initializer = w_initializer # variables - self.V = variable2(V_initializer, self.varshape, mode) - self.w = variable2(w_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.V = variable_(V_initializer, self.varshape, mode) + self.w = variable_(w_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(mode, BatchingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) if self.tau_ref is not None: - self.refractory = variable2(partial(bm.zeros, dtype=bool), self.varshape, mode) - self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e8, self.varshape, mode) + self.refractory = variable_(partial(bm.zeros, dtype=bool), self.varshape, mode) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, mode) # functions if self.noise is None: @@ -631,16 +631,16 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.w.value = variable2(self._w_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=(bm.dftype() + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.w.value = variable_(self._w_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=(bm.dftype() if isinstance(self.mode, TrainingMode) else bool)), self.varshape, batch_size) if self.tau_ref is not None: - self.refractory.value = variable2(partial(bm.zeros, dtype=bool), self.varshape, batch_size) - self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e8, self.varshape, batch_size) + self.refractory.value = variable_(partial(bm.zeros, dtype=bool), self.varshape, batch_size) + self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, batch_size) def dV(self, V, t, w, I_ext): exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) @@ -782,13 +782,13 @@ def __init__( self._V_initializer = V_initializer # variables - self.V = variable2(V_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.V = variable_(V_initializer, 
self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) - self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode) if self.tau_ref is not None: - self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -797,13 +797,13 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) - self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) if self.tau_ref is not None: - self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def derivative(self, V, t, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau @@ -954,12 +954,12 @@ def __init__( self._w_initializer = w_initializer # variables - self.V = variable2(V_initializer, self.varshape, mode) - self.w = variable2(w_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.V = variable_(V_initializer, self.varshape, mode) + self.w = variable_(w_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) - self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -968,12 +968,12 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.w.value = variable2(self._w_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.w.value = variable_(self._w_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, 
batch_size) - self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, w, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau @@ -1148,13 +1148,13 @@ def __init__( self._Vth_initializer = Vth_initializer # variables - self.I1 = variable2(I1_initializer, self.varshape, mode) - self.I2 = variable2(I2_initializer, self.varshape, mode) - self.V_th = variable2(Vth_initializer, self.varshape, mode) - self.V = variable2(V_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.I1 = variable_(I1_initializer, self.varshape, mode) + self.I2 = variable_(I2_initializer, self.varshape, mode) + self.V_th = variable_(Vth_initializer, self.varshape, mode) + self.V = variable_(V_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) # integral if self.noise is None: @@ -1163,13 +1163,13 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.I1.value = variable2(self._I1_initializer, self.varshape, batch_size) - self.I2.value = variable2(self._I2_initializer, self.varshape, batch_size) - self.V_th.value = variable2(self._Vth_initializer, self.varshape, batch_size) - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.I1.value = variable_(self._I1_initializer, self.varshape, batch_size) + self.I2.value = variable_(self._I2_initializer, self.varshape, batch_size) + self.V_th.value = variable_(self._Vth_initializer, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dI1(self, I1, t): return - self.k1 * I1 @@ -1303,14 +1303,14 @@ def __init__( self._a_initializer = a_initializer # variables - self.a = variable2(a_initializer, self.varshape, mode) - self.V = variable2(V_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.a = variable_(a_initializer, self.varshape, mode) + self.V = variable_(V_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) if self.tau_ref is not None: - self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) - self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode) + self.refractory = variable_(lambda s: 
bm.zeros(s, dtype=bool), self.varshape, mode) # integral if self.noise is None: @@ -1329,14 +1329,14 @@ def derivative(self): return JointEq([self.dV, self.da]) def reset_state(self, batch_size=None): - self.a.value = variable2(self._a_initializer, self.varshape, batch_size) - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.a.value = variable_(self._a_initializer, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) - self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def update(self, tdi, x=None): t, dt = tdi.t, tdi.dt @@ -1493,14 +1493,14 @@ def __init__( self._u_initializer = u_initializer # variables - self.u = variable2(u_initializer, self.varshape, mode) - self.V = variable2(V_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.u = variable_(u_initializer, self.varshape, mode) + self.V = variable_(V_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) if self.tau_ref is not None: - self.t_last_spike = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, mode) - self.refractory = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode) # functions if self.noise is None: @@ -1509,14 +1509,14 @@ def __init__( self.integral = sdeint(method=method, f=JointEq([self.dV, self.du]), g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.u.value = variable2(self._u_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.u.value = variable_(self._u_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable2(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) - self.refractory.value = variable2(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.t_last_spike.value = 
variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, u, I_ext): dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext @@ -1727,12 +1727,12 @@ def __init__( self._z_initializer = z_initializer # variables - self.V = variable2(self._V_initializer, self.varshape, mode) - self.y = variable2(self._y_initializer, self.varshape, mode) - self.z = variable2(self._z_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.V = variable_(self._V_initializer, self.varshape, mode) + self.y = variable_(self._y_initializer, self.varshape, mode) + self.z = variable_(self._z_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) # integral if self.noise is None: @@ -1741,12 +1741,12 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.y.value = variable2(self._y_initializer, self.varshape, batch_size) - self.z.value = variable2(self._z_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.y.value = variable_(self._y_initializer, self.varshape, batch_size) + self.z.value = variable_(self._z_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dV(self, V, t, y, z, I_ext): return y - self.a * V * V * V + self.b * V * V - z + I_ext @@ -1900,11 +1900,11 @@ def __init__( self._w_initializer = w_initializer # variables - self.V = variable2(self._V_initializer, self.varshape, mode) - self.w = variable2(self._w_initializer, self.varshape, mode) - self.input = variable2(bm.zeros, self.varshape, mode) + self.V = variable_(self._V_initializer, self.varshape, mode) + self.w = variable_(self._w_initializer, self.varshape, mode) + self.input = variable_(bm.zeros, self.varshape, mode) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool - self.spike = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode) # integral if self.noise is None: @@ -1913,11 +1913,11 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable2(self._V_initializer, self.varshape, batch_size) - self.w.value = variable2(self._w_initializer, self.varshape, batch_size) - self.input.value = variable2(bm.zeros, self.varshape, batch_size) + self.V.value = variable_(self._V_initializer, self.varshape, batch_size) + self.w.value = variable_(self._w_initializer, self.varshape, batch_size) + self.input.value = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else 
bool - self.spike.value = variable2(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dV(self, V, t, w, I_ext): return V - V * V * V / 3 - w + I_ext diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 17c7b84aa..54785b02c 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -3,18 +3,19 @@ from typing import Union, Dict, Callable, Optional from jax import vmap -from jax.lax import stop_gradient, cond +from jax.lax import stop_gradient import brainpy.math as bm from brainpy.connect import TwoEndConnector, All2All, One2One from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn, SynConn -from brainpy.initialize import Initializer, variable +from brainpy.initialize import Initializer, variable_ from brainpy.integrators import odeint, JointEq from brainpy.tools.checking import check_integer, check_float from brainpy.modes import Mode, BatchingMode, normal, NormalMode, check_mode from brainpy.types import Array from ..synouts import CUBA, MgBlock + __all__ = [ 'Delta', 'Exponential', @@ -298,14 +299,14 @@ def __init__( self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr') # variables - self.g = variable(bm.zeros, mode, self.post.num) + self.g = variable_(bm.zeros, self.post.num, mode) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # function self.integral = odeint(lambda g, t: -g / self.tau, method=method) def reset_state(self, batch_size=None): - self.g.value = variable(bm.zeros, batch_size, self.post.num) + self.g.value = variable_(bm.zeros, self.post.num, batch_size) self.output.reset_state(batch_size) if self.stp is not None: self.stp.reset_state(batch_size) @@ -489,16 +490,16 @@ def __init__( self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') # variables - self.h = variable(bm.zeros, mode, self.pre.num) - self.g = variable(bm.zeros, mode, self.pre.num) + self.h = variable_(bm.zeros, self.pre.num, mode) + self.g = variable_(bm.zeros, self.pre.num, mode) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # integral self.integral = odeint(method=method, f=JointEq([self.dg, self.dh])) def reset_state(self, batch_size=None): - self.h.value = variable(bm.zeros, batch_size, self.pre.num) - self.g.value = variable(bm.zeros, batch_size, self.pre.num) + self.h.value = variable_(bm.zeros, self.pre.num, batch_size) + self.g.value = variable_(bm.zeros, self.pre.num, batch_size) self.output.reset_state(batch_size) if self.stp is not None: self.stp.reset_state(batch_size) @@ -831,8 +832,8 @@ def __init__( self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij') # variables - self.g = variable(bm.zeros, mode, self.pre.num) - self.x = variable(bm.zeros, mode, self.pre.num) + self.g = variable_(bm.zeros, self.pre.num, mode) + self.x = variable_(bm.zeros, self.pre.num, mode) self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) # integral @@ -845,8 +846,8 @@ def dx(self, x, t): return -x / self.tau_rise def reset_state(self, batch_size=None): - self.g.value = variable(bm.zeros, batch_size, self.pre.num) - self.x.value = variable(bm.zeros, batch_size, self.pre.num) + self.g.value = variable_(bm.zeros, self.pre.num, batch_size) + self.x.value = variable_(bm.zeros, self.pre.num, 
batch_size) self.output.reset_state(batch_size) if self.stp is not None: self.stp.reset_state(batch_size) @@ -921,6 +922,8 @@ def __init__( ): from ..neurons.input_groups import InputGroup, OutputGroup super(PoissonInput, self).__init__(InputGroup(1), OutputGroup(1), name=name, mode=mode) + self.pre = None + self.post = None # check data if not isinstance(target_var, bm.Variable): @@ -929,7 +932,7 @@ def __init__( check_integer(num_input, 'num_input', min_bound=1) check_float(freq, 'freq', min_bound=0., allow_int=True) check_float(weight, 'weight', allow_int=True) - check_mode(mode, NormalMode, name=self.__class__.__name__) + check_mode(mode, (NormalMode, BatchingMode), name=self.__class__.__name__) # parameters self.target_var = target_var @@ -943,8 +946,27 @@ def update(self, tdi): p = self.freq * tdi.dt / 1e3 a = self.num_input * p b = self.num_input * (1 - p) - inp = bm.cond((a > 5) * (b > 5), - lambda _: self.rng.normal(a, b * p, self.target_var.shape), - lambda _: self.rng.binomial(self.num_input, p, self.target_var.shape), - None) - self.target_var += inp + if isinstance(tdi.dt, (int, float)): # dt is not in tracing + if (a > 5) and (b > 5): + inp = self.rng.normal(a, b * p, self.target_var.shape) + else: + inp = self.rng.binomial(self.num_input, p, self.target_var.shape) + + else: # dt is in tracing + inp = bm.cond((a > 5) * (b > 5), + lambda _: self.rng.normal(a, b * p, self.target_var.shape), + lambda _: self.rng.binomial(self.num_input, p, self.target_var.shape), + None) + self.target_var += inp * self.weight + + def __repr__(self): + names = self.__class__.__name__ + return f'{names}(name={self.name}, num_input={self.num_input}, freq={self.freq}, weight={self.weight})' + + def reset_state(self, batch_size=None): + pass + + def reset(self, batch_size=None): + self.rng.seed(self.seed) + self.reset_state(batch_size) + diff --git a/brainpy/initialize/generic.py b/brainpy/initialize/generic.py index ee0ed06c4..4dce2d115 100644 --- a/brainpy/initialize/generic.py +++ b/brainpy/initialize/generic.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import warnings from typing import Union, Callable, Optional import jax.numpy as jnp @@ -14,6 +14,7 @@ __all__ = [ 'parameter', 'variable', + 'variable_', 'variable2', 'noise', 'delay', @@ -81,12 +82,22 @@ def init_param( return parameter(param, size, allow_none) +def variable_( + data: Union[Callable, Array], + size: Shape = None, + batch_size_or_mode: Optional[Union[int, bool, Mode]] = None, + batch_axis: int = 0, +): + return variable(data, batch_size_or_mode, size, batch_axis) + + def variable2( data: Union[Callable, Array], size: Shape = None, batch_size_or_mode: Optional[Union[int, bool, Mode]] = None, batch_axis: int = 0, ): + warnings.warn('Use brainpy.init.variable_() instead. 
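The rewritten ``PoissonInput.update`` above branches on whether ``dt`` is a plain Python number: each of the ``num_input`` independent inputs spikes in a step of ``dt`` ms with probability ``p = freq * dt / 1e3``, and the summed count is drawn from a binomial, or, once ``num_input * p`` and ``num_input * (1 - p)`` both exceed 5, from the normal approximation used in the patch, so ``bm.cond`` is only needed when ``dt`` is being traced. A small self-contained sketch of that sampling rule, with illustrative numbers:

.. code-block:: python

    import brainpy.math as bm

    num_input, freq, dt = 100, 20., 0.1   # illustrative values
    target_shape = (8,)                   # shape of the target variable
    rng = bm.random.RandomState(1234)

    p = freq * dt / 1e3
    a = num_input * p
    b = num_input * (1 - p)
    if (a > 5) and (b > 5):
        inp = rng.normal(a, b * p, target_shape)        # normal approximation
    else:
        inp = rng.binomial(num_input, p, target_shape)  # exact binomial draw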
', UserWarning) return variable(data, batch_size_or_mode, size, batch_axis) diff --git a/brainpy/math/controls.py b/brainpy/math/controls.py index 2587b5a79..ab1d1b923 100644 --- a/brainpy/math/controls.py +++ b/brainpy/math/controls.py @@ -616,9 +616,9 @@ def for_loop(body_fun: Callable, >>> a_hist = bm.for_loop(body, dyn_vars=[a, b], operands=bm.arange(1, 5)) >>> a_hist DeviceArray([[ 1.], - [ 3.], - [ 6.], - [10.]], dtype=float32) + [ 3.], + [ 6.], + [10.]], dtype=float32) >>> a Variable([10.], dtype=float32) >>> b diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index bb3119569..02b79d381 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -195,14 +195,19 @@ def __setitem__(self, index, value): # value is JaxArray if isinstance(value, JaxArray): value = value.value + # value is numpy.ndarray + elif isinstance(value, np.ndarray): + value = jnp.asarray(value) - # tuple index + # index is a tuple if isinstance(index, tuple): index = tuple(_check_input_array(x) for x in index) - - # JaxArray index + # index is JaxArray elif isinstance(index, JaxArray): index = index.value + # index is numpy.ndarray + elif isinstance(index, np.ndarray): + index = jnp.asarray(index) # update self._value = self._value.at[index].set(value) @@ -1569,6 +1574,8 @@ def __setitem__(self, index, value): # value is JaxArray if isinstance(value, JaxArray): value = value.value + elif isinstance(value, np.ndarray): + value = jnp.asarray(value) # tuple index if isinstance(index, tuple): diff --git a/brainpy/math/operators/__init__.py b/brainpy/math/operators/__init__.py index 517a0bc95..7466d06a7 100644 --- a/brainpy/math/operators/__init__.py +++ b/brainpy/math/operators/__init__.py @@ -3,21 +3,16 @@ from . import multiplication from . import op_register -from . import pre2syn as pre2syn_module -from . import pre2post as pre2post_module -from . import syn2post as syn2post_module +from . import pre_syn_post as pre_syn_post_module from . import wrap_jax from . import spikegrad __all__ = multiplication.__all__ + op_register.__all__ -__all__ += pre2syn_module.__all__ + pre2post_module.__all__ + syn2post_module.__all__ -__all__ += wrap_jax.__all__ + spikegrad.__all__ +__all__ += pre_syn_post_module.__all__ + wrap_jax.__all__ + spikegrad.__all__ from .multiplication import * from .op_register import * -from .pre2syn import * -from .pre2post import * -from .syn2post import * +from .pre_syn_post import * from .wrap_jax import * from .spikegrad import * diff --git a/brainpy/math/operators/pre2syn.py b/brainpy/math/operators/pre2syn.py deleted file mode 100644 index b60551d5b..000000000 --- a/brainpy/math/operators/pre2syn.py +++ /dev/null @@ -1,47 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp -from jax import vmap - -from brainpy.math.numpy_ops import as_device_array - -__all__ = [ - 'pre2syn' -] - - -_pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None)) - - -def pre2syn(pre_values, pre_ids): - """The pre-to-syn computation. - - Change the pre-synaptic data to the data with the dimension of synapses. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - syn_val = np.zeros(len(pre_ids)) - for syn_i, pre_i in enumerate(pre_ids): - syn_val[i] = pre_values[pre_i] - - Parameters - ---------- - pre_values: float, jax.numpy.ndarray, JaxArray, Variable - The pre-synaptic value. - pre_ids: jax.numpy.ndarray, JaxArray - The pre-synaptic neuron index. - - Returns - ------- - syn_val: jax.numpy.ndarray, JaxArray - The synaptic value. 
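The ``__setitem__`` changes above convert NumPy arrays, whether used as the assigned value or as the index, to JAX arrays with ``jnp.asarray`` before the ``.at[index].set(value)`` update. A short sketch of the usage this enables (shapes are illustrative):

.. code-block:: python

    import numpy as np
    import brainpy.math as bm

    v = bm.Variable(bm.zeros(5))
    idx = np.array([0, 2])      # NumPy index array, converted internally
    val = np.array([1., 2.])    # NumPy value array, converted internally
    v[idx] = val                # in-place update of positions 0 and 2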
- """ - pre_values = as_device_array(pre_values) - pre_ids = as_device_array(pre_ids) - if jnp.ndim(pre_values) == 0: - return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values - else: - return _pre2syn(pre_ids, pre_values) diff --git a/brainpy/math/operators/pre2post.py b/brainpy/math/operators/pre_syn_post.py similarity index 53% rename from brainpy/math/operators/pre2post.py rename to brainpy/math/operators/pre_syn_post.py index 9f45d998c..27411acdb 100644 --- a/brainpy/math/operators/pre2post.py +++ b/brainpy/math/operators/pre_syn_post.py @@ -1,18 +1,14 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Union, Tuple import jax.numpy as jnp -from jax import vmap, jit -from jax.lax import cond +from jax import vmap, jit, ops as jops from brainpy.errors import MathError from brainpy.math.jaxarray import JaxArray from brainpy.math.numpy_ops import as_device_array from brainpy.types import Array -from .pre2syn import pre2syn -from .syn2post import syn2post_mean from .utils import _check_brainpylib try: @@ -29,9 +25,20 @@ 'pre2post_mean', # pre-to-post event operator - 'pre2post_event_sum', - 'pre2post_event_prod', - + 'pre2post_csr_event_sum', 'pre2post_event_sum', + 'pre2post_coo_event_sum', + 'pre2post_csr_event_prod', 'pre2post_event_prod', + + # pre-to-syn + 'pre2syn', + + # syn-to-post + 'syn2post_sum', 'syn2post', + 'syn2post_prod', + 'syn2post_max', + 'syn2post_min', + 'syn2post_mean', + 'syn2post_softmax', ] @@ -42,11 +49,11 @@ def _raise_pre_ids_is_none(pre_ids): f'(brainpy.math.ndim(pre_values) != 0).') -def pre2post_event_sum(events: Array, - pre2post: Tuple[Array, Array], - post_num: int, - values: Union[float, Array] = 1.): - """The pre-to-post synaptic computation with event-driven summation. +def pre2post_csr_event_sum(events: Array, + pre2post: Tuple[Array, Array], + post_num: int, + values: Union[float, Array] = 1.): + """The pre-to-post event-driven synaptic summation with `CSR` synapse structure. When ``values`` is a scalar, this function is equivalent to @@ -91,7 +98,7 @@ def pre2post_event_sum(events: Array, out: JaxArray, jax.numpy.ndarray A tensor with the shape of ``post_num``. """ - _check_brainpylib(pre2post_event_sum.__name__) + _check_brainpylib('pre2post_event_sum') indices, idnptr = pre2post events = as_device_array(events) indices = as_device_array(indices) @@ -100,45 +107,24 @@ def pre2post_event_sum(events: Array, return brainpylib.event_sum(events, (indices, idnptr), post_num, values) -def pre2post_event_sum2(events: Array, - pre2post: Tuple[Array, Array], - post_num: int, - values: Union[float, Array] = 1.): - """The pre-to-post synaptic computation with event-driven summation. - - When ``values`` is a scalar, this function is equivalent to - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - post_ids, idnptr = pre2post - for i in range(pre_num): - if events[i]: - for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[i]] += values - - When ``values`` is a vector (with the length of ``len(post_ids)``), - this function is equivalent to - - .. highlight:: python - .. 
code-block:: python - - post_val = np.zeros(post_num) +pre2post_event_sum = pre2post_csr_event_sum - post_ids, idnptr = pre2post - for i in range(pre_num): - if events[i]: - for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[i]] += values[j] +def pre2post_coo_event_sum(events: Array, + pre_ids: Array, + post_ids: Array, + post_num: int, + values: Union[float, Array] = 1.): + """The pre-to-post synaptic computation with event-driven summation. Parameters ---------- events: Array The events, must be bool. - pre2post: tuple of Array, tuple of Array - A tuple contains the connection information of pre-to-post. + pre_ids: Array + Pre-synaptic ids. + post_ids: Array + Post-synaptic idsd. post_num: int The number of post-synaptic group. values: float, Array @@ -149,16 +135,15 @@ def pre2post_event_sum2(events: Array, out: JaxArray, jax.numpy.ndarray A tensor with the shape of ``post_num``. """ - _check_brainpylib(pre2post_event_sum.__name__) - indices, idnptr = pre2post + _check_brainpylib('pre2post_event_sum') events = as_device_array(events) - indices = as_device_array(indices) - idnptr = as_device_array(idnptr) + post_ids = as_device_array(post_ids) + pre_ids = as_device_array(pre_ids) values = as_device_array(values) - return brainpylib.event_sum2(events, (indices, idnptr), post_num, values) + return brainpylib.event_sum2(events, pre_ids, post_ids, post_num, values) -def pre2post_event_prod(events, pre2post, post_num, values=1.): +def pre2post_csr_event_prod(events, pre2post, post_num, values=1.): """The pre-to-post synaptic computation with event-driven production. When ``values`` is a scalar, this function is equivalent to @@ -204,7 +189,7 @@ def pre2post_event_prod(events, pre2post, post_num, values=1.): out: JaxArray, jax.numpy.ndarray A tensor with the shape of ``post_num``. """ - _check_brainpylib(pre2post_event_prod.__name__) + _check_brainpylib('pre2post_event_prod') indices, idnptr = pre2post events = as_device_array(events) indices = as_device_array(indices) @@ -213,6 +198,9 @@ def pre2post_event_prod(events, pre2post, post_num, values=1.): return brainpylib.event_prod(events, (indices, idnptr), post_num, values) +pre2post_event_prod = pre2post_csr_event_prod + + def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None): """The pre-to-post synaptic summation. @@ -388,8 +376,8 @@ def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): pre_values = as_device_array(pre_values) post_ids = as_device_array(post_ids) if jnp.ndim(pre_values) == 0: - # return out.at[post_ids].set(pre_values) - return out.at[jnp.unique(post_ids)].set(pre_values) + return out.at[post_ids].set(pre_values) + # return out.at[jnp.unique(post_ids)].set(pre_values) else: _raise_pre_ids_is_none(pre_ids) pre_ids = as_device_array(pre_ids) @@ -428,62 +416,254 @@ def pre2post_matmul2(event, conn): return f1(jnp.arange(Cr.shape[1])) -def pre2post_matmul_mask(event, conn, mask): - event = event.value if isinstance(event, JaxArray) else event - Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0] - Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1] - Ml = mask[0].value if isinstance(mask[0], JaxArray) else mask[0] - Mr = mask[1].value if isinstance(mask[1], JaxArray) else mask[1] - if jnp.ndim(event) != 1: - raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}') - if jnp.ndim(Cl) != 2: - raise ValueError(f'"conn" must be a two-dimensional matrix. 
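``pre2post_csr_event_sum`` keeps a loop-style equivalent in its docstring, while the new ``pre2post_coo_event_sum`` does not; the following is a hedged NumPy reference of what the COO variant is assumed to compute, one ``(pre_id, post_id)`` pair per synapse with contributions only from presynaptic neurons whose event is True, mirroring the CSR semantics rather than the exact ``brainpylib`` kernel:

.. code-block:: python

    import numpy as np

    def coo_event_sum_reference(events, pre_ids, post_ids, post_num, value=1.):
        # reference-only loop; the real operator runs as a brainpylib kernel
        post_val = np.zeros(post_num)
        for k in range(len(pre_ids)):
            if events[pre_ids[k]]:
                post_val[post_ids[k]] += value
        return post_val

    events = np.array([True, False, True])
    pre_ids = np.array([0, 1, 2, 2])
    post_ids = np.array([1, 0, 0, 1])
    print(coo_event_sum_reference(events, pre_ids, post_ids, post_num=2))
    # -> [1. 2.]: only synapses whose presynaptic neuron spiked contribute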
But we got {jnp.shape(Cl)}') - if jnp.ndim(Cr) != 2: - raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}') - if jnp.ndim(Mr) != 2: - raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Mr)}') - if jnp.ndim(Ml) != 2: - raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Ml)}') +_pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None)) - f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum() * (Ml[i] * Mr[:, j]).sum(), in_axes=(0, None)) - f1 = jit(vmap(lambda ii, j: f0(ii, j).sum(), in_axes=(None, 0))) - return f1(jnp.arange(Cl.shape[0]), jnp.arange(Cr.shape[1])) +def pre2syn(pre_values, pre_ids): + """The pre-to-syn computation. -def pre2post_matmul_mask2(event, conn, mask): - event = event.value if isinstance(event, JaxArray) else event - Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0] - Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1] - Ml = mask[0].value if isinstance(mask[0], JaxArray) else mask[0] - Mr = mask[1].value if isinstance(mask[1], JaxArray) else mask[1] - if jnp.ndim(event) != 1: - raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}') - if jnp.ndim(Cl) != 2: - raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}') - if jnp.ndim(Cr) != 2: - raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}') - if jnp.ndim(Mr) != 2: - raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Mr)}') - if jnp.ndim(Ml) != 2: - raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Ml)}') - - # f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum() * (Ml[i] * Mr[:, j]).sum(), in_axes=(0, None)) - @partial(vmap, in_axes=(0, None)) - def f0(i, j): - return cond(event[i], - lambda: cond(Ml[i] @ Mr[:, j], - lambda: (Cl[i] * Cr[:, j]).sum(), - lambda: 0.), - lambda: 0.) + Change the pre-synaptic data to the data with the dimension of synapses. - ii = jnp.arange(Cl.shape[0]) - jj = jnp.arange(Cr.shape[1]) + This function is equivalent to: - # def body(_, j): - # r = f0(ii, j).sum() - # return 0, r - # _, out = scan(body, 0, jj) - # return out + .. highlight:: python + .. code-block:: python + + syn_val = np.zeros(len(pre_ids)) + for syn_i, pre_i in enumerate(pre_ids): + syn_val[i] = pre_values[pre_i] - f = jit(vmap(lambda j: f0(ii, j).sum())) - return f(jj) + Parameters + ---------- + pre_values: float, jax.numpy.ndarray, JaxArray, Variable + The pre-synaptic value. + pre_ids: jax.numpy.ndarray, JaxArray + The pre-synaptic neuron index. + + Returns + ------- + syn_val: jax.numpy.ndarray, JaxArray + The synaptic value. + """ + pre_values = as_device_array(pre_values) + pre_ids = as_device_array(pre_ids) + if jnp.ndim(pre_values) == 0: + return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values + else: + return _pre2syn(pre_ids, pre_values) + + +_jit_seg_sum = jit(jops.segment_sum, static_argnums=(2, 3)) +_jit_seg_prod = jit(jops.segment_prod, static_argnums=(2, 3)) +_jit_seg_max = jit(jops.segment_max, static_argnums=(2, 3)) +_jit_seg_min = jit(jops.segment_min, static_argnums=(2, 3)) + + +def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post summation computation. + + This function is equivalent to: + + .. highlight:: python + .. 
code-block:: python + + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] += syn_values[syn_i] + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. + post_num: int + The number of the post-synaptic neurons. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + + +syn2post = syn2post_sum + + +def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post product computation. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] *= syn_values[syn_i] + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted) + + +def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post maximum computation. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i]) + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) + + +def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post minimization computation. + + This function is equivalent to: + + .. highlight:: python + .. 
code-block:: python + + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i]) + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted) + + +def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post mean computation. + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. + """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted) + return jnp.nan_to_num(nominator / denominator) + + +def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=True): + """The syn-to-post softmax computation. + + Parameters + ---------- + syn_values: jax.numpy.ndarray, JaxArray, Variable + The synaptic values. + post_ids: jax.numpy.ndarray, JaxArray + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns + ------- + post_val: jax.numpy.ndarray, JaxArray + The post-synaptic value. 
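+  Examples
+  --------
+  A small sketch (it assumes these operators are re-exported under
+  ``brainpy.math``, e.g. as ``bm.syn2post_softmax``):
+
+  >>> import jax.numpy as jnp
+  >>> import brainpy.math as bm
+  >>> syn_values = jnp.array([1., 2., 3., 4.])
+  >>> post_ids = jnp.array([0, 0, 1, 2])
+  >>> # one weight per synapse, normalized within each post-synaptic neuron:
+  >>> # the two synapses onto post 0 share softmax([1., 2.]), while the single
+  >>> # synapses onto posts 1 and 2 each get 1.0
+  >>> bm.syn2post_softmax(syn_values, post_ids, 3)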
+ """ + post_ids = as_device_array(post_ids) + syn_values = as_device_array(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) + syn_values = syn_values - syn_maxs[post_ids] + syn_values = jnp.exp(syn_values) + normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + softmax = syn_values / normalizers[post_ids] + return jnp.nan_to_num(softmax) diff --git a/brainpy/math/operators/syn2post.py b/brainpy/math/operators/syn2post.py deleted file mode 100644 index d022c14a1..000000000 --- a/brainpy/math/operators/syn2post.py +++ /dev/null @@ -1,235 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp -from jax import jit, vmap -from jax import ops as jops - -from brainpy.math.numpy_ops import as_device_array - - -_jit_seg_sum = jit(jops.segment_sum, static_argnums=(2, 3)) -_jit_seg_prod = jit(jops.segment_prod, static_argnums=(2, 3)) -_jit_seg_max = jit(jops.segment_max, static_argnums=(2, 3)) -_jit_seg_min = jit(jops.segment_min, static_argnums=(2, 3)) - - -__all__ = [ - 'syn2post_sum', 'syn2post', - 'syn2post_prod', - 'syn2post_max', - 'syn2post_min', - 'syn2post_mean', - 'syn2post_softmax', - -] - - -def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post summation computation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] += syn_values[syn_i] - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. - post_num: int - The number of the post-synaptic neurons. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - - -syn2post = syn2post_sum - - -def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post product computation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] *= syn_values[syn_i] - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted) - - -def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post maximum computation. 
- - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i]) - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) - - -def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post minimization computation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i]) - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted) - - -def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post mean computation. - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. 
- """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted) - return jnp.nan_to_num(nominator / denominator) - - -def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=True): - """The syn-to-post softmax computation. - - Parameters - ---------- - syn_values: jax.numpy.ndarray, JaxArray, Variable - The synaptic values. - post_ids: jax.numpy.ndarray, JaxArray - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns - ------- - post_val: jax.numpy.ndarray, JaxArray - The post-synaptic value. - """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) - syn_values = syn_values - syn_maxs[post_ids] - syn_values = jnp.exp(syn_values) - normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - softmax = syn_values / normalizers[post_ids] - return jnp.nan_to_num(softmax) - diff --git a/brainpy/math/operators/wrap_jax.py b/brainpy/math/operators/wrap_jax.py index 432bcc8cd..3fcf043e0 100644 --- a/brainpy/math/operators/wrap_jax.py +++ b/brainpy/math/operators/wrap_jax.py @@ -24,6 +24,38 @@ def segment_sum(data: Union[JaxArray, jnp.ndarray], unique_indices: bool = False, bucket_size: Optional[int] = None, mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: + """``segment_sum`` operator for brainpy `JaxArray` and `Variable`. + + Parameters + ---------- + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + + Returns + ------- + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. 
+ """ return JaxArray(jops.segment_sum(data.value if isinstance(data, JaxArray) else data, segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, num_segments, @@ -39,6 +71,38 @@ def segment_prod(data: Union[JaxArray, jnp.ndarray], unique_indices: bool = False, bucket_size: Optional[int] = None, mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: + """``segment_prod`` operator for brainpy `JaxArray` and `Variable`. + + Parameters + ---------- + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + + Returns + ------- + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ return JaxArray(jops.segment_prod(data.value if isinstance(data, JaxArray) else data, segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, num_segments, @@ -54,6 +118,38 @@ def segment_max(data: Union[JaxArray, jnp.ndarray], unique_indices: bool = False, bucket_size: Optional[int] = None, mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: + """``segment_max`` operator for brainpy `JaxArray` and `Variable`. + + Parameters + ---------- + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + + Returns + ------- + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. 
+ """ return JaxArray(jops.segment_max(data.value if isinstance(data, JaxArray) else data, segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, num_segments, @@ -69,6 +165,38 @@ def segment_min(data: Union[JaxArray, jnp.ndarray], unique_indices: bool = False, bucket_size: Optional[int] = None, mode: Optional[lax.GatherScatterMode] = None) -> JaxArray: + """``segment_min`` operator for brainpy `JaxArray` and `Variable`. + + Parameters + ---------- + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + + Returns + ------- + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ return JaxArray(jops.segment_min(data.value if isinstance(data, JaxArray) else data, segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids, num_segments, diff --git a/brainpy/math/random.py b/brainpy/math/random.py index e833998a5..84d11b0f5 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -393,14 +393,16 @@ def __init__(self, seed=None): Parameters ---------- seed : int, jax.DeviceArray, Optional - The initial seed of the random number generator. + It can be an integer for initial seed of the random number generator, + or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype. """ if seed is None: seed = np.random.randint(0, 100000, 2, dtype=np.uint32) if isinstance(seed, int): key = jr.PRNGKey(seed) else: - assert len(seed) == 2 + if len(seed) != 2 and seed.dtype != np.uint32: + raise ValueError key = seed super(RandomState, self).__init__(key) @@ -408,16 +410,24 @@ def __init__(self, seed=None): # seed and random key # # ------------------- # - def seed(self, seed): + def seed(self, seed=None): """Sets a new random seed. Parameters ---------- - seed : int - The new initial seed of the random number generator. + seed : int, ndarray + It can be an integer for initial seed of the random number generator, + or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype. """ - if seed is not None: - self.value = jr.PRNGKey(seed) + if seed is None: + seed = np.random.randint(0, 100000, 2, dtype=np.uint32) + if isinstance(seed, int): + key = jr.PRNGKey(seed) + else: + if len(seed) != 2 and seed.dtype != np.uint32: + raise ValueError + key = seed + self.value = key def split_key(self): """Create a new seed from the current seed. 
diff --git a/brainpy/measure/__init__.py b/brainpy/measure/__init__.py index 31078b539..976345282 100644 --- a/brainpy/measure/__init__.py +++ b/brainpy/measure/__init__.py @@ -5,6 +5,10 @@ You can access them through ``brainpy.measure.XXX``. """ +from . import correlation, firings, lfp + from .correlation import * from .firings import * +from .lfp import * + diff --git a/brainpy/measure/correlation.py b/brainpy/measure/correlation.py index c742fa05b..49b89d39a 100644 --- a/brainpy/measure/correlation.py +++ b/brainpy/measure/correlation.py @@ -2,8 +2,8 @@ from functools import partial -import numpy as np -from jax import vmap, jit, lax, numpy as jnp +import numpy as onp +from jax import vmap, lax, numpy as jnp from brainpy import math as bm @@ -17,17 +17,7 @@ ] -# @jit -@partial(vmap, in_axes=(None, 0, 0)) -def _cc(states, i, j): - sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j])) - return lax.cond(sqrt_ij == 0., - lambda _: 0., - lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij, - None) - - -def cross_correlation(spikes, bin, dt=None): +def cross_correlation(spikes, bin, dt=None, numpy=True): r"""Calculate cross correlation index between neurons. The coherence [1]_ between two neurons i and j is measured by their @@ -47,14 +37,21 @@ def cross_correlation(spikes, bin, dt=None): average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the network. + .. note:: + To JIT compile this function, users should make ``bin``, ``dt``, ``numpy`` static. + For example, ``partial(brainpy.measure.cross_correlation, bin=10, numpy=False)``. + Parameters ---------- - spikes : + spikes : ndarray The history of spike states of the neuron group. bin : float, int The time bin to normalize spike states. dt : float, optional The time precision. + numpy: bool + Whether we use numpy array as the functional output. + If ``False``, this function can be JIT compiled. Returns ------- @@ -67,27 +64,30 @@ def cross_correlation(spikes, bin, dt=None): inhibition in a hippocampal interneuronal network model." Journal of neuroscience 16.20 (1996): 6402-6413. """ - spikes = bm.as_device_array(spikes) + spikes = bm.as_numpy(spikes) if numpy else bm.as_device_array(spikes) + np = onp if numpy else jnp dt = bm.get_dt() if dt is None else dt bin_size = int(bin / dt) num_hist, num_neu = spikes.shape - num_bin = int(np.ceil(num_hist / bin_size)) + num_bin = int(onp.ceil(num_hist / bin_size)) + + @partial(vmap, in_axes=(None, 0, 0)) + def _cc(states, i, j): + sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j])) + return lax.cond(sqrt_ij == 0., + lambda _: 0., + lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij, + None) + if num_bin * bin_size != num_hist: - spikes = jnp.append(spikes, jnp.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0) + spikes = np.append(spikes, np.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0) states = spikes.T.reshape((num_neu, num_bin, bin_size)) - states = jnp.asarray(jnp.sum(states, axis=2) > 0., dtype=jnp.float_) + states = jnp.asarray(np.sum(states, axis=2) > 0., dtype=jnp.float_) indices = jnp.tril_indices(num_neu, k=-1) - return jnp.mean(_cc(states, *indices)) - + return onp.mean(np.asarray(_cc(states, *indices))) -@partial(vmap, in_axes=(None, 0)) -def _var(neu_signal, i): - neu_signal = neu_signal[:, i] - return jnp.mean(neu_signal * neu_signal) - jnp.mean(neu_signal) ** 2 - -# @jit -def voltage_fluctuation(potentials): +def voltage_fluctuation(potentials, numpy=True): r"""Calculate neuronal synchronization via voltage variance. 
The method comes from [1]_ [2]_ [3]_. @@ -125,8 +125,11 @@ def voltage_fluctuation(potentials): Parameters ---------- - potentials : - The membrane potential matrix of the neuron group. + potentials : ndarray + The membrane potential matrix of the neuron group. + numpy: bool + Whether we use numpy array as the functional output. + If ``False``, this function can be JIT compiled. Returns ------- @@ -143,37 +146,43 @@ def voltage_fluctuation(potentials): """ potentials = bm.as_device_array(potentials) - num_hist, num_neu = potentials.shape - var_mean = jnp.mean(_var(potentials, jnp.arange(num_neu))) avg = jnp.mean(potentials, axis=1) avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2 - return lax.cond(var_mean != 0., lambda _: avg_var / var_mean, lambda _: 1., None) + _var = vmap(lambda signal: jnp.mean(signal * signal) - jnp.mean(signal) ** 2, in_axes=1) + var_mean = jnp.mean(_var(potentials)) + r = jnp.where(var_mean == 0., 1., avg_var / var_mean) + return bm.as_numpy(r) if numpy else r -def matrix_correlation(x, y): +def matrix_correlation(x, y, numpy=True): """Pearson correlation of the lower triagonal of two matrices. The triangular matrix is offset by k = 1 in order to ignore the diagonal line Parameters ---------- - x: tensor + x: ndarray First matrix. - y: tensor + y: ndarray Second matrix + numpy: bool + Whether we use numpy array as the functional output. + If ``False``, this function can be JIT compiled. Returns ------- - coef: tensor + coef: ndarray Correlation coefficient """ - x = bm.as_numpy(x) - y = bm.as_numpy(y) + + x = bm.as_numpy(x) if numpy else bm.as_device_array(x) + y = bm.as_numpy(y) if numpy else bm.as_device_array(y) + np = onp if numpy else jnp if x.ndim != 2: - raise ValueError(f'Only support 2d tensor, but we got a tensor ' + raise ValueError(f'Only support 2d array, but we got a array ' f'with the shape of {x.shape}') if y.ndim != 2: - raise ValueError(f'Only support 2d tensor, but we got a tensor ' + raise ValueError(f'Only support 2d array, but we got a array ' f'with the shape of {y.shape}') x = x[np.triu_indices_from(x, k=1)] y = y[np.triu_indices_from(y, k=1)] @@ -181,34 +190,37 @@ def matrix_correlation(x, y): return cc -def functional_connectivity(activities): +def functional_connectivity(activities, numpy=True): """Functional connectivity matrix of timeseries activities. Parameters ---------- - activities: tensor - The multidimensional tensor with the shape of ``(num_time, num_sample)``. + activities: ndarray + The multidimensional array with the shape of ``(num_time, num_sample)``. + numpy: bool + Whether we use numpy array as the functional output. + If ``False``, this function can be JIT compiled. Returns ------- - connectivity_matrix: tensor + connectivity_matrix: ndarray ``num_sample x num_sample`` functional connectivity matrix. """ - activities = bm.as_numpy(activities) + activities = bm.as_numpy(activities) if numpy else bm.as_device_array(activities) + np = onp if numpy else jnp if activities.ndim != 2: - raise ValueError('Only support 2d tensor with shape of "(num_time, num_sample)". ' - f'But we got a tensor with the shape of {activities.shape}') + raise ValueError('Only support 2d array with shape of "(num_time, num_sample)". ' + f'But we got a array with the shape of {activities.shape}') fc = np.corrcoef(activities.T) return np.nan_to_num(fc) -# @jit def functional_connectivity_dynamics(activities, window_size=30, step_size=5): """Computes functional connectivity dynamics (FCD) matrix. 
Parameters ---------- - activities: tensor + activities: ndarray The time series with shape of ``(num_time, num_sample)``. window_size: int Size of each rolling window in time steps, defaults to 30. @@ -217,50 +229,52 @@ def functional_connectivity_dynamics(activities, window_size=30, step_size=5): Returns ------- - fcd_matrix: tensor + fcd_matrix: ndarray FCD matrix. """ pass -def _weighted_mean(x, w): - """Weighted Mean""" - return jnp.sum(x * w) / jnp.sum(w) - - -def _weighted_cov(x, y, w): - """Weighted Covariance""" - return jnp.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / jnp.sum(w) - - -# @jit -def weighted_correlation(x, y, w): +def weighted_correlation(x, y, w, numpy=True): """Weighted Pearson correlation of two data series. Parameters ---------- - x: tensor + x: ndarray The data series 1. - y: tensor + y: ndarray The data series 2. - w: tensor + w: ndarray Weight vector, must have same length as x and y. + numpy: bool + Whether we use numpy array as the functional output. + If ``False``, this function can be JIT compiled. Returns ------- - corr: tensor + corr: ndarray Weighted correlation coefficient. """ - x = bm.as_device_array(x) - y = bm.as_device_array(y) - w = bm.as_device_array(w) + x = bm.as_numpy(x) if numpy else bm.as_device_array(x) + y = bm.as_numpy(y) if numpy else bm.as_device_array(y) + w = bm.as_numpy(w) if numpy else bm.as_device_array(w) + np = onp if numpy else jnp + + def _weighted_mean(x, w): + """Weighted Mean""" + return np.sum(x * w) / np.sum(w) + + def _weighted_cov(x, y, w): + """Weighted Covariance""" + return np.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / np.sum(w) + if x.ndim != 1: - raise ValueError(f'Only support 1d tensor, but we got a tensor ' + raise ValueError(f'Only support 1d array, but we got a array ' f'with the shape of {x.shape}') if y.ndim != 1: - raise ValueError(f'Only support 1d tensor, but we got a tensor ' + raise ValueError(f'Only support 1d array, but we got a array ' f'with the shape of {y.shape}') if w.ndim != 1: - raise ValueError(f'Only support 1d tensor, but we got a tensor ' + raise ValueError(f'Only support 1d array, but we got a array ' f'with the shape of {w.shape}') - return _weighted_cov(x, y, w) / jnp.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w)) + return _weighted_cov(x, y, w) / np.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w)) diff --git a/brainpy/measure/firings.py b/brainpy/measure/firings.py index 0f335e8b9..9918f5ee3 100644 --- a/brainpy/measure/firings.py +++ b/brainpy/measure/firings.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- -import numpy as np +import numpy as onp import jax.numpy as jnp -from jax import jit from brainpy import math as bm @@ -29,21 +28,14 @@ def raster_plot(sp_matrix, times): Include (neuron index, spike time). """ sp_matrix = bm.as_numpy(sp_matrix) - times = np.asarray(times) - elements = np.where(sp_matrix > 0.) + times = onp.asarray(times) + elements = onp.where(sp_matrix > 0.) index = elements[1] time = times[elements[0]] return index, time -# @jit -def _firing_rate(sp_matrix, window): - sp_matrix = bm.as_device_array(sp_matrix) - rate = jnp.sum(sp_matrix, axis=1) / sp_matrix.shape[1] - return jnp.convolve(rate, window, mode='same') - - -def firing_rate(sp_matrix, width, dt=None, numpy=True): +def firing_rate(spikes, width, dt=None, numpy=True): r"""Calculate the mean firing rate over in a neuron group. This method is adopted from Brian2. 
@@ -57,21 +49,24 @@ def firing_rate(sp_matrix, width, dt=None, numpy=True): Parameters ---------- - sp_matrix : math.JaxArray, np.ndarray + spikes : ndarray The spike matrix which record spiking activities. width : int, float The width of the ``window`` in millisecond. dt : float, optional The sample rate. + numpy: bool + Whether we use numpy array as the functional output. + If ``False``, this function can be JIT compiled. Returns ------- - rate : numpy.ndarray + rate : ndarray The population rate in Hz, smoothed with the given window. """ + spikes = bm.as_numpy(spikes) if numpy else bm.as_device_array(spikes) + np = onp if numpy else jnp dt = bm.get_dt() if (dt is None) else dt width1 = int(width / 2 / dt) * 2 + 1 - window = jnp.ones(width1) * 1000 / width - fr = _firing_rate(sp_matrix, window) - return bm.as_numpy(fr) if numpy else fr - + window = np.ones(width1) * 1000 / width + return np.convolve(np.mean(spikes, axis=1), window, mode='same') diff --git a/brainpy/measure/lfp.py b/brainpy/measure/lfp.py new file mode 100644 index 000000000..8d1545084 --- /dev/null +++ b/brainpy/measure/lfp.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + + +from jax import numpy as jnp + +import brainpy.math as bm + +__all__ = [ + 'unitary_LFP', +] + + +def unitary_LFP(times, spikes, spike_type='exc', + xmax=0.2, ymax=0.2, va=200., lambda_=0.2, + sig_i=2.1, sig_e=2.1 * 1.5, location='soma layer', seed=None): + """A kernel-based method to calculate unitary local field potentials (uLFP) + from a network of spiking neurons [1]_. + + .. note:: + This method calculates LFP only from the neuronal spikes. It does not consider + the subthreshold synaptic events, or the dendritic voltage-dependent ion channels. + + Examples + -------- + + If you have spike data of excitatory and inhibtiory neurons, you can get the LFP + by the following methods: + + >>> import brainpy as bp + >>> n_time = 1000 + >>> n_exc = 100 + >>> n_inh = 25 + >>> times = bm.arange(n_time) * 0.1 + >>> exc_sps = bp.math.random.random((n_time, n_exc)) < 0.3 + >>> inh_sps = bp.math.random.random((n_time, n_inh)) < 0.4 + >>> lfp = bp.measure.unitary_LFP(times, exc_sps, 'exc') + >>> lfp += bp.measure.unitary_LFP(times, inh_sps, 'inh') + + Parameters + ---------- + times: ndarray + The times of the recording points. + spikes: ndarray + The spikes of excitatory neurons recorded by brainpy monitors. + spike_type: str + The neuron type of the spike trains. It can be "exc" or "inh". + location: str + The location of the spikes recorded. It can be "soma layer", "deep layer", + "superficial layer" and "surface". + xmax: float + Size of the array (in mm). + ymax: float + Size of the array (in mm). + va: int, float + The axon velocity (mm/sec). + lambda_: float + The space constant (mm). + sig_i: float + The std-dev of inhibition (in ms) + sig_e: float + The std-dev for excitation (in ms). + seed: int + The random seed. + + References + ---------- + .. [1] Telenczuk, Bartosz, Maria Telenczuk, and Alain Destexhe. "A kernel-based + method to calculate local field potentials from networks of spiking + neurons." Journal of Neuroscience Methods 344 (2020): 108871. + + """ + times = bm.as_device_array(times) + spikes = bm.as_device_array(spikes) + if spike_type not in ['exc', 'inh']: + raise ValueError('"spike_type" should be "exc or ""inh". ') + if spikes.ndim != 2: + raise ValueError('"E_spikes" should be a matrix with shape of (num_time, num_neuron). 
' + f'But we got {spikes.shape}') + if times.shape[0] != spikes.shape[0]: + raise ValueError('times and spikes should be consistent at the firs axis. ' + f'Bug we got {times.shape[0]} != {spikes.shape}.') + + # Distributing cells in a 2D grid + rng = bm.random.RandomState(seed) + num_neuron = spikes.shape[1] + pos_xs, pos_ys = rng.rand(2, num_neuron).value * jnp.array([[xmax], [ymax]]) + pos_xs, pos_ys = jnp.asarray(pos_xs), jnp.asarray(pos_ys) + + # distance/coordinates + xe, ye = xmax / 2, ymax / 2 # coordinates of electrode + dist = jnp.sqrt((pos_xs - xe) ** 2 + (pos_ys - ye) ** 2) # distance to electrode in mm + + # amplitude + if location == 'soma layer': + amp_e, amp_i = 0.48, 3. # exc/inh uLFP amplitude (soma layer) + elif location == 'deep layer': + amp_e, amp_i = -0.16, -0.2 # exc/inh uLFP amplitude (deep layer) + elif location == 'superficial layer': + amp_e, amp_i = 0.24, -1.2 # exc/inh uLFP amplitude (superficial layer) + elif location == 'surface layer': + amp_e, amp_i = -0.08, 0.3 # exc/inh uLFP amplitude (surface) + else: + raise NotImplementedError + A = bm.exp(-dist / lambda_) * (amp_e if spike_type == 'exc' else amp_i) + + # delay + delay = 10.4 + dist / va # delay to peak (in ms) + + # LFP Calculation + iis, ids = jnp.where(spikes) + tts = times[iis] + delay[ids] + exc_amp = A[ids] + tau = (2 * sig_e * sig_e) if spike_type == 'exc' else (2 * sig_i * sig_i) + return bm.for_loop(lambda t: bm.sum(exc_amp * bm.exp(-(t - tts) ** 2 / tau)), [], times) diff --git a/brainpy/measure/tests/test_correlation.py b/brainpy/measure/tests/test_correlation.py index 188fe66dd..aa963fe34 100644 --- a/brainpy/measure/tests/test_correlation.py +++ b/brainpy/measure/tests/test_correlation.py @@ -3,13 +3,19 @@ import unittest import brainpy as bp +import brainpy.math as bm +from jax import jit +from functools import partial class TestCrossCorrelation(unittest.TestCase): def test_c(self): spikes = bp.math.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.) - print(cc1) + f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.)) + cc2 = f_cc(spikes) + print(cc1, cc2) + self.assertTrue(cc1 == cc2) def test_cc(self): spikes = bp.math.ones((1000, 10)) @@ -47,19 +53,31 @@ def test_cc5(self): class TestVoltageFluctuation(unittest.TestCase): def test_vf1(self): - bp.math.random.seed() - voltages = bp.math.random.normal(0, 10, size=(1000, 100)) + rng = bp.math.random.RandomState(122) + voltages = rng.normal(0, 10, size=(1000, 100)) print(bp.measure.voltage_fluctuation(voltages)) voltages = bp.math.ones((1000, 100)) - print(bp.measure.voltage_fluctuation(voltages)) + r1 = bp.measure.voltage_fluctuation(voltages) + + jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False)) + r2 = jit_f(voltages) + + print(r1, r2) # TODO: JIT results are different? 
+ + # self.assertTrue(r1 == r2) class TestFunctionalConnectivity(unittest.TestCase): def test_cf1(self): bp.math.random.seed() act = bp.math.random.random((10000, 3)) - print(bp.measure.functional_connectivity(act)) + r1 = bp.measure.functional_connectivity(act) + + jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False)) + r2 = jit_f(act) + + self.assertTrue(bm.allclose(r1, r2)) class TestMatrixCorrelation(unittest.TestCase): @@ -67,5 +85,11 @@ def test_mc(self): bp.math.random.seed() A = bp.math.random.random((100, 100)) B = bp.math.random.random((100, 100)) - print(bp.measure.matrix_correlation(A, B)) + r1 = (bp.measure.matrix_correlation(A, B)) + + jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False)) + r2 = jit_f(A, B) + + self.assertTrue(bm.allclose(r1, r2)) + diff --git a/docs/auto_generater.py b/docs/auto_generater.py index e95994bb7..0eeb8b3ec 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -491,6 +491,19 @@ def generate_math_docs(path='apis/auto/math/'): with open(os.path.join(path, 'comparison_table.rst.inc'), 'w') as f: f.write(codes) + module_and_name = [ + ('pre_syn_post', '``pre-syn-post`` Transformations',), + ('multiplication', 'Sparse Matrix Multiplication',), + ('spikegrad', 'Surrogate Gradients for Spike Operation',), + ('op_register', 'Operator Registration',), + ('wrap_jax', 'Other Operators',), + ] + write_submodules(module_name='brainpy.math.operators', + filename=os.path.join(path, 'operators.rst'), + header='Sparse & Event-based Operators', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) + write_module(module_name='brainpy.math.activations', filename=os.path.join(path, 'activations.rst'), header='Activation Functions') @@ -500,9 +513,7 @@ def generate_math_docs(path='apis/auto/math/'): write_module(module_name='brainpy.math.controls', filename=os.path.join(path, 'controls.rst'), header='Control Flows') - write_module(module_name='brainpy.math.operators', - filename=os.path.join(path, 'operators.rst'), - header='Operators') + write_module(module_name='brainpy.math.parallels', filename=os.path.join(path, 'parallels.rst'), header='Parallel Compilation') diff --git a/docs/tutorial_math/control_flows.ipynb b/docs/tutorial_math/control_flows.ipynb index de8719421..f5b8616f5 100644 --- a/docs/tutorial_math/control_flows.ipynb +++ b/docs/tutorial_math/control_flows.ipynb @@ -3,11 +3,7 @@ { "cell_type": "markdown", "id": "254bbbf2", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "# Control Flows" ] @@ -15,11 +11,7 @@ { "cell_type": "markdown", "id": "355bb9b6", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "@[Chaoming Wang](https://github.com/chaoming0625)" ] @@ -37,31 +29,20 @@ "In this section, we are going to talk about how to build effective control flows with BrainPy and JAX." ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { "cell_type": "markdown", "id": "465bd161", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 1, "id": "38a2bb50", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "import brainpy as bp\n", @@ -76,10 +57,7 @@ "## 1\\. 
Selection" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -92,10 +70,7 @@ "- if-elif-else" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -104,10 +79,7 @@ "### Non-`Variable`-based control statements" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -116,10 +88,7 @@ "Actually, BrainPy (based on JAX) **allows to write control flows normally like your familiar Python programs, when the conditional statement depends on non-Variable instances**. For example," ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -143,10 +112,7 @@ " return self.a" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -155,10 +121,7 @@ "In the above example, the target *statement* in ``if (statement)`` syntax relies on a scalar, which is not an instance of [brainpy.math.Variable](./arrays_and_variables.ipynb). In this case, the conditional statements can be arbitrarily complex. You can write your models with normal Python codes. These models will work very well with [JIT compilation](./jit_compilation.ipynb)." ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -180,10 +143,7 @@ "model()" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -205,10 +165,7 @@ "model()" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -231,10 +188,7 @@ " print(f\"ValueError: {str(e)}\")" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -243,10 +197,7 @@ "### `Variable`-based control statements" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -255,10 +206,7 @@ "However, if the `statement` target in a ``if ... else ...`` syntax relies on instances of [brainpy.math.Variable](./arrays_and_variables.ipynb), writing Pythonic control flows will cause errors when using JIT compilation." ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -278,10 +226,7 @@ " return self.a" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -308,10 +253,7 @@ " print(f\"{e.__class__.__name__}: {str(e)}\")" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -323,10 +265,7 @@ "- [brainpy.math.ifelse](../apis/auto/math/generated/brainpy.math.controls.ifelse.rst): Conditional statements of `if-else`, or `if-elif-else`, ... for a scalar-typed value." ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -335,10 +274,7 @@ "### `brainpy.math.where`" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -347,10 +283,7 @@ "``where(condition, x, y)`` function returns elements chosen from *x* or *y* depending on *condition*. It can perform well on scalars, vectors, and high-dimensional arrays." 
], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -371,10 +304,7 @@ "bm.where(a < 0, 0., 1.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -395,10 +325,7 @@ "bm.where(a < 0.5, 0., 1.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -419,10 +346,7 @@ "bm.where(a < 0.5, 0., 1.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -431,10 +355,7 @@ "For the above example, we can rewrite it by using `where` syntax as:" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -453,10 +374,7 @@ " return self.a" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -477,10 +395,7 @@ "model()" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -489,10 +404,7 @@ "### `brainpy.math.ifelse`" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -512,10 +424,7 @@ "```" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -524,10 +433,7 @@ "Based on this function, we can rewrite the above example by using `cond` syntax as:" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -547,10 +453,7 @@ " return self.a" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -571,10 +474,7 @@ "model()" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -596,10 +496,7 @@ "```" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -624,10 +521,7 @@ "It can be expressed as:" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -640,10 +534,7 @@ " branches=[1., 2., 3., 4., 5.])" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -663,10 +554,7 @@ "f(11.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -686,10 +574,7 @@ "f(6.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -709,10 +594,7 @@ "f(1.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -732,10 +614,7 @@ "f(-4.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -755,10 +634,7 @@ "f(-6.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -767,10 +643,7 @@ "A more complex example is:" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -788,10 +661,7 @@ " operands=x)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -811,10 +681,7 @@ "f2(11, 1.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -834,10 +701,7 @@ "f2(6, 1.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -857,10 +721,7 @@ "f2(1, 1.)" ], 
"metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -880,10 +741,7 @@ "f2(-4, 1.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -903,10 +761,7 @@ "f2(-6, 1.)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -915,10 +770,7 @@ "If instances of `brainpy.math.Variable` are used in branching functions, you can declare them in the `dyn_vars` argument." ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -947,10 +799,7 @@ "print('b:', b)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -959,10 +808,7 @@ "## 2\\. Repetition" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -976,10 +822,7 @@ "- **while loop**: Execute a block of statements repeatedly until a given condition is satisfied." ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -988,10 +831,7 @@ "### Pythonic loop syntax" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1000,10 +840,7 @@ "Actually, JAX enables to write Pythonic loops. You just need to iterate over you sequence data and then apply your logic on the iterated items. Such kind of Pythonic loop syntax can be compatible with JIT compilation, but will cause long time to trace and compile. For example," ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1024,10 +861,7 @@ " return self.res.value" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1046,10 +880,7 @@ " return r if return_res else None" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1071,10 +902,7 @@ "measure_time(model)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1094,10 +922,7 @@ "measure_time(model)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1106,10 +931,7 @@ "When the model is complex and the iteration is long, the compilation during the first running will become unbearable. For such cases, you need structural loop syntax." ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1128,10 +950,7 @@ "In this section, we only talk about how to use our provided loop functions." 
], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1140,10 +959,7 @@ "### ``brainpy.math.for_loop()``" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1155,21 +971,18 @@ "\n", "```python\n", "\n", - "def for_loop_function(body_fun, dyn_vars, out_vars, xs):\n", + "def for_loop_function(body_fun, dyn_vars, xs):\n", " ys = []\n", " for x in xs:\n", - " # 'dyn_vars' and 'out_vars' are updated in 'body_fun()'\n", + " # 'dyn_vars' are updated in 'body_fun()'\n", " results = body_fun(x)\n", - " ys.append([out_vars, results])\n", + " ys.append(results)\n", " return ys\n", "\n", "```" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1184,10 +997,7 @@ "```" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1196,10 +1006,7 @@ "For the above example, we can rewrite it by using ``brainpy.math.for_loop`` as:" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1222,10 +1029,7 @@ " return bm.for_loop(body_fun=add, dyn_vars=[self.res], operands=self.seq)" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1248,10 +1052,7 @@ "r.shape" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1262,10 +1063,7 @@ "``operands`` specified the inputs of the ``body_fun``. It will be looped over the fist axis." ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1274,10 +1072,7 @@ "### ``brainpy.math.while_loop()``" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1307,10 +1102,7 @@ "A concreate example of ``brainpy.math.while_loop`` is as the follows:" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1340,10 +1132,7 @@ "bm.while_loop(body_f, cond_f, dyn_vars=[i, counter], operands=())" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1352,10 +1141,7 @@ "In the above example, we try to implement a sum from 0 to 10 by using two JaxArrays ``i`` and ``counter``." 
], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1375,10 +1161,7 @@ "counter" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1398,10 +1181,7 @@ "i" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } }, { @@ -1410,10 +1190,7 @@ "Or, similarly," ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } + "collapsed": false } }, { @@ -1442,10 +1219,7 @@ "bm.while_loop(body_f, cond_f, dyn_vars=[i], operands=(1., ))" ], "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } + "collapsed": false } } ], @@ -1509,4 +1283,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/simulation/Susin_2021_gamma_oscillation_nets.py b/examples/simulation/Susin_2021_gamma_oscillation_nets.py new file mode 100644 index 000000000..d55bb66b2 --- /dev/null +++ b/examples/simulation/Susin_2021_gamma_oscillation_nets.py @@ -0,0 +1,647 @@ +# -*- coding: utf-8 -*- + +""" +Implementation of the paper: + +- Susin, Eduarda, and Alain Destexhe. "Integration, coincidence detection and + resonance in networks of spiking neurons expressing gamma oscillations and + asynchronous states." PLoS computational biology 17.9 (2021): e1009416. + +""" + +import numpy as np +import matplotlib.pyplot as plt +from scipy.signal import kaiserord, lfilter, firwin, hilbert + +import brainpy as bp +import brainpy.math as bm + +# Table 1: specific neuron model parameters +RS_par = dict(Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, gL=10, EL=-65, V_reset=-65, + E_e=0., E_i=-80.) +FS_par = dict(Vth=-47.5, delta=0.5, tau_ref=5., tau_w=500, a=0, b=0, C=150, gL=10, EL=-65, V_reset=-65, + E_e=0., E_i=-80.) +Ch_par = dict(Vth=-47.5, delta=0.5, tau_ref=1., tau_w=50, a=80, b=150, C=150, gL=10, EL=-58, V_reset=-65, + E_e=0., E_i=-80.) 
+ + +class AdEx(bp.NeuGroup): + def __init__( + self, + size, + + # neuronal parameters + Vth=-40, delta=2., tau_ref=5., tau_w=500, a=4, b=20, C=150, + gL=10, EL=-65, V_reset=-65, V_sp_th=-30., + + # synaptic parameters + tau_e=1.5, tau_i=7.5, E_e=0., E_i=-80., + + # other parameters + name=None, method='exp_euler', + V_initializer=bp.init.Uniform(-65, -50), + w_initializer=bp.init.Constant(0.), + ): + super(AdEx, self).__init__(size=size, name=name) + + # neuronal parameters + self.Vth = Vth + self.delta = delta + self.tau_ref = tau_ref + self.tau_w = tau_w + self.a = a + self.b = b + self.C = C + self.gL = gL + self.EL = EL + self.V_reset = V_reset + self.V_sp_th = V_sp_th + + # synaptic parameters + self.tau_e = tau_e + self.tau_i = tau_i + self.E_e = E_e + self.E_i = E_i + + # neuronal variables + self.V = bp.init.variable_(V_initializer, self.num) + self.w = bp.init.variable_(w_initializer, self.num) + self.spike = bm.Variable(self.num, dtype=bool) + self.refractory = bm.Variable(self.num, dtype=bool) + self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e8) + + # synaptic parameters + self.ge = bm.Variable(self.num) + self.gi = bm.Variable(self.num) + + # integral + self.integral = bp.odeint(bp.JointEq(self.dV, self.dw, self.dge, self.dgi), method=method) + + def dge(self, ge, t): + return -ge / self.tau_e + + def dgi(self, gi, t): + return -gi / self.tau_i + + def dV(self, V, t, w, ge, gi, Iext=None): + I = ge * (self.E_e - V) + gi * (self.E_i - V) + if Iext is not None: I += Iext + dVdt = (self.gL * self.delta * bm.exp((V - self.Vth) / self.delta) + - w + self.gL * (self.EL - V) + I) / self.C + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.EL) - w) / self.tau_w + return dwdt + + def update(self, tdi, x=None): + V, w, ge, gi = self.integral(self.V.value, self.w.value, self.ge.value, self.gi.value, + tdi.t, Iext=x, dt=tdi.dt) + refractory = (tdi.t - self.t_last_spike) <= self.tau_ref + V = bm.where(refractory, self.V.value, V) + spike = V >= self.V_sp_th + self.V.value = bm.where(spike, self.V_reset, V) + self.w.value = bm.where(spike, w + self.b, w) + self.ge.value = ge + self.gi.value = gi + self.spike.value = spike + self.refractory.value = bm.logical_or(refractory, spike) + self.t_last_spike.value = bm.where(spike, tdi.t, self.t_last_spike) + + +class PINGNet(bp.Network): + def __init__(self, ext_varied_rates, ext_weight=4., method='exp_euler', dt=bm.get_dt()): + super(PINGNet, self).__init__() + + self.num_exc = 20000 + self.num_inh = 5000 + self.exc_syn_tau = 1. # ms + self.inh_syn_tau = 7.5 # ms + self.exc_syn_weight = 5. 
# nS + self.inh_syn_weight = 3.34 # nS + self.num_delay_step = int(1.5 / dt) + self.ext_varied_rates = ext_varied_rates + + # neuronal populations + RS_par_ = RS_par.copy() + FS_par_ = FS_par.copy() + RS_par_.update(Vth=-50, V_sp_th=-40) + FS_par_.update(Vth=-50, V_sp_th=-40) + self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_) + self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_) + self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1)) + + # Poisson inputs + self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=ext_weight) + self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=ext_weight) + + # synaptic projections + self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight, + delay_step=self.num_delay_step) + self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight, + delay_step=self.num_delay_step) + + def change_freq(self, tdi): + self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i] + + +class AINet(bp.Network): + def __init__(self, ext_varied_rates, ext_weight=1., method='exp_euler', dt=bm.get_dt()): + super(AINet, self).__init__() + + self.num_exc = 20000 + self.num_inh = 5000 + self.exc_syn_tau = 5. # ms + self.inh_syn_tau = 5. # ms + self.exc_syn_weight = 1. # nS + self.inh_syn_weight = 5. 
# nS + self.num_delay_step = int(1.5 / dt) + self.ext_varied_rates = ext_varied_rates + + # neuronal populations + RS_par_ = RS_par.copy() + FS_par_ = FS_par.copy() + RS_par_.update(Vth=-50, V_sp_th=-40) + FS_par_.update(Vth=-50, V_sp_th=-40) + self.rs_pop = AdEx(self.num_exc, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_) + self.fs_pop = AdEx(self.num_inh, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_) + self.ext_pop = bp.neurons.PoissonGroup(self.num_exc, freqs=bm.Variable(1)) + + # Poisson inputs + self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=ext_weight) + self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=ext_weight) + + # synaptic projections + self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight, + delay_step=self.num_delay_step) + self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight, + delay_step=self.num_delay_step) + + def change_freq(self, tdi): + self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i] + + +class INGNet(bp.Network): + def __init__(self, ext_varied_rates, ext_weight=0.9, method='exp_euler', dt=bm.get_dt()): + super(INGNet, self).__init__() + + self.num_rs = 20000 + self.num_fs = 4000 + self.num_fs2 = 1000 + self.exc_syn_tau = 5. # ms + self.inh_syn_tau = 5. # ms + self.exc_syn_weight = 1. # nS + self.inh_syn_weight = 5. 
# nS + self.num_delay_step = int(1.5 / dt) + self.ext_varied_rates = ext_varied_rates + + # neuronal populations + RS_par_ = RS_par.copy() + FS_par_ = FS_par.copy() + FS2_par_ = FS_par.copy() + RS_par_.update(Vth=-50, V_sp_th=-40) + FS_par_.update(Vth=-50, V_sp_th=-40) + FS2_par_.update(Vth=-50, V_sp_th=-40) + self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_) + self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_) + self.fs2_pop = AdEx(self.num_fs2, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS2_par_) + self.ext_pop = bp.neurons.PoissonGroup(self.num_rs, freqs=bm.Variable(1)) + + # Poisson inputs + self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=ext_weight) + self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=ext_weight) + self.ext_to_FS2 = bp.synapses.Delta(self.ext_pop, self.fs2_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=ext_weight) + + # synaptic projections + self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.RS_to_FS2 = bp.synapses.Delta(self.rs_pop, self.fs2_pop, bp.conn.FixedProb(0.15), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + + self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight, + delay_step=self.num_delay_step) + self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight, + delay_step=self.num_delay_step) + self.FS_to_FS2 = bp.synapses.Delta(self.fs_pop, self.fs2_pop, bp.conn.FixedProb(0.03), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight, + delay_step=self.num_delay_step) + + self.FS2_to_RS = bp.synapses.Delta(self.fs2_pop, self.rs_pop, bp.conn.FixedProb(0.15), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.FS2_to_FS = bp.synapses.Delta(self.fs2_pop, self.fs_pop, bp.conn.FixedProb(0.15), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.FS2_to_FS2 = bp.synapses.Delta(self.fs2_pop, self.fs2_pop, bp.conn.FixedProb(0.6), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + + def change_freq(self, tdi): + self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i] + + +class CHINGNet(bp.Network): + def __init__(self, ext_varied_rates, method='exp_euler', dt=bm.get_dt()): + super(CHINGNet, self).__init__() + + self.num_rs = 19000 + self.num_fs = 5000 + self.num_ch = 1000 + self.exc_syn_tau = 5. # ms + self.inh_syn_tau = 5. # ms + self.exc_syn_weight = 1. # nS + self.inh_syn_weight1 = 7. # nS + self.inh_syn_weight2 = 5. # nS + self.ext_weight1 = 1. 
# nS + self.ext_weight2 = 0.75 # nS + self.num_delay_step = int(1.5 / dt) + self.ext_varied_rates = ext_varied_rates + + # neuronal populations + RS_par_ = RS_par.copy() + FS_par_ = FS_par.copy() + Ch_par_ = Ch_par.copy() + RS_par_.update(Vth=-50, V_sp_th=-40) + FS_par_.update(Vth=-50, V_sp_th=-40) + Ch_par_.update(Vth=-50, V_sp_th=-40) + self.rs_pop = AdEx(self.num_rs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **RS_par_) + self.fs_pop = AdEx(self.num_fs, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **FS_par_) + self.ch_pop = AdEx(self.num_ch, tau_e=self.exc_syn_tau, tau_i=self.inh_syn_tau, method=method, **Ch_par_) + self.ext_pop = bp.neurons.PoissonGroup(self.num_rs, freqs=bm.Variable(1)) + + # Poisson inputs + self.ext_to_FS = bp.synapses.Delta(self.ext_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.ext_weight2) + self.ext_to_RS = bp.synapses.Delta(self.ext_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.ext_weight1) + self.ext_to_CH = bp.synapses.Delta(self.ext_pop, self.ch_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.ext_weight1) + + # synaptic projections + self.RS_to_FS = bp.synapses.Delta(self.rs_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.RS_to_RS = bp.synapses.Delta(self.rs_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.RS_to_Ch = bp.synapses.Delta(self.rs_pop, self.ch_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + + self.FS_to_RS = bp.synapses.Delta(self.fs_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight1, + delay_step=self.num_delay_step) + self.FS_to_FS = bp.synapses.Delta(self.fs_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight2, + delay_step=self.num_delay_step) + self.FS_to_Ch = bp.synapses.Delta(self.fs_pop, self.ch_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='gi'), + g_max=self.inh_syn_weight1, + delay_step=self.num_delay_step) + + self.Ch_to_RS = bp.synapses.Delta(self.ch_pop, self.rs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.Ch_to_FS = bp.synapses.Delta(self.ch_pop, self.fs_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + self.Ch_to_Ch = bp.synapses.Delta(self.ch_pop, self.ch_pop, bp.conn.FixedProb(0.02), + output=bp.synouts.CUBA(target_var='ge'), + g_max=self.exc_syn_weight, + delay_step=self.num_delay_step) + + def change_freq(self, tdi): + self.ext_pop.freqs[0] = self.ext_varied_rates[tdi.i] + + +def get_inputs(c_low, c_high, t_transition, t_min_plato, t_max_plato, t_gap, t_total, dt=None): + dt = bm.get_dt() if dt is None else dt + t = 0 + num_gap = int(t_gap / dt) + num_total = int(t_total / dt) + num_transition = int(t_transition / dt) + + inputs = [] + ramp_up = np.linspace(c_low, c_high, num_transition) + ramp_down = np.linspace(c_high, c_low, num_transition) + plato_base = np.ones(num_gap) * c_low + while t < num_total: + num_plato = 
int(np.random.uniform(low=t_min_plato, high=t_max_plato) / dt)
+    inputs.extend([plato_base, ramp_up, np.ones(num_plato) * c_high, ramp_down])
+    t += (num_gap + num_transition + num_plato + num_transition)
+  return bm.asarray(np.concatenate(inputs)[:num_total])
+
+
+def signal_phase_by_Hilbert(signal, signal_time, low_cut, high_cut, sampling_space):
+  # sampling_space: sampling interval, in seconds
+  # signal_time: time points of the signal, in seconds
+  # low_cut: lower edge of the band to filter, in Hz
+  # high_cut: upper edge of the band to filter, in Hz
+
+  signal = signal - np.mean(signal)
+  width = 5.0  # desired width of the transition from pass to stop, in Hz
+  ripple_db = 60.0  # desired attenuation in the stop band, in dB
+  sampling_rate = 1. / sampling_space
+  Nyquist = sampling_rate / 2.
+
+  num_taps, beta = kaiserord(ripple_db, width / Nyquist)
+  if num_taps % 2 == 0:
+    num_taps = num_taps + 1  # num_taps must be odd
+  taps = firwin(num_taps, [low_cut / Nyquist, high_cut / Nyquist], window=('kaiser', beta),
+                nyq=1.0, pass_zero=False, scale=True)
+  filtered_signal = lfilter(taps, 1.0, signal)
+  delay = 0.5 * (num_taps - 1) / sampling_rate  # correct for the filter's group delay (zero-phase alignment)
+  delay_index = int(np.floor(delay * sampling_rate))
+  filtered_signal = filtered_signal[num_taps - 1:]  # drop the start-up transient corrupted by the filter
+  # correct the delay and remove the corrupted part of the raw signal
+  filtered_time = signal_time[num_taps - 1:] - delay
+  cutted_signal = signal[(num_taps - 1 - delay_index): (len(signal) - (num_taps - 1 - delay_index))]
+
+  # --------------------------------------------------------------------------
+  # The Hilbert transform is very slow when the signal has an odd length.
+  # If the length is odd, append a zero to all vectors related to the
+  # filtered signal:
+  if len(filtered_signal) % 2 != 0:  # if the length is odd
+    tmp1 = filtered_signal.tolist()
+    tmp1.append(0)
+    tmp2 = filtered_time.tolist()
+    tmp2.append((len(filtered_time) + 1) * sampling_space + filtered_time[0])
+    tmp3 = cutted_signal.tolist()
+    tmp3.append(0)
+    filtered_signal = np.asarray(tmp1)
+    filtered_time = np.asarray(tmp2)
+    cutted_signal = np.asarray(tmp3)
+  # --------------------------------------------------------------------------
+
+  ht_filtered_signal = hilbert(filtered_signal)
+  envelope = np.abs(ht_filtered_signal)
+  phase = np.angle(ht_filtered_signal)  # the phase is between -pi and pi, in radians
+
+  return filtered_time, filtered_signal, cutted_signal, envelope, phase
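+
+
+# A minimal, self-contained check of `signal_phase_by_Hilbert` (illustrative
+# sketch only; the helper name `_check_hilbert_phase` is not part of the
+# original example and is never called by this script): band-pass a noisy
+# 40 Hz sine in the 30-50 Hz band and inspect the recovered envelope/phase.
+def _check_hilbert_phase(freq=40., duration=2., dt=1e-3):
+  t = np.arange(0, duration, dt)  # time points, in seconds
+  sig = np.sin(2 * np.pi * freq * t) + 0.3 * np.random.randn(t.size)
+  f_time, f_sig, raw, envelope, phase = signal_phase_by_Hilbert(sig, t, 30., 50., dt)
+  # for a unit-amplitude 40 Hz sine, the Hilbert envelope should stay close to 1
+  print('mean Hilbert envelope:', envelope.mean())
+  return f_time, f_sig, envelope, phase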
+
+
+def visualize_simulation_results(times, spikes, example_potentials, varied_rates,
+                                 xlim=None, t_lfp_start=None, t_lfp_end=None, filename=None):
+  fig, gs = bp.visualize.get_figure(7, 1, 1, 12)
+  # 1. external input firing rate
+  ax = fig.add_subplot(gs[0])
+  plt.plot(times, varied_rates)
+  if xlim is None:
+    xlim = (0, times[-1])
+  ax.set_xlim(*xlim)
+  ax.set_xticks([])
+  ax.set_ylabel('External\nRate (Hz)')
+
+  # 2. spike raster plot of each population
+  ax = fig.add_subplot(gs[1: 3])
+  i = 0
+  y_ticks = ([], [])
+  for key, (sp_matrix, sp_type) in spikes.items():
+    iis, sps = np.where(sp_matrix)
+    tts = times[iis]
+    plt.plot(tts, sps + i, '.', markersize=1, label=key)
+    y_ticks[0].append(i + sp_matrix.shape[1] / 2)
+    y_ticks[1].append(key)
+    i += sp_matrix.shape[1]
+  ax.set_xlim(*xlim)
+  ax.set_xlabel('')
+  ax.set_ylabel('Neuron Index')
+  ax.set_xticks([])
+  ax.set_yticks(*y_ticks)
+  # ax.legend()
+
+  # 3. example membrane potentials
+  ax = fig.add_subplot(gs[3: 5])
+  for key, potential in example_potentials.items():
+    # mask the potential of neuron 0 at its spike times (shown as 0 mV)
+    vs = np.where(spikes[key][0][:, 0], 0, potential)
+    plt.plot(times, vs, label=key)
+  ax.set_xlim(*xlim)
+  ax.set_xticks([])
+  ax.set_ylabel('V (mV)')
+  ax.legend()
+
+  # 4. LFP
+  ax = fig.add_subplot(gs[5:7])
+  ax.set_xlim(*xlim)
+  t1 = int(t_lfp_start / bm.get_dt()) if t_lfp_start is not None else 0
+  t2 = int(t_lfp_end / bm.get_dt()) if t_lfp_end is not None else len(times)
+  times = times[t1: t2]
+  lfp = 0
+  for sp_matrix, sp_type in spikes.values():
+    lfp += bp.measure.unitary_LFP(times, sp_matrix[t1: t2], sp_type)
+  phase_ts, filtered, cutted, envelope, _ = signal_phase_by_Hilbert(bm.as_numpy(lfp), times * 1e-3, 30, 50,
+                                                                    bm.get_dt() * 1e-3)
+  plt.plot(phase_ts * 1e3, cutted, color='k', label='Raw LFP')
+  plt.plot(phase_ts * 1e3, filtered, color='orange', label="Filtered LFP (30-50 Hz)")
+  plt.plot(phase_ts * 1e3, envelope, color='purple', label="Hilbert Envelope")
+  plt.legend(loc='best')
+  plt.xlabel('Time (ms)')
+
+  # save or show
+  if filename:
+    plt.savefig(filename, dpi=500)
+  plt.show()
+
+
+def simulate_single_neuron(duration=4e3):
+  input_currents = get_inputs(0., 500, 50, 500, 600, 2e3, duration)
+
+  RS_cell = AdEx(1, V_sp_th=RS_par['Vth'], **RS_par)
+  runner = bp.DSRunner(RS_cell, monitors=['V'])
+  runner.run(duration, inputs=input_currents)
+
+  FS_cell = AdEx(1, V_sp_th=FS_par['Vth'], **FS_par)
+  runner2 = bp.DSRunner(FS_cell, monitors=['V'])
+  runner2.run(duration, inputs=input_currents)
+
+  fig, gs = bp.visualize.get_figure(3, 1, 3, 10)
+  ax = fig.add_subplot(gs[0, 0])
+  bp.visualize.line_plot(runner.mon.ts, input_currents)
+  ax.set_xlim(1600, 3000)
+  ax.set_title('Input Current')
+
+  ax = fig.add_subplot(gs[1, 0])
+  ax.set_xlim(1600, 3000)
+  ax.set_title('RS Neuron')
+  bp.visualize.line_plot(runner.mon.ts, runner.mon.V)
+
+  ax = fig.add_subplot(gs[2, 0])
+  ax.set_xlim(1600, 3000)
+  ax.set_title('FS Neuron')
+  bp.visualize.line_plot(runner2.mon.ts, runner2.mon.V)
+  plt.show()
+
+
+def simulate_ping_net():
+  duration = 6e3
+  varied_rates = get_inputs(2., 3., 50., 150, 600, 1e3, duration)
+
+  net = PINGNet(varied_rates, ext_weight=4.)
+  # `fun_inputs` is called at every time step to update the Poisson input rate;
+  # `fun_monitors` records the user-defined quantities at every time step.
+  runner = bp.DSRunner(
+    net,
+    fun_inputs=net.change_freq,
+    fun_monitors={'FS.V0': lambda tdi: net.fs_pop.V[0],
+                  'RS.V0': lambda tdi: net.rs_pop.V[0],
+                  'FS.spike': lambda tdi: net.fs_pop.spike,
+                  'RS.spike': lambda tdi: net.rs_pop.spike}
+  )
+  runner.run(duration)
+
+  visualize_simulation_results(times=runner.mon.ts,
+                               spikes={'FS': (runner.mon['FS.spike'], 'inh'),
+                                       'RS': (runner.mon['RS.spike'], 'exc')},
+                               example_potentials={'FS': runner.mon['FS.V0'],
+                                                   'RS': runner.mon['RS.V0']},
+                               varied_rates=varied_rates.to_numpy(),
+                               xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3)
+
+
+def simulate_ai_net():
+  duration = 2e3
+  varied_rates = get_inputs(2., 2., 50., 150, 600, 1e3, duration)
+
+  net = AINet(varied_rates, ext_weight=5.)  # use the asynchronous-irregular network defined above
+ runner = bp.DSRunner( + net, + fun_inputs=net.change_freq, + fun_monitors={'FS.V0': lambda tdi: net.fs_pop.V[0], + 'RS.V0': lambda tdi: net.rs_pop.V[0], + 'FS.spike': lambda tdi: net.fs_pop.spike, + 'RS.spike': lambda tdi: net.rs_pop.spike} + ) + runner.run(duration) + + visualize_simulation_results(times=runner.mon.ts, + spikes={'FS': (runner.mon['FS.spike'], 'inh'), + 'RS': (runner.mon['RS.spike'], 'exc')}, + example_potentials={'FS': runner.mon['FS.V0'], + 'RS': runner.mon['RS.V0']}, + varied_rates=varied_rates.to_numpy()) + + +def simulate_ing_net(): + duration = 6e3 + varied_rates = get_inputs(2., 3., 50., 350, 600, 1e3, duration) + + net = INGNet(varied_rates, ext_weight=0.9) + runner = bp.DSRunner( + net, + fun_inputs=net.change_freq, + fun_monitors={'FS.V0': lambda tdi: net.fs_pop.V[0], + 'FS2.V0': lambda tdi: net.fs2_pop.V[0], + 'RS.V0': lambda tdi: net.rs_pop.V[0], + 'FS.spike': lambda tdi: net.fs_pop.spike, + 'FS2.spike': lambda tdi: net.fs2_pop.spike, + 'RS.spike': lambda tdi: net.rs_pop.spike} + ) + runner.run(duration) + + visualize_simulation_results(times=runner.mon.ts, + spikes={'FS': (runner.mon['FS.spike'], 'inh'), + 'FS2': (runner.mon['FS2.spike'], 'inh'), + 'RS': (runner.mon['RS.spike'], 'exc')}, + example_potentials={'FS': runner.mon['FS.V0'], + 'FS2': runner.mon['FS2.V0'], + 'RS': runner.mon['RS.V0']}, + varied_rates=varied_rates.to_numpy(), + xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3) + + +def simulate_ching_net(): + duration = 6e3 + varied_rates = get_inputs(1., 2., 50., 150, 600, 1e3, duration) + + net = CHINGNet(varied_rates) + runner = bp.DSRunner( + net, + fun_inputs=net.change_freq, + fun_monitors={'FS.V0': lambda tdi: net.fs_pop.V[0], + 'CH.V0': lambda tdi: net.ch_pop.V[0], + 'RS.V0': lambda tdi: net.rs_pop.V[0], + 'FS.spike': lambda tdi: net.fs_pop.spike, + 'CH.spike': lambda tdi: net.ch_pop.spike, + 'RS.spike': lambda tdi: net.rs_pop.spike} + ) + runner.run(duration) + + visualize_simulation_results(times=runner.mon.ts, + spikes={'FS': (runner.mon['FS.spike'], 'inh'), + 'CH': (runner.mon['CH.spike'], 'exc'), + 'RS': (runner.mon['RS.spike'], 'exc')}, + example_potentials={'FS': runner.mon['FS.V0'], + 'CH': runner.mon['CH.V0'], + 'RS': runner.mon['RS.V0']}, + varied_rates=varied_rates.to_numpy(), + xlim=(2e3, 3.4e3), t_lfp_start=1e3, t_lfp_end=5e3) + + +if __name__ == '__main__': + simulate_single_neuron() + simulate_ping_net() + simulate_ai_net() + simulate_ing_net() + simulate_ching_net()
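+
+
+# Optional post-hoc check (illustrative sketch only, not part of the original
+# example and never called by this script): estimate the LFP power spectrum to
+# confirm the ~30-50 Hz gamma peak. It assumes `scipy.signal.welch` is
+# available; the helper name `_lfp_gamma_peak` and its arguments are only
+# illustrative, and it reuses the same `bp.measure.unitary_LFP` call as
+# `visualize_simulation_results` above.
+def _lfp_gamma_peak(times, exc_spikes, inh_spikes):
+  from scipy.signal import welch
+  lfp = bp.measure.unitary_LFP(times, exc_spikes, 'exc')
+  lfp += bp.measure.unitary_LFP(times, inh_spikes, 'inh')
+  fs = 1e3 / bm.get_dt()  # sampling rate in Hz (simulation times are in ms)
+  freqs, power = welch(bm.as_numpy(lfp), fs=fs, nperseg=4096)
+  return freqs[np.argmax(power)]  # frequency of the spectral peak, in Hz
+
+# Example usage (e.g., inside `simulate_ping_net()` after `runner.run(duration)`):
+#   peak = _lfp_gamma_peak(runner.mon.ts, runner.mon['RS.spike'], runner.mon['FS.spike'])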