
Commit 84af124
Remove the threshold voltage, since the cell's built-in bias can play the role of the threshold voltage
fangwei123456 committed Sep 13, 2020
1 parent ddefc8e commit 84af124
Showing 1 changed file with 16 additions and 23 deletions.
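In brief, the change is safe because a spiking gate of the form heaviside(W_ih·x + W_hh·h - v_threshold) can fold the constant threshold into the learnable bias of its linear layers: subtracting v_threshold is the same as shifting the bias by -v_threshold, and with bias=True the bias can learn an arbitrary per-neuron threshold anyway. A minimal sketch checking this equivalence numerically (not part of the commit; the shapes, seed, and variable names are illustrative):

import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, v_threshold = 4, 3, 1.0

linear_ih = nn.Linear(input_size, hidden_size, bias=True)
linear_hh = nn.Linear(hidden_size, hidden_size, bias=True)
x = torch.randn(2, input_size)
h = torch.randn(2, hidden_size)

# Old formulation: subtract an explicit threshold from the pre-activation.
spikes_old = (linear_ih(x) + linear_hh(h) - v_threshold > 0).float()

# New formulation: fold the threshold into one bias and drop the parameter.
with torch.no_grad():
    linear_ih.bias -= v_threshold
spikes_new = (linear_ih(x) + linear_hh(h) > 0).float()

assert torch.equal(spikes_old, spikes_new)  # identical spike trains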
39 changes: 16 additions & 23 deletions spikingjelly/clock_driven/rnn.py
@@ -218,7 +218,7 @@ def forward(self, x: torch.Tensor, states=None):
return torch.cat(output, dim=0), torch.split(states_list, 1, dim=0)

class SpikingLSTMCell(SpikingRNNCellBase):
- def __init__(self, input_size: int, hidden_size: int, bias=True, v_threshold=1.0,
+ def __init__(self, input_size: int, hidden_size: int, bias=True,
surrogate_function1=surrogate.Erf(), surrogate_function2=None):
'''
A `spiking` long short-term memory (LSTM) cell, which is firstly proposed in
@@ -245,9 +245,6 @@ def __init__(self, input_size: int, hidden_size: int, bias=True, v_threshold=1.0
``b_hh``. Default: ``True``
:type bias: bool
- :param v_threshold: threshold voltage of neurons
- :type v_threshold: float
:param surrogate_function1: surrogate function for replacing gradient of spiking functions during
back-propagation, which is used for generating ``i``, ``f``, ``o``
@@ -282,7 +279,6 @@ def __init__(self, input_size: int, hidden_size: int, bias=True, v_threshold=1.0
'''

super().__init__(input_size, hidden_size, bias)
- self.v_threshold = v_threshold

self.linear_ih = nn.Linear(input_size, 4 * hidden_size, bias=bias)
self.linear_hh = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
@@ -320,10 +316,10 @@ def forward(self, x: torch.Tensor, hc=None):
h = hc[0]
c = hc[1]
if self.surrogate_function2 is None:
- i, f, g, o = torch.split(self.surrogate_function1(self.linear_ih(x) + self.linear_hh(h) - self.v_threshold),
+ i, f, g, o = torch.split(self.surrogate_function1(self.linear_ih(x) + self.linear_hh(h)),
self.hidden_size, dim=1)
else:
- i, f, g, o = torch.split(self.linear_ih(x) + self.linear_hh(h) - self.v_threshold, self.hidden_size, dim=1)
+ i, f, g, o = torch.split(self.linear_ih(x) + self.linear_hh(h), self.hidden_size, dim=1)
i = self.surrogate_function1(i)
f = self.surrogate_function1(f)
g = self.surrogate_function2(g)
@@ -343,10 +339,10 @@ def forward(self, x: torch.Tensor, hc=None):

class SpikingLSTM(SpikingRNNBase):
def __init__(self, input_size, hidden_size, num_layers, bias=True, dropout_p=0,
- invariant_dropout_mask=False, bidirectional=False, v_threshold=1.0,
+ invariant_dropout_mask=False, bidirectional=False,
surrogate_function1=surrogate.Erf(), surrogate_function2=None):
super().__init__(input_size, hidden_size, num_layers, bias, dropout_p, invariant_dropout_mask, bidirectional,
- v_threshold, surrogate_function1, surrogate_function2)
+ surrogate_function1, surrogate_function2)
@staticmethod
def base_cell():
return SpikingLSTMCell
@@ -356,10 +352,9 @@ def states_num():
return 2

class SpikingVanillaRNNCell(SpikingRNNCellBase):
- def __init__(self, input_size: int, hidden_size: int, bias=True, v_threshold=1.0,
+ def __init__(self, input_size: int, hidden_size: int, bias=True,
surrogate_function=surrogate.Erf()):
super().__init__(input_size, hidden_size, bias)
- self.v_threshold = v_threshold

self.linear_ih = nn.Linear(input_size, hidden_size, bias=bias)
self.linear_hh = nn.Linear(hidden_size, hidden_size, bias=bias)
@@ -371,14 +366,13 @@ def __init__(self, input_size: int, hidden_size: int, bias=True, v_threshold=1.0
def forward(self, x: torch.Tensor, h=None):
if h is None:
h = torch.zeros(size=[x.shape[0], self.hidden_size], dtype=torch.float, device=x.device)
- return self.surrogate_function(self.linear_ih(x) + self.linear_hh(h) - self.v_threshold)
+ return self.surrogate_function(self.linear_ih(x) + self.linear_hh(h))

class SpikingVanillaRNN(SpikingRNNBase):
def __init__(self, input_size, hidden_size, num_layers, bias=True, dropout_p=0,
- invariant_dropout_mask=False, bidirectional=False, v_threshold=1.0,
- surrogate_function=surrogate.Erf()):
+ invariant_dropout_mask=False, bidirectional=False, surrogate_function=surrogate.Erf()):
super().__init__(input_size, hidden_size, num_layers, bias, dropout_p, invariant_dropout_mask, bidirectional,
- v_threshold, surrogate_function)
+ surrogate_function)

@staticmethod
def base_cell():
@@ -389,10 +383,9 @@ def states_num():
return 1

class SpikingGRUCell(SpikingRNNCellBase):
- def __init__(self, input_size: int, hidden_size: int, bias=True, v_threshold=1.0,
+ def __init__(self, input_size: int, hidden_size: int, bias=True,
surrogate_function1=surrogate.Erf(), surrogate_function2=None):
super().__init__(input_size, hidden_size, bias)
- self.v_threshold = v_threshold

self.linear_ih = nn.Linear(input_size, 3 * hidden_size, bias=bias)
self.linear_hh = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
@@ -413,14 +406,14 @@ def forward(self, x: torch.Tensor, hc=None):

y_ih = torch.split(self.linear_ih(x), self.hidden_size, dim=1)
y_hh = torch.split(self.linear_hh(h), self.hidden_size, dim=1)
- r = self.surrogate_function1(y_ih[0] + y_hh[0] - self.v_threshold)
- z = self.surrogate_function1(y_ih[1] + y_hh[1] - self.v_threshold)
+ r = self.surrogate_function1(y_ih[0] + y_hh[0])
+ z = self.surrogate_function1(y_ih[1] + y_hh[1])

if self.surrogate_function2 is None:
- n = self.surrogate_function1(y_ih[2] + r * y_hh[2] - self.v_threshold)
+ n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
else:
assert self.surrogate_function1.spiking == self.surrogate_function2.spiking
- n = self.surrogate_function2(y_ih[2] + r * y_hh[2] - self.v_threshold)
+ n = self.surrogate_function2(y_ih[2] + r * y_hh[2])

if self.surrogate_function1.spiking:
# spike-based acceleration can be used
@@ -432,10 +425,10 @@ def forward(self, x: torch.Tensor, hc=None):

class SpikingGRU(SpikingRNNBase):
def __init__(self, input_size, hidden_size, num_layers, bias=True, dropout_p=0,
- invariant_dropout_mask=False, bidirectional=False, v_threshold=1.0,
+ invariant_dropout_mask=False, bidirectional=False,
surrogate_function1=surrogate.Erf(), surrogate_function2=None):
super().__init__(input_size, hidden_size, num_layers, bias, dropout_p, invariant_dropout_mask, bidirectional,
- v_threshold, surrogate_function1, surrogate_function2)
+ surrogate_function1, surrogate_function2)
@staticmethod
def base_cell():
return SpikingGRUCell
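After this commit, the cells and the stacked networks no longer accept a v_threshold argument. A minimal usage sketch of the updated constructor (assuming, as with torch.nn.LSTMCell, that forward returns the new (h, c) pair; the sizes are illustrative):

import torch
from spikingjelly.clock_driven import rnn, surrogate

cell = rnn.SpikingLSTMCell(input_size=8, hidden_size=16, bias=True,
                           surrogate_function1=surrogate.Erf())
x = torch.randn(4, 8)    # [batch_size, input_size]
h, c = cell(x)           # hc=None, so both states start at zero
print(h.shape, c.shape)  # torch.Size([4, 16]) torch.Size([4, 16])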
