Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
fangwei123456 committed Sep 14, 2020
2 parents c966108 + c0e4224 commit 6941481
Showing 1 changed file with 167 additions and 0 deletions.
167 changes: 167 additions & 0 deletions spikingjelly/clock_driven/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,4 +582,171 @@ def forward(self, dv: torch.Tensor):

return self.spiking()

class AdaptThresholdNode(nn.Module):
def __init__(self, neuron_shape, tau_m: float, tau_adp: float, v_threshold_baseline=1.0, v_threshold_range=1.8, v_reset=0.0, surrogate_function=surrogate.Erf(), monitor_state=False, dt=1.0):
'''
* :ref:`API in English <AdaptThresholdNode.__init__-en>`
.. _AdaptThresholdNode.__init__-cn:
:param neuron_shape: 神经元张量的形状
:type neuron_shape: array_like
:param tau_m: 膜电位时间常数
:type tau_m: float
:param tau_adp: 阈值时间常数
:type tau_adp: float
:param v_threshold_baseline: 最小阈值,也为初始阈值 :math:`b_0` ,默认为1.0
:type v_threshold_baseline: float
:param v_threshold_range: 决定阈值变化范围的参数 :math:`\\beta` ,默认为1.8。控制阈值的范围为 :math:`[b_0,b_0+\\beta]`
:type v_threshold_range: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;如果设置为 ``None``,则电压会被减去 ``v_threshold``,默认为0.0
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数,默认为surrogate.Erf()
:param monitor_state: 是否设置监视器来保存神经元的电压和释放的脉冲。若为 ``True``,则 ``self.monitor`` 是一个字典,键包括 ``v`` 和 ``s``,分别记录电压和输出脉冲。对应的值是一个链表。为了节省显存(内存),列表中存入的是原始变量转换为 ``numpy`` 数组后的值。还需要注意,``self.reset()`` 函数会清空这些链表, 默认为False
:type monitor_state: bool
:param dt: 神经元的仿真间隔时间参数, 默认为1.0
:type dt: float
`Effective and Efficient Computation with Multiple-timescale Spiking Recurrent Neural Networks <https://arxiv.org/abs/2005.11633>`_ 中提出的自适应阈值神经元模型。在LIF神经元的基础上增加了一个阈值的动态方程:
.. math::
\\begin{align}
\\eta_t&=\\rho\\eta_{t-1}+(1-\\rho)S_{t-1},\\\\
\\theta_t&=b_0+\\beta\\eta_t,
\\end{align}
其中 :math:`\\eta_t` 为t时刻的阈值增幅,:math:`\\rho` 为阈值动态方程中由 ``tau_adp`` 决定的时间常数。:math:`\\theta_t` 为t时刻的电压阈值。
所有神经元动态方程的时间常数均为\ **可学习**\ 的网络参数。
.. hint::
不同于该模块中的其它神经元层,同层的各神经元不共享时间常数。
* :ref:`中文API <AdaptThresholdNode.__init__-cn>`
.. _AdaptThresholdNode.__init__-en:
:param neuron_shape: Shape of neuron tensor
:type neuron_shape: array_like
:param tau_m: Membrane potential time-constant
:type tau_m: float
:param tau_adp: Threshold time-constant
:type tau_adp: float
:param v_threshold_baseline: Minimal threshold, also the initial threshold :math:`b_0`, defaults to 1.0
:type v_threshold_baseline: float
:param v_threshold_range: Parameter :math:`\\beta` determining the range of threshold to :math:`[b_0,b_0+\\beta]` , defaults to 1.8
:type v_threshold_range: float
:param v_reset: Reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to ``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold``, defaults to 0.0
:type v_reset: float
:param surrogate_function: Surrogate function for replacing gradient of spiking functions during back-propagation, defaults to surrogate.Erf()
:param monitor_state: Whether to turn on the monitor, defaults to False
:type monitor_state: bool
:param dt: Simulation interval constant of neurons, defaults to 1.0
:type dt: float
An neuron model with adaptive threshold proposed in `Effective and Efficient Computation with Multiple-timescale Spiking Recurrent Neural Networks <https://arxiv.org/abs/2005.11633>`_. Compared to vanilla LIF neuron, an additional dynamic equation of threshold is added:
.. math::
\\begin{align}
\\eta_t & = \\rho\\eta_{t-1}+(1-\\rho)S_{t-1},\\\\
\\theta_t & = b_0+\\beta\\eta_t,
\\end{align}
where :math:`\\eta_t` is the growth of threshold at timestep t, :math:`\\rho` is the time-constant determined by ``tau_adp`` in threshold dynamic. :math:`\\theta_t` is the threshold at timestep t.
All time constants in neurons' dynamics are **learnable** network parameters.
.. admonition:: Hint
:class: hint
Different from other types of neuron in this module, time-constant is NOT shared in the same layer.
'''

super().__init__()
self.neuron_shape = neuron_shape
self.b_0 = v_threshold_baseline
self.b = 0
self.v_reset = v_reset
self.beta = v_threshold_range
self.tau_m = nn.Parameter(torch.full(neuron_shape, fill_value=tau_m, dtype=torch.float))
self.tau_adp = nn.Parameter(torch.full(neuron_shape, fill_value=tau_adp, dtype=torch.float))
self.dt = dt
self.last_spike = torch.rand(neuron_shape)

if self.v_reset is None:
self.v = 0
else:
self.v = self.v_reset
self.v_threshold = self.b_0

self.surrogate_function = surrogate_function
if monitor_state:
self.monitor = {'v': [], 's': []}
else:
self.monitor = False

def extra_repr(self):
return f'v_threshold_baseline={self.b_0}, v_threshold_range={self.beta}, v_reset={self.v_reset}'

def set_monitor(self, monitor_state=True):
if monitor_state:
self.monitor = {'v': [], 's': []}
else:
self.monitor = False

def spiking(self):
spike = self.surrogate_function(self.v - self.v_threshold)
if self.monitor:
if self.monitor['v'].__len__() == 0:
# 补充在0时刻的电压
if self.v_reset is None:
self.monitor['v'].append(self.v.data.cpu().numpy().copy() * 0)
else:
self.monitor['v'].append(self.v.data.cpu().numpy().copy() * self.v_reset)

self.monitor['v'].append(self.v.data.cpu().numpy().copy())
self.monitor['s'].append(spike.data.cpu().numpy().copy())

if self.v_reset is None:
if self.surrogate_function.spiking:
self.v = accelerating.soft_voltage_transform(self.v, spike, self.v_threshold)
else:
self.v = self.v - spike * self.v_threshold
else:
if self.surrogate_function.spiking:
self.v = accelerating.hard_voltage_transform(self.v, spike, self.v_reset)
else:
self.v = self.v * (1 - spike) + self.v_reset * spike

if self.monitor:
self.monitor['v'].append(self.v.data.cpu().numpy().copy())

return spike

def forward(self, dv: torch.Tensor):
alpha = torch.exp(-self.dt / self.tau_m)
rho = torch.exp(-self.dt / self.tau_adp)

self.b = rho * self.b + (1 - rho) * self.last_spike
self.v_threshold = self.b_0 + self.beta * self.b
self.v = self.v * alpha + (1 - alpha) * dv

spike = self.spiking()

self.last_spike = spike

return spike

def reset(self):

if self.v_reset is None:
self.v = 0
else:
self.v = self.v_reset
self.v_threshold = self.b_0
self.b = 0
self.last_spike = torch.rand(self.neuron_shape)
if self.monitor:
self.monitor = {'v': [], 's': []}

0 comments on commit 6941481

Please sign in to comment.