Commit

fix #374
fangwei committed Apr 30, 2023
1 parent 214a847 commit 70a196f
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions spikingjelly/activation_based/neuron.py
@@ -1899,6 +1899,9 @@ def forward(self, x_seq: torch.Tensor):
         spike_seq = self.surrogate_function(h_seq)
         return spike_seq.view(x_seq.shape)
 
+    def extra_repr(self):
+        return super().extra_repr() + f', T={self.T}'
+
 
 class MaskedPSN(base.MemoryModule):
     @staticmethod
@@ -1967,14 +1970,14 @@ def __init__(self, k: int, T: int, lambda_init: float = 0.,
         self.surrogate_function = surrogate_function
         weight = torch.zeros([T, T])
         bias = torch.zeros([T, 1])
+        self.register_buffer('_lambda_', torch.as_tensor(lambda_init))
 
         self.weight = nn.Parameter(weight)
         self.bias = nn.Parameter(bias)
 
         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
         nn.init.constant_(self.bias, -1.)
 
-        self.lambda_ = lambda_init
         mask1 = torch.ones([T, T])
         mask0 = torch.tril(mask1) * torch.triu(mask1, -(self.k - 1))
         self.register_buffer('mask0', mask0)
@@ -2015,17 +2018,16 @@ def multi_step_forward(self, x_seq: torch.Tensor):
         spike_seq = self.surrogate_function(h_seq).view(x_seq.shape)
         return spike_seq
 
-    def state_dict(self, *args, **kwargs):
-        sd = super().state_dict()
-        sd['lambda_'] = torch.as_tensor(self.lambda_)
-        return sd
+    @property
+    def lambda_(self):
+        return self._lambda_
 
-    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
-                        strict: bool = True):
+    @lambda_.setter
+    def lambda_(self, value: float):
+        torch.fill_(self.lambda_, value)
 
-        self.lambda_ = state_dict['lambda_'].item()
-        state_dict.pop('lambda_')
-        super().load_state_dict(state_dict, strict)
+    def extra_repr(self):
+        return super().extra_repr() + f', lambda_={self.lambda_}, T={self.T}'
 
 
 class SlidingPSN(base.MemoryModule):
@@ -2162,3 +2164,6 @@ def multi_step_forward(self, x_seq: torch.Tensor):
 
         else:
             raise NotImplementedError(self.backend)
+
+    def extra_repr(self):
+        return super().extra_repr() + f', order={self.k}'
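
The MaskedPSN change replaces the hand-written state_dict/load_state_dict overrides with a registered buffer plus a lambda_ property, so lambda_ is saved and restored by the standard PyTorch serialization path. A minimal usage sketch of the new behaviour, assuming the constructor arguments not shown in this diff (surrogate function, step mode, etc.) keep their defaults:

import torch
from spikingjelly.activation_based import neuron

# Sketch only: constructor arguments other than k, T, lambda_init are assumed
# to have usable defaults; they are not shown in this diff.
net = neuron.MaskedPSN(k=2, T=4, lambda_init=0.5)

# lambda_ is now backed by the registered buffer '_lambda_', so it appears in
# state_dict automatically, without the removed state_dict override.
sd = net.state_dict()
assert '_lambda_' in sd

# Loading restores lambda_ through the normal buffer path.
net2 = neuron.MaskedPSN(k=2, T=4)   # lambda_init defaults to 0.
net2.load_state_dict(sd)
print(net2.lambda_)                 # tensor(0.5000), read through the property

# The setter fills the buffer in place instead of rebinding the attribute.
net2.lambda_ = 0.25
print(net2)                         # repr now reports lambda_ and T via extra_repr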
