Skip to content

Commit

Permalink
Merge pull request #187 from h3dema/patch-1
Browse files Browse the repository at this point in the history
change on _set_hidden_state()
  • Loading branch information
benedekrozemberczki committed Sep 4, 2022
2 parents 57006a4 + cbb8ba5 commit 9d4a7ad
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions torch_geometric_temporal/nn/recurrent/temporalgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ class TGCN2(torch.nn.Module):
add_self_loops (bool): Adding self-loops for smoothing. Default is True.
"""

def __init__(self, in_channels: int, out_channels: int, batch_size: int, improved: bool = False, cached: bool = False,
def __init__(self, in_channels: int, out_channels: int,
batch_size: int, # this entry is unnecessary, kept only for backward compatibility
improved: bool = False, cached: bool = False,
add_self_loops: bool = True):
super(TGCN2, self).__init__()

Expand All @@ -153,7 +155,7 @@ def __init__(self, in_channels: int, out_channels: int, batch_size: int, improve
self.improved = improved
self.cached = cached
self.add_self_loops = add_self_loops
self.batch_size = batch_size
self.batch_size = batch_size # not needed
self._create_parameters_and_layers()

def _create_update_gate_parameters_and_layers(self):
Expand All @@ -178,7 +180,8 @@ def _create_parameters_and_layers(self):

def _set_hidden_state(self, X, H):
if H is None:
H = torch.zeros(self.batch_size,X.shape[1], self.out_channels).to(X.device) #(b, 207, 32)
# can infer batch_size from X.shape, because X is [B, N, F]
H = torch.zeros(X.shape[0], X.shape[1], self.out_channels).to(X.device) #(b, 207, 32)
return H

def _calculate_update_gate(self, X, edge_index, edge_weight, H):
Expand Down

0 comments on commit 9d4a7ad

Please sign in to comment.