Skip to content

Commit

Permalink
Merge pull request #381 from YufangMo/master
Browse files Browse the repository at this point in the history
Update rnn.py
  • Loading branch information
fangwei123456 committed May 18, 2023
2 parents bd10454 + 4bb5bb4 commit 916dcf4
Showing 1 changed file with 44 additions and 36 deletions.
80 changes: 44 additions & 36 deletions spikingjelly/activation_based/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
from spikingjelly.activation_based import surrogate, layer
import math

def directional_rnn_cell_forward(cell: nn.Module, x: torch.Tensor,
states: torch.Tensor):

T = x.shape[0]
ss = states

output = []
for t in range(T):
ss = cell(x[t], ss)
if states.dim() == 2:
output.append(ss)
elif states.dim() == 3:
output.append(ss[0])
# 当RNN cell具有多个隐藏状态时,通常第0个隐藏状态是其输出
return torch.stack(output), ss

def bidirectional_rnn_cell_forward(cell: nn.Module, cell_reverse: nn.Module, x: torch.Tensor,
states: torch.Tensor, states_reverse: torch.Tensor):
'''
Expand Down Expand Up @@ -406,21 +422,17 @@ def forward(self, x: torch.Tensor, states=None):
# states非None且为tuple,则合并成tensor
states_list = torch.stack(states)
# shape = [self.states_num(), self.num_layers * 2, batch_size, self.hidden_size]
elif isinstance(states, torch.Tensor):
# states非None且不为tuple时,它本身就是一个tensor,例如普通RNN的状态
elif states.dim() == 3:
states_list = states
elif states is None:
# squeeze(0)的作用是,若states_num() == 1则去掉多余的维度
if self.bidirectional:
states_list = torch.zeros(
size=[self.states_num(), self.num_layers * 2, batch_size, self.hidden_size]).to(x).squeeze(0)
else:
states_list = torch.zeros(size=[self.states_num(), self.num_layers, batch_size, self.hidden_size]).to(
x).squeeze(0)
else:
raise TypeError

# print(states_list.shape) [state_num num_direction*num_layer, B, H] or [num_direction*num_layer, B, H]

if self.bidirectional:
# 判断 num_direction*num_layers 是否符合要求,否则 new_states_list 会存在额外的0矩阵
if (states_list.dim() == 4 and states_list.shape[1] != 2*self.num_layers) or (states_list.dim() == 3 and states_list.shape[0] != 2*self.num_layers):
raise ValueError
# y 表示第i层的输出。初始化时,y即为输入
y = x.clone()
if self.training and self.dropout_p > 0 and self.invariant_dropout_mask:
Expand Down Expand Up @@ -455,47 +467,43 @@ def forward(self, x: torch.Tensor, states=None):
if self.states_num() == 1:
return y, new_states_list
else:
# split使得返回值是tuple
return y, torch.split(new_states_list, 1, dim=0)

return y, tuple(new_states_list)

else:
# 判断 num_direction*num_layers 是否符合要求,否则 new_states_list 会存在额外的0矩阵
if (states_list.dim() == 4 and states_list.shape[1] != self.num_layers) or (states_list.dim() == 3 and states_list.shape[0] != self.num_layers):
raise ValueError
# y 表示第i层的输出。初始化时,y即为输入
y = x.clone()
if self.training and self.dropout_p > 0 and self.invariant_dropout_mask:
mask = F.dropout(torch.ones(size=[self.num_layers - 1, batch_size, self.hidden_size]),
mask = F.dropout(torch.ones(size=[self.num_layers - 1, batch_size, self.hidden_size * 2]),
p=self.dropout_p, training=True, inplace=True).to(x)

output = []

for t in range(T):
for i in range(self.num_layers):
# 第i层神经元的起始状态从输入states_list获取
new_states_list = torch.zeros_like(states_list.data)
if self.states_num() == 1:
new_states_list[0] = self.cells[0](x[t], states_list[0])
cell_init_states = states_list[i]
else:
new_states_list[:, 0] = torch.stack(self.cells[0](x[t], states_list[:, 0]))
for i in range(1, self.num_layers):
y = states_list[0, i - 1]
if self.training and self.dropout_p > 0:
cell_init_states = states_list[:, i]

if self.training and self.dropout_p > 0:
if i > 1:
if self.invariant_dropout_mask:
y = y * mask[i - 1]
else:
y = F.dropout(y, p=self.dropout_p, training=True)
if self.states_num() == 1:
new_states_list[i] = self.cells[i](y, states_list[i])
else:
new_states_list[:, i] = torch.stack(self.cells[i](y, states_list[:, i]))
y, ss = directional_rnn_cell_forward(
self.cells[i], y, cell_init_states)
# 更新states_list[i]
if self.states_num() == 1:
output.append(new_states_list[-1].clone().unsqueeze(0))
new_states_list[i] = ss
else:
output.append(new_states_list[0, -1].clone().unsqueeze(0))
new_states_list[:, i] = torch.stack(ss)
states_list = new_states_list.clone()

if self.states_num() == 1:
return torch.cat(output, dim=0), new_states_list
return y, new_states_list
else:
# split使得返回值是tuple
new_states_list = list(torch.split(new_states_list, 1, dim=0))
for i in range(new_states_list.__len__()):
new_states_list[i] = new_states_list[i].squeeze(0)
return torch.cat(output, dim=0), new_states_list
return y, tuple(new_states_list)

class SpikingLSTMCell(SpikingRNNCellBase):
def __init__(self, input_size: int, hidden_size: int, bias=True,
Expand Down

0 comments on commit 916dcf4

Please sign in to comment.