
Commit

add api doc
fangwei123456 committed Apr 25, 2023
1 parent 2243f16 commit dbaa402
Showing 1 changed file with 130 additions and 31 deletions.
161 changes: 130 additions & 31 deletions spikingjelly/activation_based/neuron.py
@@ -541,21 +541,23 @@ def multi_step_forward(self, x_seq: torch.Tensor):
else:
raise NotImplementedError(x_seq.dtype)

if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset, dtype=dtype):

if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
dtype=dtype):
self.forward_kernel = ac_neuron_kernel.IFNodeFPTTKernel(hard_reset=hard_reset, dtype=dtype)

if self.backward_kernel is None or not self.backward_kernel.check_attributes(surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype):

if self.backward_kernel is None or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype):
self.backward_kernel = ac_neuron_kernel.IFNodeBPTTKernel(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype)

self.v_float_to_tensor(x_seq[0])

spike_seq, v_seq = ac_neuron_kernel.IFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(0),
self.v_threshold, self.v_reset, self.forward_kernel,
self.backward_kernel)
self.v_threshold, self.v_reset,
self.forward_kernel,
self.backward_kernel)

spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
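A brief usage sketch (editor's addition, not part of this commit): the lazy kernel construction above runs the first time a multi-step forward pass executes on the CuPy backend. The class name IFNode and the 'cupy' backend string are not shown in this excerpt and are assumed here.

# Editor's sketch; assumes a CUDA device and a spikingjelly install with CuPy.
import torch
from spikingjelly.activation_based import neuron

if_layer = neuron.IFNode(v_threshold=1., v_reset=0.,    # a float v_reset typically selects the hard-reset kernels
                         step_mode='m', backend='cupy')
x_seq = torch.rand([8, 4, 32], device='cuda')           # [T, N, *]; float32 input selects the float kernels
spike_seq = if_layer(x_seq)                              # first call builds IFNodeFPTTKernel / IFNodeBPTTKernel lazily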
@@ -942,20 +944,27 @@ def multi_step_forward(self, x_seq: torch.Tensor):
else:
raise NotImplementedError(x_seq.dtype)

if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset, dtype=dtype, decay_input=self.decay_input):
self.forward_kernel = ac_neuron_kernel.LIFNodeFPTTKernel(decay_input=self.decay_input, hard_reset=hard_reset, dtype=dtype)
if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
dtype=dtype,
decay_input=self.decay_input):
self.forward_kernel = ac_neuron_kernel.LIFNodeFPTTKernel(decay_input=self.decay_input,
hard_reset=hard_reset, dtype=dtype)

if self.backward_kernel is None or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input):
self.backward_kernel = ac_neuron_kernel.LIFNodeBPTTKernel(decay_input=self.decay_input, surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype)
self.backward_kernel = ac_neuron_kernel.LIFNodeBPTTKernel(decay_input=self.decay_input,
surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset,
dtype=dtype)

self.v_float_to_tensor(x_seq[0])

spike_seq, v_seq = ac_neuron_kernel.LIFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(0),
self.v_threshold, self.v_reset, 1. / self.tau,
self.forward_kernel,
self.backward_kernel)
self.v_threshold, self.v_reset, 1. / self.tau,
self.forward_kernel,
self.backward_kernel)

spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
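For reference (editor's addition): the decay_input flag threaded through the LIF kernel constructors above selects between two charge equations; a minimal sketch, assuming the public LIFNode class wraps this code path.

# Editor's sketch, not part of this commit.
#   decay_input=True : H[t] = V[t-1] + (X[t] - (V[t-1] - v_reset)) / tau
#   decay_input=False: H[t] = V[t-1] - (V[t-1] - v_reset) / tau + X[t]
from spikingjelly.activation_based import neuron

lif_decay    = neuron.LIFNode(tau=2., decay_input=True,  step_mode='m', backend='cupy')
lif_no_decay = neuron.LIFNode(tau=2., decay_input=False, step_mode='m', backend='cupy')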
@@ -1175,16 +1184,16 @@ def multi_step_forward(self, x_seq: torch.Tensor):
dtype=dtype,
decay_input=self.decay_input):
self.forward_kernel = ac_neuron_kernel.ParametricLIFNodeFPTTKernel(decay_input=self.decay_input,
hard_reset=hard_reset, dtype=dtype)
hard_reset=hard_reset, dtype=dtype)

if self.backward_kernel is None or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input):
self.backward_kernel = ac_neuron_kernel.ParametricLIFNodeBPTTKernel(decay_input=self.decay_input,
surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype)

surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset,
dtype=dtype)

self.v_float_to_tensor(x_seq[0])

@@ -1628,7 +1637,7 @@ def single_step_forward(self, x: torch.Tensor):


class KLIFNode(BaseNode):
def __init__(self, scale_reset: bool=False, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
def __init__(self, scale_reset: bool = False, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False):
"""
@@ -1799,23 +1808,20 @@ def __init__(self, scale_reset: bool=False, tau: float = 2., decay_input: bool =

self.k = nn.Parameter(torch.as_tensor(1.))


@staticmethod
@torch.jit.script
def neuronal_charge_decay_input(x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float, k: torch.Tensor):
v = v + (x - (v - v_reset)) / tau
v = torch.relu_(k * v)
return v


@staticmethod
@torch.jit.script
def neuronal_charge_no_decay_input(x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float, k: torch.Tensor):
v = v - (v - v_reset) / tau + x
v = torch.relu_(k * v)
return v


def neuronal_charge(self, x: torch.Tensor):
if self.v_reset is None:
v_reset = 0.
@@ -1827,14 +1833,12 @@ def neuronal_charge(self, x: torch.Tensor):
else:
self.v = self.neuronal_charge_no_decay_input(x, self.v, v_reset, self.tau, self.k)


def neuronal_reset(self, spike):
if self.detach_reset:
spike_d = spike.detach()
else:
spike_d = spike


if self.scale_reset:
if self.v_reset is None:
# soft reset
Expand All @@ -1855,13 +1859,28 @@ def neuronal_reset(self, spike):
self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)
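A short usage sketch of the KLIF neuron defined above (editor's addition, not part of this commit); the charge step is the one implemented by the jit-scripted methods.

# Editor's sketch. With decay_input=True the charge above is
#   v <- relu(k * (v + (x - (v - v_reset)) / tau)),  where k is a learnable scalar initialised to 1.
import torch
from spikingjelly.activation_based import neuron

klif = neuron.KLIFNode(scale_reset=True, tau=2., step_mode='s')
x = torch.rand([4, 16])
spike = klif(x)            # single-step forward; klif.v holds the post-reset membrane potential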


class PSN(nn.Module, base.MultiStepModule):
def __init__(self, T: int, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan()):
"""
:param T: the number of time-steps
:type T: int
:param surrogate_function: the function for calculating surrogate gradients of the Heaviside step function in backward
:type surrogate_function: Callable
The Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability`. The neuronal dynamics are defined as
.. math::
H &= WX, ~~~~~~~~~~~~~~~W \\in \\mathbb{R}^{T \\times T}, X \\in \\mathbb{R}^{T \\times N} \\label{eq psn neuronal charge}\\\\
S &= \\Theta(H - B), ~~~~~B \\in \\mathbb{R}^{T}, S\\in \\{0, 1\\}^{T \\times N}
where :math:`W` is the learnable weight matrix, and :math:`B` is the learnable threshold.
.. admonition:: Note
:class: note
The PSN only supports the multi-step mode.
"""
super().__init__()
self.T = T
self.surrogate_function = surrogate_function
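A minimal usage sketch of the PSN described in the docstring above (editor's addition, not part of this commit).

# Editor's sketch. PSN is multi-step only: the input is the full sequence of shape [T, N, *],
# with T matching the constructor argument.
import torch
from spikingjelly.activation_based import neuron

T, N = 8, 4
psn = neuron.PSN(T=T)
x_seq = torch.rand([T, N, 32])
spike_seq = psn(x_seq)      # H = W X with a learnable T x T weight, S = Theta(H - B)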
@@ -1893,7 +1912,52 @@ def masked_weight(self):
else:
return self.gen_masked_weight(self.k, self.mask0, self.mask1, self.weight)

def __init__(self, order: int, T: int, k_init: float=0., surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's'):
def __init__(self, order: int, T: int, k_init: float = 0.,
surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's'):
"""
:param order: the order of the Masked PSN
:type order: int
:param T: the number of time-steps
:type T: int
:param k_init: the initial value of ``k`` to adjust the progressive masking process
:type k_init: float
:param surrogate_function: the function for calculating surrogate gradients of the Heaviside step function in backward
:type surrogate_function: Callable
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
The Masked Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability`. The neuronal dynamics are defined as
.. math::
H &= (W \\cdot {M}_{order})X, ~~~~~~~~~~~~~~~W \\in \\mathbb{R}^{T \\times T}, {M}_{order} \\in \\mathbb{R}^{T \\times T}, X \\in \\mathbb{R}^{T \\times N} \\\\
S &= \\Theta(H - B), ~~~~~B \\in \\mathbb{R}^{T}, S\\in \\{0, 1\\}^{T \\times N}
where :math:`W` is the learnable weight matrix, :math:`B` is the learnable threshold, and :math:`{M}_{order}` is defined as
.. math::
{M}_{order}[i][j] = \\begin{cases}
1, ~~ j \\leq i \\leq j + k - 1 \\\\
0, \\mathrm{otherwise}
\\end{cases}.
:math:`k` is used to adjust the progressive masking process, which is
.. math::
M_{order}(k) = k \\cdot M_{order} + (1 - k) \\cdot J,
where :math:`J` is an all-one matrix.
The user can adjust :math:`k` during training by assigning ``self.k = a``.
.. admonition:: Note
:class: note
The masked PSN supports both single-step and multi-step modes, but the multi-step mode is much faster than the single-step mode.
"""
super().__init__()
self.register_memory('time_step', 0)
self.register_memory('queue', [])
@@ -1916,7 +1980,6 @@ def __init__(self, order: int, T: int, k_init: float=0., surrogate_function: sur
self.register_buffer('mask0', mask0)
self.register_buffer('mask1', mask1)


@property
def k(self):
return self._k
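A sketch of the progressive masking workflow described in the docstring above (editor's addition); the class name MaskedPSN is assumed, since the class statement lies outside this excerpt.

# Editor's sketch, not part of this commit. Annealing k from k_init towards 1 makes the effective
# mask k * M_order + (1 - k) * J converge to M_order.
import torch
from spikingjelly.activation_based import neuron

mpsn = neuron.MaskedPSN(order=2, T=8, k_init=0., step_mode='m')
x_seq = torch.rand([8, 4, 32])
for epoch in range(10):
    mpsn.k = min(1., epoch / 5)      # set k during training via the property above
    spike_seq = mpsn(x_seq)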
Expand Down Expand Up @@ -1947,7 +2010,6 @@ def single_step_forward(self, x: torch.Tensor):
self.time_step += 1
return spike.view(x.shape)


def multi_step_forward(self, x_seq: torch.Tensor):
# x_seq.shape = [T, N, *]
assert x_seq.shape[0] == self.T
@@ -1967,6 +2029,7 @@ def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
state_dict.pop('k')
super().load_state_dict(state_dict, strict)


class SlidingPSN(base.MemoryModule):

@property
@@ -1983,7 +2046,44 @@ def gen_gemm_weight(self, T: int):

return weight

def __init__(self, order: int, exp_init: bool = True, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's', backend: str = 'gemm'):
def __init__(self, order: int, exp_init: bool = True,
surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's',
backend: str = 'gemm'):
"""
:param order: the order of the sliding PSN
:type order: int
:param exp_init: if ``True``, the weight will be initialized as ``(..., 1/4, 1/2, 1)``. If ``False``, the weight will be initialized by the Kaiming uniform method
:type exp_init: bool
:param surrogate_function: the function for calculating surrogate gradients of the Heaviside step function in backward
:type surrogate_function: Callable
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend for this neuron layer, which can be "gemm" or "conv"
:type backend: str
The Sliding Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability`. The neuronal dynamics are defined as
.. math::
H[t] &= \\sum_{i=0}^{k-1}W_{i}\\cdot X[t - k + 1 + i], \\\\
S[t] &= \\Theta(H[t] - B),
where :math:`W = [W_{0}, W_{1}, ..., W_{k-1}] \\in \\mathbb{R}^{k}` is the learnable weight, and :math:`B` is the learnable threshold.
.. admonition:: Note
:class: note
The sliding PSN supports both single-step and multi-step modes, but the multi-step mode is much faster than the single-step mode.
.. admonition:: Note
:class: note
The "gemm" backend is much faster than the "conv" backend.
"""

super().__init__()
self.register_memory('queue', [])
self.step_mode = step_mode
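A usage sketch for the sliding PSN above (editor's addition, not part of this commit), using the 'gemm' backend recommended in the docstring.

# Editor's sketch. With backend='gemm' the whole sequence is handled by one matrix product built
# from gen_gemm_weight, so T need not be fixed at construction time.
import torch
from spikingjelly.activation_based import neuron

spsn = neuron.SlidingPSN(order=2, step_mode='m', backend='gemm')
x_seq = torch.rand([8, 4, 32])      # [T, N, *]
spike_seq = spsn(x_seq)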
@@ -2019,7 +2119,6 @@ def single_step_forward(self, x: torch.Tensor):

return spike.view(x.shape)


def multi_step_forward(self, x_seq: torch.Tensor):
if self.backend == 'gemm':
weight = self.gen_gemm_weight(x_seq.shape[0])
@@ -2036,4 +2135,4 @@ def multi_step_forward(self, x_seq: torch.Tensor):
return self.surrogate_function(x_seq + self.bias)

else:
raise NotImplementedError(self.backend)
raise NotImplementedError(self.backend)
