Skip to content

Commit

Permalink
Support for Delta synapse projections (#568)
Browse files Browse the repository at this point in the history
* [dyn] synaptic projection updates

1. reorganize the projection structures;
2. rename previous reduced projections with intuitive names
3. add `brainpy.dyn.HalfProjDelta` and `brainpy.dyn.FullProjDelta`

* [doc] update doc

* [fix] fix bug

* [doc] upgrade the documentation of synaptic projections
  • Loading branch information
chaoming0625 committed Dec 28, 2023
1 parent fb9a321 commit c346f29
Show file tree
Hide file tree
Showing 32 changed files with 2,024 additions and 1,584 deletions.
10 changes: 10 additions & 0 deletions brainpy/_add_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@
# neurons
'NeuGroup': ('brainpy.dyn.NeuGroup', 'brainpy.dyn.NeuDyn', NeuDyn),

# projections
'ProjAlignPostMg1': ('brainpy.dyn.ProjAlignPostMg1', 'brainpy.dyn.HalfProjAlignPostMg', dyn.HalfProjAlignPostMg),
'ProjAlignPostMg2': ('brainpy.dyn.ProjAlignPostMg2', 'brainpy.dyn.FullProjAlignPostMg', dyn.FullProjAlignPostMg),
'ProjAlignPost1': ('brainpy.dyn.ProjAlignPost1', 'brainpy.dyn.HalfProjAlignPost', dyn.HalfProjAlignPost),
'ProjAlignPost2': ('brainpy.dyn.ProjAlignPost2', 'brainpy.dyn.FullProjAlignPost', dyn.FullProjAlignPost),
'ProjAlignPreMg1': ('brainpy.dyn.ProjAlignPreMg1', 'brainpy.dyn.FullProjAlignPreSDMg', dyn.FullProjAlignPreSDMg),
'ProjAlignPreMg2': ('brainpy.dyn.ProjAlignPreMg2', 'brainpy.dyn.FullProjAlignPreDSMg', dyn.FullProjAlignPreDSMg),
'ProjAlignPre1': ('brainpy.dyn.ProjAlignPre1', 'brainpy.dyn.FullProjAlignPreSD', dyn.FullProjAlignPreSD),
'ProjAlignPre2': ('brainpy.dyn.ProjAlignPre2', 'brainpy.dyn.FullProjAlignPreDS', dyn.FullProjAlignPreDS),

# synapses
'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
Expand Down
23 changes: 13 additions & 10 deletions brainpy/_src/dyn/neurons/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class CondNeuGroupLTC(HHTypedNeuron, Container, TreeNode):
where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants.
.. versionadded:: 2.1.9
Model the conductance-based neuron model.
Modeling the conductance-based neuron model.
Parameters
----------
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(

def derivative(self, V, t, I):
# synapses
I = self.sum_inputs(V, init=I)
I = self.sum_current_inputs(V, init=I)
# channels
for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values():
I = I + ch.current(V)
Expand All @@ -140,7 +140,7 @@ def update(self, x=None):
x = x * (1e-3 / self.A)

# integral
V = self.integral(self.V.value, share['t'], x, share['dt'])
V = self.integral(self.V.value, share['t'], x, share['dt']) + self.sum_delta_inputs()

# check whether the children channels have the correct parents.
channels = self.nodes(level=1, include_self=False).subset(IonChaDyn).unique()
Expand Down Expand Up @@ -176,7 +176,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
# inputs
x = 0. if x is None else x
x = self.sum_inputs(self.V.value, init=x)
x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)


Expand Down Expand Up @@ -384,7 +384,7 @@ def reset_state(self, batch_size=None, **kwargs):
self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size)

def dV(self, V, t, m, h, n, I):
I = self.sum_inputs(V, init=I)
I = self.sum_current_inputs(V, init=I)
I_Na = (self.gNa * m * m * m * h) * (V - self.ENa)
n2 = n * n
I_K = (self.gK * n2 * n2) * (V - self.EK)
Expand All @@ -402,6 +402,7 @@ def update(self, x=None):
x = 0. if x is None else x

V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt)
V += self.sum_delta_inputs()
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.m.value = m
Expand Down Expand Up @@ -532,7 +533,7 @@ def derivative(self):

def update(self, x=None):
x = 0. if x is None else x
x = self.sum_inputs(self.V.value, init=x)
x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)


Expand Down Expand Up @@ -662,7 +663,7 @@ def reset_state(self, batch_or_mode=None, **kwargs):
self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_or_mode)

def dV(self, V, t, W, I):
I = self.sum_inputs(V, init=I)
I = self.sum_current_inputs(V, init=I)
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
I_Ca = self.g_Ca * M_inf * (V - self.V_Ca)
I_K = self.g_K * W * (V - self.V_K)
Expand All @@ -685,6 +686,7 @@ def update(self, x=None):
dt = share.load('dt')
x = 0. if x is None else x
V, W = self.integral(self.V, self.W, t, x, dt)
V += self.sum_delta_inputs()
spike = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.W.value = W
Expand Down Expand Up @@ -761,7 +763,7 @@ def dV(self, V, t, W, I):

def update(self, x=None):
x = 0. if x is None else x
x = self.sum_inputs(self.V.value, init=x)
x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)


Expand Down Expand Up @@ -951,7 +953,7 @@ def dn(self, n, t, V):
return self.phi * dndt

def dV(self, V, t, h, n, I):
I = self.sum_inputs(V, init=I)
I = self.sum_current_inputs(V, init=I)
INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa)
IK = self.gK * n ** 4 * (V - self.EK)
IL = self.gL * (V - self.EL)
Expand All @@ -968,6 +970,7 @@ def update(self, x=None):
x = 0. if x is None else x

V, h, n = self.integral(self.V, self.h, self.n, t, x, dt)
V += self.sum_delta_inputs()
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
self.h.value = h
Expand Down Expand Up @@ -1091,5 +1094,5 @@ def dV(self, V, t, h, n, I):

def update(self, x=None):
x = 0. if x is None else x
x = self.sum_inputs(self.V.value, init=x)
x = self.sum_current_inputs(self.V.value, init=x)
return super().update(x)

0 comments on commit c346f29

Please sign in to comment.