Merge pull request #470 from pkuxmq/master
add DSR, OTTT, SLTT training modules
fangwei123456 committed Dec 4, 2023
2 parents 915abc5 + 4b12a38 commit 024edfb
Showing 4 changed files with 1,264 additions and 0 deletions.
80 changes: 80 additions & 0 deletions spikingjelly/activation_based/functional.py
@@ -1277,4 +1277,84 @@ def fptt_online_training(model: nn.Module, optimizer: torch.optim.Optimizer, x_s



def ottt_online_training(model: nn.Module, optimizer: torch.optim.Optimizer, x_seq: torch.Tensor, target_seq: torch.Tensor, f_loss_t: Callable, online: bool) -> tuple:
"""
:param model: the neural network
:type model: nn.Module
:param optimizer: the optimizer for the network
:type optimizer: torch.optim.Optimizer
:param x_seq: the input sequence
:type x_seq: torch.Tensor
    :param target_seq: the target sequence
    :type target_seq: torch.Tensor
    :param f_loss_t: the loss function, which should have the signature ``def f_loss_t(x_t, y_t) -> torch.Tensor``
    :type f_loss_t: Callable
    :param online: whether to update the parameters online at each time step, or to accumulate gradients over all time steps and apply a single update
    :type online: bool
    :return: the loss accumulated over all time steps and the outputs of all time steps stacked to shape ``[B, T, ...]``
    :rtype: tuple

    The OTTT online training method is proposed by `Online Training Through Time for Spiking Neural Networks <https://openreview.net/forum?id=Siv3nHYHheI>`_.
    This function can also be used for the SLTT training method proposed by `Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks <https://openaccess.thecvf.com/content/ICCV2023/html/Meng_Towards_Memory-_and_Time-Efficient_Backpropagation_for_Training_Spiking_Neural_Networks_ICCV_2023_paper.html>`_.

    Example:

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        from spikingjelly.activation_based import neuron, layer, functional

        net = layer.OTTTSequential(
            nn.Linear(8, 4),
            neuron.OTTTLIFNode(),
            nn.Linear(4, 2),
            neuron.LIFNode()
        )
        optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
        T = 4
        N = 2
        online = True
        for epoch in range(2):
            x_seq = torch.rand([N, T, 8])
            target_seq = torch.rand([N, T, 2])
            batch_loss, y_all = functional.ottt_online_training(model=net, optimizer=optimizer, x_seq=x_seq, target_seq=target_seq, f_loss_t=F.mse_loss, online=online)
            functional.reset_net(net)
"""

# input x_seq/target_seq: [B, T, ...]
# transpose to [T, B, ...]
x_seq = x_seq.transpose(0, 1)
target_seq = target_seq.transpose(0, 1)
T = x_seq.shape[0]

batch_loss = 0.
y_all = []
if not online:
optimizer.zero_grad()
for t in range(T):
if online:
optimizer.zero_grad()

y_t = model(x_seq[t])
loss = f_loss_t(y_t, target_seq[t].contiguous())

loss.backward()

# update params
if online:
optimizer.step()

batch_loss += loss.data
y_all.append(y_t.detach())

if not online:
optimizer.step()

    # stack the outputs over time: y_all.shape = [B, T, ...]
y_all = torch.stack(y_all, dim=1)

return batch_loss, y_all
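
As a usage note, here is a minimal sketch (an illustrative addition, not part of this diff) of how the two return values can be consumed; the toy network mirrors the docstring example, and only ``ottt_online_training`` and ``reset_net`` are taken from this module.

import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import neuron, layer, functional

net = layer.OTTTSequential(
    nn.Linear(8, 4),
    neuron.OTTTLIFNode(),
    nn.Linear(4, 2),
    neuron.LIFNode()
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
x_seq = torch.rand([2, 4, 8])       # [N, T, ...]
target_seq = torch.rand([2, 4, 2])  # [N, T, ...]

# online=True steps the optimizer at every time step;
# online=False accumulates gradients over all T steps and steps once.
batch_loss, y_all = functional.ottt_online_training(
    model=net, optimizer=optimizer, x_seq=x_seq, target_seq=target_seq,
    f_loss_t=F.mse_loss, online=True)
print(batch_loss)   # loss summed over the T time steps
print(y_all.shape)  # torch.Size([2, 4, 2]), i.e. [N, T, ...]
functional.reset_net(net)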

227 changes: 227 additions & 0 deletions spikingjelly/activation_based/layer.py
@@ -10,6 +10,7 @@
from typing import Optional, List, Tuple, Union
from typing import Callable
from torch.nn.modules.batchnorm import _BatchNorm
import numpy as np


class MultiStepContainer(nn.Sequential, base.MultiStepModule):
@@ -2534,3 +2535,229 @@ def __init__(
def multi_step_forward(self, x_seq: torch.Tensor):
# x.shape = [T, N, C, H, W, D]
return self.bn(x_seq) * self.scale.view(-1, 1, 1, 1, 1, 1)


# OTTT modules

class ReplaceforGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, x_r):
        # the forward output is the replacement tensor x_r
        return x_r

    @staticmethod
    def backward(ctx, grad):
        # route the incoming gradient to both x and x_r, so the gradient computed
        # for the replacement tensor also flows back to the original tensor
        return (grad, grad)


class GradwithTrace(nn.Module):
def __init__(self, module):
"""
* :ref:`API in English <GradwithTrace-en>`
.. _GradwithTrace-cn:
:param module: 需要包装的模块
用于随时间在线训练时,根据神经元的迹计算梯度
出处:'Online Training Through Time for Spiking Neural Networks <https://openreview.net/forum?id=Siv3nHYHheI>'
* :ref:`中文 API <GradwithTrace-cn>`
.. _GradwithTrace-en:
:param module: the module that requires wrapping
Used for online training through time, calculate gradients by the traces of neurons
Reference: 'Online Training Through Time for Spiking Neural Networks <https://openreview.net/forum?id=Siv3nHYHheI>'
"""
super().__init__()
self.module = module

def forward(self, x: Tensor):
# x: [spike, trace], defined in OTTTLIFNode in neuron.py
spike, trace = x[0], x[1]

with torch.no_grad():
out = self.module(spike).detach()

in_for_grad = ReplaceforGrad.apply(spike, trace)
out_for_grad = self.module(in_for_grad)

x = ReplaceforGrad.apply(out_for_grad, out)

return x
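
A minimal sketch (illustrative, not part of this diff) of the wrapper's effect, assuming ``GradwithTrace`` is importable from ``spikingjelly.activation_based.layer`` after this commit: the forward value is computed from the spike, while the weight gradient is computed as if the layer had been fed the trace.

import torch
import torch.nn as nn
from spikingjelly.activation_based.layer import GradwithTrace

fc = nn.Linear(3, 3, bias=False)
spike = torch.tensor([[0., 1., 0.]])
trace = torch.tensor([[0.2, 0.9, 0.1]])

out = GradwithTrace(fc)([spike, trace])
out.sum().backward()

# the weight gradient is the one obtained from fc(trace): with an upstream gradient
# of ones, every row of the gradient equals the trace, not the spike
print(torch.allclose(fc.weight.grad, torch.ones(3, 1) @ trace))  # True
print(torch.allclose(out, fc(spike)))                            # True: the forward value uses the spike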


class SpikeTraceOp(nn.Module):
def __init__(self, module):
"""
* :ref:`API in English <SpikeTraceOp-en>`
.. _SpikeTraceOp-cn:
:param module: 需要包装的模块
对脉冲和迹进行相同的运算,如Dropout,AvgPool等
* :ref:`中文 API <GradwithTrace-cn>`
.. _SpikeTraceOp-en:
:param module: the module that requires wrapping
perform the same operations for spike and trace, such as Dropout, Avgpool, etc.
"""
super().__init__()
self.module = module

def forward(self, x: Tensor):
# x: [spike, trace], defined in OTTTLIFNode in neuron.py
spike, trace = x[0], x[1]

spike = self.module(spike)
with torch.no_grad():
trace = self.module(trace)

x = [spike, trace]

return x


class OTTTSequential(nn.Sequential):
def __init__(self, *args):
super().__init__(*args)

def forward(self, input):
for module in self:
if not isinstance(input, list):
input = module(input)
else:
if len(list(module.parameters())) > 0: # e.g., Conv2d, Linear, etc.
module = GradwithTrace(module)
else: # e.g., Dropout, AvgPool, etc.
module = SpikeTraceOp(module)
input = module(input)
return input
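
A hypothetical sketch (not part of this diff) of the dispatch above, assuming ``OTTTLIFNode`` emits the ``[spike, trace]`` pair in training mode as the forward comments suggest: a plain tensor is passed through unchanged, parametric modules are wrapped by ``GradwithTrace``, and parameter-free modules by ``SpikeTraceOp``.

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, neuron

net = layer.OTTTSequential(
    nn.Linear(8, 4),       # sees a plain tensor -> called directly
    neuron.OTTTLIFNode(),  # emits [spike, trace] (see the forward comments above)
    nn.Dropout(0.5),       # no parameters -> wrapped by SpikeTraceOp
    nn.Linear(4, 2),       # has parameters -> wrapped by GradwithTrace
)
net.train()
y = net(torch.rand(2, 8))  # single-step forward
print(y.shape)             # torch.Size([2, 2])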


# weight standardization modules

class WSConv2d(Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
step_mode: str = 's',
gain: bool = True,
eps: float = 1e-4
) -> None:
"""
* :ref:`API in English <WSConv2d-en>`
.. _WSConv2d-cn:
:param gain: 是否对权重引入可学习的缩放系数
:type gain: bool
:param eps: 预防数值问题的小量
:type eps: float
其他的参数API参见 :class:`Conv2d`
* :ref:`中文 API <WSConv2d-cn>`
.. _WSConv2d-en:
:param gain: whether introduce learnable scale factors for weights
:type step_mode: bool
:param eps: a small number to prevent numerical problems
:type eps: float
Refer to :class:`Conv2d` for other parameters' API
"""
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, step_mode)
if gain:
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
else:
self.gain = None
self.eps = eps

def get_weight(self):
fan_in = np.prod(self.weight.shape[1:])
mean = torch.mean(self.weight, axis=[1, 2, 3], keepdims=True)
var = torch.var(self.weight, axis=[1, 2, 3], keepdims=True)
weight = (self.weight - mean) / ((var * fan_in + self.eps) ** 0.5)
if self.gain is not None:
weight = weight * self.gain
return weight

def _forward(self, x: Tensor):
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)

def forward(self, x: Tensor):
if self.step_mode == 's':
x = self._forward(x)

elif self.step_mode == 'm':
if x.dim() != 5:
raise ValueError(f'expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!')
x = functional.seq_to_ann_forward(x, self._forward)

return x
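
A quick numerical sketch (illustrative, not part of this diff) of what ``get_weight`` produces: each output channel's filter is shifted to zero mean and scaled so that its variance is roughly ``1 / fan_in``, then multiplied by the learnable per-channel gain (initialized to 1).

import torch
from spikingjelly.activation_based.layer import WSConv2d

conv = WSConv2d(3, 16, kernel_size=3, padding=1, gain=True)
w = conv.get_weight()
fan_in = 3 * 3 * 3  # in_channels * kH * kW

print(w.mean(dim=[1, 2, 3]))          # ~0 for every output channel
print(w.var(dim=[1, 2, 3]) * fan_in)  # ~1 for every output channel

y = conv(torch.rand(2, 3, 8, 8))      # step_mode 's': input is [N, C, H, W]
print(y.shape)                        # torch.Size([2, 16, 8, 8])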


class WSLinear(Linear):
def __init__(self, in_features: int, out_features: int, bias: bool = True, step_mode='s', gain=True, eps=1e-4) -> None:
"""
* :ref:`API in English <WSLinear-en>`
.. _WSLinear-cn:
:param gain: 是否对权重引入可学习的缩放系数
:type gain: bool
:param eps: 预防数值问题的小量
:type eps: float
其他的参数API参见 :class:`Linear`
* :ref:`中文 API <WSLinear-cn>`
.. _WSLinear-en:
:param gain: whether introduce learnable scale factors for weights
:type step_mode: bool
:param eps: a small number to prevent numerical problems
:type eps: float
Refer to :class:`Linear` for other parameters' API
"""
super().__init__(in_features, out_features, bias, step_mode)
if gain:
            self.gain = nn.Parameter(torch.ones(self.out_features, 1))
else:
self.gain = None
self.eps = eps

def get_weight(self):
fan_in = np.prod(self.weight.shape[1:])
mean = torch.mean(self.weight, axis=[1], keepdims=True)
var = torch.var(self.weight, axis=[1], keepdims=True)
weight = (self.weight - mean) / ((var * fan_in + self.eps) ** 0.5)
if self.gain is not None:
weight = weight * self.gain
return weight

def forward(self, x: Tensor):
return F.linear(x, self.get_weight(), self.bias)
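
Finally, a hypothetical end-to-end sketch (not part of this diff) that combines the weight-standardized layers with the OTTT wrappers added above; the architecture and hyper-parameters are illustrative only.

import torch
import torch.nn.functional as F
from spikingjelly.activation_based import layer, neuron, functional

net = layer.OTTTSequential(
    layer.WSLinear(8, 4),   # gain defaults to True
    neuron.OTTTLIFNode(),
    layer.WSLinear(4, 2),
    neuron.LIFNode()
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

for epoch in range(2):
    x_seq = torch.rand([2, 4, 8])       # [N, T, ...]
    target_seq = torch.rand([2, 4, 2])  # [N, T, ...]
    functional.ottt_online_training(model=net, optimizer=optimizer, x_seq=x_seq,
                                    target_seq=target_seq, f_loss_t=F.mse_loss, online=True)
    functional.reset_net(net)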
