Commit

No commit message
github-actions[bot] committed May 21, 2022
2 parents c925ada + ccb9b54 commit 23c3f85
Showing 3 changed files with 280 additions and 4 deletions.
2 changes: 1 addition & 1 deletion brainpy/dyn/base.py
@@ -124,7 +124,7 @@ def register_delay(
"""
# delay steps
if delay_step is None:
- return delay_step
+ delay_type = 'none'
elif isinstance(delay_step, int):
delay_type = 'homo'
elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)):
8 changes: 6 additions & 2 deletions brainpy/dyn/synapses/abstract_models.py
@@ -1052,7 +1052,7 @@ class NMDA(TwoEndConn):
.. math::
- I_{syn} = g_{NMDA}(t) (V(t)-E) \cdot g_{\infty}
+ I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}
where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
reversal potential.
@@ -1061,7 +1061,7 @@ class NMDA(TwoEndConn):
.. math::
- & g_{NMDA} (t) = g_{max} g \\
+ & g_\mathrm{NMDA} (t) = g_{max} g \\
& \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\
& \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k})
@@ -1235,6 +1235,10 @@ def __init__(
# integral
self.integral = odeint(method=method, f=JointEq([self.dg, self.dx]))

def reset(self):
self.g.value = bm.zeros(self.pre.num)
self.x.value = bm.zeros(self.pre.num)

def dg(self, g, t, x):
return -g / self.tau_decay + self.a * x * (1 - g)

274 changes: 273 additions & 1 deletion brainpy/dyn/synapses/biological_models.py
@@ -6,12 +6,13 @@
from brainpy.connect import TwoEndConnector, All2All, One2One
from brainpy.dyn.base import NeuGroup, TwoEndConn
from brainpy.initialize import Initializer, init_param
- from brainpy.integrators import odeint
+ from brainpy.integrators import odeint, JointEq
from brainpy.types import Tensor

__all__ = [
'AMPA',
'GABAa',
'BioNMDA',
]


@@ -328,3 +329,274 @@ def __init__(
T_duration=T_duration,
method=method,
name=name)


class BioNMDA(TwoEndConn):
r"""Biological NMDA synapse model.
**Model Descriptions**
The NMDA receptor is a glutamate receptor and ion channel found in neurons.
The NMDA receptor is one of three types of ionotropic glutamate receptors,
the other two being AMPA and kainate receptors.
The NMDA receptor mediated conductance depends on the postsynaptic voltage.
The voltage dependence is due to the blocking of the pore of the NMDA receptor
from the outside by a positively charged magnesium ion. The channel is
nearly completely blocked at resting potential, but the magnesium block is
relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
that are not blocked by magnesium can be fitted to
.. math::
g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V}
\frac{[{Mg}^{2+}]_{o}} {b})^{-1}
Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
usually 1 mM. The channel therefore acts as a "coincidence detector": only
when glutamate is bound and the postsynaptic membrane is depolarized does
the channel open and allow positively charged ions (cations) to flow through
the cell membrane [2]_.
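For example, with the default parameters (:math:`a=0.062`, :math:`b=3.57`,
:math:`[{Mg}^{2+}]_{o}=1.2` mM), the unblocked fraction can be evaluated with a
small throw-away helper; the values in the comments are approximate:

>>> import numpy as np
>>> a, b, cc_Mg = 0.062, 3.57, 1.2
>>> g_inf = lambda V: 1. / (1. + np.exp(-a * V) * cc_Mg / b)
>>> g_inf(-65.)  # near the resting potential: ~0.05, i.e. mostly blocked
>>> g_inf(0.)    # depolarized: ~0.75, i.e. largely unblocked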
If we make the approximation that the magnesium block changes
instantaneously with voltage and is independent of the gating of the channel,
the net NMDA receptor-mediated synaptic current is given by
.. math::
I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}
where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
reversal potential.
Meanwhile, the dynamics of the synaptic gating variable :math:`g` are governed by second-order kinetics [1]_:
.. math::
& g_\mathrm{NMDA} (t) = g_{max} g \\
& \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\
& \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x
where :math:`\alpha_1, \beta_1` are the forward and backward conversion rates of the variable :math:`g`,
and :math:`\alpha_2, \beta_2` are the forward and backward conversion rates of the variable :math:`x`.
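During a sustained transmitter pulse :math:`[T] = T_0`, the gating variables relax
towards the steady states obtained by setting both derivatives to zero; with the
default rate constants listed under Parameters below, :math:`g` saturates close to 1
(approximate values in the comments):

>>> alpha1, beta1, alpha2, beta2, T0 = 2., 0.01, 1., 0.5, 1.
>>> x_ss = alpha2 * T0 / (alpha2 * T0 + beta2)      # ~0.67
>>> g_ss = alpha1 * x_ss / (alpha1 * x_ss + beta1)  # ~0.99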
The NMDA receptor has been thought to be very important for controlling
synaptic plasticity and mediating learning and memory functions [3]_.
.. plot::
:include-source: True
>>> import brainpy as bp
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = bp.dyn.HH(1)
>>> neu2 = bp.dyn.HH(1)
>>> syn1 = bp.dyn.BioNMDA(neu1, neu2, bp.connect.All2All(), E=0.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>>
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x')
>>> plt.legend()
>>> plt.show()
Parameters
----------
pre: NeuGroup
The pre-synaptic neuron group.
post: NeuGroup
The post-synaptic neuron group.
conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
conn_type: str
The connection type used for model speed optimization. It can be
`sparse` or `dense`. The default is `dense`.
delay_step: int, ndarray, JaxArray, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
g_max: float, ndarray, JaxArray, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 0.15.
E: float, JaxArray, ndarray
The reversal potential for the synaptic current. [mV]
a: float, JaxArray, ndarray
Binding constant. Default 0.062
b: float, JaxArray, ndarray
Unbinding constant. Default 3.57
cc_Mg: float, JaxArray, ndarray
Concentration of Magnesium ion. Default 1.2 [mM].
alpha1: float, JaxArray, ndarray
The conversion rate of g from inactive to active. Default 2 ms^-1.
beta1: float, JaxArray, ndarray
The conversion rate of g from active to inactive. Default 0.01 ms^-1.
alpha2: float, JaxArray, ndarray
The conversion rate of x from inactive to active. Default 1 ms^-1.
beta2: float, JaxArray, ndarray
The conversion rate of x from active to inactive. Default 0.5 ms^-1.
T_0: float, JaxArray, ndarray
The transmitter concentration when a pre-synaptic spike arrives. Default 1.
T_dur: float, JaxArray, ndarray
The duration of the transmitter pulse after a pre-synaptic spike. Default 0.5 [ms].
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.
References
----------
.. [1] Devaney A J . Mathematical Foundations of Neuroscience[M].
Springer New York, 2010: 162.
.. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
Nature 438, no. 7065 (2005): 185-192.
.. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
England journal of medicine, 361(3), p.302.
.. [4] https://en.wikipedia.org/wiki/NMDA_receptor
"""

def __init__(
self,
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]],
conn_type: str = 'dense',
g_max: Union[float, Tensor, Initializer, Callable] = 0.15,
delay_step: Union[int, Tensor, Initializer, Callable] = None,
E: Union[float, Tensor] = 0.,
cc_Mg: Union[float, Tensor] = 1.2,
a: Union[float, Tensor] = 0.062,
b: Union[float, Tensor] = 3.57,
alpha1: Union[float, Tensor] = 2.,
beta1: Union[float, Tensor] = 0.01,
alpha2: Union[float, Tensor] = 1.,
beta2: Union[float, Tensor] = 0.5,
T_0: Union[float, Tensor] = 1.,
T_dur: Union[float, Tensor] = 0.5,
method: str = 'exp_auto',
name: str = None,
):
super(BioNMDA, self).__init__(pre=pre, post=post, conn=conn, name=name)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')

# parameters
self.E = E
self.alpha = a
self.beta = b
self.cc_Mg = cc_Mg
self.beta1 = beta1
self.beta2 = beta2
self.alpha1 = alpha1
self.alpha2 = alpha2
self.T_0 = T_0
self.T_dur = T_dur
if bm.size(alpha1) != 1:
raise ValueError(f'"alpha1" must be a scalar or a tensor with size of 1. But we got {alpha1}')
if bm.size(beta1) != 1:
raise ValueError(f'"beta1" must be a scalar or a tensor with size of 1. But we got {beta1}')
if bm.size(alpha2) != 1:
raise ValueError(f'"alpha2" must be a scalar or a tensor with size of 1. But we got {alpha2}')
if bm.size(beta2) != 1:
raise ValueError(f'"beta2" must be a scalar or a tensor with size of 1. But we got {beta2}')
if bm.size(E) != 1:
raise ValueError(f'"E" must be a scalar or a tensor with size of 1. But we got {E}')
if bm.size(a) != 1:
raise ValueError(f'"a" must be a scalar or a tensor with size of 1. But we got {a}')
if bm.size(b) != 1:
raise ValueError(f'"b" must be a scalar or a tensor with size of 1. But we got {b}')
if bm.size(cc_Mg) != 1:
raise ValueError(f'"cc_Mg" must be a scalar or a tensor with size of 1. But we got {cc_Mg}')
if bm.size(T_0) != 1:
raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}')
if bm.size(T_dur) != 1:
raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. But we got {T_dur}')

# connections and weights
self.conn_type = conn_type
if conn_type not in ['sparse', 'dense']:
raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initializing the model {self.name}')
if isinstance(self.conn, One2One):
self.g_max = init_param(g_max, (self.pre.num,), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif isinstance(self.conn, All2All):
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
if bm.size(self.g_max) != 1:
self.weight_type = 'heter'
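# zero the diagonal so that a neuron does not project onto itself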
bm.fill_diagonal(self.g_max, 0.)
else:
self.weight_type = 'homo'
else:
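# generic connector: build pre/post index lists for 'sparse', or a dense weight/connection matrix for 'dense'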
if conn_type == 'sparse':
self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
elif conn_type == 'dense':
self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False)
self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo'
if self.weight_type == 'homo':
self.conn_mat = self.conn.require('conn_mat')
else:
raise ValueError(f'Unknown connection type: {conn_type}')

# variables
self.g = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
self.x = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_))
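# last pre-synaptic spike times, initialized far in the past so that no transmitter is present at t = 0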
self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num, dtype=bm.float_) * -1e7)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# integral
self.integral = odeint(method=method, f=JointEq([self.dg, self.dx]))

def reset(self):
self.g.value = bm.zeros(self.pre.num)
self.x.value = bm.zeros(self.pre.num)
self.spike_arrival_time.value = bm.ones(self.pre.num) * -1e7

def dg(self, g, t, x):
return self.alpha1 * x * (1 - g) - self.beta1 * g

def dx(self, x, t, T):
return self.alpha2 * T * (1 - x) - self.beta2 * x

def update(self, t, dt):
# delays
delayed_pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)

# update synapse variables
self.spike_arrival_time.value = bm.where(delayed_pre_spike, t, self.spike_arrival_time)
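# transmitter concentration [T]: a square pulse of amplitude T_0 that lasts T_dur after each delayed pre-synaptic spike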
T = ((t - self.spike_arrival_time) < self.T_dur) * self.T_0
self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt=dt)

# post-synaptic value
assert self.weight_type in ['homo', 'heter']
assert self.conn_type in ['sparse', 'dense']
if isinstance(self.conn, All2All):
if self.weight_type == 'homo':
post_g = bm.sum(self.g)
if not self.conn.include_self:
post_g = post_g - self.g
post_g = post_g * self.g_max
else:
post_g = self.g @ self.g_max
elif isinstance(self.conn, One2One):
post_g = self.g_max * self.g
else:
if self.conn_type == 'sparse':
post_g = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids)
else:
if self.weight_type == 'homo':
post_g = (self.g_max * self.g) @ self.conn_mat
else:
post_g = self.g @ self.g_max

# output
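# note: this quantity is the reciprocal of the magnesium un-block factor g_inf(V), hence the division below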
g_inf = 1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post.V)
self.post.input += post_g * (self.E - self.post.V) / g_inf
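Besides the new BioNMDA class, this commit also adds ``reset()`` methods to the
NMDA synapses. A minimal sketch of how they might be used to re-run a built
network from a clean synaptic state, assuming the API shown in the docstring
example above (the neuron groups would need their own reset for a fully fresh run):

>>> import brainpy as bp
>>> neu1, neu2 = bp.dyn.HH(1), bp.dyn.HH(1)
>>> syn = bp.dyn.BioNMDA(neu1, neu2, bp.connect.All2All(), E=0.)
>>> net = bp.dyn.Network(pre=neu1, syn=syn, post=neu2)
>>> runner = bp.dyn.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['syn.g'])
>>> runner.run(100.)
>>> syn.reset()       # clears g, x, and spike_arrival_time, as defined in reset() above
>>> runner.run(100.)  # continue from a reset synaptic state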
