diff --git a/brainpy/dyn/neurons/lif.py b/brainpy/dyn/neurons/lif.py index 23fa28c8..f7ddf886 100644 --- a/brainpy/dyn/neurons/lif.py +++ b/brainpy/dyn/neurons/lif.py @@ -857,6 +857,10 @@ class ExpIF(ExpIFLTC): conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire + .. seealso:: + + :class:`brainpy.state.ExpIF` provides the state-based formulation of this neuron. + **Examples** There is a simple usage example:: @@ -978,6 +982,10 @@ class ExpIFRefLTC(ExpIFLTC): conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire + .. seealso:: + + :class:`brainpy.state.ExpIFRef` provides the state-based formulation of this neuron. + **Examples** There is a simple usage example:: @@ -1319,6 +1327,10 @@ class AdExIFLTC(GradNeuDyn): inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + .. seealso:: + + :class:`brainpy.state.AdExIF` provides the state-based formulation of this model. + **Examples** An example usage: diff --git a/brainpy/state/__init__.py b/brainpy/state/__init__.py index 05fa5fed..8fbeed14 100644 --- a/brainpy/state/__init__.py +++ b/brainpy/state/__init__.py @@ -18,8 +18,12 @@ from ._base import __all__ as base_all from ._exponential import * from ._exponential import __all__ as exp_all +from ._hh import * +from ._hh import __all__ as hh_all from ._inputs import * from ._inputs import __all__ as inputs_all +from ._izhikevich import * +from ._izhikevich import __all__ as izh_all from ._lif import * from ._lif import __all__ as neuron_all from ._projection import * @@ -34,13 +38,8 @@ from ._synaptic_projection import __all__ as synproj_all from ._synouts import * from ._synouts import __all__ as synout_all -from .. 
import mixin -__main__ = ['version2', 'mixin'] + inputs_all + neuron_all + readout_all + stp_all + synapse_all +__main__ = inputs_all + neuron_all + izh_all + hh_all + readout_all + stp_all + synapse_all __main__ = __main__ + synout_all + base_all + exp_all + proj_all + synproj_all -del inputs_all, neuron_all, readout_all, stp_all, synapse_all, synout_all, base_all +del inputs_all, neuron_all, izh_all, hh_all, readout_all, stp_all, synapse_all, synout_all, base_all del exp_all, proj_all, synproj_all - -if __name__ == '__main__': - mixin - diff --git a/brainpy/state/_base.py b/brainpy/state/_base.py index 8be13440..7b17dd78 100644 --- a/brainpy/state/_base.py +++ b/brainpy/state/_base.py @@ -40,6 +40,8 @@ def _input_label_repr(name: str, label: Optional[str] = None): class Dynamics(brainstate.nn.Dynamics): + __module__ = 'brainpy.state' + def __init__(self, in_size: Size, name: Optional[str] = None): # initialize super().__init__(name=name, in_size=in_size) @@ -401,21 +403,21 @@ def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T: Examples -------- >>> import brainstate - >>> n1 = brainstate.nn.LIF(10) - >>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1 + >>> n1 = brainpy.state.LIF(10) + >>> n1.align_pre(brainpy.state.Expon.desc(n1.varshape)) # n2 will run after n1 """ if isinstance(dyn, Dynamics): - self._add_after_update(id(dyn), dyn) + self.add_after_update(id(dyn), dyn) return dyn elif isinstance(dyn, ParamDescriber): if not issubclass(dyn.cls, Dynamics): raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.') - if not self._has_after_update(dyn.identifier): - self._add_after_update( + if not self.has_after_update(dyn.identifier): + self.add_after_update( dyn.identifier, dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape) ) - return self._get_after_update(dyn.identifier) + return self.get_after_update(dyn.identifier) else: raise TypeError(f'The input {dyn} should be an 
instance of {Dynamics} or a delayed initializer.') @@ -425,7 +427,7 @@ class Neuron(Dynamics): Base class for all spiking neuron models. This abstract class serves as the foundation for implementing various spiking neuron - models in the BrainPy framework. It extends the ``brainstate.nn.Dynamics`` class and + models in the BrainPy framework. It extends the ``brainpy.state.Dynamics`` class and provides common functionality for spike generation, membrane potential dynamics, and surrogate gradient handling required for training spiking neural networks. @@ -595,7 +597,7 @@ class Neuron(Dynamics): .. [3] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). Neuronal dynamics: From single neurons to networks and models of cognition. Cambridge University Press. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -849,4 +851,4 @@ class Synapse(Dynamics): .. [3] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). Neuronal dynamics: From single neurons to networks and models of cognition. Cambridge University Press. 
""" - __module__ = 'brainpy' + __module__ = 'brainpy.state' diff --git a/brainpy/state/_base_test.py b/brainpy/state/_base_test.py index b378ed55..f3dcdad5 100644 --- a/brainpy/state/_base_test.py +++ b/brainpy/state/_base_test.py @@ -148,7 +148,7 @@ def test_neuron_soft_vs_hard_reset(self): def test_neuron_module_attribute(self): """Test __module__ attribute is correctly set.""" neuron = Neuron(in_size=self.in_size) - self.assertEqual(neuron.__module__, 'brainpy') + self.assertEqual(neuron.__module__, 'brainpy.state') class TestSynapseBaseClass(unittest.TestCase): @@ -247,7 +247,7 @@ def update(self, x=None): def test_synapse_module_attribute(self): """Test __module__ attribute is correctly set.""" synapse = Synapse(in_size=self.in_size) - self.assertEqual(synapse.__module__, 'brainpy') + self.assertEqual(synapse.__module__, 'brainpy.state') def test_synapse_varshape_attribute(self): """Test varshape attribute is correctly set.""" diff --git a/brainpy/state/_exponential.py b/brainpy/state/_exponential.py index 1f0ff1b0..30b39189 100644 --- a/brainpy/state/_exponential.py +++ b/brainpy/state/_exponential.py @@ -71,7 +71,7 @@ class Expon(Synapse, AlignPost): where synaptic variables are aligned with post-synaptic neurons, enabling event-driven computation and more efficient handling of sparse connectivity patterns. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -156,7 +156,7 @@ class DualExpon(Synapse, AlignPost): where synaptic variables are aligned with post-synaptic neurons, enabling event-driven computation and more efficient handling of sparse connectivity patterns. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, diff --git a/brainpy/state/_hh.py b/brainpy/state/_hh.py new file mode 100644 index 00000000..86d4f1f2 --- /dev/null +++ b/brainpy/state/_hh.py @@ -0,0 +1,666 @@ +# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# -*- coding: utf-8 -*- + +from typing import Callable + +import brainstate +import braintools +import brainunit as u +import jax +from brainstate.typing import ArrayLike, Size + +from ._base import Neuron + +__all__ = [ + 'HH', 'MorrisLecar', 'WangBuzsakiHH', +] + + +class HH(Neuron): + r"""Hodgkin–Huxley neuron model. + + **Model Descriptions** + + The Hodgkin-Huxley (HH; Hodgkin & Huxley, 1952) model for the generation of + the nerve action potential is one of the most successful mathematical models of + a complex biological process that has ever been formulated. The basic concepts + expressed in the model have proved a valid approach to the study of bio-electrical + activity from the most primitive single-celled organisms such as *Paramecium*, + right through to the neurons within our own brains. 
+ + Mathematically, the model is given by, + + $$ + C \frac {dV} {dt} = -(\bar{g}_{Na} m^3 h (V-E_{Na}) + + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) + $$ + + $$ + \frac {dx} {dt} = \alpha_x (1-x) - \beta_x x, \quad x\in {\rm{\{m, h, n\}}} + $$ + + where + + $$ + \alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} + $$ + + $$ + \beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) + $$ + + $$ + \alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) + $$ + + $$ + \beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} + $$ + + $$ + \alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} + $$ + + $$ + \beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) + $$ + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + ENa : ArrayLike, default=50. * u.mV + Reversal potential of sodium. + gNa : ArrayLike, default=120. * u.msiemens + Maximum conductance of sodium channel. + EK : ArrayLike, default=-77. * u.mV + Reversal potential of potassium. + gK : ArrayLike, default=36. * u.msiemens + Maximum conductance of potassium channel. + EL : ArrayLike, default=-54.387 * u.mV + Reversal potential of leak channel. + gL : ArrayLike, default=0.03 * u.msiemens + Conductance of leak channel. + V_th : ArrayLike, default=20. * u.mV + Threshold of the membrane spike. + C : ArrayLike, default=1.0 * u.ufarad + Membrane capacitance. + V_initializer : Callable + Initializer for membrane potential. + m_initializer : Callable, optional + Initializer for m channel. If None, uses steady state. + h_initializer : Callable, optional + Initializer for h channel. If None, uses steady state. + n_initializer : Callable, optional + Initializer for n channel. If None, uses steady state. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function. + spk_reset : str, default='soft' + Reset mechanism after spike generation. + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential.
+ m : HiddenState + Sodium activation variable. + h : HiddenState + Sodium inactivation variable. + n : HiddenState + Potassium activation variable. + + Examples + -------- + >>> import brainpy + >>> import brainstate + >>> import brainunit as u + >>> + >>> # Create an HH neuron layer with 10 neurons + >>> hh = brainpy.state.HH(10) + >>> + >>> # Initialize the state + >>> hh.init_state(batch_size=1) + >>> + >>> # Apply an input current and update the neuron state + >>> spikes = hh.update(x=10.*u.uA) + + References + ---------- + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description + of membrane current and its application to conduction and excitation + in nerve." The Journal of physiology 117.4 (1952): 500. + .. [2] https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model + """ + + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + ENa: ArrayLike = 50. * u.mV, + gNa: ArrayLike = 120. * u.msiemens, + EK: ArrayLike = -77. * u.mV, + gK: ArrayLike = 36. * u.msiemens, + EL: ArrayLike = -54.387 * u.mV, + gL: ArrayLike = 0.03 * u.msiemens, + V_th: ArrayLike = 20. * u.mV, + C: ArrayLike = 1.0 * u.ufarad, + V_initializer: Callable = braintools.init.Uniform(-70. * u.mV, -60. 
* u.mV), + m_initializer: Callable = None, + h_initializer: Callable = None, + n_initializer: Callable = None, + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'soft', + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.ENa = braintools.init.param(ENa, self.varshape) + self.EK = braintools.init.param(EK, self.varshape) + self.EL = braintools.init.param(EL, self.varshape) + self.gNa = braintools.init.param(gNa, self.varshape) + self.gK = braintools.init.param(gK, self.varshape) + self.gL = braintools.init.param(gL, self.varshape) + self.C = braintools.init.param(C, self.varshape) + self.V_th = braintools.init.param(V_th, self.varshape) + + # initializers + self.V_initializer = V_initializer + self.m_initializer = m_initializer + self.h_initializer = h_initializer + self.n_initializer = n_initializer + + def m_alpha(self, V): + return 1. / u.math.exprel(-(V + 40. * u.mV) / (10. * u.mV)) / u.ms + + def m_beta(self, V): + return 4.0 / u.ms * u.math.exp(-(V + 65. * u.mV) / (18. * u.mV)) + + def m_inf(self, V): + return self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) + + def h_alpha(self, V): + return 0.07 / u.ms * u.math.exp(-(V + 65. * u.mV) / (20. * u.mV)) + + def h_beta(self, V): + return 1. / u.ms / (1. + u.math.exp(-(V + 35. * u.mV) / (10. * u.mV))) + + def h_inf(self, V): + return self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) + + def n_alpha(self, V): + return 0.1 / u.ms / u.math.exprel(-(V + 55. * u.mV) / (10. * u.mV)) + + def n_beta(self, V): + return 0.125 / u.ms * u.math.exp(-(V + 65. * u.mV) / (80. 
* u.mV)) + + def n_inf(self, V): + return self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) + + def init_state(self, batch_size: int = None, **kwargs): + self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + if self.m_initializer is None: + self.m = brainstate.HiddenState(self.m_inf(self.V.value)) + else: + self.m = brainstate.HiddenState(braintools.init.param(self.m_initializer, self.varshape, batch_size)) + if self.h_initializer is None: + self.h = brainstate.HiddenState(self.h_inf(self.V.value)) + else: + self.h = brainstate.HiddenState(braintools.init.param(self.h_initializer, self.varshape, batch_size)) + if self.n_initializer is None: + self.n = brainstate.HiddenState(self.n_inf(self.V.value)) + else: + self.n = brainstate.HiddenState(braintools.init.param(self.n_initializer, self.varshape, batch_size)) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + if self.m_initializer is None: + self.m.value = self.m_inf(self.V.value) + else: + self.m.value = braintools.init.param(self.m_initializer, self.varshape, batch_size) + if self.h_initializer is None: + self.h.value = self.h_inf(self.V.value) + else: + self.h.value = braintools.init.param(self.h_initializer, self.varshape, batch_size) + if self.n_initializer is None: + self.n.value = self.n_inf(self.V.value) + else: + self.n.value = braintools.init.param(self.n_initializer, self.varshape, batch_size) + + def get_spike(self, V: ArrayLike = None): + V = self.V.value if V is None else V + v_scaled = (V - self.V_th) / self.V_th + return self.spk_fun(v_scaled) + + def update(self, x=0. 
* u.uA): + last_V = self.V.value + last_m = self.m.value + last_h = self.h.value + last_n = self.n.value + + # Ionic currents + I_Na = (self.gNa * last_m ** 3 * last_h) * (last_V - self.ENa) + I_K = (self.gK * last_n ** 4) * (last_V - self.EK) + I_leak = self.gL * (last_V - self.EL) + + # Voltage dynamics + I_total = self.sum_current_inputs(x, last_V) + dV = lambda V: (-I_Na - I_K - I_leak + I_total) / self.C + + # Gating variable dynamics + dm = lambda m: self.m_alpha(last_V) * (1. - m) - self.m_beta(last_V) * m + dh = lambda h: self.h_alpha(last_V) * (1. - h) - self.h_beta(last_V) * h + dn = lambda n: self.n_alpha(last_V) * (1. - n) - self.n_beta(last_V) * n + + V = brainstate.nn.exp_euler_step(dV, last_V) + V = self.sum_delta_inputs(V) + m = brainstate.nn.exp_euler_step(dm, last_m) + h = brainstate.nn.exp_euler_step(dh, last_h) + n = brainstate.nn.exp_euler_step(dn, last_n) + + self.V.value = V + self.m.value = m + self.h.value = h + self.n.value = n + return self.get_spike(V) + + +class MorrisLecar(Neuron): + r"""The Morris-Lecar neuron model. + + **Model Descriptions** + + The Morris-Lecar model (Also known as :math:`I_{Ca}+I_K`-model) + is a two-dimensional "reduced" excitation model applicable to + systems having two non-inactivating voltage-sensitive conductances. + This model was named after Cathy Morris and Harold Lecar, who + derived it in 1981. Because it is two-dimensional, the Morris-Lecar + model is one of the favorite conductance-based models in computational neuroscience. + + The original form of the model employed an instantaneously + responding voltage-sensitive Ca2+ conductance for excitation and a delayed + voltage-dependent K+ conductance for recovery. 
The equations of the model are: + + $$ + \begin{aligned} + C\frac{dV}{dt} =& - g_{Ca} M_{\infty} (V - V_{Ca})- g_{K} W(V - V_{K}) - + g_{Leak} (V - V_{Leak}) + I_{ext} \\ + \frac{dW}{dt} =& \frac{W_{\infty}(V) - W}{ \tau_W(V)} + \end{aligned} + $$ + + Here, :math:`V` is the membrane potential, :math:`W` is the "recovery variable", + which is almost invariably the normalized :math:`K^+`-ion conductance, and + :math:`I_{ext}` is the applied current stimulus. + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + V_Ca : ArrayLike, default=130. * u.mV + Equilibrium potential of Ca+. + g_Ca : ArrayLike, default=4.4 * u.msiemens + Maximum conductance of Ca+. + V_K : ArrayLike, default=-84. * u.mV + Equilibrium potential of K+. + g_K : ArrayLike, default=8. * u.msiemens + Maximum conductance of K+. + V_leak : ArrayLike, default=-60. * u.mV + Equilibrium potential of leak current. + g_leak : ArrayLike, default=2. * u.msiemens + Conductance of leak current. + C : ArrayLike, default=20. * u.ufarad + Membrane capacitance. + V1 : ArrayLike, default=-1.2 * u.mV + Potential at which M_inf = 0.5. + V2 : ArrayLike, default=18. * u.mV + Reciprocal of slope of voltage dependence of M_inf. + V3 : ArrayLike, default=2. * u.mV + Potential at which W_inf = 0.5. + V4 : ArrayLike, default=30. * u.mV + Reciprocal of slope of voltage dependence of W_inf. + phi : ArrayLike, default=0.04 / u.ms + Temperature factor. + V_th : ArrayLike, default=10. * u.mV + Spike threshold. + V_initializer : Callable + Initializer for membrane potential. + W_initializer : Callable + Initializer for recovery variable. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function. + spk_reset : str, default='soft' + Reset mechanism after spike generation. + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential. + W : HiddenState + Recovery variable. 
+ + Examples + -------- + >>> import brainpy + >>> import brainstate + >>> import brainunit as u + >>> + >>> # Create a Morris-Lecar neuron layer with 10 neurons + >>> ml = brainpy.state.MorrisLecar(10) + >>> + >>> # Initialize the state + >>> ml.init_state(batch_size=1) + >>> + >>> # Apply an input current and update the neuron state + >>> spikes = ml.update(x=100.*u.uA) + + References + ---------- + .. [1] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. + .. [2] http://www.scholarpedia.org/article/Morris-Lecar_model + .. [3] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model + """ + + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + V_Ca: ArrayLike = 130. * u.mV, + g_Ca: ArrayLike = 4.4 * u.msiemens, + V_K: ArrayLike = -84. * u.mV, + g_K: ArrayLike = 8. * u.msiemens, + V_leak: ArrayLike = -60. * u.mV, + g_leak: ArrayLike = 2. * u.msiemens, + C: ArrayLike = 20. * u.ufarad, + V1: ArrayLike = -1.2 * u.mV, + V2: ArrayLike = 18. * u.mV, + V3: ArrayLike = 2. * u.mV, + V4: ArrayLike = 30. * u.mV, + phi: ArrayLike = 0.04 / u.ms, + V_th: ArrayLike = 10. * u.mV, + V_initializer: Callable = braintools.init.Uniform(-70. * u.mV, -60. 
* u.mV), + W_initializer: Callable = braintools.init.Constant(0.02), + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'soft', + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.V_Ca = braintools.init.param(V_Ca, self.varshape) + self.g_Ca = braintools.init.param(g_Ca, self.varshape) + self.V_K = braintools.init.param(V_K, self.varshape) + self.g_K = braintools.init.param(g_K, self.varshape) + self.V_leak = braintools.init.param(V_leak, self.varshape) + self.g_leak = braintools.init.param(g_leak, self.varshape) + self.C = braintools.init.param(C, self.varshape) + self.V1 = braintools.init.param(V1, self.varshape) + self.V2 = braintools.init.param(V2, self.varshape) + self.V3 = braintools.init.param(V3, self.varshape) + self.V4 = braintools.init.param(V4, self.varshape) + self.phi = braintools.init.param(phi, self.varshape) + self.V_th = braintools.init.param(V_th, self.varshape) + + # initializers + self.V_initializer = V_initializer + self.W_initializer = W_initializer + + def init_state(self, batch_size: int = None, **kwargs): + self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + self.W = brainstate.HiddenState(braintools.init.param(self.W_initializer, self.varshape, batch_size)) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + self.W.value = braintools.init.param(self.W_initializer, self.varshape, batch_size) + + def get_spike(self, V: ArrayLike = None): + V = self.V.value if V is None else V + v_scaled = (V - self.V_th) / self.V_th + return self.spk_fun(v_scaled) + + def update(self, x=0. * u.uA): + last_V = self.V.value + last_W = self.W.value + + # Steady states + M_inf = 0.5 * (1. + u.math.tanh((last_V - self.V1) / self.V2)) + W_inf = 0.5 * (1. + u.math.tanh((last_V - self.V3) / self.V4)) + tau_W = 1. 
/ (self.phi * u.math.cosh((last_V - self.V3) / (2. * self.V4))) + + # Ionic currents + I_Ca = self.g_Ca * M_inf * (last_V - self.V_Ca) + I_K = self.g_K * last_W * (last_V - self.V_K) + I_leak = self.g_leak * (last_V - self.V_leak) + + # Dynamics + I_total = self.sum_current_inputs(x, last_V) + dV = lambda V: (-I_Ca - I_K - I_leak + I_total) / self.C + dW = lambda W: (W_inf - W) / tau_W + + V = brainstate.nn.exp_euler_step(dV, last_V) + V = self.sum_delta_inputs(V) + W = brainstate.nn.exp_euler_step(dW, last_W) + + self.V.value = V + self.W.value = W + return self.get_spike(V) + + +class WangBuzsakiHH(Neuron): + r"""Wang-Buzsaki model, an implementation of a modified Hodgkin-Huxley model. + + Each model is described by a single compartment and obeys the current balance equation: + + $$ + C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}+I_{\mathrm{app}} + $$ + + where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the + injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). The leak current + :math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance + :math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`. + + The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion + currents are of the Hodgkin-Huxley type. The transient sodium current + :math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`, + where the activation variable :math:`m` is assumed fast and substituted by its steady-state + function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)`; + :math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1)`, :math:`\beta_{m}(V)=4 \exp (-(V+60) / 18)`. + + The inactivation variable :math:`h` obeys: + + $$ + \frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right) + $$ + + where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and + :math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1)`. 
+ + The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`, + where the activation variable :math:`n` obeys: + + $$ + \frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right) + $$ + + with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and + :math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)`. + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + ENa : ArrayLike, default=55. * u.mV + Reversal potential of sodium. + gNa : ArrayLike, default=35. * u.msiemens + Maximum conductance of sodium channel. + EK : ArrayLike, default=-90. * u.mV + Reversal potential of potassium. + gK : ArrayLike, default=9. * u.msiemens + Maximum conductance of potassium channel. + EL : ArrayLike, default=-65. * u.mV + Reversal potential of leak channel. + gL : ArrayLike, default=0.1 * u.msiemens + Conductance of leak channel. + V_th : ArrayLike, default=20. * u.mV + Threshold of the membrane spike. + phi : ArrayLike, default=5.0 + Temperature regulator constant. + C : ArrayLike, default=1.0 * u.ufarad + Membrane capacitance. + V_initializer : Callable + Initializer for membrane potential. + h_initializer : Callable + Initializer for h channel. + n_initializer : Callable + Initializer for n channel. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function. + spk_reset : str, default='soft' + Reset mechanism after spike generation. + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential. + h : HiddenState + Sodium inactivation variable. + n : HiddenState + Potassium activation variable. 
+ + Examples + -------- + >>> import brainpy + >>> import brainstate + >>> import brainunit as u + >>> + >>> # Create a WangBuzsakiHH neuron layer with 10 neurons + >>> wb = brainpy.state.WangBuzsakiHH(10) + >>> + >>> # Initialize the state + >>> wb.init_state(batch_size=1) + >>> + >>> # Apply an input current and update the neuron state + >>> spikes = wb.update(x=1.*u.uA) + + References + ---------- + .. [1] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic + inhibition in a hippocampal interneuronal network model. Journal of + neuroscience, 16(20), pp.6402-6413. + """ + + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + ENa: ArrayLike = 55. * u.mV, + gNa: ArrayLike = 35. * u.msiemens, + EK: ArrayLike = -90. * u.mV, + gK: ArrayLike = 9. * u.msiemens, + EL: ArrayLike = -65. * u.mV, + gL: ArrayLike = 0.1 * u.msiemens, + V_th: ArrayLike = 20. * u.mV, + phi: ArrayLike = 5.0, + C: ArrayLike = 1.0 * u.ufarad, + V_initializer: Callable = braintools.init.Constant(-65. * u.mV), + h_initializer: Callable = braintools.init.Constant(0.6), + n_initializer: Callable = braintools.init.Constant(0.32), + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'soft', + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.ENa = braintools.init.param(ENa, self.varshape) + self.EK = braintools.init.param(EK, self.varshape) + self.EL = braintools.init.param(EL, self.varshape) + self.gNa = braintools.init.param(gNa, self.varshape) + self.gK = braintools.init.param(gK, self.varshape) + self.gL = braintools.init.param(gL, self.varshape) + self.phi = braintools.init.param(phi, self.varshape) + self.C = braintools.init.param(C, self.varshape) + self.V_th = braintools.init.param(V_th, self.varshape) + + # initializers + self.V_initializer = V_initializer + self.h_initializer = h_initializer + self.n_initializer = n_initializer + + def m_inf(self, V): + alpha = 1. 
/ u.math.exprel(-0.1 * (V + 35. * u.mV) / u.mV) / u.ms + beta = 4. / u.ms * u.math.exp(-(V + 60. * u.mV) / (18. * u.mV)) + return alpha / (alpha + beta) + + def init_state(self, batch_size: int = None, **kwargs): + self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + self.h = brainstate.HiddenState(braintools.init.param(self.h_initializer, self.varshape, batch_size)) + self.n = brainstate.HiddenState(braintools.init.param(self.n_initializer, self.varshape, batch_size)) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + self.h.value = braintools.init.param(self.h_initializer, self.varshape, batch_size) + self.n.value = braintools.init.param(self.n_initializer, self.varshape, batch_size) + + def get_spike(self, V: ArrayLike = None): + V = self.V.value if V is None else V + v_scaled = (V - self.V_th) / self.V_th + return self.spk_fun(v_scaled) + + def update(self, x=0. * u.uA): + last_V = self.V.value + last_h = self.h.value + last_n = self.n.value + + # Ionic currents + m_inf_val = self.m_inf(last_V) + I_Na = self.gNa * m_inf_val ** 3 * last_h * (last_V - self.ENa) + I_K = self.gK * last_n ** 4 * (last_V - self.EK) + I_L = self.gL * (last_V - self.EL) + + # Voltage dynamics + I_total = self.sum_current_inputs(x, last_V) + dV = lambda V: (-I_Na - I_K - I_L + I_total) / self.C + + # Gating variable dynamics + h_alpha = 0.07 / u.ms * u.math.exp(-(last_V + 58. * u.mV) / (20. * u.mV)) + h_beta = 1. / u.ms / (u.math.exp(-0.1 * (last_V + 28. * u.mV) / u.mV) + 1.) + dh = lambda h: self.phi * (h_alpha * (1. - h) - h_beta * h) + + n_alpha = 1. / u.ms / u.math.exprel(-0.1 * (last_V + 34. * u.mV) / u.mV) + n_beta = 0.125 / u.ms * u.math.exp(-(last_V + 44. * u.mV) / (80. * u.mV)) + dn = lambda n: self.phi * (n_alpha * (1. 
- n) - n_beta * n) + + V = brainstate.nn.exp_euler_step(dV, last_V) + V = self.sum_delta_inputs(V) + h = brainstate.nn.exp_euler_step(dh, last_h) + n = brainstate.nn.exp_euler_step(dn, last_n) + + self.V.value = V + self.h.value = h + self.n.value = n + return self.get_spike(V) diff --git a/brainpy/state/_hh_test.py b/brainpy/state/_hh_test.py new file mode 100644 index 00000000..94930174 --- /dev/null +++ b/brainpy/state/_hh_test.py @@ -0,0 +1,303 @@ +# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# -*- coding: utf-8 -*- + + +import unittest + +import brainstate +import braintools +import brainunit as u +import jax +import jax.numpy as jnp + +from brainpy.state import HH, MorrisLecar, WangBuzsakiHH + + +class TestHHNeuron(unittest.TestCase): + def setUp(self): + self.in_size = 10 + self.batch_size = 5 + self.time_steps = 100 + self.dt = 0.01 * u.ms + + def generate_input(self): + return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.uA + + def test_hh_neuron(self): + with brainstate.environ.context(dt=self.dt): + neuron = HH(self.in_size) + inputs = self.generate_input() + + # Test initialization + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + + # Test forward pass + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + # Check state variables + self.assertEqual(neuron.V.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.m.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.h.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.n.value.shape, (self.batch_size, self.in_size)) + + def test_morris_lecar_neuron(self): + with brainstate.environ.context(dt=self.dt): + neuron = MorrisLecar(self.in_size) + inputs = self.generate_input() + + # Test initialization + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + + # Test forward pass + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + # Check state variables + self.assertEqual(neuron.V.value.shape, (self.batch_size, self.in_size)) + 
self.assertEqual(neuron.W.value.shape, (self.batch_size, self.in_size)) + + def test_wang_buzsaki_hh_neuron(self): + with brainstate.environ.context(dt=self.dt): + neuron = WangBuzsakiHH(self.in_size) + inputs = self.generate_input() + + # Test initialization + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + + # Test forward pass + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + # Check state variables + self.assertEqual(neuron.V.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.h.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.n.value.shape, (self.batch_size, self.in_size)) + + def test_spike_function(self): + for NeuronClass in [HH, MorrisLecar, WangBuzsakiHH]: + neuron = NeuronClass(self.in_size) + neuron.init_state() + v = jnp.linspace(-80, 40, self.in_size) * u.mV + spikes = neuron.get_spike(v) + self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1))) + + def test_soft_reset(self): + for NeuronClass in [HH, MorrisLecar, WangBuzsakiHH]: + neuron = NeuronClass(self.in_size, spk_reset='soft') + inputs = self.generate_input() + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + # Check that voltage doesn't exceed threshold too much + self.assertTrue(jnp.all(neuron.V.value <= neuron.V_th + 20 * u.mV)) + + def test_hard_reset(self): + for NeuronClass in [HH, MorrisLecar, WangBuzsakiHH]: + neuron = NeuronClass(self.in_size, spk_reset='hard') + inputs = self.generate_input() + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + # Just check that it runs 
without error + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + def test_detach_spike(self): + for NeuronClass in [HH, MorrisLecar, WangBuzsakiHH]: + neuron = NeuronClass(self.in_size) + inputs = self.generate_input() + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type) + + def test_keep_size(self): + in_size = (2, 3) + for NeuronClass in [HH, MorrisLecar, WangBuzsakiHH]: + neuron = NeuronClass(in_size) + self.assertEqual(neuron.in_size, in_size) + self.assertEqual(neuron.out_size, in_size) + + inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.uA + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, *in_size)) + + def test_hh_gating_variables(self): + # Test that gating variables are properly initialized and updated + neuron = HH(self.in_size) + neuron.init_state(self.batch_size) + + # Check initial values are in valid range [0, 1] + self.assertTrue(jnp.all((neuron.m.value >= 0) & (neuron.m.value <= 1))) + self.assertTrue(jnp.all((neuron.h.value >= 0) & (neuron.h.value <= 1))) + self.assertTrue(jnp.all((neuron.n.value >= 0) & (neuron.n.value <= 1))) + + # Run for some time steps + inputs = self.generate_input() + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(20): + out = call(inputs[t]) + + # Gating variables should still be in valid range + self.assertTrue(jnp.all((neuron.m.value >= 0) & (neuron.m.value <= 1))) + self.assertTrue(jnp.all((neuron.h.value >= 0) & (neuron.h.value <= 1))) + self.assertTrue(jnp.all((neuron.n.value >= 0) & (neuron.n.value <= 1))) + + def 
test_hh_alpha_beta_functions(self): + # Test that alpha and beta functions return positive values + neuron = HH(self.in_size) + neuron.init_state() + + V_test = jnp.linspace(-80, 40, self.in_size) * u.mV + + m_alpha = neuron.m_alpha(V_test) + m_beta = neuron.m_beta(V_test) + h_alpha = neuron.h_alpha(V_test) + h_beta = neuron.h_beta(V_test) + n_alpha = neuron.n_alpha(V_test) + n_beta = neuron.n_beta(V_test) + + # All rate constants should be positive + if hasattr(m_alpha, 'mantissa'): + self.assertTrue(jnp.all(m_alpha.mantissa > 0)) + self.assertTrue(jnp.all(m_beta.mantissa > 0)) + self.assertTrue(jnp.all(h_alpha.mantissa > 0)) + self.assertTrue(jnp.all(h_beta.mantissa > 0)) + self.assertTrue(jnp.all(n_alpha.mantissa > 0)) + self.assertTrue(jnp.all(n_beta.mantissa > 0)) + else: + self.assertTrue(jnp.all(m_alpha > 0)) + self.assertTrue(jnp.all(m_beta > 0)) + self.assertTrue(jnp.all(h_alpha > 0)) + self.assertTrue(jnp.all(h_beta > 0)) + self.assertTrue(jnp.all(n_alpha > 0)) + self.assertTrue(jnp.all(n_beta > 0)) + + def test_morris_lecar_steady_states(self): + # Test that steady-state functions return values in valid range + neuron = MorrisLecar(self.in_size) + neuron.init_state() + + V_test = jnp.linspace(-100, 50, self.in_size) * u.mV + + # Manually compute steady states + M_inf = 0.5 * (1. + u.math.tanh((V_test - neuron.V1) / neuron.V2)) + W_inf = 0.5 * (1. 
+ u.math.tanh((V_test - neuron.V3) / neuron.V4)) + + # Steady states should be in [0, 1] + if hasattr(M_inf, 'mantissa'): + self.assertTrue(jnp.all((M_inf.mantissa >= 0) & (M_inf.mantissa <= 1))) + self.assertTrue(jnp.all((W_inf.mantissa >= 0) & (W_inf.mantissa <= 1))) + else: + self.assertTrue(jnp.all((M_inf >= 0) & (M_inf <= 1))) + self.assertTrue(jnp.all((W_inf >= 0) & (W_inf <= 1))) + + def test_wang_buzsaki_m_inf(self): + # Test that m_inf is properly computed and in valid range + neuron = WangBuzsakiHH(self.in_size) + neuron.init_state() + + V_test = jnp.linspace(-80, 40, self.in_size) * u.mV + m_inf = neuron.m_inf(V_test) + + # m_inf should be in [0, 1] + if hasattr(m_inf, 'mantissa'): + self.assertTrue(jnp.all((m_inf.mantissa >= 0) & (m_inf.mantissa <= 1))) + else: + self.assertTrue(jnp.all((m_inf >= 0) & (m_inf <= 1))) + + def test_different_parameters(self): + # Test HH with different conductance values + hh_custom = HH( + self.in_size, + ENa=50. * u.mV, + gNa=100. * u.msiemens, + EK=-80. * u.mV, + gK=30. * u.msiemens + ) + hh_custom.init_state(self.batch_size) + self.assertEqual(hh_custom.ENa, 50. * u.mV) + self.assertEqual(hh_custom.gNa, 100. * u.msiemens) + + # Test MorrisLecar with different parameters + ml_custom = MorrisLecar( + self.in_size, + V_Ca=120. * u.mV, + g_Ca=4.0 * u.msiemens, + phi=0.05 / u.ms + ) + ml_custom.init_state(self.batch_size) + self.assertEqual(ml_custom.V_Ca, 120. * u.mV) + self.assertEqual(ml_custom.phi, 0.05 / u.ms) + + # Test WangBuzsakiHH with different phi + wb_custom = WangBuzsakiHH( + self.in_size, + phi=10.0 + ) + wb_custom.init_state(self.batch_size) + if hasattr(wb_custom.phi, 'mantissa'): + self.assertEqual(float(wb_custom.phi.mantissa), 10.0) + else: + self.assertEqual(float(wb_custom.phi), 10.0) + + def test_ionic_currents(self): + # Test that ionic currents are computed + neuron = HH(self.in_size) + neuron.init_state(self.batch_size) + + # Run one update + inputs = jnp.ones((self.batch_size, self.in_size)) * 10. 
* u.uA + with brainstate.environ.context(dt=self.dt): + out = neuron.update(inputs) + + # Check that state variables have changed (indicating currents were applied) + initial_V = braintools.init.param(neuron.V_initializer, neuron.varshape, self.batch_size) + if hasattr(initial_V, 'mantissa'): + self.assertFalse(jnp.allclose(neuron.V.value.mantissa, initial_V.mantissa)) + else: + self.assertFalse(jnp.allclose(neuron.V.value, initial_V)) + + +if __name__ == '__main__': + unittest.main() diff --git a/brainpy/state/_inputs.py b/brainpy/state/_inputs.py index 6ab87a85..8b88f9c8 100644 --- a/brainpy/state/_inputs.py +++ b/brainpy/state/_inputs.py @@ -60,7 +60,7 @@ class SpikeTime(brainstate.nn.Dynamics): name : str, optional The name of the dynamic system. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -116,7 +116,7 @@ class PoissonSpike(brainstate.nn.Dynamics): """ Poisson Neuron Group. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -173,7 +173,7 @@ class PoissonEncoder(brainstate.nn.Dynamics): >>> import numpy as np >>> >>> # Create a Poisson encoder for 10 neurons - >>> encoder = brainpy.PoissonEncoder(10) + >>> encoder = brainpy.state.PoissonEncoder(10) >>> >>> # Generate spikes with varying firing rates >>> rates = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100]) * u.Hz @@ -186,7 +186,7 @@ class PoissonEncoder(brainstate.nn.Dynamics): >>> spike_train = encoder.update(firing_rates) >>> >>> # Feed the spikes into a spiking neural network - >>> neuron_layer = brainpy.LIF(10) + >>> neuron_layer = brainpy.state.LIF(10) >>> neuron_layer.init_state(batch_size=1) >>> output_spikes = neuron_layer.update(spike_train) @@ -203,7 +203,7 @@ class PoissonEncoder(brainstate.nn.Dynamics): - The independence of spike generation between time steps results in renewal process statistics without memory of previous spiking history. 
""" - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -278,11 +278,11 @@ class PoissonInput(brainstate.nn.Module): >>> import numpy as np >>> >>> # Create a neuron group with membrane potential - >>> neuron = brainpy.LIF(100) + >>> neuron = brainpy.state.LIF(100) >>> neuron.init_state(batch_size=1) >>> >>> # Add Poisson input to all neurons - >>> poisson_in = brainpy.PoissonInput( + >>> poisson_in = brainpy.state.PoissonInput( ... target=neuron.V, ... indices=None, ... num_input=200, @@ -292,7 +292,7 @@ class PoissonInput(brainstate.nn.Module): >>> >>> # Add Poisson input only to specific neurons >>> indices = np.array([0, 10, 20, 30]) - >>> specific_input = brainpy.PoissonInput( + >>> specific_input = brainpy.state.PoissonInput( ... target=neuron.V, ... indices=indices, ... num_input=50, @@ -316,7 +316,7 @@ class PoissonInput(brainstate.nn.Module): - The update method internally calls the poisson_input function which handles the spike generation and target state updates. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -348,7 +348,7 @@ def update(self): ) -@set_module_as('brainpy') +@set_module_as('brainpy.state') def poisson_input( freq: u.Quantity[u.Hz], num_input: int, @@ -421,7 +421,7 @@ def poisson_input( >>> V = brainstate.HiddenState(np.zeros(100) * u.mV) >>> >>> # Add Poisson input to all neurons at 50 Hz - >>> brainpy.poisson_input( + >>> brainpy.state.poisson_input( ... freq=50 * u.Hz, ... num_input=200, ... weight=0.1 * u.mV, @@ -430,7 +430,7 @@ def poisson_input( >>> >>> # Apply Poisson input only to a subset of neurons >>> indices = np.array([0, 10, 20, 30]) - >>> brainpy.poisson_input( + >>> brainpy.state.poisson_input( ... freq=100 * u.Hz, ... num_input=50, ... 
weight=0.2 * u.mV, @@ -441,7 +441,7 @@ def poisson_input( >>> # Apply input with refractory mask >>> refractory = np.zeros(100, dtype=bool) >>> refractory[40:60] = True # neurons 40-59 are in refractory period - >>> brainpy.poisson_input( + >>> brainpy.state.poisson_input( ... freq=75 * u.Hz, ... num_input=100, ... weight=0.15 * u.mV, diff --git a/brainpy/state/_izhikevich.py b/brainpy/state/_izhikevich.py new file mode 100644 index 00000000..586e23c7 --- /dev/null +++ b/brainpy/state/_izhikevich.py @@ -0,0 +1,407 @@ +# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# -*- coding: utf-8 -*- + +from typing import Callable + +import brainstate +import braintools +import brainunit as u +import jax +from brainstate.typing import ArrayLike, Size + +from ._base import Neuron + +__all__ = [ + 'Izhikevich', 'IzhikevichRef', +] + + +class Izhikevich(Neuron): + r"""Izhikevich neuron model. + + This class implements the Izhikevich neuron model, a two-dimensional spiking neuron + model that can reproduce a wide variety of neuronal firing patterns observed in + biological neurons. The model combines computational efficiency with biological + plausibility through a quadratic voltage dynamics and a linear recovery variable. 
+ + The model is characterized by the following differential equations: + + $$ + \frac{dV}{dt} = 0.04 V^2 + 5V + 140 - u + I(t) + $$ + + $$ + \frac{du}{dt} = a(bV - u) + $$ + + Spike condition: + If $V \geq V_{th}$: emit spike, set $V = c$ and $u = u + d$ + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + a : ArrayLike, default=0.02 / u.ms + Time scale of the recovery variable u. Smaller values result in slower recovery. + b : ArrayLike, default=0.2 / u.ms + Sensitivity of the recovery variable u to the membrane potential V. + c : ArrayLike, default=-65. * u.mV + After-spike reset value of the membrane potential. + d : ArrayLike, default=8. * u.mV / u.ms + After-spike increment of the recovery variable u. + V_th : ArrayLike, default=30. * u.mV + Spike threshold voltage. + V_initializer : Callable + Initializer for the membrane potential state. + u_initializer : Callable + Initializer for the recovery variable state. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function for the non-differentiable spike generation. + spk_reset : str, default='hard' + Reset mechanism after spike generation: + - 'soft': subtract threshold V = V - V_th + - 'hard': strict reset using stop_gradient + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential. + u : HiddenState + Recovery variable. + + Examples + -------- + >>> import brainpy + >>> import brainstate + >>> import brainunit as u + >>> + >>> # Create an Izhikevich neuron layer with 10 neurons + >>> izh = brainpy.state.Izhikevich(10) + >>> + >>> # Initialize the state + >>> izh.init_state(batch_size=1) + >>> + >>> # Apply an input current and update the neuron state + >>> spikes = izh.update(x=10.*u.mV/u.ms) + + Notes + ----- + - The quadratic term in the voltage equation (0.04*V^2) provides a sharp spike + upstroke similar to biological neurons. 
+ - Different combinations of parameters (a, b, c, d) can reproduce various neuronal + behaviors including regular spiking, intrinsically bursting, chattering, and + fast spiking. + - The model uses a hard reset mechanism where V is set to c and u is incremented + by d when a spike occurs. + - Parameter ranges: a ∈ [0.01, 0.1], b ∈ [0.2, 0.3], c ∈ [-65, -50], d ∈ [0.1, 10] + + References + ---------- + .. [1] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions + on neural networks, 14(6), 1569-1572. + .. [2] Izhikevich, E. M. (2004). Which model to use for cortical spiking neurons?. + IEEE transactions on neural networks, 15(5), 1063-1070. + """ + + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + a: ArrayLike = 0.02 / u.ms, + b: ArrayLike = 0.2 / u.ms, + c: ArrayLike = -65. * u.mV, + d: ArrayLike = 8. * u.mV / u.ms, + V_th: ArrayLike = 30. * u.mV, + V_initializer: Callable = braintools.init.Constant(-65. * u.mV), + u_initializer: Callable = braintools.init.Constant(0. * u.mV / u.ms), + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'hard', + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.a = braintools.init.param(a, self.varshape) + self.b = braintools.init.param(b, self.varshape) + self.c = braintools.init.param(c, self.varshape) + self.d = braintools.init.param(d, self.varshape) + self.V_th = braintools.init.param(V_th, self.varshape) + + # pre-computed coefficients for quadratic equation + self.p1 = 0.04 / (u.ms * u.mV) + self.p2 = 5. / u.ms + self.p3 = 140. 
* u.mV / u.ms + + # initializers + self.V_initializer = V_initializer + self.u_initializer = u_initializer + + def init_state(self, batch_size: int = None, **kwargs): + self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + self.u = brainstate.HiddenState(braintools.init.param(self.u_initializer, self.varshape, batch_size)) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + self.u.value = braintools.init.param(self.u_initializer, self.varshape, batch_size) + + def get_spike(self, V: ArrayLike = None): + V = self.V.value if V is None else V + v_scaled = (V - self.V_th) / self.V_th + return self.spk_fun(v_scaled) + + def update(self, x=0. * u.mV / u.ms): + last_v = self.V.value + last_u = self.u.value + last_spk = self.get_spike(last_v) + + # Izhikevich uses hard reset: V → c, u → u + d + V = u.math.where(last_spk > 0., self.c, last_v) + u_val = last_u + self.d * last_spk + + # voltage dynamics: dV/dt = 0.04*V^2 + 5*V + 140 - u + I + def dv(v): + I_total = self.sum_current_inputs(x, v) + return self.p1 * v * v + self.p2 * v + self.p3 - u_val + I_total + + # recovery dynamics: du/dt = a(bV - u) + def du(u_): + return self.a * (self.b * V - u_) + + V = brainstate.nn.exp_euler_step(dv, V) + V = self.sum_delta_inputs(V) + u_val = brainstate.nn.exp_euler_step(du, u_val) + + self.V.value = V + self.u.value = u_val + return self.get_spike(V) + + +class IzhikevichRef(Neuron): + r"""Izhikevich neuron model with refractory period. + + This class implements the Izhikevich neuron model with an absolute refractory period. + During the refractory period after a spike, the neuron cannot fire regardless of input, + which better captures the behavior of biological neurons that exhibit a recovery period + after action potential generation. 
+ + The model is characterized by the following equations: + + When not in refractory period: + + $$ + \frac{dV}{dt} = 0.04 V^2 + 5V + 140 - u + I(t) + $$ + + $$ + \frac{du}{dt} = a(bV - u) + $$ + + During refractory period: + + $$ + V = c, \quad u = u + $$ + + Spike condition: + If $V \geq V_{th}$ and not in refractory period: emit spike, set $V = c$, $u = u + d$, + and enter refractory period for $\tau_{ref}$ + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + a : ArrayLike, default=0.02 / u.ms + Time scale of the recovery variable u. + b : ArrayLike, default=0.2 / u.ms + Sensitivity of the recovery variable u to the membrane potential V. + c : ArrayLike, default=-65. * u.mV + After-spike reset value of the membrane potential. + d : ArrayLike, default=8. * u.mV / u.ms + After-spike increment of the recovery variable u. + V_th : ArrayLike, default=30. * u.mV + Spike threshold voltage. + tau_ref : ArrayLike, default=0. * u.ms + Refractory period duration. + V_initializer : Callable + Initializer for the membrane potential state. + u_initializer : Callable + Initializer for the recovery variable state. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function for the non-differentiable spike generation. + spk_reset : str, default='hard' + Reset mechanism after spike generation. + ref_var : bool, default=False + Whether to expose a boolean refractory state variable. + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential. + u : HiddenState + Recovery variable. + last_spike_time : ShortTermState + Time of the last spike, used to implement refractory period. + refractory : HiddenState + Neuron refractory state (if ref_var=True). 
+ + Examples + -------- + >>> import brainpy + >>> import brainstate + >>> import brainunit as u + >>> + >>> # Create an IzhikevichRef neuron layer with 10 neurons + >>> izh_ref = brainpy.state.IzhikevichRef(10, tau_ref=2.*u.ms) + >>> + >>> # Initialize the state + >>> izh_ref.init_state(batch_size=1) + >>> + >>> # Generate inputs and run simulation + >>> time_steps = 100 + >>> inputs = brainstate.random.randn(time_steps, 1, 10) * u.mV / u.ms + >>> + >>> with brainstate.environ.context(dt=0.1 * u.ms): + >>> for t in range(time_steps): + >>> with brainstate.environ.context(t=t*0.1*u.ms): + >>> spikes = izh_ref.update(x=inputs[t]) + + Notes + ----- + - The refractory period is implemented by tracking the time of the last spike + and preventing membrane potential updates if the elapsed time is less than tau_ref. + - During the refractory period, the membrane potential remains at the reset value c + regardless of input current strength. + - Refractory periods prevent high-frequency repetitive firing and are critical + for realistic neural dynamics. + - The simulation environment time variable 't' is used to track the refractory state. + + References + ---------- + .. [1] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions + on neural networks, 14(6), 1569-1572. + .. [2] Izhikevich, E. M. (2004). Which model to use for cortical spiking neurons?. + IEEE transactions on neural networks, 15(5), 1063-1070. + """ + + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + a: ArrayLike = 0.02 / u.ms, + b: ArrayLike = 0.2 / u.ms, + c: ArrayLike = -65. * u.mV, + d: ArrayLike = 8. * u.mV / u.ms, + V_th: ArrayLike = 30. * u.mV, + tau_ref: ArrayLike = 0. * u.ms, + V_initializer: Callable = braintools.init.Constant(-65. * u.mV), + u_initializer: Callable = braintools.init.Constant(0. 
* u.mV / u.ms), + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'hard', + ref_var: bool = False, + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.a = braintools.init.param(a, self.varshape) + self.b = braintools.init.param(b, self.varshape) + self.c = braintools.init.param(c, self.varshape) + self.d = braintools.init.param(d, self.varshape) + self.V_th = braintools.init.param(V_th, self.varshape) + self.tau_ref = braintools.init.param(tau_ref, self.varshape) + + # pre-computed coefficients for quadratic equation + self.p1 = 0.04 / (u.ms * u.mV) + self.p2 = 5. / u.ms + self.p3 = 140. * u.mV / u.ms + + # initializers + self.V_initializer = V_initializer + self.u_initializer = u_initializer + self.ref_var = ref_var + + def init_state(self, batch_size: int = None, **kwargs): + self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + self.u = brainstate.HiddenState(braintools.init.param(self.u_initializer, self.varshape, batch_size)) + self.last_spike_time = brainstate.ShortTermState( + braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) + ) + if self.ref_var: + self.refractory = brainstate.HiddenState( + braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size) + ) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + self.u.value = braintools.init.param(self.u_initializer, self.varshape, batch_size) + self.last_spike_time.value = braintools.init.param( + braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size + ) + if self.ref_var: + self.refractory.value = braintools.init.param( + braintools.init.Constant(False), self.varshape, batch_size + ) + + def get_spike(self, V: ArrayLike = None): + V = self.V.value if V is None else V + v_scaled = (V - self.V_th) / 
self.V_th + return self.spk_fun(v_scaled) + + def update(self, x=0. * u.mV / u.ms): + t = brainstate.environ.get('t') + last_v = self.V.value + last_u = self.u.value + last_spk = self.get_spike(last_v) + + # Izhikevich uses hard reset: V → c, u → u + d + v_reset = u.math.where(last_spk > 0., self.c, last_v) + u_reset = last_u + self.d * last_spk + + # voltage dynamics: dV/dt = 0.04*V^2 + 5*V + 140 - u + I + def dv(v): + I_total = self.sum_current_inputs(x, v) + return self.p1 * v * v + self.p2 * v + self.p3 - u_reset + I_total + + # recovery dynamics: du/dt = a(bV - u) + def du(u_): + return self.a * (self.b * V_candidate - u_) + + V_candidate = brainstate.nn.exp_euler_step(dv, v_reset) + V_candidate = self.sum_delta_inputs(V_candidate) + u_candidate = brainstate.nn.exp_euler_step(du, u_reset) + + # apply refractory period + refractory = (t - self.last_spike_time.value) < self.tau_ref + self.V.value = u.math.where(refractory, v_reset, V_candidate) + self.u.value = u.math.where(refractory, u_reset, u_candidate) + + # spike time evaluation + spike_cond = self.V.value >= self.V_th + self.last_spike_time.value = jax.lax.stop_gradient( + u.math.where(spike_cond, t, self.last_spike_time.value) + ) + if self.ref_var: + self.refractory.value = jax.lax.stop_gradient( + u.math.logical_or(refractory, spike_cond) + ) + return self.get_spike() diff --git a/brainpy/state/_izhikevich_test.py b/brainpy/state/_izhikevich_test.py new file mode 100644 index 00000000..c6cd8a23 --- /dev/null +++ b/brainpy/state/_izhikevich_test.py @@ -0,0 +1,291 @@ +# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# -*- coding: utf-8 -*- + + +import unittest + +import brainstate +import brainunit as u +import jax +import jax.numpy as jnp + +from brainpy.state import Izhikevich, IzhikevichRef + + +class TestIzhikevichNeuron(unittest.TestCase): + def setUp(self): + self.in_size = 10 + self.batch_size = 5 + self.time_steps = 100 + self.dt = 0.1 * u.ms + + def generate_input(self): + return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mV / u.ms + + def test_izhikevich_neuron(self): + with brainstate.environ.context(dt=self.dt): + neuron = Izhikevich(self.in_size) + inputs = self.generate_input() + + # Test initialization + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + + # Test forward pass + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + # Check state variables + self.assertEqual(neuron.V.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.u.value.shape, (self.batch_size, self.in_size)) + + def test_izhikevich_ref_neuron(self): + tau_ref = 2.0 * u.ms + neuron = IzhikevichRef(self.in_size, tau_ref=tau_ref) + inputs = self.generate_input() + + # Test initialization + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau_ref, tau_ref) + + # Test 
forward pass + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t * self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + # Check state variables + self.assertEqual(neuron.V.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.u.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.last_spike_time.value.shape, (self.batch_size, self.in_size)) + + def test_izhikevich_ref_with_ref_var(self): + tau_ref = 2.0 * u.ms + ref_var = True + neuron = IzhikevichRef(self.in_size, tau_ref=tau_ref, ref_var=ref_var) + inputs = self.generate_input() + + # Test initialization + self.assertEqual(neuron.ref_var, ref_var) + + # Test forward pass + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t * self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + # Check refractory variable + if neuron.ref_var: + self.assertEqual(neuron.refractory.value.shape, (self.batch_size, self.in_size)) + + def test_spike_function(self): + for NeuronClass in [Izhikevich, IzhikevichRef]: + neuron = NeuronClass(self.in_size) + neuron.init_state() + v = jnp.linspace(-80, 40, self.in_size) * u.mV + spikes = neuron.get_spike(v) + self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1))) + + def test_soft_reset(self): + for NeuronClass in [Izhikevich, IzhikevichRef]: + neuron = NeuronClass(self.in_size, spk_reset='soft') + inputs = self.generate_input() + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t * self.dt): + out = call(inputs[t]) + # For Izhikevich 
model, soft reset still applies hard reset logic + # So we just check that V doesn't exceed V_th significantly + self.assertTrue(jnp.all(neuron.V.value <= neuron.V_th + 10 * u.mV)) + + def test_hard_reset(self): + for NeuronClass in [Izhikevich, IzhikevichRef]: + neuron = NeuronClass(self.in_size, spk_reset='hard') + inputs = self.generate_input() + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t * self.dt): + out = call(inputs[t]) + # For Izhikevich, after spike V should be reset to c + # Check that V is either below threshold or near reset value + above_c = neuron.V.value >= (neuron.c - 5 * u.mV) + below_th = neuron.V.value < neuron.V_th + self.assertTrue(jnp.all(above_c | below_th)) + + def test_detach_spike(self): + for NeuronClass in [Izhikevich, IzhikevichRef]: + neuron = NeuronClass(self.in_size) + inputs = self.generate_input() + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t * self.dt): + out = call(inputs[t]) + self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type) + + def test_keep_size(self): + in_size = (2, 3) + for NeuronClass in [Izhikevich, IzhikevichRef]: + neuron = NeuronClass(in_size) + self.assertEqual(neuron.in_size, in_size) + self.assertEqual(neuron.out_size, in_size) + + inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mV / u.ms + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t * self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, *in_size)) + + def test_different_parameters(self): + # Test regular spiking (RS) parameters + rs_neuron = 
Izhikevich( + self.in_size, + a=0.02 / u.ms, + b=0.2 / u.ms, + c=-65. * u.mV, + d=8. * u.mV / u.ms + ) + rs_neuron.init_state(self.batch_size) + self.assertEqual(rs_neuron.a, 0.02 / u.ms) + self.assertEqual(rs_neuron.b, 0.2 / u.ms) + + # Test intrinsically bursting (IB) parameters + ib_neuron = Izhikevich( + self.in_size, + a=0.02 / u.ms, + b=0.2 / u.ms, + c=-55. * u.mV, + d=4. * u.mV / u.ms + ) + ib_neuron.init_state(self.batch_size) + self.assertEqual(ib_neuron.c, -55. * u.mV) + self.assertEqual(ib_neuron.d, 4. * u.mV / u.ms) + + # Test chattering (CH) parameters + ch_neuron = Izhikevich( + self.in_size, + a=0.02 / u.ms, + b=0.2 / u.ms, + c=-50. * u.mV, + d=2. * u.mV / u.ms + ) + ch_neuron.init_state(self.batch_size) + self.assertEqual(ch_neuron.c, -50. * u.mV) + + # Test fast spiking (FS) parameters + fs_neuron = Izhikevich( + self.in_size, + a=0.1 / u.ms, + b=0.2 / u.ms, + c=-65. * u.mV, + d=2. * u.mV / u.ms + ) + fs_neuron.init_state(self.batch_size) + self.assertEqual(fs_neuron.a, 0.1 / u.ms) + + def test_refractory_period_effectiveness(self): + # Test that refractory period actually prevents firing + tau_ref = 5.0 * u.ms + neuron = IzhikevichRef(self.in_size, tau_ref=tau_ref) + neuron.init_state(self.batch_size) + + # Strong constant input to encourage firing + strong_input = jnp.ones((self.batch_size, self.in_size)) * 20. 
* u.mV / u.ms + + spike_times = [] + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t * self.dt): + out = call(strong_input) + if jnp.any(out > 0): + spike_times.append(t * self.dt) + + # Check that consecutive spikes are separated by at least tau_ref + if len(spike_times) > 1: + for i in range(len(spike_times) - 1): + time_diff = spike_times[i + 1] - spike_times[i] + # Allow small numerical errors + self.assertGreaterEqual(time_diff.to_value(u.ms), (tau_ref - 0.5 * self.dt).to_value(u.ms)) + + def test_quadratic_dynamics(self): + # Test that the quadratic term in voltage dynamics is working + neuron = Izhikevich(self.in_size) + neuron.init_state(1) + + # Set initial conditions + V_low = -70. * u.mV + V_high = -50. * u.mV + + # Check that dV/dt has quadratic relationship with V + # At low V, dV/dt should be more negative + # At high V, dV/dt should be less negative or positive + + # This is a qualitative test to ensure the quadratic term is present + # coefficient p1 should be positive for upward parabola + # p1 = 0.04 / (ms * mV), just check it's set and positive + self.assertIsNotNone(neuron.p1) + # Extract the mantissa value for comparison + if hasattr(neuron.p1, 'mantissa'): + self.assertGreater(float(neuron.p1.mantissa), 0) + else: + self.assertGreater(float(neuron.p1), 0) + + def test_recovery_variable_dynamics(self): + # Test that recovery variable u properly tracks and affects V + neuron = Izhikevich(self.in_size) + neuron.init_state(self.batch_size) + + initial_u = neuron.u.value.mantissa.copy() + + # Run for some time steps with moderate input + moderate_input = jnp.ones((self.batch_size, self.in_size)) * 5. 
* u.mV / u.ms + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(20): + call(moderate_input) + + # u should change from initial value + self.assertFalse(jnp.allclose(neuron.u.value.mantissa, initial_u, rtol=0.01)) + + # After a spike, u should increase by d + # (This is implicitly tested in the spike generation tests) + + +if __name__ == '__main__': + unittest.main() diff --git a/brainpy/state/_lif.py b/brainpy/state/_lif.py index 0a4eb164..713f4fe2 100644 --- a/brainpy/state/_lif.py +++ b/brainpy/state/_lif.py @@ -26,7 +26,8 @@ from ._base import Neuron __all__ = [ - 'IF', 'LIF', 'LIFRef', 'ALIF', + 'IF', 'LIF', 'ExpIF', 'ExpIFRef', 'AdExIF', 'AdExIFRef', 'LIFRef', 'ALIF', + 'QuaIF', 'AdQuaIF', 'AdQuaIFRef', 'Gif', 'GifRef', ] @@ -112,7 +113,7 @@ class IF(Neuron): I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -207,12 +208,12 @@ class LIF(Neuron): Examples -------- - >>> import brainpy.state as brainpy + >>> import brainpy >>> import brainstate >>> import brainunit as u >>> >>> # Create a LIF neuron layer with 10 neurons - >>> lif = brainpy.LIF(10, tau=10*u.ms, V_th=0.8*u.mV) + >>> lif = brainpy.state.LIF(10, tau=10*u.ms, V_th=0.8*u.mV) >>> >>> # Initialize the state >>> lif.init_state(batch_size=1) @@ -236,7 +237,7 @@ class LIF(Neuron): .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model: I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -274,9 +275,9 @@ def get_spike(self, V: ArrayLike = None): def update(self, x=0. 
* u.mA): last_v = self.V.value - lst_spk = self.get_spike(last_v) + last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) - V = last_v - (V_th - self.V_reset) * lst_spk + V = last_v - (V_th - self.V_reset) * last_spk # membrane potential dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau V = brainstate.nn.exp_euler_step(dv, V) @@ -285,28 +286,22 @@ def update(self, x=0. * u.mA): return self.get_spike(V) -class LIFRef(Neuron): - r"""Leaky Integrate-and-Fire neuron model with refractory period. - - This class implements a Leaky Integrate-and-Fire neuron model that includes a - refractory period after spiking, during which the neuron cannot fire regardless - of input. This better captures the behavior of biological neurons that exhibit - a recovery period after action potential generation. +class ExpIF(Neuron): + r"""Exponential Integrate-and-Fire (ExpIF) neuron model. - The model is characterized by the following equations: + This model augments the LIF neuron by adding an exponential spike-initiation + term, which provides a smooth approximation of the action potential onset + and improves biological plausibility for cortical pyramidal cells. - When not in refractory period: - $$ - \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t) - $$ + The membrane potential dynamics follow: - During refractory period: $$ - V = V_{reset} + \tau \frac{dV}{dt} = -(V - V_{rest}) + \Delta_T \exp\left(\frac{V - V_T}{\Delta_T}\right) + R \cdot I(t) $$ Spike condition: - If $V \geq V_{th}$: emit spike, set $V = V_{reset}$, and enter refractory period for $\tau_{ref}$ + If $V \geq V_{th}$: emit spike and reset $V = V_{reset}$ (hard reset) or + $V = V - (V_{th} - V_{reset})$ (soft reset). Parameters ---------- @@ -314,24 +309,24 @@ class LIFRef(Neuron): Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. - tau : ArrayLike, default=5. 
* u.ms + tau : ArrayLike, default=10. * u.ms Membrane time constant. - tau_ref : ArrayLike, default=5. * u.ms - Refractory period duration. - V_th : ArrayLike, default=1. * u.mV - Firing threshold voltage. - V_reset : ArrayLike, default=0. * u.mV + V_th : ArrayLike, default=-30. * u.mV + Numerical firing threshold voltage. + V_reset : ArrayLike, default=-68. * u.mV Reset voltage after spike. - V_rest : ArrayLike, default=0. * u.mV + V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. + V_T : ArrayLike, default=-59.9 * u.mV + Threshold potential of the exponential term. + delta_T : ArrayLike, default=3.48 * u.mV + Spike slope factor controlling the sharpness of spike initiation. V_initializer : Callable Initializer for the membrane potential state. spk_fun : Callable, default=surrogate.ReluGrad() - Surrogate gradient function for the non-differentiable spike generation. + Surrogate gradient function for the spike generation. spk_reset : str, default='soft' - Reset mechanism after spike generation: - - 'soft': subtract threshold V = V - V_th - - 'hard': strict reset using stop_gradient + Reset mechanism after spike generation. name : str, optional Name of the neuron layer. @@ -339,66 +334,65 @@ class LIFRef(Neuron): ---------- V : HiddenState Membrane potential. - last_spike_time : ShortTermState - Time of the last spike, used to implement refractory period. Examples -------- - >>> import brainpy.state as brainpy + >>> import brainpy >>> import brainstate >>> import brainunit as u >>> - >>> # Create a LIFRef neuron layer with 10 neurons - >>> lifref = brainpy.LIFRef(10, - ... tau=10*u.ms, - ... tau_ref=5*u.ms, - ... 
V_th=0.8*u.mV) + >>> # Create a ExpIF neuron layer with 10 neurons + >>> expif = brainpy.state.ExpIF(10, tau=10*u.ms, V_th=-30*u.mV) >>> >>> # Initialize the state - >>> lifref.init_state(batch_size=1) + >>> expif.init_state(batch_size=1) >>> >>> # Apply an input current and update the neuron state - >>> spikes = lifref.update(x=1.5*u.mA) - >>> - >>> # Create a network with refractory neurons - >>> network = brainstate.nn.Sequential([ - ... brainpy.LIFRef(100, tau_ref=4*u.ms), - ... brainstate.nn.Linear(100, 10) - ... ]) + >>> spikes = expif.update(x=1.5*u.mA) Notes ----- - - The refractory period is implemented by tracking the time of the last spike - and preventing membrane potential updates if the elapsed time is less than tau_ref. - - During the refractory period, the membrane potential remains at the reset value - regardless of input current strength. - - Refractory periods prevent high-frequency repetitive firing and are critical - for realistic neural dynamics. - - The time-dependent dynamics are integrated using an exponential Euler method. - - The simulation environment time variable 't' is used to track the refractory state. + - The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk + and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. + It is one of the prominent examples of a precise theoretical prediction in computational + neuroscience that was later confirmed by experimental neuroscience. + - The right-hand side of the above equation contains a nonlinearity + that can be directly extracted from experimental data [3]_. In this sense the exponential + nonlinearity is not an arbitrary choice but directly supported by experimental evidence. + - Even though it is a nonlinear model, it is simple enough to calculate the firing + rate for constant input, and the linear response to fluctuations, even in the presence + of input noise [4]_. References ---------- - .. 
[1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). - Neuronal dynamics: From single neurons to networks and models of cognition. - Cambridge University Press. - .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model: - I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19. - .. [3] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on - neural networks, 14(6), 1569-1572. + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). + Neuronal dynamics: From single neurons to networks and models + of cognition. Cambridge University Press. + .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, + Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves + are reliable predictors of naturalistic pyramidal-neuron voltage + traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. + .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear + integrate-and-fire neurons to modulated current-based and + conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. + .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, - tau: ArrayLike = 5. * u.ms, - tau_ref: ArrayLike = 5. * u.ms, - V_th: ArrayLike = 1. * u.mV, - V_reset: ArrayLike = 0. * u.mV, - V_rest: ArrayLike = 0. * u.mV, - V_initializer: Callable = braintools.init.Constant(0. * u.mV), + tau: ArrayLike = 10. * u.ms, + V_th: ArrayLike = -30. * u.mV, + V_reset: ArrayLike = -68. * u.mV, + V_rest: ArrayLike = -65. * u.mV, + V_T: ArrayLike = -59.9 * u.mV, + delta_T: ArrayLike = 3.48 * u.mV, + V_initializer: Callable = braintools.init.Constant(-65. 
* u.mV), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', name: str = None, @@ -408,23 +402,162 @@ def __init__( # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) - self.tau_ref = braintools.init.param(tau_ref, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) + self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) + self.V_T = braintools.init.param(V_T, self.varshape) + self.delta_T = braintools.init.param(delta_T, self.varshape) + self.V_initializer = V_initializer + + def init_state(self, batch_size: int = None, **kwargs): + self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + + def get_spike(self, V: ArrayLike = None): + V = self.V.value if V is None else V + v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) + return self.spk_fun(v_scaled) + + def update(self, x=0. * u.mA): + last_v = self.V.value + last_spk = self.get_spike(last_v) + V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) + V = last_v - (V_th - self.V_reset) * last_spk + + def dv(v): + exp_term = self.delta_T * u.math.exp((v - self.V_T) / self.delta_T) + return (-(v - self.V_rest) + exp_term + self.R * self.sum_current_inputs(x, v)) / self.tau + + V = brainstate.nn.exp_euler_step(dv, V) + V = self.sum_delta_inputs(V) + self.V.value = V + return self.get_spike(V) + + +class ExpIFRef(Neuron): + r"""Exponential Integrate-and-Fire neuron model with refractory mechanism. + + This neuron adds an absolute refractory period to :class:`ExpIF`. 
While the exponential + spike-initiation term keeps the membrane potential dynamics smooth, the refractory + mechanism prevents the neuron from firing within ``tau_ref`` after a spike. + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + R : ArrayLike, default=1. * u.ohm + Membrane resistance. + tau : ArrayLike, default=10. * u.ms + Membrane time constant. + tau_ref : ArrayLike, default=1.7 * u.ms + Absolute refractory period duration. + V_th : ArrayLike, default=-30. * u.mV + Numerical firing threshold voltage. + V_reset : ArrayLike, default=-68. * u.mV + Reset voltage after spike. + V_rest : ArrayLike, default=-65. * u.mV + Resting membrane potential. + V_T : ArrayLike, default=-59.9 * u.mV + Threshold potential of the exponential term. + delta_T : ArrayLike, default=3.48 * u.mV + Spike slope factor controlling spike initiation sharpness. + V_initializer : Callable + Initializer for the membrane potential state. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function for the spike generation. + spk_reset : str, default='soft' + Reset mechanism after spike generation. + ref_var : bool, default=False + Whether to expose a boolean refractory state variable. + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential. + last_spike_time : ShortTermState + Last spike time recorder. + refractory : HiddenState + Neuron refractory state. 
+
+    Examples
+    --------
+    >>> import brainpy
+    >>> import brainstate
+    >>> import brainunit as u
+    >>>
+    >>> # Create an ExpIFRef neuron layer with 10 neurons
+    >>> expif = brainpy.state.ExpIFRef(10, tau=10*u.ms, tau_ref=1.7*u.ms, V_th=-30*u.mV)
+    >>>
+    >>> # Initialize the state
+    >>> expif.init_state(batch_size=1)
+    >>>
+    >>> # Generate inputs
+    >>> time_steps = 100
+    >>> inputs = brainstate.random.randn(time_steps, 1, 10) * u.mA
+    >>>
+    >>> # Apply an input current and update the neuron state
+    >>>
+    >>> with brainstate.environ.context(dt=0.1 * u.ms):
+    >>>    for t in range(time_steps):
+    >>>        with brainstate.environ.context(t=t*0.1*u.ms):
+    >>>            spikes = expif.update(x=inputs[t])
+    """
+    __module__ = 'brainpy.state'
+
+    def __init__(
+        self,
+        in_size: Size,
+        R: ArrayLike = 1. * u.ohm,
+        tau: ArrayLike = 10. * u.ms,
+        tau_ref: ArrayLike = 1.7 * u.ms,
+        V_th: ArrayLike = -30. * u.mV,
+        V_reset: ArrayLike = -68. * u.mV,
+        V_rest: ArrayLike = -65. * u.mV,
+        V_T: ArrayLike = -59.9 * u.mV,
+        delta_T: ArrayLike = 3.48 * u.mV,
+        V_initializer: Callable = braintools.init.Constant(-65.
* u.mV), + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'soft', + ref_var: bool = False, + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.R = braintools.init.param(R, self.varshape) + self.tau = braintools.init.param(tau, self.varshape) + self.tau_ref = braintools.init.param(tau_ref, self.varshape) + self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) + self.V_rest = braintools.init.param(V_rest, self.varshape) + self.V_T = braintools.init.param(V_T, self.varshape) + self.delta_T = braintools.init.param(delta_T, self.varshape) self.V_initializer = V_initializer + self.ref_var = ref_var def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.last_spike_time = brainstate.ShortTermState( braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) ) + if self.ref_var: + self.refractory = brainstate.HiddenState( + braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size) + ) def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.last_spike_time.value = braintools.init.param( braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size ) + if self.ref_var: + self.refractory.value = braintools.init.param( + braintools.init.Constant(False), self.varshape, batch_size + ) def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V @@ -434,42 +567,50 @@ def get_spike(self, V: ArrayLike = None): def update(self, x=0. 
* u.mA): t = brainstate.environ.get('t') last_v = self.V.value - lst_spk = self.get_spike(last_v) + last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) - last_v = last_v - (V_th - self.V_reset) * lst_spk - # membrane potential - dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau - V = brainstate.nn.exp_euler_step(dv, last_v) - V = self.sum_delta_inputs(V) - self.V.value = u.math.where(t - self.last_spike_time.value < self.tau_ref, last_v, V) - # spike time evaluation - lst_spk_time = u.math.where( - self.V.value >= self.V_th, brainstate.environ.get('t'), self.last_spike_time.value) - self.last_spike_time.value = jax.lax.stop_gradient(lst_spk_time) - return self.get_spike() + v_reset = last_v - (V_th - self.V_reset) * last_spk + def dv(v): + exp_term = self.delta_T * u.math.exp((v - self.V_T) / self.delta_T) + return (-(v - self.V_rest) + exp_term + self.R * self.sum_current_inputs(x, v)) / self.tau -class ALIF(Neuron): - r"""Adaptive Leaky Integrate-and-Fire (ALIF) neuron model. + V_candidate = brainstate.nn.exp_euler_step(dv, v_reset) + V_candidate = self.sum_delta_inputs(V_candidate) - This class implements the Adaptive Leaky Integrate-and-Fire neuron model, which extends - the basic LIF model by adding an adaptation variable. This adaptation mechanism increases - the effective firing threshold after each spike, allowing the neuron to exhibit - spike-frequency adaptation - a common feature in biological neurons that reduces - firing rate during sustained stimulation. 
+ refractory = (t - self.last_spike_time.value) < self.tau_ref + self.V.value = u.math.where(refractory, v_reset, V_candidate) - The model is characterized by the following differential equations: + spike_cond = self.V.value >= self.V_th + self.last_spike_time.value = jax.lax.stop_gradient( + u.math.where(spike_cond, t, self.last_spike_time.value) + ) + if self.ref_var: + self.refractory.value = jax.lax.stop_gradient( + u.math.logical_or(refractory, spike_cond) + ) + return self.get_spike() + + +class AdExIF(Neuron): + r"""Adaptive exponential Integrate-and-Fire (AdExIF) neuron model. + + This model extends :class:`ExpIF` by adding an adaptation current ``w`` that is + incremented after each spike and relaxes with time constant ``tau_w``. The membrane + dynamics are governed by two coupled differential equations [1]_: $$ - \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t) + \tau \frac{dV}{dt} = -(V - V_{rest}) + \Delta_T + \exp\left(\frac{V - V_T}{\Delta_T}\right) - R w + R \cdot I(t) $$ $$ - \tau_a \frac{da}{dt} = -a + \tau_w \frac{dw}{dt} = a (V - V_{rest}) - w $$ - Spike condition: - If $V \geq V_{th} + \beta \cdot a$: emit spike, set $V = V_{reset}$, and increment $a = a + 1$ + After each spike the membrane potential is reset and the adaptation current + increases by ``b``. This simple mechanism generates rich firing patterns such + as spike-frequency adaptation and bursting. Parameters ---------- @@ -477,29 +618,32 @@ class ALIF(Neuron): Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. - tau : ArrayLike, default=5. * u.ms + tau : ArrayLike, default=10. * u.ms Membrane time constant. - tau_a : ArrayLike, default=100. * u.ms - Adaptation time constant (typically much longer than tau). - V_th : ArrayLike, default=1. * u.mV - Base firing threshold voltage. - V_reset : ArrayLike, default=0. * u.mV - Reset voltage after spike. - V_rest : ArrayLike, default=0. * u.mV + tau_w : ArrayLike, default=30. 
* u.ms + Adaptation current time constant. + V_th : ArrayLike, default=-55. * u.mV + Spike threshold used for reset. + V_reset : ArrayLike, default=-68. * u.mV + Reset potential after spike. + V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. - beta : ArrayLike, default=0.1 * u.mV - Adaptation coupling parameter that scales the effect of the adaptation variable. - spk_fun : Callable - Surrogate gradient function for the non-differentiable spike generation. - spk_reset : str, default='soft' - Reset mechanism after spike generation: - - - 'soft': subtract threshold V = V - V_th - - 'hard': strict reset using stop_gradient + V_T : ArrayLike, default=-59.9 * u.mV + Threshold of the exponential term. + delta_T : ArrayLike, default=3.48 * u.mV + Spike slope factor controlling the sharpness of spike initiation. + a : ArrayLike, default=1. * u.siemens + Coupling strength from voltage to adaptation current. + b : ArrayLike, default=1. * u.mA + Increment of the adaptation current after a spike. V_initializer : Callable Initializer for the membrane potential state. - a_initializer : Callable - Initializer for the adaptation variable. + w_initializer : Callable + Initializer for the adaptation current. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function for the spike generation. + spk_reset : str, default='soft' + Reset mechanism after spike generation. name : str, optional Name of the neuron layer. @@ -507,74 +651,54 @@ class ALIF(Neuron): ---------- V : HiddenState Membrane potential. - a : HiddenState - Adaptation variable that increases after each spike and decays exponentially. + w : HiddenState + Adaptation current. Examples -------- - >>> import brainpy.state as brainpy + >>> import brainpy >>> import brainstate >>> import brainunit as u >>> - >>> # Create an ALIF neuron layer with 10 neurons - >>> alif = brainpy.ALIF(10, - ... tau=10*u.ms, - ... tau_a=200*u.ms, - ... 
beta=0.2*u.mV) + >>> # Create a AdExIF neuron layer with 10 neurons + >>> adexif = brainpy.state.AdExIF(10, tau=10*u.ms) >>> >>> # Initialize the state - >>> alif.init_state(batch_size=1) + >>> adexif.init_state(batch_size=1) >>> >>> # Apply an input current and update the neuron state - >>> spikes = alif.update(x=1.5*u.mA) - >>> - >>> # Create a network with adaptation for burst detection - >>> network = brainstate.nn.Sequential([ - ... brainpy.ALIF(100, tau_a=150*u.ms, beta=0.3*u.mV), - ... brainstate.nn.Linear(100, 10) - ... ]) - - Notes - ----- - - The adaptation variable 'a' increases by 1 with each spike and decays exponentially - with time constant tau_a between spikes. - - The effective threshold increases by beta*a, making it progressively harder for the - neuron to fire when it has recently been active. - - This adaptation mechanism creates spike-frequency adaptation, allowing the neuron - to respond strongly to input onset but then reduce its firing rate even if the - input remains constant. - - The adaptation time constant tau_a is typically much larger than the membrane time - constant tau, creating a longer-lasting adaptation effect. - - The time-dependent dynamics are integrated using an exponential Euler method. + >>> spikes = adexif.update(x=1.5*u.mA) References ---------- - .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). - Neuronal dynamics: From single neurons to networks and models of cognition. - Cambridge University Press. - .. [2] Brette, R., & Gerstner, W. (2005). Adaptive exponential integrate-and-fire model - as an effective description of neuronal activity. Journal of neurophysiology, - 94(5), 3637-3642. - .. [3] Naud, R., Marcille, N., Clopath, C., & Gerstner, W. (2008). Firing patterns in - the adaptive exponential integrate-and-fire model. Biological cybernetics, - 99(4), 335-347. + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." 
Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + + .. seealso:: + + :class:`brainpy.dyn.AdExIF` for the dynamical-system counterpart. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, - tau: ArrayLike = 5. * u.ms, - tau_a: ArrayLike = 100. * u.ms, - V_th: ArrayLike = 1. * u.mV, - V_reset: ArrayLike = 0. * u.mV, - V_rest: ArrayLike = 0. * u.mV, - beta: ArrayLike = 0.1 * u.mV, + tau: ArrayLike = 10. * u.ms, + tau_w: ArrayLike = 30. * u.ms, + V_th: ArrayLike = -55. * u.mV, + V_reset: ArrayLike = -68. * u.mV, + V_rest: ArrayLike = -65. * u.mV, + V_T: ArrayLike = -59.9 * u.mV, + delta_T: ArrayLike = 3.48 * u.mV, + a: ArrayLike = 1. * u.siemens, + b: ArrayLike = 1. * u.mA, + V_initializer: Callable = braintools.init.Constant(-65. * u.mV), + w_initializer: Callable = braintools.init.Constant(0. * u.mA), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', - V_initializer: Callable = braintools.init.Constant(0. 
* u.mV), - a_initializer: Callable = braintools.init.Constant(0.), name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) @@ -582,42 +706,1582 @@ def __init__( # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) - self.tau_a = braintools.init.param(tau_a, self.varshape) + self.tau_w = braintools.init.param(tau_w, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) - self.beta = braintools.init.param(beta, self.varshape) + self.V_T = braintools.init.param(V_T, self.varshape) + self.delta_T = braintools.init.param(delta_T, self.varshape) + self.a = braintools.init.param(a, self.varshape) + self.b = braintools.init.param(b, self.varshape) - # functions + # initializers self.V_initializer = V_initializer - self.a_initializer = a_initializer + self.w_initializer = w_initializer def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) - self.a = brainstate.HiddenState(braintools.init.param(self.a_initializer, self.varshape, batch_size)) + self.w = brainstate.HiddenState(braintools.init.param(self.w_initializer, self.varshape, batch_size)) def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) - self.a.value = braintools.init.param(self.a_initializer, self.varshape, batch_size) + self.w.value = braintools.init.param(self.w_initializer, self.varshape, batch_size) - def get_spike(self, V=None, a=None): + def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V - a = self.a.value if a is None else a - v_scaled = (V - self.V_th - self.beta * a) / (self.V_th - self.V_reset) + v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) 
return self.spk_fun(v_scaled) def update(self, x=0. * u.mA): last_v = self.V.value - last_a = self.a.value - lst_spk = self.get_spike(last_v, last_a) + last_w = self.w.value + last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) - V = last_v - (V_th - self.V_reset) * lst_spk - a = last_a + lst_spk - # membrane potential - dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau - da = lambda a: -a / self.tau_a + V = last_v - (V_th - self.V_reset) * last_spk + w = last_w + self.b * last_spk + + def dv(v): + exp_term = self.delta_T * u.math.exp((v - self.V_T) / self.delta_T) + I_total = self.sum_current_inputs(x, v) + return (-(v - self.V_rest) + exp_term - self.R * w + self.R * I_total) / self.tau + V = brainstate.nn.exp_euler_step(dv, V) - a = brainstate.nn.exp_euler_step(da, a) - self.V.value = self.sum_delta_inputs(V) + V = self.sum_delta_inputs(V) + + def dw_func(w_val): + return (self.a * (V - self.V_rest) - w_val) / self.tau_w + + w = brainstate.nn.exp_euler_step(dw_func, w) + self.V.value = V + self.w.value = w + return self.get_spike(self.V.value) + + +class AdExIFRef(Neuron): + r"""Adaptive exponential Integrate-and-Fire neuron model with refractory mechanism. + + This model extends :class:`AdExIF` by adding an absolute refractory period. While the + exponential spike-initiation term and adaptation current keep the membrane potential + dynamics biologically realistic, the refractory mechanism prevents the neuron from + firing within ``tau_ref`` after a spike. + + The membrane dynamics are governed by two coupled differential equations: + + $$ + \tau \frac{dV}{dt} = -(V - V_{rest}) + \Delta_T + \exp\left(\frac{V - V_T}{\Delta_T}\right) - R w + R \cdot I(t) + $$ + + $$ + \tau_w \frac{dw}{dt} = a (V - V_{rest}) - w + $$ + + After each spike the membrane potential is reset and the adaptation current + increases by ``b``. 
During the refractory period, the membrane potential + remains at the reset value. + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + R : ArrayLike, default=1. * u.ohm + Membrane resistance. + tau : ArrayLike, default=10. * u.ms + Membrane time constant. + tau_w : ArrayLike, default=30. * u.ms + Adaptation current time constant. + tau_ref : ArrayLike, default=1.7 * u.ms + Absolute refractory period duration. + V_th : ArrayLike, default=-55. * u.mV + Spike threshold used for reset. + V_reset : ArrayLike, default=-68. * u.mV + Reset potential after spike. + V_rest : ArrayLike, default=-65. * u.mV + Resting membrane potential. + V_T : ArrayLike, default=-59.9 * u.mV + Threshold of the exponential term. + delta_T : ArrayLike, default=3.48 * u.mV + Spike slope factor controlling the sharpness of spike initiation. + a : ArrayLike, default=1. * u.siemens + Coupling strength from voltage to adaptation current. + b : ArrayLike, default=1. * u.mA + Increment of the adaptation current after a spike. + V_initializer : Callable + Initializer for the membrane potential state. + w_initializer : Callable + Initializer for the adaptation current. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function for the spike generation. + spk_reset : str, default='soft' + Reset mechanism after spike generation. + ref_var : bool, default=False + Whether to expose a boolean refractory state variable. + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential. + w : HiddenState + Adaptation current. + last_spike_time : ShortTermState + Last spike time recorder. + refractory : HiddenState + Neuron refractory state (if ref_var=True). 
+ + Examples + -------- + >>> import brainpy + >>> import brainstate + >>> import brainunit as u + >>> + >>> # Create an AdExIFRef neuron layer with 10 neurons + >>> adexif_ref = brainpy.state.AdExIFRef(10, tau=10*u.ms, tau_ref=2*u.ms) + >>> + >>> # Initialize the state + >>> adexif_ref.init_state(batch_size=1) + >>> + >>> # Generate inputs + >>> time_steps = 100 + >>> inputs = brainstate.random.randn(time_steps, 1, 10) * u.mA + >>> + >>> # Apply input currents and update the neuron state + >>> with brainstate.environ.context(dt=0.1 * u.ms): + >>> for t in range(time_steps): + >>> with brainstate.environ.context(t=t*0.1*u.ms): + >>> spikes = adexif_ref.update(x=inputs[t]) + + References + ---------- + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + + .. seealso:: + + :class:`brainpy.dyn.AdExIFRef` for the dynamical-system counterpart. + """ + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + R: ArrayLike = 1. * u.ohm, + tau: ArrayLike = 10. * u.ms, + tau_w: ArrayLike = 30. * u.ms, + tau_ref: ArrayLike = 1.7 * u.ms, + V_th: ArrayLike = -55. * u.mV, + V_reset: ArrayLike = -68. * u.mV, + V_rest: ArrayLike = -65. * u.mV, + V_T: ArrayLike = -59.9 * u.mV, + delta_T: ArrayLike = 3.48 * u.mV, + a: ArrayLike = 1. * u.siemens, + b: ArrayLike = 1. * u.mA, + V_initializer: Callable = braintools.init.Constant(-65. * u.mV), + w_initializer: Callable = braintools.init.Constant(0. 
* u.mA), + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'soft', + ref_var: bool = False, + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.R = braintools.init.param(R, self.varshape) + self.tau = braintools.init.param(tau, self.varshape) + self.tau_w = braintools.init.param(tau_w, self.varshape) + self.tau_ref = braintools.init.param(tau_ref, self.varshape) + self.V_th = braintools.init.param(V_th, self.varshape) + self.V_reset = braintools.init.param(V_reset, self.varshape) + self.V_rest = braintools.init.param(V_rest, self.varshape) + self.V_T = braintools.init.param(V_T, self.varshape) + self.delta_T = braintools.init.param(delta_T, self.varshape) + self.a = braintools.init.param(a, self.varshape) + self.b = braintools.init.param(b, self.varshape) + + # initializers + self.V_initializer = V_initializer + self.w_initializer = w_initializer + self.ref_var = ref_var + + def init_state(self, batch_size: int = None, **kwargs): + self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + self.w = brainstate.HiddenState(braintools.init.param(self.w_initializer, self.varshape, batch_size)) + self.last_spike_time = brainstate.ShortTermState( + braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) + ) + if self.ref_var: + self.refractory = brainstate.HiddenState( + braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size) + ) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + self.w.value = braintools.init.param(self.w_initializer, self.varshape, batch_size) + self.last_spike_time.value = braintools.init.param( + braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size + ) + if self.ref_var: + self.refractory.value = braintools.init.param( + 
class LIFRef(Neuron):
    r"""Leaky Integrate-and-Fire neuron model with refractory period.

    This class implements a Leaky Integrate-and-Fire neuron model that includes a
    refractory period after spiking, during which the neuron cannot fire regardless
    of input. This better captures the behavior of biological neurons that exhibit
    a recovery period after action potential generation.

    The model is characterized by the following equations:

    When not in refractory period:
    $$
    \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
    $$

    During refractory period:
    $$
    V = V_{reset}
    $$

    Spike condition:
    If $V \geq V_{th}$: emit spike, set $V = V_{reset}$, and enter refractory period for $\tau_{ref}$

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    tau_ref : ArrayLike, default=5. * u.ms
        Refractory period duration.
    V_th : ArrayLike, default=1. * u.mV
        Firing threshold voltage.
    V_reset : ArrayLike, default=0. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=0. * u.mV
        Resting membrane potential.
    V_initializer : Callable
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:

        - 'soft': subtract threshold V = V - V_th
        - 'hard': strict reset using stop_gradient
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    last_spike_time : ShortTermState
        Time of the last spike, used to implement refractory period.

    Examples
    --------
    >>> import brainpy
    >>> import brainstate
    >>> import brainunit as u
    >>>
    >>> # Create a LIFRef neuron layer with 10 neurons
    >>> lifref = brainpy.state.LIFRef(10,
    ...                               tau=10*u.ms,
    ...                               tau_ref=5*u.ms,
    ...                               V_th=0.8*u.mV)
    >>>
    >>> # Initialize the state
    >>> lifref.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state.
    >>> # ``update`` reads the environment time ``t``, so it must be provided.
    >>> with brainstate.environ.context(dt=0.1*u.ms, t=0.*u.ms):
    ...     spikes = lifref.update(x=1.5*u.mA)
    >>>
    >>> # Create a network with refractory neurons
    >>> network = brainstate.nn.Sequential([
    ...     brainpy.state.LIFRef(100, tau_ref=4*u.ms),
    ...     brainstate.nn.Linear(100, 10)
    ... ])

    Notes
    -----
    - The refractory period is implemented by tracking the time of the last spike
      and preventing membrane potential updates if the elapsed time is less than tau_ref.
    - During the refractory period, the membrane potential remains at the reset value
      regardless of input current strength.
    - Refractory periods prevent high-frequency repetitive firing and are critical
      for realistic neural dynamics.
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - The simulation environment time variable 't' is used to track the refractory state.

    References
    ----------
    .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
           Neuronal dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model:
           I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19.
    .. [3] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on
           neural networks, 14(6), 1569-1572.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        tau_ref: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        V_reset: ArrayLike = 0. * u.mV,
        V_rest: ArrayLike = 0. * u.mV,
        V_initializer: Callable = braintools.init.Constant(0. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the neuron shape
        self.R = braintools.init.param(R, self.varshape)
        self.tau = braintools.init.param(tau, self.varshape)
        self.tau_ref = braintools.init.param(tau_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_rest = braintools.init.param(V_rest, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size))
        # start far in the past so no neuron begins inside a refractory window
        self.last_spike_time = brainstate.ShortTermState(
            braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size)
        )

    def reset_state(self, batch_size: int = None, **kwargs):
        self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        self.last_spike_time.value = braintools.init.param(
            braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size
        )

    def get_spike(self, V: ArrayLike = None):
        """Surrogate spike output computed from the (scaled) membrane potential."""
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        # current simulation time, needed to evaluate the refractory window
        t = brainstate.environ.get('t')
        last_v = self.V.value
        last_spk = self.get_spike(last_v)
        # 'soft' reset subtracts the fixed threshold; 'hard' reset uses the
        # gradient-stopped previous potential so V lands exactly on V_reset
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v)
        last_v = last_v - (V_th - self.V_reset) * last_spk
        # membrane potential integration (exponential Euler)
        dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau
        V = brainstate.nn.exp_euler_step(dv, last_v)
        V = self.sum_delta_inputs(V)
        # hold the membrane potential at the reset value while refractory
        self.V.value = u.math.where(t - self.last_spike_time.value < self.tau_ref, last_v, V)
        # spike time evaluation (reuse `t` instead of reading the environment again)
        last_spk_time = u.math.where(self.V.value >= self.V_th, t, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spk_time)
        return self.get_spike()
class ALIF(Neuron):
    r"""Adaptive Leaky Integrate-and-Fire (ALIF) neuron model.

    This class implements the Adaptive Leaky Integrate-and-Fire neuron model, which extends
    the basic LIF model by adding an adaptation variable. This adaptation mechanism increases
    the effective firing threshold after each spike, allowing the neuron to exhibit
    spike-frequency adaptation - a common feature in biological neurons that reduces
    firing rate during sustained stimulation.

    The model is characterized by the following differential equations:

    $$
    \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
    $$

    $$
    \tau_a \frac{da}{dt} = -a
    $$

    Spike condition:
    If $V \geq V_{th} + \beta \cdot a$: emit spike, set $V = V_{reset}$, and increment $a = a + 1$

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    tau_a : ArrayLike, default=100. * u.ms
        Adaptation time constant (typically much longer than tau).
    V_th : ArrayLike, default=1. * u.mV
        Base firing threshold voltage.
    V_reset : ArrayLike, default=0. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=0. * u.mV
        Resting membrane potential.
    beta : ArrayLike, default=0.1 * u.mV
        Adaptation coupling parameter that scales the effect of the adaptation variable.
    spk_fun : Callable
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:

        - 'soft': subtract threshold V = V - V_th
        - 'hard': strict reset using stop_gradient
    V_initializer : Callable
        Initializer for the membrane potential state.
    a_initializer : Callable
        Initializer for the adaptation variable.
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    a : HiddenState
        Adaptation variable that increases after each spike and decays exponentially.

    Examples
    --------
    >>> import brainpy
    >>> import brainstate
    >>> import brainunit as u
    >>>
    >>> # Create an ALIF neuron layer with 10 neurons
    >>> alif = brainpy.state.ALIF(10,
    ...                           tau=10*u.ms,
    ...                           tau_a=200*u.ms,
    ...                           beta=0.2*u.mV)
    >>>
    >>> # Initialize the state
    >>> alif.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state
    >>> spikes = alif.update(x=1.5*u.mA)
    >>>
    >>> # Create a network with adaptation for burst detection
    >>> network = brainstate.nn.Sequential([
    ...     brainpy.state.ALIF(100, tau_a=150*u.ms, beta=0.3*u.mV),
    ...     brainstate.nn.Linear(100, 10)
    ... ])

    Notes
    -----
    - The adaptation variable 'a' increases by 1 with each spike and decays exponentially
      with time constant tau_a between spikes.
    - The effective threshold increases by beta*a, making it progressively harder for the
      neuron to fire when it has recently been active.
    - This adaptation mechanism creates spike-frequency adaptation, allowing the neuron
      to respond strongly to input onset but then reduce its firing rate even if the
      input remains constant.
    - The adaptation time constant tau_a is typically much larger than the membrane time
      constant tau, creating a longer-lasting adaptation effect.
    - The time-dependent dynamics are integrated using an exponential Euler method.

    References
    ----------
    .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
           Neuronal dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [2] Brette, R., & Gerstner, W. (2005). Adaptive exponential integrate-and-fire model
           as an effective description of neuronal activity. Journal of neurophysiology,
           94(5), 3637-3642.
    .. [3] Naud, R., Marcille, N., Clopath, C., & Gerstner, W. (2008). Firing patterns in
           the adaptive exponential integrate-and-fire model. Biological cybernetics,
           99(4), 335-347.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        tau_a: ArrayLike = 100. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        V_reset: ArrayLike = 0. * u.mV,
        V_rest: ArrayLike = 0. * u.mV,
        beta: ArrayLike = 0.1 * u.mV,
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        V_initializer: Callable = braintools.init.Constant(0. * u.mV),
        a_initializer: Callable = braintools.init.Constant(0.),
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the neuron shape
        self.R = braintools.init.param(R, self.varshape)
        self.tau = braintools.init.param(tau, self.varshape)
        self.tau_a = braintools.init.param(tau_a, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.V_rest = braintools.init.param(V_rest, self.varshape)
        self.beta = braintools.init.param(beta, self.varshape)

        # functions
        self.V_initializer = V_initializer
        self.a_initializer = a_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size))
        self.a = brainstate.HiddenState(braintools.init.param(self.a_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        self.a.value = braintools.init.param(self.a_initializer, self.varshape, batch_size)

    def get_spike(self, V=None, a=None):
        """Surrogate spike output using the adaptation-raised effective threshold."""
        V = self.V.value if V is None else V
        a = self.a.value if a is None else a
        v_scaled = (V - self.V_th - self.beta * a) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        last_v = self.V.value
        last_a = self.a.value
        last_spk = self.get_spike(last_v, last_a)
        # 'soft' reset subtracts the fixed threshold; 'hard' reset uses the
        # gradient-stopped previous potential so V lands exactly on V_reset
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v)
        V = last_v - (V_th - self.V_reset) * last_spk
        a = last_a + last_spk
        # membrane potential and adaptation dynamics (exponential Euler)
        dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau
        da = lambda a: -a / self.tau_a
        V = brainstate.nn.exp_euler_step(dv, V)
        a = brainstate.nn.exp_euler_step(da, a)
        # state writes restored as separate statements (they were fused onto a
        # single line in the source, which is invalid Python)
        self.V.value = self.sum_delta_inputs(V)
        self.a.value = a
        return self.get_spike(self.V.value, self.a.value)
class QuaIF(Neuron):
    r"""Quadratic Integrate-and-Fire (QuaIF) neuron model.

    The QuaIF neuron replaces the linear leak of the basic integrate-and-fire
    model with a quadratic voltage term, which yields a soft, biologically more
    plausible spike initiation (Type I excitability).

    Membrane dynamics:

    $$
    \tau \frac{dV}{dt} = c(V - V_{rest})(V - V_c) + R \cdot I(t)
    $$

    A spike is emitted when $V \geq V_{th}$, after which $V$ is reset toward
    $V_{reset}$ (softly or hard, depending on ``spk_reset``).

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=10. * u.ms
        Membrane time constant.
    V_th : ArrayLike, default=-30. * u.mV
        Firing threshold voltage.
    V_reset : ArrayLike, default=-68. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=-65. * u.mV
        Resting membrane potential.
    V_c : ArrayLike, default=-50. * u.mV
        Critical voltage for spike initiation. Must be larger than V_rest.
    c : ArrayLike, default=0.07 / u.mV
        Coefficient describing membrane potential update. Larger than 0.
    V_initializer : Callable
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation.
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.

    Examples
    --------
    >>> import brainpy
    >>> import brainunit as u
    >>> quaif = brainpy.state.QuaIF(10, tau=10*u.ms, V_th=-30*u.mV, V_c=-50*u.mV)
    >>> quaif.init_state(batch_size=1)
    >>> spikes = quaif.update(x=2.5*u.mA)

    Notes
    -----
    - As V approaches the critical voltage V_c, the quadratic term accelerates
      the membrane potential toward threshold.
    - The model produces a continuous f-I curve (Type I excitability).

    References
    ----------
    .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg
           (2000) Intrinsic dynamics in neuronal networks. I. Theory.
           J. Neurophysiology 83, pp. 808–827.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 10. * u.ms,
        V_th: ArrayLike = -30. * u.mV,
        V_reset: ArrayLike = -68. * u.mV,
        V_rest: ArrayLike = -65. * u.mV,
        V_c: ArrayLike = -50. * u.mV,
        c: ArrayLike = 0.07 / u.mV,
        V_initializer: Callable = braintools.init.Constant(-65. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # model parameters, broadcast to the neuron shape
        self.R = braintools.init.param(R, self.varshape)
        self.tau = braintools.init.param(tau, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.V_rest = braintools.init.param(V_rest, self.varshape)
        self.V_c = braintools.init.param(V_c, self.varshape)
        self.c = braintools.init.param(c, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        v0 = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        self.V = brainstate.HiddenState(v0)

    def reset_state(self, batch_size: int = None, **kwargs):
        self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)

    def get_spike(self, V: ArrayLike = None):
        """Surrogate spike output computed from the scaled membrane potential."""
        membrane = self.V.value if V is None else V
        normalized = (membrane - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(normalized)

    def update(self, x=0. * u.mA):
        v_prev = self.V.value
        spk_prev = self.get_spike(v_prev)
        # 'soft' reset subtracts the fixed threshold; 'hard' reset uses the
        # gradient-stopped previous potential so V lands exactly on V_reset
        thr = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(v_prev)
        v_start = v_prev - (thr - self.V_reset) * spk_prev

        # quadratic membrane dynamics, integrated with exponential Euler
        dv = lambda v: (self.c * (v - self.V_rest) * (v - self.V_c)
                        + self.R * self.sum_current_inputs(x, v)) / self.tau
        v_new = self.sum_delta_inputs(brainstate.nn.exp_euler_step(dv, v_start))

        self.V.value = v_new
        return self.get_spike(v_new)
class AdQuaIF(Neuron):
    r"""Adaptive Quadratic Integrate-and-Fire (AdQuaIF) neuron model.

    Extends :class:`QuaIF` with an adaptation current ``w`` that jumps by ``b``
    at every spike and otherwise relaxes with time constant ``tau_w``. The
    adaptation current feeds back negatively onto the membrane, producing
    spike-frequency adaptation and a richer repertoire of firing patterns.

    Dynamics:

    $$
    \tau \frac{dV}{dt} = c(V - V_{rest})(V - V_c) - w + R \cdot I(t)
    $$

    $$
    \tau_w \frac{dw}{dt} = a(V - V_{rest}) - w
    $$

    After a spike: $V \rightarrow V_{reset}$, $w \rightarrow w + b$

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=10. * u.ms
        Membrane time constant.
    tau_w : ArrayLike, default=10. * u.ms
        Adaptation current time constant.
    V_th : ArrayLike, default=-30. * u.mV
        Firing threshold voltage.
    V_reset : ArrayLike, default=-68. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=-65. * u.mV
        Resting membrane potential.
    V_c : ArrayLike, default=-50. * u.mV
        Critical voltage for spike initiation.
    c : ArrayLike, default=0.07 / u.mV
        Coefficient describing membrane potential update.
    a : ArrayLike, default=1. * u.siemens
        Coupling strength from voltage to adaptation current (subthreshold adaptation).
    b : ArrayLike, default=0.1 * u.mA
        Increment of adaptation current after a spike (spike-triggered adaptation).
    V_initializer : Callable
        Initializer for the membrane potential state.
    w_initializer : Callable
        Initializer for the adaptation current.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation.
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    w : HiddenState
        Adaptation current.

    Examples
    --------
    >>> import brainpy
    >>> import brainunit as u
    >>> adquaif = brainpy.state.AdQuaIF(10, tau=10*u.ms, tau_w=100*u.ms,
    ...                                 a=1.0*u.siemens, b=0.1*u.mA)
    >>> adquaif.init_state(batch_size=1)
    >>> spikes = adquaif.update(x=3.0*u.mA)

    Notes
    -----
    - Parameter ``a`` controls subthreshold adaptation; parameter ``b`` controls
      the spike-triggered jump of the adaptation current.
    - With suitable parameters the model exhibits regular spiking, bursting, etc.

    References
    ----------
    .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking
           neurons?. IEEE transactions on neural networks, 15(5), 1063-1070.
    .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of
           nonlinear integrate-and-fire neurons." SIAM Journal on Applied
           Mathematics 68, no. 4 (2008): 1045-1079.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 10. * u.ms,
        tau_w: ArrayLike = 10. * u.ms,
        V_th: ArrayLike = -30. * u.mV,
        V_reset: ArrayLike = -68. * u.mV,
        V_rest: ArrayLike = -65. * u.mV,
        V_c: ArrayLike = -50. * u.mV,
        c: ArrayLike = 0.07 / u.mV,
        a: ArrayLike = 1. * u.siemens,
        b: ArrayLike = 0.1 * u.mA,
        V_initializer: Callable = braintools.init.Constant(-65. * u.mV),
        w_initializer: Callable = braintools.init.Constant(0. * u.mA),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # model parameters, broadcast to the neuron shape
        self.R = braintools.init.param(R, self.varshape)
        self.tau = braintools.init.param(tau, self.varshape)
        self.tau_w = braintools.init.param(tau_w, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.V_rest = braintools.init.param(V_rest, self.varshape)
        self.V_c = braintools.init.param(V_c, self.varshape)
        self.c = braintools.init.param(c, self.varshape)
        self.a = braintools.init.param(a, self.varshape)
        self.b = braintools.init.param(b, self.varshape)
        self.V_initializer = V_initializer
        self.w_initializer = w_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        v0 = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        w0 = braintools.init.param(self.w_initializer, self.varshape, batch_size)
        self.V = brainstate.HiddenState(v0)
        self.w = brainstate.HiddenState(w0)

    def reset_state(self, batch_size: int = None, **kwargs):
        self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        self.w.value = braintools.init.param(self.w_initializer, self.varshape, batch_size)

    def get_spike(self, V: ArrayLike = None):
        """Surrogate spike output computed from the scaled membrane potential."""
        membrane = self.V.value if V is None else V
        normalized = (membrane - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(normalized)

    def update(self, x=0. * u.mA):
        v_prev = self.V.value
        w_prev = self.w.value
        spk_prev = self.get_spike(v_prev)
        # 'soft' reset subtracts the fixed threshold; 'hard' reset uses the
        # gradient-stopped previous potential so V lands exactly on V_reset
        thr = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(v_prev)
        v_start = v_prev - (thr - self.V_reset) * spk_prev
        w_start = w_prev + self.b * spk_prev

        # membrane step first (the adaptation drift below uses the updated V)
        dv = lambda v: (self.c * (v - self.V_rest) * (v - self.V_c)
                        - self.R * w_start + self.R * self.sum_current_inputs(x, v)) / self.tau
        v_new = self.sum_delta_inputs(brainstate.nn.exp_euler_step(dv, v_start))

        # adaptation current relaxes toward a * (V - V_rest) with the new V
        dw = lambda w_: (self.a * (v_new - self.V_rest) - w_) / self.tau_w
        w_new = brainstate.nn.exp_euler_step(dw, w_start)

        self.V.value = v_new
        self.w.value = w_new
        return self.get_spike(v_new)
class AdQuaIFRef(Neuron):
    r"""Adaptive Quadratic Integrate-and-Fire neuron model with refractory mechanism.

    Extends :class:`AdQuaIF` with an absolute refractory period ``tau_ref``:
    for ``tau_ref`` after a spike the membrane potential and the adaptation
    current are clamped at their post-reset values, so the neuron cannot fire
    again regardless of input.

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=10. * u.ms
        Membrane time constant.
    tau_w : ArrayLike, default=10. * u.ms
        Adaptation current time constant.
    tau_ref : ArrayLike, default=1.7 * u.ms
        Absolute refractory period duration.
    V_th : ArrayLike, default=-30. * u.mV
        Firing threshold voltage.
    V_reset : ArrayLike, default=-68. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=-65. * u.mV
        Resting membrane potential.
    V_c : ArrayLike, default=-50. * u.mV
        Critical voltage for spike initiation.
    c : ArrayLike, default=0.07 / u.mV
        Coefficient describing membrane potential update.
    a : ArrayLike, default=1. * u.siemens
        Coupling strength from voltage to adaptation current.
    b : ArrayLike, default=0.1 * u.mA
        Increment of adaptation current after a spike.
    V_initializer : Callable
        Initializer for the membrane potential state.
    w_initializer : Callable
        Initializer for the adaptation current.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation.
    ref_var : bool, default=False
        Whether to expose a boolean refractory state variable.
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    w : HiddenState
        Adaptation current.
    last_spike_time : ShortTermState
        Last spike time recorder.
    refractory : HiddenState
        Neuron refractory state (if ref_var=True).

    Examples
    --------
    >>> import brainpy
    >>> import brainstate
    >>> import brainunit as u
    >>> adquaif_ref = brainpy.state.AdQuaIFRef(10, tau=10*u.ms, tau_w=100*u.ms,
    ...                                        tau_ref=2.0*u.ms, ref_var=True)
    >>> adquaif_ref.init_state(batch_size=1)
    >>> with brainstate.environ.context(dt=0.1*u.ms, t=0.0*u.ms):
    ...     spikes = adquaif_ref.update(x=3.0*u.mA)

    Notes
    -----
    - ``update`` reads the environment time ``t`` to evaluate the refractory window.
    - Set ``ref_var=True`` to additionally track the refractory state as a
      boolean variable.
    - The refractory period prevents unrealistically high firing rates.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 10. * u.ms,
        tau_w: ArrayLike = 10. * u.ms,
        tau_ref: ArrayLike = 1.7 * u.ms,
        V_th: ArrayLike = -30. * u.mV,
        V_reset: ArrayLike = -68. * u.mV,
        V_rest: ArrayLike = -65. * u.mV,
        V_c: ArrayLike = -50. * u.mV,
        c: ArrayLike = 0.07 / u.mV,
        a: ArrayLike = 1. * u.siemens,
        b: ArrayLike = 0.1 * u.mA,
        V_initializer: Callable = braintools.init.Constant(-65. * u.mV),
        w_initializer: Callable = braintools.init.Constant(0. * u.mA),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # model parameters, broadcast to the neuron shape
        self.R = braintools.init.param(R, self.varshape)
        self.tau = braintools.init.param(tau, self.varshape)
        self.tau_w = braintools.init.param(tau_w, self.varshape)
        self.tau_ref = braintools.init.param(tau_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.V_rest = braintools.init.param(V_rest, self.varshape)
        self.V_c = braintools.init.param(V_c, self.varshape)
        self.c = braintools.init.param(c, self.varshape)
        self.a = braintools.init.param(a, self.varshape)
        self.b = braintools.init.param(b, self.varshape)
        self.V_initializer = V_initializer
        self.w_initializer = w_initializer
        self.ref_var = ref_var

    def init_state(self, batch_size: int = None, **kwargs):
        self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size))
        self.w = brainstate.HiddenState(braintools.init.param(self.w_initializer, self.varshape, batch_size))
        # start far in the past so no neuron begins inside a refractory window
        self.last_spike_time = brainstate.ShortTermState(
            braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size)
        )
        if self.ref_var:
            self.refractory = brainstate.HiddenState(
                braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size)
            )

    def reset_state(self, batch_size: int = None, **kwargs):
        self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        self.w.value = braintools.init.param(self.w_initializer, self.varshape, batch_size)
        self.last_spike_time.value = braintools.init.param(
            braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size
        )
        if self.ref_var:
            self.refractory.value = braintools.init.param(
                braintools.init.Constant(False), self.varshape, batch_size
            )

    def get_spike(self, V: ArrayLike = None):
        """Surrogate spike output computed from the scaled membrane potential."""
        membrane = self.V.value if V is None else V
        normalized = (membrane - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(normalized)

    def update(self, x=0. * u.mA):
        # current simulation time, needed to evaluate the refractory window
        t = brainstate.environ.get('t')
        v_prev = self.V.value
        w_prev = self.w.value
        spk_prev = self.get_spike(v_prev)
        # 'soft' reset subtracts the fixed threshold; 'hard' reset uses the
        # gradient-stopped previous potential so V lands exactly on V_reset
        thr = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(v_prev)
        v_after_reset = v_prev - (thr - self.V_reset) * spk_prev
        w_after_reset = w_prev + self.b * spk_prev

        # candidate membrane step (used only outside the refractory window)
        dv = lambda v: (self.c * (v - self.V_rest) * (v - self.V_c)
                        - self.R * w_after_reset + self.R * self.sum_current_inputs(x, v)) / self.tau
        v_candidate = self.sum_delta_inputs(brainstate.nn.exp_euler_step(dv, v_after_reset))

        # candidate adaptation step, driven by the candidate membrane potential
        dw = lambda w_: (self.a * (v_candidate - self.V_rest) - w_) / self.tau_w
        w_candidate = brainstate.nn.exp_euler_step(dw, w_after_reset)

        # clamp both states at their post-reset values while refractory
        in_refractory = (t - self.last_spike_time.value) < self.tau_ref
        self.V.value = u.math.where(in_refractory, v_after_reset, v_candidate)
        self.w.value = u.math.where(in_refractory, w_after_reset, w_candidate)

        # record the spike time of neurons crossing threshold this step
        crossed = self.V.value >= self.V_th
        self.last_spike_time.value = jax.lax.stop_gradient(
            u.math.where(crossed, t, self.last_spike_time.value)
        )
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(
                u.math.logical_or(in_refractory, crossed)
            )
        return self.get_spike()
The model can reproduce diverse firing + patterns observed in biological neurons. + + The model is characterized by the following equations: + + $$ + \frac{dI_1}{dt} = -k_1 I_1 + $$ + + $$ + \frac{dI_2}{dt} = -k_2 I_2 + $$ + + $$ + \tau \frac{dV}{dt} = -(V - V_{rest}) + R(I_1 + I_2 + I(t)) + $$ + + $$ + \frac{dV_{th}}{dt} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) + $$ + + When $V \geq V_{th}$: + - $I_1 \leftarrow R_1 I_1 + A_1$ + - $I_2 \leftarrow R_2 I_2 + A_2$ + - $V \leftarrow V_{reset}$ + - $V_{th} \leftarrow \max(V_{th_{reset}}, V_{th})$ + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + R : ArrayLike, default=20. * u.ohm + Membrane resistance. + tau : ArrayLike, default=20. * u.ms + Membrane time constant. + V_rest : ArrayLike, default=-70. * u.mV + Resting potential. + V_reset : ArrayLike, default=-70. * u.mV + Reset potential after spike. + V_th_inf : ArrayLike, default=-50. * u.mV + Target value of threshold potential updating. + V_th_reset : ArrayLike, default=-60. * u.mV + Free parameter, should be larger than V_reset. + V_th_initializer : Callable + Initializer for the threshold potential. + a : ArrayLike, default=0. / u.ms + Coefficient describes dependence of V_th on membrane potential. + b : ArrayLike, default=0.01 / u.ms + Coefficient describes V_th update. + k1 : ArrayLike, default=0.2 / u.ms + Constant of I1. + k2 : ArrayLike, default=0.02 / u.ms + Constant of I2. + R1 : ArrayLike, default=0. + Free parameter describing dependence of I1 reset value on I1 before spiking. + R2 : ArrayLike, default=1. + Free parameter describing dependence of I2 reset value on I2 before spiking. + A1 : ArrayLike, default=0. * u.mA + Free parameter. + A2 : ArrayLike, default=0. * u.mA + Free parameter. + V_initializer : Callable + Initializer for the membrane potential state. + I1_initializer : Callable + Initializer for internal current I1. + I2_initializer : Callable + Initializer for internal current I2. 
+ spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function. + spk_reset : str, default='soft' + Reset mechanism after spike generation. + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential. + I1 : HiddenState + Internal current 1. + I2 : HiddenState + Internal current 2. + V_th : HiddenState + Spiking threshold potential. + + Examples + -------- + >>> import brainpy + >>> import brainstate + >>> import brainunit as u + >>> + >>> # Create a Gif neuron layer with dynamic threshold + >>> gif = brainpy.state.Gif(10, tau=20*u.ms, k1=0.2/u.ms, k2=0.02/u.ms, + ... a=0.005/u.ms, b=0.01/u.ms) + >>> + >>> # Initialize the state + >>> gif.init_state(batch_size=1) + >>> + >>> # Apply input and observe diverse firing patterns + >>> spikes = gif.update(x=1.5*u.mA) + >>> + >>> # Create a network with Gif neurons + >>> network = brainstate.nn.Sequential([ + ... brainpy.state.Gif(100, tau=20.0*u.ms), + ... brainstate.nn.Linear(100, 10) + ... ]) + + Notes + ----- + - The Gif model uses internal currents (I1, I2) for complex dynamics. + - Dynamic threshold V_th adapts based on membrane potential and its own dynamics. + - Can reproduce diverse firing patterns: regular spiking, bursting, adaptation. + - Parameters a and b control threshold adaptation. + - Parameters k1, k2, R1, R2, A1, A2 control internal current dynamics. + - More flexible than simpler IF models for matching biological data. + + References + ---------- + .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear + integrate-and-fire neural model produces diverse spiking + behaviors." Neural computation 21.3 (2009): 704-718. + .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan + Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized + leaky integrate-and-fire models classify multiple neuron types." + Nature communications 9, no. 1 (2018): 1-15. 
+ """ + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + R: ArrayLike = 20. * u.ohm, + tau: ArrayLike = 20. * u.ms, + V_rest: ArrayLike = -70. * u.mV, + V_reset: ArrayLike = -70. * u.mV, + V_th_inf: ArrayLike = -50. * u.mV, + V_th_reset: ArrayLike = -60. * u.mV, + V_th_initializer: Callable = braintools.init.Constant(-50. * u.mV), + a: ArrayLike = 0. / u.ms, + b: ArrayLike = 0.01 / u.ms, + k1: ArrayLike = 0.2 / u.ms, + k2: ArrayLike = 0.02 / u.ms, + R1: ArrayLike = 0., + R2: ArrayLike = 1., + A1: ArrayLike = 0. * u.mA, + A2: ArrayLike = 0. * u.mA, + V_initializer: Callable = braintools.init.Constant(-70. * u.mV), + I1_initializer: Callable = braintools.init.Constant(0. * u.mA), + I2_initializer: Callable = braintools.init.Constant(0. * u.mA), + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'soft', + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.R = braintools.init.param(R, self.varshape) + self.tau = braintools.init.param(tau, self.varshape) + self.V_rest = braintools.init.param(V_rest, self.varshape) + self.V_reset = braintools.init.param(V_reset, self.varshape) + self.V_th_inf = braintools.init.param(V_th_inf, self.varshape) + self.V_th_reset = braintools.init.param(V_th_reset, self.varshape) + self.a = braintools.init.param(a, self.varshape) + self.b = braintools.init.param(b, self.varshape) + self.k1 = braintools.init.param(k1, self.varshape) + self.k2 = braintools.init.param(k2, self.varshape) + self.R1 = braintools.init.param(R1, self.varshape) + self.R2 = braintools.init.param(R2, self.varshape) + self.A1 = braintools.init.param(A1, self.varshape) + self.A2 = braintools.init.param(A2, self.varshape) + self.V_initializer = V_initializer + self.I1_initializer = I1_initializer + self.I2_initializer = I2_initializer + self.V_th_initializer = V_th_initializer + + def init_state(self, batch_size: int = None, **kwargs): + self.V = 
brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + self.I1 = brainstate.HiddenState(braintools.init.param(self.I1_initializer, self.varshape, batch_size)) + self.I2 = brainstate.HiddenState(braintools.init.param(self.I2_initializer, self.varshape, batch_size)) + self.V_th = brainstate.HiddenState(braintools.init.param(self.V_th_initializer, self.varshape, batch_size)) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + self.I1.value = braintools.init.param(self.I1_initializer, self.varshape, batch_size) + self.I2.value = braintools.init.param(self.I2_initializer, self.varshape, batch_size) + self.V_th.value = braintools.init.param(self.V_th_initializer, self.varshape, batch_size) + + def get_spike(self, V: ArrayLike = None, V_th: ArrayLike = None): + V = self.V.value if V is None else V + V_th = self.V_th.value if V_th is None else V_th + v_scaled = (V - V_th) / (V_th - self.V_reset) + return self.spk_fun(v_scaled) + + def update(self, x=0. 
* u.mA): + last_v = self.V.value + last_i1 = self.I1.value + last_i2 = self.I2.value + last_v_th = self.V_th.value + last_spk = self.get_spike(last_v, last_v_th) + + # Apply spike effects + V_th_val = last_v_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) + V = last_v - (V_th_val - self.V_reset) * last_spk + I1 = last_i1 + last_spk * (self.R1 * last_i1 + self.A1 - last_i1) + I2 = last_i2 + last_spk * (self.R2 * last_i2 + self.A2 - last_i2) + V_th = last_v_th + last_spk * (u.math.maximum(self.V_th_reset, last_v_th) - last_v_th) + + # Update dynamics + def dI1(i1): + return -self.k1 * i1 + + def dI2(i2): + return -self.k2 * i2 + + def dV_th_func(v_th): + return self.a * (V - self.V_rest) - self.b * (v_th - self.V_th_inf) + + def dv(v): + return (-(v - self.V_rest) + self.R * (I1 + I2 + self.sum_current_inputs(x, v))) / self.tau + + I1 = brainstate.nn.exp_euler_step(dI1, I1) + I2 = brainstate.nn.exp_euler_step(dI2, I2) + V_th = brainstate.nn.exp_euler_step(dV_th_func, V_th) + V = brainstate.nn.exp_euler_step(dv, V) + V = self.sum_delta_inputs(V) + + self.V.value = V + self.I1.value = I1 + self.I2.value = I2 + self.V_th.value = V_th + return self.get_spike(V, V_th) + + +class GifRef(Neuron): + r"""Generalized Integrate-and-Fire neuron model with refractory mechanism. + + This model extends Gif by adding an absolute refractory period during which + the neuron cannot fire. This creates more realistic firing patterns and + prevents unrealistic high-frequency firing. + + Parameters + ---------- + in_size : Size + Size of the input to the neuron. + R : ArrayLike, default=20. * u.ohm + Membrane resistance. + tau : ArrayLike, default=20. * u.ms + Membrane time constant. + tau_ref : ArrayLike, default=1.7 * u.ms + Absolute refractory period duration. + V_rest : ArrayLike, default=-70. * u.mV + Resting potential. + V_reset : ArrayLike, default=-70. * u.mV + Reset potential after spike. + V_th_inf : ArrayLike, default=-50. 
* u.mV + Target value of threshold potential updating. + V_th_reset : ArrayLike, default=-60. * u.mV + Free parameter, should be larger than V_reset. + V_th_initializer : Callable + Initializer for the threshold potential. + a : ArrayLike, default=0. / u.ms + Coefficient describes dependence of V_th on membrane potential. + b : ArrayLike, default=0.01 / u.ms + Coefficient describes V_th update. + k1 : ArrayLike, default=0.2 / u.ms + Constant of I1. + k2 : ArrayLike, default=0.02 / u.ms + Constant of I2. + R1 : ArrayLike, default=0. + Free parameter. + R2 : ArrayLike, default=1. + Free parameter. + A1 : ArrayLike, default=0. * u.mA + Free parameter. + A2 : ArrayLike, default=0. * u.mA + Free parameter. + V_initializer : Callable + Initializer for the membrane potential state. + I1_initializer : Callable + Initializer for internal current I1. + I2_initializer : Callable + Initializer for internal current I2. + spk_fun : Callable, default=surrogate.ReluGrad() + Surrogate gradient function. + spk_reset : str, default='soft' + Reset mechanism after spike generation. + ref_var : bool, default=False + Whether to expose a boolean refractory state variable. + name : str, optional + Name of the neuron layer. + + Attributes + ---------- + V : HiddenState + Membrane potential. + I1 : HiddenState + Internal current 1. + I2 : HiddenState + Internal current 2. + V_th : HiddenState + Spiking threshold potential. + last_spike_time : ShortTermState + Last spike time recorder. + refractory : HiddenState + Neuron refractory state (if ref_var=True). + + Examples + -------- + >>> import brainpy + >>> import brainstate + >>> import brainunit as u + >>> + >>> # Create a GifRef neuron layer with refractory period + >>> gif_ref = brainpy.state.GifRef(10, tau=20*u.ms, tau_ref=2.0*u.ms, + ... 
k1=0.2/u.ms, k2=0.02/u.ms, ref_var=True) + >>> + >>> # Initialize the state + >>> gif_ref.init_state(batch_size=1) + >>> + >>> # Apply input and observe refractory behavior + >>> with brainstate.environ.context(dt=0.1*u.ms, t=0.0*u.ms): + ... spikes = gif_ref.update(x=1.5*u.mA) + >>> + >>> # Create a network with refractory Gif neurons + >>> network = brainstate.nn.Sequential([ + ... brainpy.state.GifRef(100, tau=20.0*u.ms, tau_ref=2.0*u.ms), + ... brainstate.nn.Linear(100, 10) + ... ]) + + Notes + ----- + - Combines Gif model's rich dynamics with absolute refractory period. + - During refractory period, all state variables are held at reset values. + - Set ref_var=True to track refractory state as a boolean variable. + - More biologically realistic than Gif without refractory mechanism. + - Can still exhibit diverse firing patterns: regular, bursting, adaptation. + - Refractory period prevents unrealistically high firing rates. + """ + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + R: ArrayLike = 20. * u.ohm, + tau: ArrayLike = 20. * u.ms, + tau_ref: ArrayLike = 1.7 * u.ms, + V_rest: ArrayLike = -70. * u.mV, + V_reset: ArrayLike = -70. * u.mV, + V_th_inf: ArrayLike = -50. * u.mV, + V_th_reset: ArrayLike = -60. * u.mV, + V_th_initializer: Callable = braintools.init.Constant(-50. * u.mV), + a: ArrayLike = 0. / u.ms, + b: ArrayLike = 0.01 / u.ms, + k1: ArrayLike = 0.2 / u.ms, + k2: ArrayLike = 0.02 / u.ms, + R1: ArrayLike = 0., + R2: ArrayLike = 1., + A1: ArrayLike = 0. * u.mA, + A2: ArrayLike = 0. * u.mA, + V_initializer: Callable = braintools.init.Constant(-70. * u.mV), + I1_initializer: Callable = braintools.init.Constant(0. * u.mA), + I2_initializer: Callable = braintools.init.Constant(0. 
* u.mA), + spk_fun: Callable = braintools.surrogate.ReluGrad(), + spk_reset: str = 'soft', + ref_var: bool = False, + name: str = None, + ): + super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) + + # parameters + self.R = braintools.init.param(R, self.varshape) + self.tau = braintools.init.param(tau, self.varshape) + self.tau_ref = braintools.init.param(tau_ref, self.varshape) + self.V_rest = braintools.init.param(V_rest, self.varshape) + self.V_reset = braintools.init.param(V_reset, self.varshape) + self.V_th_inf = braintools.init.param(V_th_inf, self.varshape) + self.V_th_reset = braintools.init.param(V_th_reset, self.varshape) + self.a = braintools.init.param(a, self.varshape) + self.b = braintools.init.param(b, self.varshape) + self.k1 = braintools.init.param(k1, self.varshape) + self.k2 = braintools.init.param(k2, self.varshape) + self.R1 = braintools.init.param(R1, self.varshape) + self.R2 = braintools.init.param(R2, self.varshape) + self.A1 = braintools.init.param(A1, self.varshape) + self.A2 = braintools.init.param(A2, self.varshape) + self.V_initializer = V_initializer + self.I1_initializer = I1_initializer + self.I2_initializer = I2_initializer + self.V_th_initializer = V_th_initializer + self.ref_var = ref_var + + def init_state(self, batch_size: int = None, **kwargs): + self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) + self.I1 = brainstate.HiddenState(braintools.init.param(self.I1_initializer, self.varshape, batch_size)) + self.I2 = brainstate.HiddenState(braintools.init.param(self.I2_initializer, self.varshape, batch_size)) + self.V_th = brainstate.HiddenState(braintools.init.param(self.V_th_initializer, self.varshape, batch_size)) + self.last_spike_time = brainstate.ShortTermState( + braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) + ) + if self.ref_var: + self.refractory = brainstate.HiddenState( + 
braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size) + ) + + def reset_state(self, batch_size: int = None, **kwargs): + self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) + self.I1.value = braintools.init.param(self.I1_initializer, self.varshape, batch_size) + self.I2.value = braintools.init.param(self.I2_initializer, self.varshape, batch_size) + self.V_th.value = braintools.init.param(self.V_th_initializer, self.varshape, batch_size) + self.last_spike_time.value = braintools.init.param( + braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size + ) + if self.ref_var: + self.refractory.value = braintools.init.param( + braintools.init.Constant(False), self.varshape, batch_size + ) + + def get_spike(self, V: ArrayLike = None, V_th: ArrayLike = None): + V = self.V.value if V is None else V + V_th = self.V_th.value if V_th is None else V_th + v_scaled = (V - V_th) / (V_th - self.V_reset) + return self.spk_fun(v_scaled) + + def update(self, x=0. 
* u.mA): + t = brainstate.environ.get('t') + last_v = self.V.value + last_i1 = self.I1.value + last_i2 = self.I2.value + last_v_th = self.V_th.value + last_spk = self.get_spike(last_v, last_v_th) + + # Apply spike effects + V_th_val = last_v_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) + v_reset = last_v - (V_th_val - self.V_reset) * last_spk + i1_reset = last_i1 + last_spk * (self.R1 * last_i1 + self.A1 - last_i1) + i2_reset = last_i2 + last_spk * (self.R2 * last_i2 + self.A2 - last_i2) + v_th_reset = last_v_th + last_spk * (u.math.maximum(self.V_th_reset, last_v_th) - last_v_th) + + # Update dynamics + def dI1(i1): + return -self.k1 * i1 + + def dI2(i2): + return -self.k2 * i2 + + def dV_th_func(v_th): + return self.a * (v_reset - self.V_rest) - self.b * (v_th - self.V_th_inf) + + def dv(v): + return (-(v - self.V_rest) + self.R * (i1_reset + i2_reset + self.sum_current_inputs(x, v))) / self.tau + + I1_candidate = brainstate.nn.exp_euler_step(dI1, i1_reset) + I2_candidate = brainstate.nn.exp_euler_step(dI2, i2_reset) + V_th_candidate = brainstate.nn.exp_euler_step(dV_th_func, v_th_reset) + V_candidate = brainstate.nn.exp_euler_step(dv, v_reset) + V_candidate = self.sum_delta_inputs(V_candidate) + + refractory = (t - self.last_spike_time.value) < self.tau_ref + self.V.value = u.math.where(refractory, v_reset, V_candidate) + self.I1.value = u.math.where(refractory, i1_reset, I1_candidate) + self.I2.value = u.math.where(refractory, i2_reset, I2_candidate) + self.V_th.value = u.math.where(refractory, v_th_reset, V_th_candidate) + + spike_cond = self.V.value >= self.V_th.value + self.last_spike_time.value = jax.lax.stop_gradient( + u.math.where(spike_cond, t, self.last_spike_time.value) + ) + if self.ref_var: + self.refractory.value = jax.lax.stop_gradient( + u.math.logical_or(refractory, spike_cond) + ) + return self.get_spike() diff --git a/brainpy/state/_lif_test.py b/brainpy/state/_lif_test.py index 1f320f2f..f0587134 100644 --- 
a/brainpy/state/_lif_test.py +++ b/brainpy/state/_lif_test.py @@ -23,7 +23,7 @@ import jax import jax.numpy as jnp -from brainpy.state import IF, LIF, ALIF +from brainpy.state import IF, LIF, LIFRef, ALIF, ExpIF, ExpIFRef, AdExIF, AdExIFRef, QuaIF, AdQuaIF, AdQuaIFRef, Gif, GifRef class TestNeuron(unittest.TestCase): @@ -31,12 +31,13 @@ def setUp(self): self.in_size = 10 self.batch_size = 5 self.time_steps = 100 + self.dt = 0.1 * u.ms def generate_input(self): return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mA def test_if_neuron(self): - with brainstate.environ.context(dt=0.1 * u.ms): + with brainstate.environ.context(dt=self.dt): neuron = IF(self.in_size) inputs = self.generate_input() @@ -57,7 +58,7 @@ def test_if_neuron(self): self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1))) def test_lif_neuron(self): - with brainstate.environ.context(dt=0.1 * u.ms): + with brainstate.environ.context(dt=self.dt): tau = 20.0 * u.ms neuron = LIF(self.in_size, tau=tau) inputs = self.generate_input() @@ -90,13 +91,240 @@ def test_alif_neuron(self): # Test forward pass neuron.init_state(self.batch_size) call = brainstate.compile.jit(neuron) - with brainstate.environ.context(dt=0.1 * u.ms): + with brainstate.environ.context(dt=self.dt): for t in range(self.time_steps): out = call(inputs[t]) self.assertEqual(out.shape, (self.batch_size, self.in_size)) + def test_expif_neuron(self): + tau = 10.0 * u.ms + neuron = ExpIF(self.in_size, tau=tau) + inputs = self.generate_input() + + # Test initialization + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + + # Test forward pass + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + def test_LIFRef_neuron(self): + tau = 
10.0 * u.ms + tau_ref = 2.0 * u.ms + neuron = LIFRef(self.in_size, tau=tau, tau_ref=tau_ref) + inputs = self.generate_input() + + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + self.assertEqual(neuron.tau_ref, tau_ref) + + neuron.init_state(self.batch_size) + + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.last_spike_time.value.shape, (self.batch_size, self.in_size)) + + def test_ExpIFRef_neuron(self): + tau = 10.0 * u.ms + tau_ref = 2.0 * u.ms + ref_var = True + neuron = ExpIFRef(self.in_size, tau=tau, tau_ref=tau_ref, ref_var=ref_var) + inputs = self.generate_input() + + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + self.assertEqual(neuron.tau_ref, tau_ref) + self.assertEqual(neuron.ref_var, ref_var) + + neuron.init_state(self.batch_size) + + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.last_spike_time.value.shape, (self.batch_size, self.in_size)) + if neuron.ref_var: + self.assertEqual(neuron.refractory.value.shape, (self.batch_size, self.in_size)) + + def test_adexif_neuron(self): + tau = 10.0 * u.ms + tau_w = 30.0 * u.ms + neuron = AdExIF(self.in_size, tau=tau, tau_w=tau_w) + inputs = self.generate_input() + + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + self.assertEqual(neuron.tau_w, tau_w) + + 
neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.w.value.shape, (self.batch_size, self.in_size)) + + def test_adexifref_neuron(self): + tau = 10.0 * u.ms + tau_w = 30.0 * u.ms + tau_ref = 2.0 * u.ms + ref_var = True + neuron = AdExIFRef(self.in_size, tau=tau, tau_w=tau_w, tau_ref=tau_ref, ref_var=ref_var) + inputs = self.generate_input() + + # Test initialization + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + self.assertEqual(neuron.tau_w, tau_w) + self.assertEqual(neuron.tau_ref, tau_ref) + self.assertEqual(neuron.ref_var, ref_var) + + neuron.init_state(self.batch_size) + + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + # Test state variables + self.assertEqual(neuron.V.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.w.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.last_spike_time.value.shape, (self.batch_size, self.in_size)) + if neuron.ref_var: + self.assertEqual(neuron.refractory.value.shape, (self.batch_size, self.in_size)) + + def test_quaif_neuron(self): + tau = 10.0 * u.ms + neuron = QuaIF(self.in_size, tau=tau) + inputs = self.generate_input() + + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + 
self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + def test_adquaif_neuron(self): + tau = 10.0 * u.ms + tau_w = 30.0 * u.ms + neuron = AdQuaIF(self.in_size, tau=tau, tau_w=tau_w) + inputs = self.generate_input() + + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + self.assertEqual(neuron.tau_w, tau_w) + + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.w.value.shape, (self.batch_size, self.in_size)) + + def test_adquaifref_neuron(self): + tau = 10.0 * u.ms + tau_w = 30.0 * u.ms + tau_ref = 2.0 * u.ms + ref_var = True + neuron = AdQuaIFRef(self.in_size, tau=tau, tau_w=tau_w, tau_ref=tau_ref, ref_var=ref_var) + inputs = self.generate_input() + + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + self.assertEqual(neuron.tau_w, tau_w) + self.assertEqual(neuron.tau_ref, tau_ref) + self.assertEqual(neuron.ref_var, ref_var) + + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + self.assertEqual(neuron.V.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.w.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.last_spike_time.value.shape, (self.batch_size, self.in_size)) + if neuron.ref_var: + self.assertEqual(neuron.refractory.value.shape, (self.batch_size, self.in_size)) + + def test_gif_neuron(self): + tau = 20.0 * u.ms + neuron = Gif(self.in_size, tau=tau) + inputs = 
self.generate_input() + + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.I1.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.I2.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.V_th.value.shape, (self.batch_size, self.in_size)) + + def test_gifref_neuron(self): + tau = 20.0 * u.ms + tau_ref = 2.0 * u.ms + ref_var = True + neuron = GifRef(self.in_size, tau=tau, tau_ref=tau_ref, ref_var=ref_var) + inputs = self.generate_input() + + self.assertEqual(neuron.in_size, (self.in_size,)) + self.assertEqual(neuron.out_size, (self.in_size,)) + self.assertEqual(neuron.tau, tau) + self.assertEqual(neuron.tau_ref, tau_ref) + self.assertEqual(neuron.ref_var, ref_var) + + neuron.init_state(self.batch_size) + call = brainstate.compile.jit(neuron) + with brainstate.environ.context(dt=self.dt): + for t in range(self.time_steps): + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, self.in_size)) + + self.assertEqual(neuron.V.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.I1.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.I2.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.V_th.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(neuron.last_spike_time.value.shape, (self.batch_size, self.in_size)) + if neuron.ref_var: + self.assertEqual(neuron.refractory.value.shape, (self.batch_size, self.in_size)) + def test_spike_function(self): - for NeuronClass in [IF, LIF, ALIF]: + for NeuronClass in [IF, LIF, ALIF, ExpIF, LIFRef, ExpIFRef, 
AdExIF, AdExIFRef, QuaIF, AdQuaIF, AdQuaIFRef, Gif, GifRef]: neuron = NeuronClass(self.in_size) neuron.init_state() v = jnp.linspace(-1, 1, self.in_size) * u.mV @@ -104,41 +332,48 @@ def test_spike_function(self): self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1))) def test_soft_reset(self): - for NeuronClass in [IF, LIF, ALIF]: + for NeuronClass in [IF, LIF, ALIF, ExpIF, LIFRef, ExpIFRef, AdExIF, AdExIFRef, QuaIF, AdQuaIF, AdQuaIFRef, Gif, GifRef]: neuron = NeuronClass(self.in_size, spk_reset='soft') inputs = self.generate_input() state = neuron.init_state(self.batch_size) call = brainstate.compile.jit(neuron) - with brainstate.environ.context(dt=0.1 * u.ms): + with brainstate.environ.context(dt=self.dt): for t in range(self.time_steps): - out = call(inputs[t]) - self.assertTrue(jnp.all(neuron.V.value <= neuron.V_th)) + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + # For Gif models, V_th is a state variable + V_th = neuron.V_th.value if hasattr(neuron.V_th, 'value') else neuron.V_th + self.assertTrue(jnp.all(neuron.V.value <= V_th)) def test_hard_reset(self): - for NeuronClass in [IF, LIF, ALIF]: + for NeuronClass in [IF, LIF, ALIF, ExpIF, LIFRef, ExpIFRef, AdExIF, AdExIFRef, QuaIF, AdQuaIF, AdQuaIFRef, Gif, GifRef]: neuron = NeuronClass(self.in_size, spk_reset='hard') inputs = self.generate_input() state = neuron.init_state(self.batch_size) call = brainstate.compile.jit(neuron) - with brainstate.environ.context(dt=0.1 * u.ms): + with brainstate.environ.context(dt=self.dt): for t in range(self.time_steps): - out = call(inputs[t]) - self.assertTrue(jnp.all((neuron.V.value < neuron.V_th) | (neuron.V.value == 0. * u.mV))) + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + # For Gif models, V_th is a state variable + V_th = neuron.V_th.value if hasattr(neuron.V_th, 'value') else neuron.V_th + self.assertTrue(jnp.all((neuron.V.value < V_th) | (neuron.V.value == 0. 
* u.mV))) def test_detach_spike(self): - for NeuronClass in [IF, LIF, ALIF]: + for NeuronClass in [IF, LIF, ALIF, ExpIF, LIFRef, ExpIFRef, AdExIF, AdExIFRef, QuaIF, AdQuaIF, AdQuaIFRef, Gif, GifRef]: neuron = NeuronClass(self.in_size) inputs = self.generate_input() state = neuron.init_state(self.batch_size) call = brainstate.compile.jit(neuron) - with brainstate.environ.context(dt=0.1 * u.ms): + with brainstate.environ.context(dt=self.dt): for t in range(self.time_steps): - out = call(inputs[t]) - self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type) + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type) def test_keep_size(self): in_size = (2, 3) - for NeuronClass in [IF, LIF, ALIF]: + for NeuronClass in [IF, LIF, ALIF, ExpIF, LIFRef, ExpIFRef, AdExIF, AdExIFRef, QuaIF, AdQuaIF, AdQuaIFRef, Gif, GifRef]: neuron = NeuronClass(in_size) self.assertEqual(neuron.in_size, in_size) self.assertEqual(neuron.out_size, in_size) @@ -146,12 +381,14 @@ def test_keep_size(self): inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA state = neuron.init_state(self.batch_size) call = brainstate.compile.jit(neuron) - with brainstate.environ.context(dt=0.1 * u.ms): + with brainstate.environ.context(dt=self.dt): for t in range(self.time_steps): - out = call(inputs[t]) - self.assertEqual(out.shape, (self.batch_size, *in_size)) + with brainstate.environ.context(t=t*self.dt): + out = call(inputs[t]) + self.assertEqual(out.shape, (self.batch_size, *in_size)) if __name__ == '__main__': - with brainstate.environ.context(dt=0.1): - unittest.main() + # with brainstate.environ.context(dt=0.1): + # unittest.main() + unittest.main() diff --git a/brainpy/state/_projection.py b/brainpy/state/_projection.py index 5a47d510..49da161f 100644 --- a/brainpy/state/_projection.py +++ b/brainpy/state/_projection.py @@ -71,7 +71,7 @@ class Projection(brainstate.nn.Module): Derived 
classes should implement specific projection behaviors, such as dense connectivity, sparse connectivity, or specific weight update rules. """ - __module__ = 'brainstate.nn' + __module__ = 'brainpy.state' def update(self, *args, **kwargs): sub_nodes = tuple(self.nodes(allowed_hierarchy=(1, 1)).values()) @@ -181,7 +181,7 @@ class AlignPostProj(Projection): >>> exe_current = E(pop.get_spike()) """ - __module__ = 'brainstate.nn' + __module__ = 'brainpy.state' def __init__( self, @@ -303,7 +303,7 @@ class DeltaProj(Projection): ... ) >>> delta_input(1.0) # Apply voltage increment directly """ - __module__ = 'brainstate.nn' + __module__ = 'brainpy.state' def __init__( self, @@ -380,7 +380,7 @@ class CurrentProj(Projection): ... ) >>> current_input(0.2) # Apply external current """ - __module__ = 'brainstate.nn' + __module__ = 'brainpy.state' def __init__( self, @@ -439,6 +439,7 @@ class align_pre_projection(Projection): short-term plasticity. """ + __module__ = 'brainpy.state' def __init__( self, @@ -486,6 +487,7 @@ class align_post_projection(Projection): properties, synaptic outputs, post-synaptic dynamics, and short-term plasticity. """ + __module__ = 'brainpy.state' def __init__( self, diff --git a/brainpy/state/_readout.py b/brainpy/state/_readout.py index 3775220a..2125f83f 100644 --- a/brainpy/state/_readout.py +++ b/brainpy/state/_readout.py @@ -74,7 +74,7 @@ class LeakyRateReadout(brainstate.nn.Module): r : HiddenState Hidden state representing the output values """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -160,7 +160,7 @@ class LeakySpikeReadout(Neuron): Synaptic weight matrix """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, diff --git a/brainpy/state/_stp.py b/brainpy/state/_stp.py index 8442bdd5..1dd53cb7 100644 --- a/brainpy/state/_stp.py +++ b/brainpy/state/_stp.py @@ -97,7 +97,7 @@ class STP(Synapse): .. [2] Tsodyks, M., Pawelzik, K., & Markram, H. (1998). 
Neural networks with dynamic synapses. Neural computation, 10(4), 821-835. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -203,7 +203,7 @@ class STD(Synapse): pyramidal neurons depends on neurotransmitter release probability. Proceedings of the National Academy of Sciences, 94(2), 719-723. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, diff --git a/brainpy/state/_synapse.py b/brainpy/state/_synapse.py index 625d479a..17811014 100644 --- a/brainpy/state/_synapse.py +++ b/brainpy/state/_synapse.py @@ -26,7 +26,7 @@ from ._base import Synapse __all__ = [ - 'Alpha', 'AMPA', 'GABAa', + 'Alpha', 'AMPA', 'GABAa', 'BioNMDA', ] @@ -70,7 +70,7 @@ class Alpha(Synapse): This implementation uses an exponential Euler integration method. The output of this synapse is the conductance value. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -176,7 +176,7 @@ class AMPA(Synapse): and implications for stimulus processing. Proceedings of the National Academy of Sciences, 109(45), 18553-18558. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -289,7 +289,7 @@ class GABAa(AMPA): properties of neocortical pyramidal neurons in vivo. Journal of neurophysiology, 81(4), 1531-1547. """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, @@ -310,3 +310,159 @@ def __init__( in_size=in_size, g_initializer=g_initializer ) + + +class BioNMDA(Synapse): + r"""Biological NMDA receptor synapse model. + + This class implements a detailed kinetic model of NMDA (N-methyl-D-aspartate) + receptor-mediated synaptic transmission. NMDA receptors are ionotropic glutamate + receptors that play a crucial role in synaptic plasticity, learning, and memory. + + Unlike AMPA receptors, NMDA receptors exhibit both ligand-gating and voltage-dependent + properties. 
The voltage dependence arises from the blocking of the receptor pore by + extracellular magnesium ions (Mg²⁺) at resting potential. The model uses a second-order + kinetic scheme to capture the dynamics of NMDA receptors: + + $$ + \frac{dg}{dt} = \alpha_1 x (1-g) - \beta_1 g + $$ + + $$ + \frac{dx}{dt} = \alpha_2 [T] (1-x) - \beta_2 x + $$ + + $$ + I_{syn} = g_{max} \cdot g \cdot g_{\infty}(V,[Mg^{2+}]_o) \cdot (V - E) + $$ + + where: + + - $g$ represents the fraction of receptors in the open state + - $x$ is an auxiliary variable representing an intermediate state + - $\alpha_1, \beta_1$ are the conversion rates for the $g$ variable [ms^-1] + - $\alpha_2, \beta_2$ are the conversion rates for the $x$ variable [ms^-1] or [ms^-1 mM^-1] + - $[T]$ is the neurotransmitter (glutamate) concentration [mM] + - $g_{\infty}(V,[Mg^{2+}]_o)$ is the voltage-dependent magnesium block function + - $I_{syn}$ is the resulting synaptic current + - $g_{max}$ is the maximum conductance + - $V$ is the membrane potential + - $E$ is the reversal potential + + The magnesium block is typically modeled as: + + $$ + g_{\infty}(V,[Mg^{2+}]_o) = \frac{1}{1 + [Mg^{2+}]_o \cdot \exp(-a \cdot V) / b} + $$ + + The neurotransmitter concentration $[T]$ follows a square pulse of amplitude T and + duration T_dur after each presynaptic spike. + + Parameters + ---------- + in_size : Size + Size of the input. + name : str, optional + Name of the synapse instance. + alpha1 : ArrayLike, default=2.0/u.ms + Conversion rate of g from inactive to active [ms^-1]. + beta1 : ArrayLike, default=0.01/u.ms + Conversion rate of g from active to inactive [ms^-1]. + alpha2 : ArrayLike, default=1.0/(u.ms*u.mM) + Conversion rate of x from inactive to active [ms^-1 mM^-1]. + beta2 : ArrayLike, default=0.5/u.ms + Conversion rate of x from active to inactive [ms^-1]. + T : ArrayLike, default=1.0*u.mM + Peak neurotransmitter concentration when released [mM]. 
+    T_dur : ArrayLike, default=0.5*u.ms
+        Duration of neurotransmitter presence in the synaptic cleft [ms].
+    g_initializer : ArrayLike or Callable, default=init.Constant(0. * u.mS)
+        Initial value or initializer for the synaptic conductance.
+    x_initializer : ArrayLike or Callable, default=init.Constant(0.)
+        Initial value or initializer for the auxiliary state variable.
+
+    Attributes
+    ----------
+    g : HiddenState
+        Fraction of receptors in the open state.
+    x : HiddenState
+        Auxiliary state variable representing intermediate receptor state.
+    spike_arrival_time : ShortTermState
+        Time of the most recent presynaptic spike.
+
+    Notes
+    -----
+    - NMDA receptors have slower kinetics compared to AMPA receptors, with rise times
+      of several milliseconds and decay time constants of tens to hundreds of milliseconds.
+    - The voltage-dependent magnesium block is typically implemented in the output layer
+      or postsynaptic neuron model, not in this synapse model itself.
+    - NMDA receptors are permeable to calcium ions, which can trigger various intracellular
+      signaling cascades important for synaptic plasticity.
+    - This implementation uses an exponential Euler integration method for both state variables.
+
+    References
+    ----------
+    .. [1] Ermentrout, G. B., & Terman, D. H. (2010). Mathematical Foundations of Neuroscience. Springer New York, 162.
+    .. [2] Furukawa, H., Singh, S. K., Mancusso, R., & Gouaux, E. (2005). Subunit arrangement
+           and function in NMDA receptors. Nature, 438(7065), 185-192.
+    .. [3] Li, F., & Tsien, J. Z. (2009). Memory and the NMDA receptors. The New England
+           Journal of Medicine, 361(3), 302.
+    .. [4] Jahr, C. E., & Stevens, C. F. (1990). Voltage dependence of NMDA-activated
+           macroscopic conductances predicted by single-channel kinetics. Journal of Neuroscience,
+           10(9), 3178-3182. 
+ """ + __module__ = 'brainpy.state' + + def __init__( + self, + in_size: Size, + name: Optional[str] = None, + alpha1: ArrayLike = 2.0 / u.ms, + beta1: ArrayLike = 0.01 / u.ms, + alpha2: ArrayLike = 1.0 / (u.ms * u.mM), + beta2: ArrayLike = 0.5 / u.ms, + T: ArrayLike = 1.0 * u.mM, + T_dur: ArrayLike = 0.5 * u.ms, + g_initializer: ArrayLike | Callable = braintools.init.Constant(0. * u.mS), + x_initializer: ArrayLike | Callable = braintools.init.Constant(0.), + ): + super().__init__(name=name, in_size=in_size) + + # parameters + self.alpha1 = braintools.init.param(alpha1, self.varshape) + self.beta1 = braintools.init.param(beta1, self.varshape) + self.alpha2 = braintools.init.param(alpha2, self.varshape) + self.beta2 = braintools.init.param(beta2, self.varshape) + self.T = braintools.init.param(T, self.varshape) + self.T_duration = braintools.init.param(T_dur, self.varshape) + self.g_initializer = g_initializer + self.x_initializer = x_initializer + + def init_state(self, batch_size=None): + self.g = brainstate.HiddenState(braintools.init.param(self.g_initializer, self.varshape, batch_size)) + self.x = brainstate.HiddenState(braintools.init.param(self.x_initializer, self.varshape, batch_size)) + self.spike_arrival_time = brainstate.ShortTermState( + braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) + ) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.g.value = braintools.init.param(self.g_initializer, self.varshape, batch_or_mode) + self.x.value = braintools.init.param(self.x_initializer, self.varshape, batch_or_mode) + self.spike_arrival_time.value = braintools.init.param( + braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_or_mode + ) + + def update(self, pre_spike): + t = brainstate.environ.get('t') + self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value) + TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T + + # Update x first (intermediate state) 
+ dx = lambda x: self.alpha2 * TT * (1 * u.get_unit(x) - x) - self.beta2 * x + self.x.value = brainstate.nn.exp_euler_step(dx, self.x.value) + + # Update g (open state) based on current x value + dg = lambda g: self.alpha1 * self.x.value * (1 * u.get_unit(g) - g) - self.beta1 * g + self.g.value = brainstate.nn.exp_euler_step(dg, self.g.value) + + return self.g.value diff --git a/brainpy/state/_synapse_test.py b/brainpy/state/_synapse_test.py index 4154bd9d..0260fd9e 100644 --- a/brainpy/state/_synapse_test.py +++ b/brainpy/state/_synapse_test.py @@ -21,7 +21,7 @@ import jax.numpy as jnp import pytest -from brainpy.state import Expon, STP, STD +from brainpy.state import Expon, STP, STD, AMPA, GABAa, BioNMDA class TestSynapse(unittest.TestCase): @@ -125,7 +125,140 @@ def test_keep_size(self): out = call(inputs[t]) self.assertEqual(out.shape, (self.batch_size, *in_size)) + def test_ampa_synapse(self): + alpha = 0.98 / (u.ms * u.mM) + beta = 0.18 / u.ms + T = 0.5 * u.mM + T_dur = 0.5 * u.ms + synapse = AMPA(self.in_size, alpha=alpha, beta=beta, T=T, T_dur=T_dur) + + # Test initialization + self.assertEqual(synapse.in_size, (self.in_size,)) + self.assertEqual(synapse.out_size, (self.in_size,)) + self.assertEqual(synapse.alpha, alpha) + self.assertEqual(synapse.beta, beta) + self.assertEqual(synapse.T, T) + self.assertEqual(synapse.T_duration, T_dur) + + # Test forward pass + synapse.init_state(self.batch_size) + call = brainstate.compile.jit(synapse) + with brainstate.environ.context(dt=0.1 * u.ms, t=0. 
* u.ms): + # Test with spike input (True/False array) + spike_input = jnp.zeros((self.batch_size, self.in_size), dtype=bool) + spike_input = spike_input.at[0, 0].set(True) # Single spike + + out1 = call(spike_input) + self.assertEqual(out1.shape, (self.batch_size, self.in_size)) + + # Conductance should increase after spike + out2 = call(jnp.zeros((self.batch_size, self.in_size), dtype=bool)) + self.assertTrue(jnp.any(out2[0, 0] > 0 * u.mS)) # Should have some conductance + + def test_gabaa_synapse(self): + alpha = 0.53 / (u.ms * u.mM) + beta = 0.18 / u.ms + T = 1.0 * u.mM + T_dur = 1.0 * u.ms + synapse = GABAa(self.in_size, alpha=alpha, beta=beta, T=T, T_dur=T_dur) + + # Test initialization + self.assertEqual(synapse.in_size, (self.in_size,)) + self.assertEqual(synapse.out_size, (self.in_size,)) + self.assertEqual(synapse.alpha, alpha) + self.assertEqual(synapse.beta, beta) + self.assertEqual(synapse.T, T) + self.assertEqual(synapse.T_duration, T_dur) + + # Test forward pass + synapse.init_state(self.batch_size) + call = brainstate.compile.jit(synapse) + with brainstate.environ.context(dt=0.1 * u.ms, t=0. 
* u.ms): + spike_input = jnp.zeros((self.batch_size, self.in_size), dtype=bool) + spike_input = spike_input.at[0, 0].set(True) + + out1 = call(spike_input) + self.assertEqual(out1.shape, (self.batch_size, self.in_size)) + + # Conductance should increase after spike + out2 = call(jnp.zeros((self.batch_size, self.in_size), dtype=bool)) + self.assertTrue(jnp.any(out2[0, 0] > 0 * u.mS)) + + def test_bionmda_synapse(self): + alpha1 = 2.0 / u.ms + beta1 = 0.01 / u.ms + alpha2 = 1.0 / (u.ms * u.mM) + beta2 = 0.5 / u.ms + T = 1.0 * u.mM + T_dur = 0.5 * u.ms + synapse = BioNMDA(self.in_size, alpha1=alpha1, beta1=beta1, + alpha2=alpha2, beta2=beta2, T=T, T_dur=T_dur) + + # Test initialization + self.assertEqual(synapse.in_size, (self.in_size,)) + self.assertEqual(synapse.out_size, (self.in_size,)) + self.assertEqual(synapse.alpha1, alpha1) + self.assertEqual(synapse.beta1, beta1) + self.assertEqual(synapse.alpha2, alpha2) + self.assertEqual(synapse.beta2, beta2) + self.assertEqual(synapse.T, T) + self.assertEqual(synapse.T_duration, T_dur) + + # Test forward pass with spike inputs + synapse.init_state(self.batch_size) + call = brainstate.compile.jit(synapse) + with brainstate.environ.context(dt=0.1 * u.ms, t=0. 
* u.ms): + # Create spike input at first time step + spike_input = jnp.zeros((self.batch_size, self.in_size), dtype=bool) + spike_input = spike_input.at[0, 0].set(True) # Single spike at position (0, 0) + + # First call with spike + out1 = call(spike_input) + self.assertEqual(out1.shape, (self.batch_size, self.in_size)) + + # Verify state variables exist and have correct shape + self.assertEqual(synapse.g.value.shape, (self.batch_size, self.in_size)) + self.assertEqual(synapse.x.value.shape, (self.batch_size, self.in_size)) + + # Continue simulation without spikes + no_spike = jnp.zeros((self.batch_size, self.in_size), dtype=bool) + + # NMDA should have slower dynamics - collect several time points + outputs = [out1] + for _ in range(10): + out = call(no_spike) + outputs.append(out) + + # Check that conductance increases over time initially (slower rise time for NMDA) + # Due to the two-state kinetics, there should be some non-zero conductance + self.assertTrue(jnp.any(outputs[-1][0, 0] >= 0 * u.mS)) # Should have developed some conductance + + def test_bionmda_two_state_dynamics(self): + """Test that BioNMDA properly implements second-order kinetics with two state variables""" + synapse = BioNMDA(self.in_size) + synapse.init_state(self.batch_size) + call = brainstate.compile.jit(synapse) + + with brainstate.environ.context(dt=0.1 * u.ms, t=0. 
* u.ms): + # Initial state should be zero (g has units, x is dimensionless) + self.assertTrue(jnp.allclose(synapse.g.value.to_decimal(u.mS), 0.)) + self.assertTrue(jnp.allclose(synapse.x.value, 0.)) + + # Apply a spike + spike_input = jnp.zeros((self.batch_size, self.in_size), dtype=bool) + spike_input = spike_input.at[0, 0].set(True) + + call(spike_input) + + # After spike, both x and g should be non-negative + x_val = synapse.x.value[0, 0] + g_val = synapse.g.value[0, 0] + + # x is dimensionless, g has units + self.assertTrue(x_val >= 0) + self.assertTrue(g_val >= 0 * u.mS) + if __name__ == '__main__': - with brainstate.environ.context(dt=0.1): + with brainstate.environ.context(dt=0.1 * u.ms): unittest.main() diff --git a/brainpy/state/_synaptic_projection.py b/brainpy/state/_synaptic_projection.py index e291100c..b534d5ce 100644 --- a/brainpy/state/_synaptic_projection.py +++ b/brainpy/state/_synaptic_projection.py @@ -22,6 +22,7 @@ import brainunit as u from brainstate.typing import ArrayLike +from ._misc import set_module_as from ._projection import Projection __all__ = [ @@ -31,11 +32,11 @@ class align_pre_ltp(Projection): - pass + __module__ = 'brainpy.state' class align_post_ltp(Projection): - pass + __module__ = 'brainpy.state' def get_gap_junction_post_key(i: int): @@ -82,6 +83,8 @@ class SymmetryGapJunction(Projection): AsymmetryGapJunction : For gap junctions with different conductances in each direction. """ + __module__ = 'brainpy.state' + def __init__( self, couples: Union[Tuple[brainstate.nn.Dynamics, brainstate.nn.Dynamics], brainstate.nn.Dynamics], @@ -132,6 +135,7 @@ def update(self, *args, **kwargs): ) +@set_module_as('brainpy.state') def symmetry_gap_junction_projection( pre: brainstate.nn.Dynamics, pre_value: ArrayLike, @@ -281,6 +285,7 @@ class AsymmetryGapJunction(Projection): -------- SymmetryGapJunction : For gap junctions with identical conductance in both directions. 
""" + __module__ = 'brainpy.state' def __init__( self, @@ -322,6 +327,7 @@ def update(self, *args, **kwargs): ) +@set_module_as('brainpy.state') def asymmetry_gap_junction_projection( pre: brainstate.nn.Dynamics, pre_value: ArrayLike, diff --git a/brainpy/state/_synouts.py b/brainpy/state/_synouts.py index 54605e52..8592ded8 100644 --- a/brainpy/state/_synouts.py +++ b/brainpy/state/_synouts.py @@ -33,7 +33,7 @@ class SynOut(brainstate.nn.Module, BindCondData): :py:class:`~.SynOut` is also subclass of :py:class:`~.ParamDesc` and :py:class:`~.BindCondData`. """ - __module__ = 'brainstate.nn' + __module__ = 'brainpy.state' def __init__(self, ): super().__init__() @@ -71,7 +71,7 @@ class COBA(SynOut): -------- CUBA """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__(self, E: brainstate.typing.ArrayLike): super().__init__() @@ -100,7 +100,7 @@ class CUBA(SynOut): -------- COBA """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__(self, scale: brainstate.typing.ArrayLike = u.volt): super().__init__() @@ -140,7 +140,7 @@ class MgBlock(SynOut): V_offset: ArrayLike The offset potential. Default 0. [mV] """ - __module__ = 'brainpy' + __module__ = 'brainpy.state' def __init__( self, diff --git a/changelog.md b/changelog.md index 097aa7c6..31105d61 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,110 @@ # Changelog +## Version 2.7.1 + +**Release Date:** October 2025 + +This is a feature release that introduces new neuron and synapse models in the state-based API (`brainpy.state`) and enhances the Dynamics base class with improved input handling. 
+
+### Major Changes
+
+#### New Neuron Models (brainpy.state)
+- **LIF (Leaky Integrate-and-Fire) Variants**: Added comprehensive set of LIF neuron models
+  - `LIF`: Basic LIF neuron with exponential synaptic input
+  - `LIFRef`: LIF with refractory period
+  - `ExpIF`: Exponential Integrate-and-Fire neuron
+  - `ExpIFRef`: ExpIF with refractory period
+  - `AdExIF`: Adaptive Exponential Integrate-and-Fire neuron
+  - `AdExIFRef`: AdExIF with refractory period
+  - `QuaIF`: Quadratic Integrate-and-Fire neuron
+  - `QuaIFRef`: QuaIF with refractory period
+  - `AdQuaIF`: Adaptive Quadratic Integrate-and-Fire neuron
+  - `AdQuaIFRef`: AdQuaIF with refractory period
+  - `GifRef`: Generalized Integrate-and-Fire with refractory period
+
+- **Izhikevich Neuron Models**: Added new Izhikevich neuron implementations
+  - `Izhikevich`: Basic Izhikevich neuron model
+  - `IzhikevichRef`: Izhikevich with refractory period
+
+- **Hodgkin-Huxley Model**: Added classic biophysical neuron model
+  - `HH`: Classic Hodgkin-Huxley model with Na+ and K+ channels
+
+#### New Synapse Models (brainpy.state)
+- **BioNMDA**: Biological NMDA receptor with second-order kinetics
+  - Implements two-state cascade dynamics (x and g variables)
+  - Slower rise time compared to AMPA (biologically realistic)
+  - Comprehensive documentation with mathematical formulation
+
+### Features
+
+#### Model Implementation
+- All new models use the brainstate ecosystem (HiddenState, ShortTermState, LongTermState)
+- Proper unit support with brainunit integration
+- Exponential Euler integration for numerical stability
+- Batch processing support across all models
+- Consistent API design following BrainPy v2.7+ architecture
+
+#### Dynamics Class Enhancements
+- Enhanced input handling capabilities in the Dynamics base class
+- Added new properties for better state management
+- Improved integration with brainstate framework
+- Refactored to use public methods instead of private counterparts for clarity
+
+#### 
Documentation +- Added comprehensive Examples sections to all neuron classes in `_lif.py` +- Each example includes: + - Import statements for required modules + - Basic usage with parameter specifications + - State initialization examples + - Update and spike generation examples + - Network integration with `brainstate.nn.Sequential` + - Notes highlighting key features +- All 13 neuron classes in `_lif.py` now have complete documentation +- Simplified documentation paths by removing 'core-concepts' and 'quickstart' prefixes in index.rst + +### Bug Fixes +- Fixed import paths in `_base.py`: changed references from brainstate to brainpy for consistency (057b872d) +- Fixed test suite issues (95ec2037) +- Fixed test suite for proper unit handling in synapse models + +### Code Quality +- Refactored module assignments to `brainpy.state` for consistency across files (06b2bf4d) +- Refactored method calls in `_base.py`: replaced private methods with public counterparts (210426ab) + +### Testing +- Added comprehensive test suites for all new neuron models +- Added AMPA and GABAa synapse tests +- Added tests for Izhikevich neuron variants +- Added tests for Hodgkin-Huxley model +- All tests passing with proper unit handling + +### Files Modified +- `brainpy/__init__.py`: Updated version to 2.7.1 +- `brainpy/state/_base.py`: Enhanced Dynamics class with improved input handling (447 lines added) +- `brainpy/state/_lif.py`: Added extensive LIF neuron variants (1862 lines total) +- `brainpy/state/_izhikevich.py`: New file with Izhikevich models (407 lines) +- `brainpy/state/_hh.py`: New file with Hodgkin-Huxley model (666 lines) +- `brainpy/state/_synapse.py`: Added BioNMDA model (158 lines) +- `brainpy/state/_projection.py`: Updated for consistency (43 lines modified) +- `brainpy/state/__init__.py`: Updated exports for new models +- Test files added: `_lif_test.py`, `_izhikevich_test.py`, `_hh_test.py`, `_synapse_test.py`, `_base_test.py` +- Documentation updates in 
`docs_state/index.rst` + +### Removed +- Removed outdated documentation notebooks from `docs_state/`: + - `checkpointing-en.ipynb` and `checkpointing-zh.ipynb` + - `snn_simulation-en.ipynb` and `snn_simulation-zh.ipynb` + - `snn_training-en.ipynb` and `snn_training-zh.ipynb` + +### Notes +- This release significantly expands the `brainpy.state` module with biologically realistic neuron and synapse models +- All new models are fully compatible with the brainstate ecosystem +- Enhanced documentation provides clear usage examples for all models +- The Dynamics class refactoring improves the foundation for future state-based model development + + + + ## Version 3.0.1 **Release Date:** October 2025