From c5a96b6025a274e41a549094d6aa0f67160e8f93 Mon Sep 17 00:00:00 2001
From: henrique moura
Date: Mon, 29 Aug 2022 14:32:10 +0200
Subject: [PATCH 1/2] Change _set_hidden_state() to infer the batch size

Allows TGCN2 to work with batches of different sizes. The current
implementation forces every batch to have the same size (defined by
self.batch_size).
---
 torch_geometric_temporal/nn/recurrent/temporalgcn.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/torch_geometric_temporal/nn/recurrent/temporalgcn.py b/torch_geometric_temporal/nn/recurrent/temporalgcn.py
index d84c81cf..a5980a47 100644
--- a/torch_geometric_temporal/nn/recurrent/temporalgcn.py
+++ b/torch_geometric_temporal/nn/recurrent/temporalgcn.py
@@ -138,13 +138,13 @@ class TGCN2(torch.nn.Module):
     Args:
         in_channels (int): Number of input features.
         out_channels (int): Number of output features.
-        batch_size (int): Size of the batch.
         improved (bool): Stronger self loops. Default is False.
         cached (bool): Caching the message weights. Default is False.
         add_self_loops (bool): Adding self-loops for smoothing. Default is True.
     """

-    def __init__(self, in_channels: int, out_channels: int, batch_size: int, improved: bool = False, cached: bool = False,
+    def __init__(self, in_channels: int, out_channels: int,
+                 improved: bool = False, cached: bool = False,
                  add_self_loops: bool = True):
         super(TGCN2, self).__init__()

@@ -153,7 +153,7 @@ def __init__(self, in_channels: int, out_channels: int, batch_size: int, improve
         self.improved = improved
         self.cached = cached
         self.add_self_loops = add_self_loops
-        self.batch_size = batch_size
+        # self.batch_size = batch_size  # no longer needed
         self._create_parameters_and_layers()

     def _create_update_gate_parameters_and_layers(self):
@@ -178,7 +178,8 @@ def _create_parameters_and_layers(self):

     def _set_hidden_state(self, X, H):
         if H is None:
-            H = torch.zeros(self.batch_size,X.shape[1], self.out_channels).to(X.device) #(b, 207, 32)
+            # The batch size can be inferred from X, since X has shape [B, N, F].
+            H = torch.zeros(X.shape[0], X.shape[1], self.out_channels).to(X.device)  # (B, 207, 32)
         return H

     def _calculate_update_gate(self, X, edge_index, edge_weight, H):

From cbb8ba51c3f7485450f3e50858228e530da66ca0 Mon Sep 17 00:00:00 2001
From: henrique moura
Date: Tue, 30 Aug 2022 10:25:18 +0200
Subject: [PATCH 2/2] Keep batch_size as a parameter for backward compatibility
---
 torch_geometric_temporal/nn/recurrent/temporalgcn.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_geometric_temporal/nn/recurrent/temporalgcn.py b/torch_geometric_temporal/nn/recurrent/temporalgcn.py
index a5980a47..f9eaa347 100644
--- a/torch_geometric_temporal/nn/recurrent/temporalgcn.py
+++ b/torch_geometric_temporal/nn/recurrent/temporalgcn.py
@@ -138,12 +138,14 @@ class TGCN2(torch.nn.Module):
     Args:
         in_channels (int): Number of input features.
         out_channels (int): Number of output features.
+        batch_size (int): Size of the batch.
         improved (bool): Stronger self loops. Default is False.
         cached (bool): Caching the message weights. Default is False.
         add_self_loops (bool): Adding self-loops for smoothing. Default is True.
     """

     def __init__(self, in_channels: int, out_channels: int,
+                 batch_size: int,  # no longer used internally; kept for backward compatibility
                  improved: bool = False, cached: bool = False,
                  add_self_loops: bool = True):
         super(TGCN2, self).__init__()

@@ -153,7 +155,7 @@ def __init__(self, in_channels: int, out_channels: int,
         self.improved = improved
         self.cached = cached
         self.add_self_loops = add_self_loops
-        # self.batch_size = batch_size  # no longer needed
+        self.batch_size = batch_size  # kept for backward compatibility
         self._create_parameters_and_layers()

     def _create_update_gate_parameters_and_layers(self):