diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 333d6bd0c..2b795276d 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -124,7 +124,7 @@ def register_delay( """ # delay steps if delay_step is None: - return delay_step + delay_type = 'none' elif isinstance(delay_step, int): delay_type = 'homo' elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)): diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index d8d69652b..a6c0a3169 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -1052,7 +1052,7 @@ class NMDA(TwoEndConn): .. math:: - I_{syn} = g_{NMDA}(t) (V(t)-E) \cdot g_{\infty} + I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the reversal potential. @@ -1061,7 +1061,7 @@ class NMDA(TwoEndConn): .. math:: - & g_{NMDA} (t) = g_{max} g \\ + & g_\mathrm{NMDA} (t) = g_{max} g \\ & \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\ & \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k}) @@ -1235,6 +1235,10 @@ def __init__( # integral self.integral = odeint(method=method, f=JointEq([self.dg, self.dx])) + def reset(self): + self.g.value = bm.zeros(self.pre.num) + self.x.value = bm.zeros(self.pre.num) + def dg(self, g, t, x): return -g / self.tau_decay + self.a * x * (1 - g) diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index 7babd50ea..98f02f6b0 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -6,12 +6,13 @@ from brainpy.connect import TwoEndConnector, All2All, One2One from brainpy.dyn.base import NeuGroup, TwoEndConn from brainpy.initialize import Initializer, init_param -from brainpy.integrators import odeint +from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor __all__ = [ 'AMPA', 'GABAa', + 'BioNMDA', ] @@ -328,3 +329,274 @@ def __init__( T_duration=T_duration, method=method, name=name) + + +class BioNMDA(TwoEndConn): + r"""Biological NMDA synapse model. + + **Model Descriptions** + + The NMDA receptor is a glutamate receptor and ion channel found in neurons. + The NMDA receptor is one of three types of ionotropic glutamate receptors, + the other two being AMPA and kainate receptors. + + The NMDA receptor mediated conductance depends on the postsynaptic voltage. + The voltage dependence is due to the blocking of the pore of the NMDA receptor + from the outside by a positively charged magnesium ion. The channel is + nearly completely blocked at resting potential, but the magnesium block is + relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` + that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\a V} + \frac{[{Mg}^{2+}]_{o}} {\b})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, + usually 1 mM. Thus, the channel acts as a + "coincidence detector" and only once both of these conditions are met, the + channel opens and it allows positively charged ions (cations) to flow through + the cell membrane [2]_. + + If we make the approximation that the magnesium block changes + instantaneously with voltage and is independent of the gating of the channel, + the net NMDA receptor-mediated synaptic current is given by + + .. math:: + + I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} + + where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the + reversal potential. + + Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_: + + .. math:: + + & g_\mathrm{NMDA} (t) = g_{max} g \\ + & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\ + & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x + + where :math:`\alpha_1, \beta_1` refers to the conversion rate of variable g and + :math:`\alpha_2, \beta_2` refers to the conversion rate of variable x. + + The NMDA receptor has been thought to be very important for controlling + synaptic plasticity and mediating learning and memory functions [3]_. + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = bp.dyn.HH(1) + >>> neu2 = bp.dyn.HH(1) + >>> syn1 = bp.dyn.BioNMDA(neu1, neu2, bp.connect.All2All(), E=0.) + >>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x') + >>> plt.legend() + >>> plt.show() + + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `dense`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + E: float, JaxArray, ndarray + The reversal potential for the synaptic current. [mV] + a: float, JaxArray, ndarray + Binding constant. Default 0.062 + b: float, JaxArray, ndarray + Unbinding constant. Default 3.57 + cc_Mg: float, JaxArray, ndarray + Concentration of Magnesium ion. Default 1.2 [mM]. + alpha1: float, JaxArray, ndarray + The conversion rate of g from inactive to active. Default 2 ms^-1. + beta1: float, JaxArray, ndarray + The conversion rate of g from active to inactive. Default 0.01 ms^-1. + alpha2: float, JaxArray, ndarray + The conversion rate of x from inactive to active. Default 1 ms^-1. + beta2: float, JaxArray, ndarray + The conversion rate of x from active to inactive. Default 0.5 ms^-1. + + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- + + .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M]. + Springer New York, 2010: 162. + .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and + Eric Gouaux. "Subunit arrangement and function in NMDA receptors." + Nature 438, no. 7065 (2005): 185-192. + .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New + England journal of medicine, 361(3), p.302. + .. [4] https://en.wikipedia.org/wiki/NMDA_receptor + + """ + + def __init__( + self, + pre: NeuGroup, + post: NeuGroup, + conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], + conn_type: str = 'dense', + g_max: Union[float, Tensor, Initializer, Callable] = 0.15, + delay_step: Union[int, Tensor, Initializer, Callable] = None, + E: Union[float, Tensor] = 0., + cc_Mg: Union[float, Tensor] = 1.2, + a: Union[float, Tensor] = 0.062, + b: Union[float, Tensor] = 3.57, + alpha1: Union[float, Tensor] = 2., + beta1: Union[float, Tensor] = 0.01, + alpha2: Union[float, Tensor] = 1., + beta2: Union[float, Tensor] = 0.5, + T_0: Union[float, Tensor] = 1., + T_dur: Union[float, Tensor] = 0.5, + method: str = 'exp_auto', + name: str = None, + ): + super(BioNMDA, self).__init__(pre=pre, post=post, conn=conn, name=name) + self.check_pre_attrs('spike') + self.check_post_attrs('input', 'V') + + # parameters + self.E = E + self.alpha = a + self.beta = b + self.cc_Mg = cc_Mg + self.beta1 = beta1 + self.beta2 = beta2 + self.alpha1 = alpha1 + self.alpha2 = alpha2 + self.T_0 = T_0 + self.T_dur = T_dur + if bm.size(alpha1) != 1: + raise ValueError(f'"alpha1" must be a scalar or a tensor with size of 1. But we got {alpha1}') + if bm.size(beta1) != 1: + raise ValueError(f'"beta1" must be a scalar or a tensor with size of 1. But we got {beta1}') + if bm.size(alpha2) != 1: + raise ValueError(f'"alpha2" must be a scalar or a tensor with size of 1. But we got {alpha2}') + if bm.size(beta2) != 1: + raise ValueError(f'"beta2" must be a scalar or a tensor with size of 1. But we got {beta2}') + if bm.size(E) != 1: + raise ValueError(f'"E" must be a scalar or a tensor with size of 1. But we got {E}') + if bm.size(a) != 1: + raise ValueError(f'"a" must be a scalar or a tensor with size of 1. But we got {a}') + if bm.size(b) != 1: + raise ValueError(f'"b" must be a scalar or a tensor with size of 1. But we got {b}') + if bm.size(cc_Mg) != 1: + raise ValueError(f'"cc_Mg" must be a scalar or a tensor with size of 1. But we got {cc_Mg}') + if bm.size(T_0) != 1: + raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}') + if bm.size(T_dur) != 1: + raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. But we got {T_dur}') + + # connections and weights + self.conn_type = conn_type + if conn_type not in ['sparse', 'dense']: + raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') + if self.conn is None: + raise ValueError(f'Must provide "conn" when initialize the model {self.name}') + if isinstance(self.conn, One2One): + self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif isinstance(self.conn, All2All): + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + if bm.size(self.g_max) != 1: + self.weight_type = 'heter' + bm.fill_diagonal(self.g_max, 0.) + else: + self.weight_type = 'homo' + else: + if conn_type == 'sparse': + self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') + self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif conn_type == 'dense': + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + if self.weight_type == 'homo': + self.conn_mat = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {conn_type}') + + # variables + self.g = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_)) + self.x = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_)) + self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num, dtype=bm.float_) * -1e7) + self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike) + + # integral + self.integral = odeint(method=method, f=JointEq([self.dg, self.dx])) + + def reset(self): + self.g.value = bm.zeros(self.pre.num) + self.x.value = bm.zeros(self.pre.num) + self.spike_arrival_time.value = bm.ones(self.pre.num) * -1e7 + + def dg(self, g, t, x): + return self.alpha1 * x * (1 - g) - self.beta1 * g + + def dx(self, x, t, T): + return self.alpha2 * T * (1 - x) - self.beta2 * x + + def update(self, t, dt): + # delays + delayed_pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) + + # update synapse variables + self.spike_arrival_time.value = bm.where(delayed_pre_spike, t, self.spike_arrival_time) + T = ((t - self.spike_arrival_time) < self.T_dur) * self.T_0 + self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt=dt) + + # post-synaptic value + assert self.weight_type in ['homo', 'heter'] + assert self.conn_type in ['sparse', 'dense'] + if isinstance(self.conn, All2All): + if self.weight_type == 'homo': + post_g = bm.sum(self.g) + if not self.conn.include_self: + post_g = post_g - self.g + post_g = post_g * self.g_max + else: + post_g = self.g @ self.g_max + elif isinstance(self.conn, One2One): + post_g = self.g_max * self.g + else: + if self.conn_type == 'sparse': + post_g = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids) + else: + if self.weight_type == 'homo': + post_g = (self.g_max * self.g) @ self.conn_mat + else: + post_g = self.g @ self.g_max + + # output + g_inf = 1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post.V) + self.post.input += post_g * (self.E - self.post.V) / g_inf