Skip to content

Commit

Permalink
Merge pull request #181 from xunil17/hetero_gc_lstm_paramater_fix
Browse files Browse the repository at this point in the history
Optimize code to prevent repeated calls to convolution operator
  • Loading branch information
benedekrozemberczki committed Jul 11, 2022
2 parents 8b0dee4 + 89f8595 commit 2a020a4
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch_geometric_temporal/nn/hetero/heterogclstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,29 +109,33 @@ def _set_cell_state(self, x_dict, c_dict):

def _calculate_input_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
i_dict = {node_type: torch.matmul(X, self.W_i[node_type]) for node_type, X in x_dict.items()}
i_dict = {node_type: I + self.conv_i(h_dict, edge_index_dict)[node_type] for node_type, I in i_dict.items()}
conv_i = self.conv_i(h_dict, edge_index_dict)
i_dict = {node_type: I + conv_i[node_type] for node_type, I in i_dict.items()}
i_dict = {node_type: I + self.b_i[node_type] for node_type, I in i_dict.items()}
i_dict = {node_type: torch.sigmoid(I) for node_type, I in i_dict.items()}
return i_dict

def _calculate_forget_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
f_dict = {node_type: torch.matmul(X, self.W_f[node_type]) for node_type, X in x_dict.items()}
f_dict = {node_type: F + self.conv_f(h_dict, edge_index_dict)[node_type] for node_type, F in f_dict.items()}
conv_f = self.conv_f(h_dict, edge_index_dict)
f_dict = {node_type: F + conv_f[node_type] for node_type, F in f_dict.items()}
f_dict = {node_type: F + self.b_f[node_type] for node_type, F in f_dict.items()}
f_dict = {node_type: torch.sigmoid(F) for node_type, F in f_dict.items()}
return f_dict

def _calculate_cell_state(self, x_dict, edge_index_dict, h_dict, c_dict, i_dict, f_dict):
t_dict = {node_type: torch.matmul(X, self.W_c[node_type]) for node_type, X in x_dict.items()}
t_dict = {node_type: T + self.conv_c(h_dict, edge_index_dict)[node_type] for node_type, T in t_dict.items()}
conv_c = self.conv_c(h_dict, edge_index_dict)
t_dict = {node_type: T + conv_c[node_type] for node_type, T in t_dict.items()}
t_dict = {node_type: T + self.b_c[node_type] for node_type, T in t_dict.items()}
t_dict = {node_type: torch.tanh(T) for node_type, T in t_dict.items()}
c_dict = {node_type: f_dict[node_type] * C + i_dict[node_type] * t_dict[node_type] for node_type, C in c_dict.items()}
return c_dict

def _calculate_output_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
o_dict = {node_type: torch.matmul(X, self.W_o[node_type]) for node_type, X in x_dict.items()}
o_dict = {node_type: O + self.conv_o(h_dict, edge_index_dict)[node_type] for node_type, O in o_dict.items()}
conv_o = self.conv_o(h_dict, edge_index_dict)
o_dict = {node_type: O + conv_o[node_type] for node_type, O in o_dict.items()}
o_dict = {node_type: O + self.b_o[node_type] for node_type, O in o_dict.items()}
o_dict = {node_type: torch.sigmoid(O) for node_type, O in o_dict.items()}
return o_dict
Expand Down

0 comments on commit 2a020a4

Please sign in to comment.