Commit
Fix the mask bug in SpikingLSTM
fangwei123456 committed Sep 7, 2020
1 parent 60e34b7 commit 2669d98
Showing 2 changed files with 29 additions and 32 deletions.
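The bug being fixed: with invariant_dropout_mask enabled, SpikingLSTM.forward built its time-invariant dropout mask with F.dropout(torch.zeros_like(x[0].data), ...). Since dropout only zeroes and rescales its input, a mask drawn from zeros is all zeros, so every hidden state passed between layers was silenced; the zeros tensor also had shape [batch_size, input_size] even though the mask multiplies hidden states of shape [batch_size, hidden_size]. The commit draws the mask from ones of the correct shape instead. A minimal sketch of the difference (illustration only, not the library code):

import torch
import torch.nn.functional as F

# Buggy mask: dropout of an all-zero tensor is still all zeros,
# so h * mask == 0 for every unit at every time step.
buggy = F.dropout(torch.zeros(4, 8), p=0.5, training=True, inplace=True)
assert buggy.abs().sum() == 0

# Fixed mask: drawn from ones with shape [batch_size, hidden_size];
# entries are 0 with probability p and 1 / (1 - p) otherwise.
fixed = F.dropout(torch.ones(4, 8), p=0.5, training=True, inplace=True)
print(fixed)  # a mix of 0.0 and 2.0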
2 changes: 0 additions & 2 deletions spikingjelly/clock_driven/accelerating.py
@@ -200,8 +200,6 @@ def sub(x: torch.Tensor, spike: torch.Tensor):
        numerical stability.
    '''
    return subtract_spike.apply(x, spike)
-
-
def mul(x: torch.Tensor, spike: torch.Tensor, spike_mul_spike=False):
    '''
    * :ref:`API in English <mul-en>`
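The hunk above only removes stray blank lines. For context, sub and mul in this module rely on the fact that a spike tensor only contains 0 and 1, so arithmetic with it can be lowered to cheaper masking operations. The sketch below shows the masking identity the idea rests on (an illustration, not the library's implementation):

import torch

x = torch.randn(4)
spike = torch.tensor([0., 1., 1., 0.])  # spikes are always 0 or 1

# Multiplying by a 0/1 spike tensor is equivalent to masking out
# the positions where no spike occurred:
assert torch.equal(x * spike, x.masked_fill(spike == 0., 0.))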
59 changes: 29 additions & 30 deletions spikingjelly/clock_driven/rnn.py
@@ -101,16 +101,16 @@ def forward(self, x: torch.Tensor, hc=None):
        :param hc: (h_0, c_0)
            h_0 : torch.Tensor
-               ``shape = [batch, hidden_size]``, tensor containing the initial hidden state for each element in the batch
+               ``shape = [batch_size, hidden_size]``, tensor containing the initial hidden state for each element in the batch
            c_0 : torch.Tensor
-               ``shape = [batch, hidden_size]``, tensor containing the initial cell state for each element in the batch
+               ``shape = [batch_size, hidden_size]``, tensor containing the initial cell state for each element in the batch
            If (h_0, c_0) is not provided, both ``h_0`` and ``c_0`` default to zero
        :type hc: tuple or None
        :return: (h_1, c_1) :
            h_1 : torch.Tensor
-               ``shape = [batch, hidden_size]``, tensor containing the next hidden state for each element in the batch
+               ``shape = [batch_size, hidden_size]``, tensor containing the next hidden state for each element in the batch
            c_1 : torch.Tensor
-               ``shape = [batch, hidden_size]``, tensor containing the next cell state for each element in the batch
+               ``shape = [batch_size, hidden_size]``, tensor containing the next cell state for each element in the batch
        :rtype: tuple
        '''
        if hc is None:
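The docstring fix above replaces ``batch`` with ``batch_size`` in the shape annotations. A minimal usage sketch of the call convention it documents (the constructor's remaining arguments are assumed to have defaults):

import torch
from spikingjelly.clock_driven.rnn import SpikingLSTMCell

batch_size, input_size, hidden_size = 4, 16, 32
cell = SpikingLSTMCell(input_size, hidden_size)

x = torch.rand(batch_size, input_size)
h1, c1 = cell(x)            # hc=None, so h_0 and c_0 default to zero
h2, c2 = cell(x, (h1, c1))  # later steps pass the previous (h, c)
print(h2.shape, c2.shape)   # both torch.Size([4, 32])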
@@ -128,7 +128,6 @@ def forward(self, x: torch.Tensor, hc=None):
        f = self.surrogate_function1(f)
        g = self.surrogate_function2(g)
        o = self.surrogate_function1(o)
-
        if self.surrogate_function2 is not None:
            assert self.surrogate_function1.spiking == self.surrogate_function2.spiking
        if self.surrogate_function1.spiking:
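For context on these lines: the cell's gates are produced by surrogate spiking functions, and each surrogate function carries a spiking flag that the assertion compares when a separate function is supplied for the g gate. A generic sketch of the surrogate-gradient pattern such functions follow (a sigmoid surrogate is assumed here; spikingjelly's actual surrogate classes are configurable modules):

import torch

class SpikingGate(torch.autograd.Function):
    # Heaviside step in the forward pass, smooth surrogate gradient in backward.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x >= 0.).to(x)  # emit 0/1 spikes

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        s = torch.sigmoid(x)
        return grad_output * s * (1. - s)  # gradient of the sigmoid relaxation

gate = SpikingGate.apply  # usable like f = gate(pre_activation)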
@@ -170,44 +169,44 @@ def __init__(self, input_size, hidden_size, num_layers, bias=True, dropout_p=0,
            raise NotImplementedError
        else:
            self.lstm_cells = []
-           for i in range(num_layers):
            self.lstm_cells.append(SpikingLSTMCell(input_size, hidden_size, bias, v_threshold,
                                                   surrogate_function1, surrogate_function2))
+           for i in range(num_layers - 1):
+               self.lstm_cells.append(SpikingLSTMCell(hidden_size, hidden_size, bias, v_threshold,
+                                                      surrogate_function1, surrogate_function2))
            self.lstm_cells = nn.Sequential(*self.lstm_cells)

    def forward(self, x: torch.Tensor, hc=None):
        # x.shape=[T, batch_size, input_size]
        T = x.shape[0]
        batch_size = x.shape[1]
        if self.bidirectional:
            raise NotImplementedError
        else:
+           # create tensors that store h and c of every layer
+           h_list = torch.zeros(size=[self.num_layers, batch_size, self.hidden_size]).to(x)
+           c_list = torch.zeros_like(h_list)
+           # the initial h and c are taken from the input
-           if hc is None:
-               hc_list = [None] * self.num_layers
-           else:
-               hc_list = []
-               for i in range(self.num_layers):
-                   hc_list.append(hc[i])
-           if self.training and self.dropout_p > 0:
-               if self.invariant_dropout_mask:
-                   # generate a dropout mask that does not change over time
-                   mask = F.dropout(torch.zeros_like(x[0].data), p=self.dropout_p, training=True, inplace=True)
-           output = []
-           for t in range(T):
-               hc_list[0] = self.lstm_cells[0](x[t], hc_list[0])
-               for i in range(1, self.num_layers):
+           if hc is not None:
+               h_list = hc[0]
+               c_list = hc[1]
+           if self.training and self.dropout_p > 0 and self.invariant_dropout_mask:
+               mask = F.dropout(torch.ones(size=[batch_size, self.hidden_size]), p=self.dropout_p, training=True, inplace=True).to(x)
+
+           output = []
+           for t in range(T):
+               h_list[0], c_list[0] = self.lstm_cells[0](x[t], (h_list[0], c_list[0]))
+               for i in range(1, self.num_layers):
+                   h = h_list[i - 1]
                    if self.training and self.dropout_p > 0:
                        if self.invariant_dropout_mask:
-                           hc_list[i] = self.lstm_cells[i](hc_list[i - 1][0] * mask, hc_list[i])
+                           h = h * mask
                        else:
-                           hc_list[i] = self.lstm_cells[i](F.dropout(hc_list[i - 1][0], p=self.dropout_p, training=True),
-                                                           hc_list[i])
-               output.append(hc_list[-1][0].unsqueeze(0))
-           h_n = []
-           c_n = []
-           for hc in hc_list:
-               h_n.append(hc[0].unsqueeze(0))
-               c_n.append(hc[1].unsqueeze(0))
-           return torch.cat(output, dim=0), torch.cat(h_n, dim=0), torch.cat(c_n, dim=0)
+                           h = F.dropout(h, p=self.dropout_p, training=True)
+                   h_list[i], c_list[i] = self.lstm_cells[i](h, (h_list[i], c_list[i]))
+               output.append(h_list[-1].unsqueeze(0))
+
+           return torch.cat(output, dim=0), (h_list, c_list)
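After the fix, the forward pass returns the output sequence together with a (h, c) pair holding the final states of all layers. A small usage sketch against the new return signature (sizes are arbitrary):

import torch
from spikingjelly.clock_driven.rnn import SpikingLSTM

T, batch_size, input_size, hidden_size, num_layers = 8, 4, 16, 32, 2
net = SpikingLSTM(input_size, hidden_size, num_layers)
net.train()  # dropout masks are only generated in training mode

x = torch.rand(T, batch_size, input_size)
output, (h_n, c_n) = net(x)
print(output.shape)  # torch.Size([8, 4, 32]): [T, batch_size, hidden_size]
print(h_n.shape)     # torch.Size([2, 4, 32]): [num_layers, batch_size, hidden_size]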