Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change on _set_hidden_state() #187

Merged
merged 2 commits into from
Sep 4, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions torch_geometric_temporal/nn/recurrent/temporalgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ class TGCN2(torch.nn.Module):
add_self_loops (bool): Adding self-loops for smoothing. Default is True.
"""

def __init__(self, in_channels: int, out_channels: int, batch_size: int, improved: bool = False, cached: bool = False,
def __init__(self, in_channels: int, out_channels: int,
batch_size: int, # this entry is unnecessary, kept only for backward compatibility
improved: bool = False, cached: bool = False,
add_self_loops: bool = True):
super(TGCN2, self).__init__()

Expand All @@ -153,7 +155,7 @@ def __init__(self, in_channels: int, out_channels: int, batch_size: int, improve
self.improved = improved
self.cached = cached
self.add_self_loops = add_self_loops
self.batch_size = batch_size
self.batch_size = batch_size # not needed
self._create_parameters_and_layers()

def _create_update_gate_parameters_and_layers(self):
Expand All @@ -178,7 +180,8 @@ def _create_parameters_and_layers(self):

def _set_hidden_state(self, X, H):
if H is None:
H = torch.zeros(self.batch_size,X.shape[1], self.out_channels).to(X.device) #(b, 207, 32)
# can infer batch_size from X.shape, because X is [B, N, F]
H = torch.zeros(X.shape[0], X.shape[1], self.out_channels).to(X.device) #(b, 207, 32)
return H

def _calculate_update_gate(self, X, edge_index, edge_weight, H):
Expand Down