Commit

add the PSN family
fangwei123456 committed Apr 25, 2023
1 parent 1dc0dff commit 2243f16
Showing 2 changed files with 186 additions and 3 deletions.
2 changes: 1 addition & 1 deletion spikingjelly/activation_based/base.py
@@ -46,7 +46,7 @@ def check_backend_library(backend: str):
            raise ImportError('Lava-DL is not installed! You can install it from ' \
                              '"https://github.com/lava-nc/lava-dl". ')
    else:
-        raise NotImplementedError(backend)
+        pass


class StepModule:
187 changes: 185 additions & 2 deletions spikingjelly/activation_based/neuron.py
@@ -2,11 +2,12 @@
from typing import Callable
import torch
import torch.nn as nn
from . import surrogate, base
from .. import configure
import torch.nn.functional as F
import math
import numpy as np
import logging

from . import surrogate, base
from .auto_cuda import neuron_kernel as ac_neuron_kernel

try:
@@ -1854,3 +1855,185 @@ def neuronal_reset(self, spike):
        self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)







class PSN(nn.Module, base.MultiStepModule):
    def __init__(self, T: int, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan()):
        super().__init__()
        self.T = T
        self.surrogate_function = surrogate_function
        weight = torch.zeros([T, T])
        bias = torch.zeros([T, 1])

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.constant_(self.bias, -1.)

    def forward(self, x_seq: torch.Tensor):
        # x_seq.shape = [T, N, *]
        h_seq = torch.addmm(self.bias, self.weight, x_seq.flatten(1))
        spike_seq = self.surrogate_function(h_seq)
        return spike_seq.view(x_seq.shape)
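
# A minimal usage sketch (illustrative comment, not part of this commit; the
# tensor shapes are assumptions). PSN applies one learnable T x T weight to the
# whole input sequence, so the spikes of all T time-steps are computed in
# parallel rather than step by step:
#
#     >>> import torch
#     >>> from spikingjelly.activation_based import neuron
#     >>> psn = neuron.PSN(T=4)
#     >>> x_seq = torch.rand([4, 2, 8])   # [T, N, *]
#     >>> psn(x_seq).shape
#     torch.Size([4, 2, 8])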


class MaskedPSN(base.MemoryModule):
    @staticmethod
    @torch.jit.script
    def gen_masked_weight(k: float, mask0: torch.Tensor, mask1: torch.Tensor, weight: torch.Tensor):
        return (k * mask0 + (1. - k) * mask1) * weight

    def masked_weight(self):
        if self.k >= 1.:
            # fully masked weight, identical to gen_masked_weight at k = 1
            return self.weight * self.mask0
        else:
            return self.gen_masked_weight(self.k, self.mask0, self.mask1, self.weight)

    def __init__(self, order: int, T: int, k_init: float = 0., surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's'):
        super().__init__()
        self.register_memory('time_step', 0)
        self.register_memory('queue', [])
        self.step_mode = step_mode
        self.order = order
        self.T = T
        self.surrogate_function = surrogate_function
        weight = torch.zeros([T, T])
        bias = torch.zeros([T, 1])

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.constant_(self.bias, -1.)

        self._k = k_init
        mask1 = torch.ones([T, T])
        # banded lower-triangular mask: time-step t only sees inputs t - order + 1 .. t
        mask0 = torch.tril(mask1) * torch.triu(mask1, -(order - 1))
        self.register_buffer('mask0', mask0)
        self.register_buffer('mask1', mask1)


    @property
    def k(self):
        return self._k

    @k.setter
    def k(self, value: float):
        self._k = value

    def single_step_forward(self, x: torch.Tensor):
        # keep only the most recent `order` inputs
        self.queue.append(x)
        if self.queue.__len__() > self.order:
            self.queue.pop(0)

        if self.time_step + 1 > self.T:
            raise OverflowError(f"The MaskedPSN(T={self.T}) has run {self.time_step + 1} time-steps!")

        weight = self.masked_weight()[self.time_step][self.time_step + 1 - self.queue.__len__(): self.time_step + 1]
        x_seq = torch.stack(self.queue)

        for i in range(x.dim()):
            weight = weight.unsqueeze(-1)

        h = torch.sum(weight * x_seq, 0)
        spike = self.surrogate_function(h + self.bias[self.time_step])

        self.time_step += 1
        return spike.view(x.shape)


    def multi_step_forward(self, x_seq: torch.Tensor):
        # x_seq.shape = [T, N, *]
        assert x_seq.shape[0] == self.T
        h_seq = torch.addmm(self.bias, self.masked_weight(), x_seq.flatten(1))
        spike_seq = self.surrogate_function(h_seq).view(x_seq.shape)
        return spike_seq

    def state_dict(self, *args, **kwargs):
        sd = super().state_dict(*args, **kwargs)
        sd['k'] = torch.as_tensor(self.k)
        return sd

    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
                        strict: bool = True):
        self.k = state_dict['k'].item()
        state_dict.pop('k')
        super().load_state_dict(state_dict, strict)
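
# A minimal usage sketch (illustrative comment, not part of this commit; shapes
# and the choice k = 1 are assumptions). MaskedPSN blends the full T x T weight
# (k = 0) with the banded causal weight mask0 * weight (k = 1), so once k
# reaches 1 each time-step only uses the most recent `order` inputs; in
# multi-step mode the input sequence length must equal T:
#
#     >>> import torch
#     >>> from spikingjelly.activation_based import neuron
#     >>> mpsn = neuron.MaskedPSN(order=2, T=4, step_mode='m')
#     >>> mpsn.k = 1.                       # fully masked weight
#     >>> x_seq = torch.rand([4, 2, 8])     # [T, N, *] with T == mpsn.T
#     >>> mpsn(x_seq).shape
#     torch.Size([4, 2, 8])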

class SlidingPSN(base.MemoryModule):

    @property
    def supported_backends(self):
        return 'gemm', 'conv'

    def gen_gemm_weight(self, T: int):
        weight = torch.zeros([T, T], device=self.weight.device)
        for i in range(T):
            end = i + 1
            start = max(0, i + 1 - self.order)
            length = min(end - start, self.order)
            weight[i][start: end] = self.weight[self.order - length: self.order]

        return weight
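
    # Worked example (illustrative comment, not part of this commit): with
    # order = 2 and the default exp_init weight [0.5, 1.0], gen_gemm_weight(3)
    # places the sliding kernel on each row, most recent input on the diagonal:
    #
    #     [[1.0, 0.0, 0.0],
    #      [0.5, 1.0, 0.0],
    #      [0.0, 0.5, 1.0]]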

    def __init__(self, order: int, exp_init: bool = True, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's', backend: str = 'gemm'):
        super().__init__()
        self.register_memory('queue', [])
        self.step_mode = step_mode
        self.order = order
        self.surrogate_function = surrogate_function
        self.backend = backend

        if exp_init:
            weight = torch.ones([order])
            for i in range(order - 2, -1, -1):
                weight[i] = weight[i + 1] / 2.
        else:
            weight = torch.ones([1, order])
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            weight = weight[0]

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.as_tensor(-1.))

    def single_step_forward(self, x: torch.Tensor):
        # keep only the most recent `order` inputs
        self.queue.append(x)
        if self.queue.__len__() > self.order:
            self.queue.pop(0)

        weight = self.weight[self.order - self.queue.__len__(): self.order]
        x_seq = torch.stack(self.queue)

        for i in range(x.dim()):
            weight = weight.unsqueeze(-1)

        h = torch.sum(weight * x_seq, 0)
        spike = self.surrogate_function(h + self.bias)

        return spike.view(x.shape)


    def multi_step_forward(self, x_seq: torch.Tensor):
        if self.backend == 'gemm':
            weight = self.gen_gemm_weight(x_seq.shape[0])
            h_seq = torch.addmm(self.bias, weight, x_seq.flatten(1)).view(x_seq.shape)
            return self.surrogate_function(h_seq)
        elif self.backend == 'conv':
            # x_seq.shape = [T, N, *]
            x_seq_shape = x_seq.shape
            # [T, N, *] -> [T, M] -> [M, T] -> [M, 1, T], where M = N * prod(*)
            x_seq = x_seq.flatten(1).t().unsqueeze(1)
            x_seq = F.pad(x_seq, pad=(self.order - 1, 0))
            x_seq = F.conv1d(x_seq, self.weight.view(1, 1, -1), stride=1)
            x_seq = x_seq.squeeze(1).t().view(x_seq_shape)
            return self.surrogate_function(x_seq + self.bias)
        else:
            raise NotImplementedError(self.backend)
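
# A minimal usage sketch (illustrative comment, not part of this commit; shapes
# are assumptions). SlidingPSN stores only an `order`-length weight and slides
# it over the most recent `order` inputs, so the sequence length does not have
# to be fixed in advance; the 'gemm' and 'conv' backends compute the same
# result, and step_mode='s' processes one input at a time:
#
#     >>> import torch
#     >>> from spikingjelly.activation_based import neuron
#     >>> spsn = neuron.SlidingPSN(order=2, step_mode='m', backend='gemm')
#     >>> x_seq = torch.rand([6, 2, 8])     # [T, N, *], any T
#     >>> spsn(x_seq).shape
#     torch.Size([6, 2, 8])
#     >>> spsn_s = neuron.SlidingPSN(order=2, step_mode='s')
#     >>> spsn_s(x_seq[0]).shape            # single-step mode, one time-step per call
#     torch.Size([2, 8])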
