
fixed lambda_max error when normalization != sym
gravins committed Mar 31, 2022
1 parent 88480db commit 1a8826f
Showing 3 changed files with 47 additions and 40 deletions.
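Context for the fix: each of these recurrent cells builds its gates from PyG ChebConv layers, and ChebConv only assumes a largest Laplacian eigenvalue of 2.0 when normalization="sym". With normalization="rw" or None it requires lambda_max to be passed to forward, but the cells never forwarded that argument, so any non-symmetric normalization raised an error. A minimal sketch (the toy graph and sizes are illustrative, not from this commit) of estimating lambda_max with PyG's LaplacianLambdaMax transform:

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import LaplacianLambdaMax

# Toy directed 4-cycle; any edge_index works the same way.
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
data = Data(edge_index=edge_index, num_nodes=4)

# Estimate the largest eigenvalue of the random-walk-normalized Laplacian;
# the transform stores the result on data.lambda_max.
data = LaplacianLambdaMax(normalization="rw")(data)
lambda_max = torch.tensor(data.lambda_max)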
26 changes: 14 additions & 12 deletions torch_geometric_temporal/nn/recurrent/gc_lstm.py
@@ -133,31 +133,31 @@ def _set_cell_state(self, X, C):
             C = torch.zeros(X.shape[0], self.out_channels).to(X.device)
         return C
 
-    def _calculate_input_gate(self, X, edge_index, edge_weight, H, C):
+    def _calculate_input_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
         I = torch.matmul(X, self.W_i)
-        I = I + self.conv_i(H, edge_index, edge_weight)
+        I = I + self.conv_i(H, edge_index, edge_weight, lambda_max=lambda_max)
         I = I + self.b_i
         I = torch.sigmoid(I)
         return I
 
-    def _calculate_forget_gate(self, X, edge_index, edge_weight, H, C):
+    def _calculate_forget_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
         F = torch.matmul(X, self.W_f)
-        F = F + self.conv_f(H, edge_index, edge_weight)
+        F = F + self.conv_f(H, edge_index, edge_weight, lambda_max=lambda_max)
         F = F + self.b_f
         F = torch.sigmoid(F)
         return F
 
-    def _calculate_cell_state(self, X, edge_index, edge_weight, H, C, I, F):
+    def _calculate_cell_state(self, X, edge_index, edge_weight, H, C, I, F, lambda_max):
         T = torch.matmul(X, self.W_c)
-        T = T + self.conv_c(H, edge_index, edge_weight)
+        T = T + self.conv_c(H, edge_index, edge_weight, lambda_max=lambda_max)
         T = T + self.b_c
         T = torch.tanh(T)
         C = F * C + I * T
         return C
 
-    def _calculate_output_gate(self, X, edge_index, edge_weight, H, C):
+    def _calculate_output_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
         O = torch.matmul(X, self.W_o)
-        O = O + self.conv_o(H, edge_index, edge_weight)
+        O = O + self.conv_o(H, edge_index, edge_weight, lambda_max=lambda_max)
         O = O + self.b_o
         O = torch.sigmoid(O)
         return O
@@ -173,6 +173,7 @@ def forward(
         edge_weight: torch.FloatTensor = None,
         H: torch.FloatTensor = None,
         C: torch.FloatTensor = None,
+        lambda_max: torch.Tensor = None,
     ) -> torch.FloatTensor:
         """
         Making a forward pass. If edge weights are not present the forward pass
@@ -186,16 +187,17 @@ def forward(
             * **edge_weight** *(PyTorch Long Tensor, optional)* - Edge weight vector.
             * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
             * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes.
+            * **lambda_max** *(PyTorch Tensor, optional but mandatory if normalization is not sym)* - Largest eigenvalue of the Laplacian.
         Return types:
             * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
             * **C** *(PyTorch Float Tensor)* - Cell state matrix for all nodes.
         """
         H = self._set_hidden_state(X, H)
         C = self._set_cell_state(X, C)
-        I = self._calculate_input_gate(X, edge_index, edge_weight, H, C)
-        F = self._calculate_forget_gate(X, edge_index, edge_weight, H, C)
-        C = self._calculate_cell_state(X, edge_index, edge_weight, H, C, I, F)
-        O = self._calculate_output_gate(X, edge_index, edge_weight, H, C)
+        I = self._calculate_input_gate(X, edge_index, edge_weight, H, C, lambda_max)
+        F = self._calculate_forget_gate(X, edge_index, edge_weight, H, C, lambda_max)
+        C = self._calculate_cell_state(X, edge_index, edge_weight, H, C, I, F, lambda_max)
+        O = self._calculate_output_gate(X, edge_index, edge_weight, H, C, lambda_max)
         H = self._calculate_hidden_state(O, C)
         return H, C
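With lambda_max now threaded through the gate helpers above, a GCLSTM cell built with a non-symmetric normalization works end to end. A hedged usage sketch (channel sizes and the toy inputs are assumptions, not part of the commit):

import torch
from torch_geometric_temporal.nn.recurrent import GCLSTM

num_nodes, in_channels, out_channels = 4, 8, 16
X = torch.randn(num_nodes, in_channels)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

cell = GCLSTM(in_channels, out_channels, K=2, normalization="rw")
# Before this commit the call below failed: lambda_max never reached the
# underlying ChebConv layers. A placeholder value is used here; compute it
# per graph as in the earlier LaplacianLambdaMax sketch.
H, C = cell(X, edge_index, lambda_max=torch.tensor(2.0))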
27 changes: 15 additions & 12 deletions torch_geometric_temporal/nn/recurrent/gconv_gru.py
@@ -116,21 +116,21 @@ def _set_hidden_state(self, X, H):
             H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
         return H
 
-    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
-        Z = self.conv_x_z(X, edge_index, edge_weight)
-        Z = Z + self.conv_h_z(H, edge_index, edge_weight)
+    def _calculate_update_gate(self, X, edge_index, edge_weight, H, lambda_max):
+        Z = self.conv_x_z(X, edge_index, edge_weight, lambda_max=lambda_max)
+        Z = Z + self.conv_h_z(H, edge_index, edge_weight, lambda_max=lambda_max)
         Z = torch.sigmoid(Z)
         return Z
 
-    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
-        R = self.conv_x_r(X, edge_index, edge_weight)
-        R = R + self.conv_h_r(H, edge_index, edge_weight)
+    def _calculate_reset_gate(self, X, edge_index, edge_weight, H, lambda_max):
+        R = self.conv_x_r(X, edge_index, edge_weight, lambda_max=lambda_max)
+        R = R + self.conv_h_r(H, edge_index, edge_weight, lambda_max=lambda_max)
         R = torch.sigmoid(R)
         return R
 
-    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
-        H_tilde = self.conv_x_h(X, edge_index, edge_weight)
-        H_tilde = H_tilde + self.conv_h_h(H * R, edge_index, edge_weight)
+    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R, lambda_max):
+        H_tilde = self.conv_x_h(X, edge_index, edge_weight, lambda_max=lambda_max)
+        H_tilde = H_tilde + self.conv_h_h(H * R, edge_index, edge_weight, lambda_max=lambda_max)
         H_tilde = torch.tanh(H_tilde)
         return H_tilde
 
@@ -144,6 +144,7 @@ def forward(
         edge_index: torch.LongTensor,
         edge_weight: torch.FloatTensor = None,
         H: torch.FloatTensor = None,
+        lambda_max: torch.Tensor = None,
     ) -> torch.FloatTensor:
         """
         Making a forward pass. If edge weights are not present the forward pass
@@ -155,13 +156,15 @@ def forward(
             * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
             * **edge_weight** *(PyTorch Long Tensor, optional)* - Edge weight vector.
             * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
+            * **lambda_max** *(PyTorch Tensor, optional but mandatory if normalization is not sym)* - Largest eigenvalue of the Laplacian.
         Return types:
             * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
         """
         H = self._set_hidden_state(X, H)
-        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
-        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
-        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
+        Z = self._calculate_update_gate(X, edge_index, edge_weight, H, lambda_max)
+        R = self._calculate_reset_gate(X, edge_index, edge_weight, H, lambda_max)
+        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R, lambda_max)
         H = self._calculate_hidden_state(Z, H, H_tilde)
         return H
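The same pattern applies to GConvGRU, whose forward returns only the hidden state. Another assumed sketch, reusing the placeholder lambda_max:

import torch
from torch_geometric_temporal.nn.recurrent import GConvGRU

X = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

cell = GConvGRU(in_channels=8, out_channels=16, K=2, normalization="rw")
H = cell(X, edge_index, lambda_max=torch.tensor(2.0))  # H has shape (4, 16)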
34 changes: 18 additions & 16 deletions torch_geometric_temporal/nn/recurrent/gconv_lstm.py
@@ -163,33 +163,33 @@ def _set_cell_state(self, X, C):
             C = torch.zeros(X.shape[0], self.out_channels).to(X.device)
         return C
 
-    def _calculate_input_gate(self, X, edge_index, edge_weight, H, C):
-        I = self.conv_x_i(X, edge_index, edge_weight)
-        I = I + self.conv_h_i(H, edge_index, edge_weight)
+    def _calculate_input_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
+        I = self.conv_x_i(X, edge_index, edge_weight, lambda_max=lambda_max)
+        I = I + self.conv_h_i(H, edge_index, edge_weight, lambda_max=lambda_max)
         I = I + (self.w_c_i * C)
         I = I + self.b_i
         I = torch.sigmoid(I)
         return I
 
-    def _calculate_forget_gate(self, X, edge_index, edge_weight, H, C):
-        F = self.conv_x_f(X, edge_index, edge_weight)
-        F = F + self.conv_h_f(H, edge_index, edge_weight)
+    def _calculate_forget_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
+        F = self.conv_x_f(X, edge_index, edge_weight, lambda_max=lambda_max)
+        F = F + self.conv_h_f(H, edge_index, edge_weight, lambda_max=lambda_max)
         F = F + (self.w_c_f * C)
         F = F + self.b_f
         F = torch.sigmoid(F)
         return F
 
-    def _calculate_cell_state(self, X, edge_index, edge_weight, H, C, I, F):
-        T = self.conv_x_c(X, edge_index, edge_weight)
-        T = T + self.conv_h_c(H, edge_index, edge_weight)
+    def _calculate_cell_state(self, X, edge_index, edge_weight, H, C, I, F, lambda_max):
+        T = self.conv_x_c(X, edge_index, edge_weight, lambda_max=lambda_max)
+        T = T + self.conv_h_c(H, edge_index, edge_weight, lambda_max=lambda_max)
         T = T + self.b_c
         T = torch.tanh(T)
         C = F * C + I * T
         return C
 
-    def _calculate_output_gate(self, X, edge_index, edge_weight, H, C):
-        O = self.conv_x_o(X, edge_index, edge_weight)
-        O = O + self.conv_h_o(H, edge_index, edge_weight)
+    def _calculate_output_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
+        O = self.conv_x_o(X, edge_index, edge_weight, lambda_max=lambda_max)
+        O = O + self.conv_h_o(H, edge_index, edge_weight, lambda_max=lambda_max)
         O = O + (self.w_c_o * C)
         O = O + self.b_o
         O = torch.sigmoid(O)
@@ -206,6 +206,7 @@ def forward(
         edge_weight: torch.FloatTensor = None,
         H: torch.FloatTensor = None,
         C: torch.FloatTensor = None,
+        lambda_max: torch.Tensor = None,
     ) -> torch.FloatTensor:
         """
         Making a forward pass. If edge weights are not present the forward pass
@@ -219,16 +220,17 @@ def forward(
             * **edge_weight** *(PyTorch Long Tensor, optional)* - Edge weight vector.
             * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
             * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes.
+            * **lambda_max** *(PyTorch Tensor, optional but mandatory if normalization is not sym)* - Largest eigenvalue of the Laplacian.
         Return types:
             * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
             * **C** *(PyTorch Float Tensor)* - Cell state matrix for all nodes.
         """
         H = self._set_hidden_state(X, H)
         C = self._set_cell_state(X, C)
-        I = self._calculate_input_gate(X, edge_index, edge_weight, H, C)
-        F = self._calculate_forget_gate(X, edge_index, edge_weight, H, C)
-        C = self._calculate_cell_state(X, edge_index, edge_weight, H, C, I, F)
-        O = self._calculate_output_gate(X, edge_index, edge_weight, H, C)
+        I = self._calculate_input_gate(X, edge_index, edge_weight, H, C, lambda_max)
+        F = self._calculate_forget_gate(X, edge_index, edge_weight, H, C, lambda_max)
+        C = self._calculate_cell_state(X, edge_index, edge_weight, H, C, I, F, lambda_max)
+        O = self._calculate_output_gate(X, edge_index, edge_weight, H, C, lambda_max)
         H = self._calculate_hidden_state(O, C)
         return H, C
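Because these cells are recurrent, lambda_max must be supplied on every time step when the hidden and cell states are carried across a sequence; for a static graph it can be computed once up front. A sketch of such a loop with GConvLSTM (the snapshot list and sizes are assumptions):

import torch
from torch_geometric_temporal.nn.recurrent import GConvLSTM

edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
snapshots = [torch.randn(4, 8) for _ in range(5)]  # 5 steps of node features

cell = GConvLSTM(in_channels=8, out_channels=16, K=2, normalization="rw")
lambda_max = torch.tensor(2.0)  # placeholder; estimate per graph as shown earlier

H, C = None, None
for X in snapshots:
    # On the first step the cell initializes H and C to zeros internally.
    H, C = cell(X, edge_index, H=H, C=C, lambda_max=lambda_max)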
