
Commit

remove batch_first, because it cannot accelerate (only time-step last can)
fangwei123456 committed May 8, 2023
1 parent 01cdcd2 commit aafe59b
Showing 1 changed file with 11 additions and 34 deletions.
spikingjelly/activation_based/neuron.py (45 changes: 11 additions & 34 deletions)
@@ -2048,7 +2048,7 @@ def gen_gemm_weight(self, T: int):

def __init__(self, k: int, exp_init: bool = True,
surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's',
- backend: str = 'gemm', batch_first: bool=False):
+ backend: str = 'gemm'):
"""
:param k: the order of the Sliding PSN
:type k: int
@@ -2060,10 +2060,6 @@ def __init__(self, k: int, exp_init: bool = True,
:type step_mode: str
:param backend: backend for this neuron layer, which can be "gemm" or "conv". This option only works for the multi-step mode
:type backend: str
- :param batch_first: whether the input sequence is batch first, which is with a shape of ``[N, T, *]``.
-     The default value of ``batch_first`` is ``False`` because SpikingJelly uses ``[T, N, *]`` by default.
-     This option only works for the multi-step mode
- :type batch_first: bool
The Sliding Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability <https://arxiv.org/abs/2304.12760>`_. The neuronal dynamics are defined as
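The equations themselves fall in a hunk that is collapsed here; as a sketch of the formulation in the linked paper (with order ``k``, the learnable weights ``self.weight`` written as w_i, the bias ``self.bias`` as b, and the Heaviside step as \Theta):

H[t] = \sum_{i=0}^{k-1} w_i \, X[t - k + 1 + i], \qquad S[t] = \Theta\bigl(H[t] + b\bigr)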
@@ -2086,10 +2082,6 @@ def __init__(self, k: int, exp_init: bool = True,
The "gemm" backend is much faster than the "conv" backend when ``T`` is small. But when ``T`` is large, the ``conv`` backend is faster.
.. admonition:: Note
:class: note
The "conv" backend is faster with ``batch_first = True`` than ``batch_first = False``. The ``gemm`` backend is faster with ``batch_first = False`` than ``batch_first = True``.
"""

@@ -2111,7 +2103,6 @@ def __init__(self, k: int, exp_init: bool = True,

self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(torch.as_tensor(-1.))
- self.batch_first = batch_first

def single_step_forward(self, x: torch.Tensor):
self.queue.append(x.flatten())
@@ -2131,35 +2122,21 @@ def single_step_forward(self, x: torch.Tensor):

def multi_step_forward(self, x_seq: torch.Tensor):
if self.backend == 'gemm':
- if self.batch_first:
-     # x_seq.shape = [N, T, *]
-     weight = self.gen_gemm_weight(x_seq.shape[1])
-     h_seq = F.linear(x_seq.transpose(1, -1), weight, self.bias)
-     h_seq = h_seq.transpose(1, -1)
- else:
-     weight = self.gen_gemm_weight(x_seq.shape[0])
-     h_seq = torch.addmm(self.bias, weight, x_seq.flatten(1)).view(x_seq.shape)

+ weight = self.gen_gemm_weight(x_seq.shape[0])
+ h_seq = torch.addmm(self.bias, weight, x_seq.flatten(1)).view(x_seq.shape)
return self.surrogate_function(h_seq)
elif self.backend == 'conv':
- if self.batch_first:
-     x_seq_shape = x_seq.shape
-     N = x_seq.shape[0]
-     T = x_seq.shape[1]
-     # [N, T, *] -> [N, *, T] -> [N * ?, 1, T]
-     x_seq = x_seq.transpose(1, -1).reshape(-1, 1, T)
- else:
-     # x_seq.shape = [T, N, *]
-     x_seq_shape = x_seq.shape
-     # [T, N, *] -> [T, N] -> [N, T] -> [N, 1, T]
-     x_seq = x_seq.flatten(1).t().unsqueeze(1)

+ # x_seq.shape = [T, N, *]
+ x_seq_shape = x_seq.shape
+ # [T, N, *] -> [T, N] -> [N, T] -> [N, 1, T]
+ x_seq = x_seq.flatten(1).t().unsqueeze(1)

x_seq = F.pad(x_seq, pad=(self.k - 1, 0))
x_seq = F.conv1d(x_seq, self.weight.view(1, 1, -1), stride=1)
- if self.batch_first:
-     # [N * ?, 1, T] -> [N, ?, T] -> [N, T, ?]
-     x_seq = x_seq.view(N, -1, T).transpose(1, -1).view(x_seq_shape)
- else:
-     x_seq = x_seq.squeeze(1).t().view(x_seq_shape)

+ x_seq = x_seq.squeeze(1).t().view(x_seq_shape)
return self.surrogate_function(x_seq + self.bias)

else:
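With ``batch_first`` gone, callers that hold batch-first data must permute it into SpikingJelly's default ``[T, N, *]`` layout themselves. A minimal sketch, assuming the class is ``SlidingPSN`` in ``spikingjelly.activation_based.neuron`` with the constructor shown above (sizes and values are illustrative only):

import torch
from spikingjelly.activation_based import neuron

T, N, C = 8, 4, 32  # illustrative sizes: time steps, batch, features

# Multi-step mode with the 'gemm' backend; it presumably builds a T x T
# lower-triangular matrix from the k sliding weights and applies it in one GEMM.
psn = neuron.SlidingPSN(k=2, step_mode='m', backend='gemm')

x_seq = torch.rand(T, N, C)      # SpikingJelly's default [T, N, *] layout
spikes = psn(x_seq)              # -> [T, N, C]

# Data stored batch-first ([N, T, *]) must now be transposed by the caller,
# since this commit removes the batch_first option.
x_batch_first = torch.rand(N, T, C)
spikes_bf = psn(x_batch_first.transpose(0, 1).contiguous())

# The 'conv' backend takes the same [T, N, *] input; per the docstring above,
# it tends to be the faster choice when T is large.
psn_conv = neuron.SlidingPSN(k=2, step_mode='m', backend='conv')
spikes_conv = psn_conv(x_seq)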
