-
Notifications
You must be signed in to change notification settings - Fork 213
Closed
Description
How can the `[:, t, :]` slicing syntax of Python be converted into C# with high performance?
source code
def forward(self, input_data):
    """Encoder forward pass with input attention (DA-RNN, Eq. 8).

    NOTE(review): indentation reconstructed from a whitespace-mangled paste;
    code tokens are unchanged from the quoted snippet.

    Parameters
    ----------
    input_data : torch.Tensor
        Assumed shape (batch_size, T, input_size) — the `permute(0, 2, 1)`
        and `input_data[:, t, :]` usage below imply this; TODO confirm.

    Returns
    -------
    (input_weighted, input_encoded) : tuple of torch.Tensor
        input_weighted: (batch_size, T, input_size) attention-weighted inputs.
        input_encoded:  (batch_size, T, hidden_size) encoder hidden states.
    """
    # Pre-allocate per-timestep outputs on the target device.
    input_weighted = torch.zeros(input_data.shape[0], self.T, self.input_size).to(device)
    input_encoded = torch.zeros(input_data.shape[0], self.T, self.hidden_size).to(device)
    # Eq. 8, parameters not in nn.Linear but to be learnt
    hidden = init_hidden(input_data, self.hidden_size).to(device)  # 1 * batch_size * hidden_size
    cell = init_hidden(input_data, self.hidden_size).to(device)
    for t in range(self.T):
        # Build attention input: previous hidden/cell state replicated once
        # per driving series, concatenated with each full input series.
        x = torch.cat((hidden.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                       cell.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                       input_data.permute(0, 2, 1)), dim=2)  # batch_size * input_size * (2*hidden_size + T)
        # Score every driving series with a shared linear layer.
        x = self.attn_linear(x.view(-1, self.hidden_size * 2 + self.T))  # (batch_size * input_size) * 1
        # Normalize the scores across the input_size driving series.
        attn_weights = F.softmax(x.view(-1, self.input_size), dim=1)
        # Re-weight the step-t input by its attention weights.
        weighted_input = torch.mul(attn_weights, input_data[:, t, :])  # (batch_size, input_size)
        self.lstm_layer.flatten_parameters()
        # One LSTM step; seq dimension of length 1 via unsqueeze(0).
        _, lstm_states = self.lstm_layer(weighted_input.unsqueeze(0), (hidden, cell))
        hidden = lstm_states[0]
        cell = lstm_states[1]
        # save output
        input_weighted[:, t, :] = weighted_input
        # NOTE(review): hidden is (1, batch, hidden_size) while the slice is
        # (batch, hidden_size) — relies on broadcasting/older-torch setitem
        # behavior; verify against the original repository.
        input_encoded[:, t, :] = hidden
    return input_weighted, input_encoded

And the LSTM does not support the `flatten_parameters` method.
Metadata
Metadata
Assignees
Labels
No labels