Skip to content

Commit

Permalink
[Dyn] Fix alpha synapse bugs (#578)
Browse files Browse the repository at this point in the history
* [dyn] fix alpha synapse bugs

* [docs] fix math expression
  • Loading branch information
ztqakita committed Jan 3, 2024
1 parent d0988a0 commit 0b297b7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 21 deletions.
35 changes: 30 additions & 5 deletions brainpy/_src/dyn/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def return_info(self):
DualExponV2.__doc__ = DualExponV2.__doc__ % (pneu_doc,)


class Alpha(DualExpon):
class Alpha(SynDyn):
r"""Alpha synapse model.
**Model Descriptions**
Expand All @@ -494,7 +494,7 @@ class Alpha(DualExpon):
.. math::
\begin{aligned}
&\frac{d g}{d t}=-\frac{g}{\tau}+h \\
&\frac{d g}{d t}=-\frac{g}{\tau}+\frac{h}{\tau} \\
&\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
\end{aligned}
Expand Down Expand Up @@ -585,16 +585,41 @@ def __init__(
tau_decay: Union[float, ArrayType, Callable] = 10.0,
):
super().__init__(
tau_decay=tau_decay,
tau_rise=tau_decay,
method=method,
name=name,
mode=mode,
size=size,
keep_size=keep_size,
sharding=sharding
)

# parameters
self.tau_decay = self.init_param(tau_decay)

# integrator
self.integral = odeint(JointEq(self.dg, self.dh), method=method)

self.reset_state(self.mode)

def reset_state(self, batch_or_mode=None, **kwargs):
self.h = self.init_variable(bm.zeros, batch_or_mode)
self.g = self.init_variable(bm.zeros, batch_or_mode)

def dh(self, h, t):
return -h / self.tau_decay

def dg(self, g, t, h):
return -g / self.tau_decay + h / self.tau_decay

def update(self, x):
# update synaptic variables
self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
self.h += x
return self.g.value

def return_info(self):
return self.g



Alpha.__doc__ = Alpha.__doc__ % (pneu_doc,)

Expand Down
52 changes: 36 additions & 16 deletions brainpy/_src/dynold/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def update(self, pre_spike=None):
return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)


class Alpha(DualExponential):
class Alpha(_TwoEndConnAlignPre):
r"""Alpha synapse model.
**Model Descriptions**
Expand All @@ -516,7 +516,7 @@ class Alpha(DualExponential):
\begin{aligned}
&g_{\mathrm{syn}}(t)= g_{\mathrm{max}} g \\
&\frac{d g}{d t}=-\frac{g}{\tau}+h \\
&\frac{d g}{d t}=-\frac{g}{\tau}+\frac{h}{\tau} \\
&\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
\end{aligned}
Expand Down Expand Up @@ -593,20 +593,40 @@ def __init__(
mode: bm.Mode = None,
stop_spike_gradient: bool = False,
):
super(Alpha, self).__init__(pre=pre,
post=post,
conn=conn,
comp_method=comp_method,
delay_step=delay_step,
g_max=g_max,
tau_decay=tau_decay,
tau_rise=tau_decay,
method=method,
output=output,
stp=stp,
name=name,
mode=mode,
stop_spike_gradient=stop_spike_gradient)
# parameters
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
self.tau_decay = tau_decay
if bm.size(self.tau_decay) != 1:
raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
f'But we got {self.tau_decay}')

syn = synapses.Alpha(pre.size,
pre.keep_size,
mode=mode,
tau_decay=tau_decay,
method=method)

super().__init__(pre=pre,
post=post,
syn=syn,
conn=conn,
comp_method=comp_method,
delay_step=delay_step,
g_max=g_max,
output=output,
stp=stp,
name=name,
mode=mode,)

self.check_post_attrs('input')
# copy the references
self.g = syn.g
self.h = syn.h

def update(self, pre_spike=None):
return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)



class NMDA(_TwoEndConnAlignPre):
Expand Down

0 comments on commit 0b297b7

Please sign in to comment.