Skip to content

Commit

Permalink
Merge pull request #382 from YufangMo/master
Browse files Browse the repository at this point in the history
Update rnn.py
  • Loading branch information
fangwei123456 committed May 19, 2023
2 parents 916dcf4 + 8b2fad8 commit 21ce1bc
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions spikingjelly/activation_based/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,17 @@ def forward(self, x: torch.Tensor, states=None):
# states非None且为tuple,则合并成tensor
states_list = torch.stack(states)
# shape = [self.states_num(), self.num_layers * 2, batch_size, self.hidden_size]
elif states.dim() == 3:
states_list = states
elif isinstance(states, torch.Tensor):
if states.dim() == 3:
states_list = states
else:
raise TypeError
elif states == None:
if self.bidirectional == True:
states_list = torch.zeros(size=[self.states_num(), self.num_layers*2, x.shape[1], self.hidden_size], dtype=torch.float, device=x.device).squeeze(0)
else:
states_list = torch.zeros(size=[self.states_num(), self.num_layers, x.shape[1], self.hidden_size], dtype=torch.float, device=x.device).squeeze(0)

else:
raise TypeError

Expand Down

0 comments on commit 21ce1bc

Please sign in to comment.