From 4bd7c2e75383a5d1143de08b4705dd8b5170f221 Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Thu, 6 Jan 2022 13:23:00 +0100 Subject: [PATCH 01/10] First commit for heterogeneous graph support This includes: - a new data-structure 'StaticHeteroGraphTemporalSignal' that works like 'StaticGraphTemporalSignal' but with torch_geometric HeteroData objects as snapshots; instead of np arrays, dictionaries with key(node/edge types as strings)/value(indices/features as np arrays) are expected - code tests for the new data-structure - documentation for the new data-structure --- docs/source/modules/signal.rst | 7 + test/dataset_test.py | 30 +++ torch_geometric_temporal/signal/__init__.py | 2 + .../static_hetero_graph_temporal_signal.py | 203 ++++++++++++++++++ 4 files changed, 242 insertions(+) create mode 100644 torch_geometric_temporal/signal/static_hetero_graph_temporal_signal.py diff --git a/docs/source/modules/signal.rst b/docs/source/modules/signal.rst index eaa06bce..50c9ffc3 100644 --- a/docs/source/modules/signal.rst +++ b/docs/source/modules/signal.rst @@ -19,6 +19,13 @@ Temporal Signal Iterators :members: :undoc-members: +Heterogeneous Temporal Signal Iterators +------------------------- + +.. 
automodule:: torch_geometric_temporal.signal.static_hetero_graph_temporal_signal + :members: + :undoc-members: + Temporal Signal Batch Iterators ------------------------------- diff --git a/test/dataset_test.py b/test/dataset_test.py index 19426c4a..44110dc5 100644 --- a/test/dataset_test.py +++ b/test/dataset_test.py @@ -7,6 +7,8 @@ from torch_geometric_temporal.signal import DynamicGraphTemporalSignal from torch_geometric_temporal.signal import DynamicGraphStaticSignal +from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignal + from torch_geometric_temporal.dataset import METRLADatasetLoader, PemsBayDatasetLoader from torch_geometric_temporal.dataset import ( ChickenpoxDatasetLoader, @@ -139,6 +141,34 @@ def test_dynamic_graph_temporal_signal_additional_attrs(): assert snapshot.optional2.shape == (1,) +def test_static_hetero_graph_temporal_signal(): + dataset = StaticHeteroGraphTemporalSignal(None, None, [None], [None]) + for snapshot in dataset: + assert len(snapshot.node_types) == 0 + assert len(snapshot.node_stores) == 0 + assert len(snapshot.edge_types) == 0 + assert len(snapshot.edge_stores) == 0 + + +def test_static_hetero_graph_temporal_signal_typing(): + dataset = StaticHeteroGraphTemporalSignal(None, None, [{'author': np.array([1])}], [{'author': np.array([2])}]) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == (1,) + assert len(snapshot.edge_types) == 0 + + +def test_static_hetero_graph_temporal_signal_additional_attrs(): + dataset = StaticHeteroGraphTemporalSignal(None, None, [None], [None], + optional1=[{'author': np.array([1])}], + optional2=[{'author': np.array([2])}]) + assert dataset.additional_feature_keys == ["optional1", "optional2"] + for snapshot in dataset: + assert snapshot.node_stores[0]['optional1'].shape == (1,) + assert snapshot.node_stores[0]['optional2'].shape == (1,) + + def test_chickenpox(): 
loader = ChickenpoxDatasetLoader() diff --git a/torch_geometric_temporal/signal/__init__.py b/torch_geometric_temporal/signal/__init__.py index a10d6496..ea226485 100644 --- a/torch_geometric_temporal/signal/__init__.py +++ b/torch_geometric_temporal/signal/__init__.py @@ -7,4 +7,6 @@ from .dynamic_graph_static_signal import * from .dynamic_graph_static_signal_batch import * +from .static_hetero_graph_temporal_signal import * + from .train_test_split import * diff --git a/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal.py b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal.py new file mode 100644 index 00000000..7bf2b7c7 --- /dev/null +++ b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal.py @@ -0,0 +1,203 @@ +import torch +import numpy as np +from typing import List, Dict, Union, Tuple +from torch_geometric.data import HeteroData + + +Edge_Index = Union[Dict[Tuple[str, str, str], np.ndarray], None] +Edge_Weight = Union[Dict[Tuple[str, str, str], np.ndarray], None] +Node_Features = List[Union[Dict[str, np.ndarray], None]] +Targets = List[Union[Dict[str, np.ndarray], None]] +Additional_Features = List[Union[Dict[str, np.ndarray], None]] + + +class StaticHeteroGraphTemporalSignal(object): + r"""A data iterator object to contain a static heterogeneous graph with a dynamically + changing constant time difference temporal feature set (multiple signals). + The node labels (target) are also temporal. The iterator returns a single + constant time difference temporal snapshot for a time period (e.g. day or week). + This single temporal snapshot is a Pytorch Geometric HeteroData object. Between two + temporal snapshots the features and optionally passed attributes might change. + However, the underlying graph is the same. + + .. 
code-block:: python + from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignal + + edge_index_dict = { + ("author", "writes", "paper"): np.array([[0, 0, 1], [0, 1, 2]]) + } + + feature_dicts = [ + {"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}, + {"author": np.array([[0.1], [0.1]]), + "paper": np.array([[0.1], [0.1], [0.1]])}, + {"author": np.array([[0.2], [0.2]]), + "paper": np.array([[0.2], [0.2], [0.2]])} + ] + + target_dicts = [ + {"author": np.array([0, 0]), + "paper": np.array([0, 0, 0])}, + {"author": np.array([1, 1]), + "paper": np.array([1, 1, 1])}, + {"author": np.array([2, 2]), + "paper": np.array([2, 2, 2])} + ] + + # Create heterogeneous graph snapshots with same structure but different features + # and labels (in this example node types "paper" and "author"): + graph_snapshots = StaticHeteroGraphTemporalSignal(edge_index_dict, None, feature_dicts, target_dicts) + + Note that in this example all feature and target dicts have the same keys. + + * To skip initializing nodes of all types in a specific snapshot simply + pass :obj:`None` instead of a dictionary: + + .. code-block:: python + feature_dicts = [ + {"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}, + None, + {"author": np.array([[0.2], [0.2]]), + "paper": np.array([[0.2], [0.2], [0.2]])} + ] + + * To skip initializing node features of type :obj:`"paper"` in a specific snapshot simply + pass :obj:`None` as dictionary value or omit this feature type: + + .. code-block:: python + feature_dicts = [ + {"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}, + {"author": np.array([[0.1], [0.1]]), + "paper": None}, # pass None as value + {"author": np.array([[0.2], [0.2]])} # omit type in dict + ] + + Args: + edge_index_dict (Dictionary of keys=Tuples and values=Numpy arrays): Relation type tuples + and their edge index tensors. 
+ edge_weight_dict (Dictionary of keys=Tuples and values=Numpy arrays): Relation type tuples + and their edge weight tensors. + feature_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node + types and their feature tensors. + target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node + types and their label (target) tensors. + **kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays): List + of node types and their additional attributes. + """ + + def __init__( + self, + edge_index_dict: Edge_Index, + edge_weight_dict: Edge_Weight, + feature_dicts: Node_Features, + target_dicts: Targets, + **kwargs: Additional_Features + ): + self.edge_index_dict = edge_index_dict + self.edge_weight_dict = edge_weight_dict + self.feature_dicts = feature_dicts + self.target_dicts = target_dicts + self.additional_feature_keys = [] + for key, value in kwargs.items(): + setattr(self, key, value) + self.additional_feature_keys.append(key) + self._check_temporal_consistency() + self._set_snapshot_count() + + def _check_temporal_consistency(self): + assert len(self.feature_dicts) == len( + self.target_dicts + ), "Temporal dimension inconsistency." + for key in self.additional_feature_keys: + assert len(self.target_dicts) == len( + getattr(self, key) + ), "Temporal dimension inconsistency." 
+ + def _set_snapshot_count(self): + self.snapshot_count = len(self.feature_dicts) + + def _get_edge_index(self): + if self.edge_index_dict is None: + return self.edge_index_dict + else: + return {key: torch.LongTensor(value) for key, value in self.edge_index_dict.items()} + + def _get_edge_weight(self): + if self.edge_weight_dict is None: + return self.edge_weight_dict + else: + return {key: torch.FloatTensor(value) for key, value in self.edge_weight_dict.items()} + + def _get_features(self, time_index: int): + if self.feature_dicts[time_index] is None: + return self.feature_dicts[time_index] + else: + return {key: torch.FloatTensor(value) for key, value in self.feature_dicts[time_index].items() + if value is not None} + + def _get_target(self, time_index: int): + if self.target_dicts[time_index] is None: + return self.target_dicts[time_index] + else: + return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in self.target_dicts[time_index].items() + if value is not None} + + def _get_additional_feature(self, time_index: int, feature_key: str): + feature = getattr(self, feature_key)[time_index] + if feature is None: + return feature + else: + return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in feature.items() + if value is not None} + + def _get_additional_features(self, time_index: int): + additional_features = { + key: self._get_additional_feature(time_index, key) + for key in self.additional_feature_keys + } + return additional_features + + def __get_item__(self, time_index: int): + x_dict = self._get_features(time_index) + edge_index_dict = self._get_edge_index() + edge_weight_dict = self._get_edge_weight() + y_dict = self._get_target(time_index) + additional_features = self._get_additional_features(time_index) + + snapshot = HeteroData() + if x_dict: + for key, value in 
x_dict.items(): + snapshot[key].x = value + if edge_index_dict: + for key, value in edge_index_dict.items(): + snapshot[key].edge_index = value + if edge_weight_dict: + for key, value in edge_weight_dict.items(): + snapshot[key].edge_attr = value + if y_dict: + for key, value in y_dict.items(): + snapshot[key].y = value + if additional_features: + for feature_name, feature_dict in additional_features.items(): + if feature_dict: + for key, value in feature_dict.items(): + snapshot[key][feature_name] = value + return snapshot + + def __next__(self): + if self.t < len(self.feature_dicts): + snapshot = self.__get_item__(self.t) + self.t = self.t + 1 + return snapshot + else: + self.t = 0 + raise StopIteration + + def __iter__(self): + self.t = 0 + return self From 57f4422a4ad4dd78a06813b87d8bc58786de5e11 Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Mon, 10 Jan 2022 09:20:12 +0100 Subject: [PATCH 02/10] Updated torch_geometric version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 038b23cc..2f613ceb 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ "torch_scatter", "torch_cluster", "torch_spline_conv", - "torch_geometric==1.7.0", + "torch_geometric==2.0.0", "numpy", "scipy", "tqdm", From 62a86b46bbb3a3cf19e246bbd88c5906588212aa Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Thu, 20 Jan 2022 14:09:16 +0100 Subject: [PATCH 03/10] Static hetero graph temporal signal batch This includes: - a new data structure, that represents PyG HeteroDataBatch snapshots for static heterogeneous graphs with temporal signals - documentation - tests --- docs/source/modules/signal.rst | 7 + test/batch_test.py | 35 ++++ torch_geometric_temporal/signal/__init__.py | 1 + ...atic_hetero_graph_temporal_signal_batch.py | 162 ++++++++++++++++++ 4 files changed, 205 insertions(+) create mode 100644 torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py diff --git a/docs/source/modules/signal.rst 
b/docs/source/modules/signal.rst index 50c9ffc3..88b47332 100644 --- a/docs/source/modules/signal.rst +++ b/docs/source/modules/signal.rst @@ -41,6 +41,13 @@ Temporal Signal Batch Iterators :members: :undoc-members: +Heterogeneous Temporal Signal Batch Iterators +------------------------------- + +.. automodule:: torch_geometric_temporal.signal.static_hetero_graph_temporal_signal_batch + :members: + :undoc-members: + Temporal Signal Train-Test Split -------------------------------- diff --git a/test/batch_test.py b/test/batch_test.py index dbf6fd10..3191c654 100644 --- a/test/batch_test.py +++ b/test/batch_test.py @@ -9,6 +9,8 @@ from torch_geometric_temporal.signal import DynamicGraphTemporalSignalBatch from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch +from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignalBatch + def get_edge_array(node_count, node_start): edges = [] @@ -83,6 +85,17 @@ def test_static_graph_temporal_signal_batch(): assert snapshot.batch is None +def test_static_hetero_graph_temporal_signal_batch(): + dataset = StaticHeteroGraphTemporalSignalBatch( + None, None, [None, None], [None, None], None + ) + for snapshot in dataset: + assert len(snapshot.node_types) == 0 + assert len(snapshot.node_stores) == 0 + assert len(snapshot.edge_types) == 0 + assert len(snapshot.edge_stores) == 0 + + def test_dynamic_graph_temporal_signal_batch(): dataset = DynamicGraphTemporalSignalBatch( [None, None], [None, None], [None, None], [None, None], [None, None] @@ -107,6 +120,18 @@ def test_static_graph_temporal_signal_typing_batch(): assert snapshot.batch is None +def test_static_hetero_graph_temporal_signal_typing_batch(): + dataset = StaticHeteroGraphTemporalSignalBatch( + None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], None + ) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == 
(1,) + assert 'batch' not in list(dict(snapshot.node_stores[0]).keys()) + assert len(snapshot.edge_types) == 0 + + def test_dynamic_graph_static_signal_typing_batch(): dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None]) for snapshot in dataset: @@ -135,6 +160,16 @@ def test_static_graph_temporal_signal_batch_additional_attrs(): assert snapshot.optional2.shape == (1,) +def test_static_hetero_graph_temporal_signal_batch_additional_attrs(): + dataset = StaticHeteroGraphTemporalSignalBatch(None, None, [None], [None], None, + optional1=[{'author': np.array([1])}], + optional2=[{'author': np.array([2])}]) + assert dataset.additional_feature_keys == ["optional1", "optional2"] + for snapshot in dataset: + assert snapshot.node_stores[0]['optional1'].shape == (1,) + assert snapshot.node_stores[0]['optional2'].shape == (1,) + + def test_dynamic_graph_static_signal_batch_additional_attrs(): dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None], optional1=[np.array([1])], optional2=[np.array([2])]) diff --git a/torch_geometric_temporal/signal/__init__.py b/torch_geometric_temporal/signal/__init__.py index ea226485..59e4d7dc 100644 --- a/torch_geometric_temporal/signal/__init__.py +++ b/torch_geometric_temporal/signal/__init__.py @@ -8,5 +8,6 @@ from .dynamic_graph_static_signal_batch import * from .static_hetero_graph_temporal_signal import * +from .static_hetero_graph_temporal_signal_batch import * from .train_test_split import * diff --git a/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py new file mode 100644 index 00000000..3516410c --- /dev/null +++ b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py @@ -0,0 +1,162 @@ +import torch +import numpy as np +from typing import List, Dict, Union, Tuple +from torch_geometric.data import Batch, HeteroData + +Edge_Index = Union[Dict[Tuple[str, str, 
str], np.ndarray], None] +Edge_Weight = Union[Dict[Tuple[str, str, str], np.ndarray], None] +Node_Features = List[Union[Dict[str, np.ndarray], None]] +Targets = List[Union[Dict[str, np.ndarray], None]] +Batches = Union[Dict[str, np.ndarray], None] +Additional_Features = List[Union[Dict[str, np.ndarray], None]] + + +class StaticHeteroGraphTemporalSignalBatch(object): + r"""A data iterator object to contain a static heterogeneous graph with a dynamically + changing constant time difference temporal feature set (multiple signals). + The node labels (target) are also temporal. The iterator returns a single + constant time difference temporal snapshot for a time period (e.g. day or week). + This single temporal snapshot is a Pytorch Geometric Batch object. Between two + temporal snapshots the feature matrix, target matrices and optionally passed + attributes might change. However, the underlying graph is the same. + + Args: + edge_index_dict (Dictionary of keys=Tuples and values=Numpy arrays): Relation type tuples + and their edge index tensors. + edge_weight_dict (Dictionary of keys=Tuples and values=Numpy arrays): Relation type tuples + and their edge weight tensors. + feature_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node + types and their feature tensors. + target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node + types and their label (target) tensors. + batch_dict (Dictionary of keys=Strings and values=Numpy arrays): Batch index tensor of each + node type. + **kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays): List + of node types and their additional attributes. 
+ """ + + def __init__( + self, + edge_index_dict: Edge_Index, + edge_weight_dict: Edge_Weight, + feature_dicts: Node_Features, + target_dicts: Targets, + batch_dict: Batches, + **kwargs: Additional_Features + ): + self.edge_index_dict = edge_index_dict + self.edge_weight_dict = edge_weight_dict + self.feature_dicts = feature_dicts + self.target_dicts = target_dicts + self.batch_dict = batch_dict + self.additional_feature_keys = [] + for key, value in kwargs.items(): + setattr(self, key, value) + self.additional_feature_keys.append(key) + self._check_temporal_consistency() + self._set_snapshot_count() + + def _check_temporal_consistency(self): + assert len(self.feature_dicts) == len( + self.target_dicts + ), "Temporal dimension inconsistency." + for key in self.additional_feature_keys: + assert len(self.target_dicts) == len( + getattr(self, key) + ), "Temporal dimension inconsistency." + + def _set_snapshot_count(self): + self.snapshot_count = len(self.feature_dicts) + + def _get_edge_index(self): + if self.edge_index_dict is None: + return self.edge_index_dict + else: + return {key: torch.LongTensor(value) for key, value in self.edge_index_dict.items()} + + def _get_batch_index(self): + if self.batch_dict is None: + return self.batch_dict + else: + return {key: torch.LongTensor(value) for key, value in self.batch_dict.items()} + + def _get_edge_weight(self): + if self.edge_weight_dict is None: + return self.edge_weight_dict + else: + return {key: torch.FloatTensor(value) for key, value in self.edge_weight_dict.items()} + + def _get_features(self, time_index: int): + if self.feature_dicts[time_index] is None: + return self.feature_dicts[time_index] + else: + return {key: torch.FloatTensor(value) for key, value in self.feature_dicts[time_index].items() + if value is not None} + + def _get_target(self, time_index: int): + if self.target_dicts[time_index] is None: + return self.target_dicts[time_index] + else: + return {key: torch.FloatTensor(value) if 
value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in self.target_dicts[time_index].items() + if value is not None} + + def _get_additional_feature(self, time_index: int, feature_key: str): + feature = getattr(self, feature_key)[time_index] + if feature is None: + return feature + else: + return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in feature.items() + if value is not None} + + def _get_additional_features(self, time_index: int): + additional_features = { + key: self._get_additional_feature(time_index, key) + for key in self.additional_feature_keys + } + return additional_features + + def __get_item__(self, time_index: int): + x_dict = self._get_features(time_index) + edge_index_dict = self._get_edge_index() + edge_weight_dict = self._get_edge_weight() + batch_dict = self._get_batch_index() + y_dict = self._get_target(time_index) + additional_features = self._get_additional_features(time_index) + + snapshot = Batch.from_data_list([HeteroData()]) + if x_dict: + for key, value in x_dict.items(): + snapshot[key].x = value + if edge_index_dict: + for key, value in edge_index_dict.items(): + snapshot[key].edge_index = value + if edge_weight_dict: + for key, value in edge_weight_dict.items(): + snapshot[key].edge_attr = value + if y_dict: + for key, value in y_dict.items(): + snapshot[key].y = value + if batch_dict: + for key, value in batch_dict.items(): + snapshot[key].batch = value + if additional_features: + for feature_name, feature_dict in additional_features.items(): + if feature_dict: + for key, value in feature_dict.items(): + snapshot[key][feature_name] = value + return snapshot + + def __next__(self): + if self.t < len(self.feature_dicts): + snapshot = self.__get_item__(self.t) + self.t = self.t + 1 + return snapshot + else: + self.t = 0 + raise StopIteration + + def __iter__(self): + self.t = 0 + 
return self From 947d0fe33fda74397cc7b607a297eadcb1787c19 Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Wed, 26 Jan 2022 10:00:52 +0100 Subject: [PATCH 04/10] Additional dataset tests for coverage edge tests edge index tests additional attributes None test --- test/dataset_test.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/test/dataset_test.py b/test/dataset_test.py index 32572aed..8c4d5837 100644 --- a/test/dataset_test.py +++ b/test/dataset_test.py @@ -169,11 +169,27 @@ def test_static_hetero_graph_temporal_signal_typing(): def test_static_hetero_graph_temporal_signal_additional_attrs(): dataset = StaticHeteroGraphTemporalSignal(None, None, [None], [None], optional1=[{'author': np.array([1])}], - optional2=[{'author': np.array([2])}]) - assert dataset.additional_feature_keys == ["optional1", "optional2"] + optional2=[{'author': np.array([2])}], + optional3=[None]) + assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"] for snapshot in dataset: assert snapshot.node_stores[0]['optional1'].shape == (1,) assert snapshot.node_stores[0]['optional2'].shape == (1,) + assert "optional3" not in list(dict(snapshot.node_stores[0]).keys()) + + +def test_static_hetero_graph_temporal_signal_edges(): + dataset = StaticHeteroGraphTemporalSignal({("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}, + {("author", "writes", "paper"): np.array([[0.1], [0.1]])}, + [{"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}, + {"author": np.array([[0.1], [0.1]]), + "paper": np.array([[0.1], [0.1], [0.1]])}], + [None, None]) + for snapshot in dataset: + assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2) + assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1) + assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] From 74b18d076e4114412b04a1427b4916e2065b067c Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Wed, 26 Jan 
2022 10:10:58 +0100 Subject: [PATCH 05/10] Added batch tests for coverage edge test edge weights test additional attributes None test batch assigned test --- test/batch_test.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/test/batch_test.py b/test/batch_test.py index 3191c654..bafc64e2 100644 --- a/test/batch_test.py +++ b/test/batch_test.py @@ -163,11 +163,13 @@ def test_static_graph_temporal_signal_batch_additional_attrs(): def test_static_hetero_graph_temporal_signal_batch_additional_attrs(): dataset = StaticHeteroGraphTemporalSignalBatch(None, None, [None], [None], None, optional1=[{'author': np.array([1])}], - optional2=[{'author': np.array([2])}]) - assert dataset.additional_feature_keys == ["optional1", "optional2"] + optional2=[{'author': np.array([2])}], + optional3=[None]) + assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"] for snapshot in dataset: assert snapshot.node_stores[0]['optional1'].shape == (1,) assert snapshot.node_stores[0]['optional2'].shape == (1,) + assert "optional3" not in list(dict(snapshot.node_stores[0]).keys()) def test_dynamic_graph_static_signal_batch_additional_attrs(): @@ -179,6 +181,33 @@ def test_dynamic_graph_static_signal_batch_additional_attrs(): assert snapshot.optional2.shape == (1,) +def test_static_hetero_graph_temporal_signal_batch_edges(): + dataset = StaticHeteroGraphTemporalSignalBatch({("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}, + {("author", "writes", "paper"): np.array([[0.1], [0.1]])}, + [{"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}, + {"author": np.array([[0.1], [0.1]]), + "paper": np.array([[0.1], [0.1], [0.1]])}], + [None, None], + None) + for snapshot in dataset: + assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2) + assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1) + assert snapshot.edge_stores[0]['edge_index'].shape[0] == 
snapshot.edge_stores[0]['edge_attr'].shape[0] + + +def test_static_hetero_graph_temporal_signal_batch_assigned(): + dataset = StaticHeteroGraphTemporalSignalBatch( + None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], {'author': np.array([1])} + ) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == (1,) + assert snapshot.node_stores[0]['batch'].shape == (1,) + assert len(snapshot.edge_types) == 0 + + def test_discrete_train_test_split_dynamic_batch(): snapshot_count = 250 From f7ae5d0a0e75ae622fdc2648f40efb6d8099eaee Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Thu, 3 Feb 2022 09:42:57 +0100 Subject: [PATCH 06/10] refactor __getitem__ for hetero graphs --- .../signal/static_hetero_graph_temporal_signal.py | 4 ++-- .../signal/static_hetero_graph_temporal_signal_batch.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal.py b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal.py index 7bf2b7c7..d983324d 100644 --- a/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal.py +++ b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal.py @@ -162,7 +162,7 @@ def _get_additional_features(self, time_index: int): } return additional_features - def __get_item__(self, time_index: int): + def __getitem__(self, time_index: int): x_dict = self._get_features(time_index) edge_index_dict = self._get_edge_index() edge_weight_dict = self._get_edge_weight() @@ -191,7 +191,7 @@ def __get_item__(self, time_index: int): def __next__(self): if self.t < len(self.feature_dicts): - snapshot = self.__get_item__(self.t) + snapshot = self[self.t] self.t = self.t + 1 return snapshot else: diff --git a/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py 
b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py index 3516410c..69a04721 100644 --- a/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py +++ b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py @@ -117,7 +117,7 @@ def _get_additional_features(self, time_index: int): } return additional_features - def __get_item__(self, time_index: int): + def __getitem__(self, time_index: int): x_dict = self._get_features(time_index) edge_index_dict = self._get_edge_index() edge_weight_dict = self._get_edge_weight() @@ -150,7 +150,7 @@ def __get_item__(self, time_index: int): def __next__(self): if self.t < len(self.feature_dicts): - snapshot = self.__get_item__(self.t) + snapshot = self[self.t] self.t = self.t + 1 return snapshot else: From 8e4ddd12e2f96a27743b98cba46ea9a092ef13ac Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Thu, 3 Feb 2022 09:44:52 +0100 Subject: [PATCH 07/10] dynamic hetero graph static signal - new data structure that works like DynamicGraphStaticSignal but with PyG HeteroData objects - tests - documentation --- docs/source/modules/signal.rst | 4 + test/dataset_test.py | 43 +++++ torch_geometric_temporal/signal/__init__.py | 2 + .../dynamic_hetero_graph_static_signal.py | 154 ++++++++++++++++++ 4 files changed, 203 insertions(+) create mode 100644 torch_geometric_temporal/signal/dynamic_hetero_graph_static_signal.py diff --git a/docs/source/modules/signal.rst b/docs/source/modules/signal.rst index 88b47332..1d450023 100644 --- a/docs/source/modules/signal.rst +++ b/docs/source/modules/signal.rst @@ -26,6 +26,10 @@ Heterogeneous Temporal Signal Iterators :members: :undoc-members: +.. 
automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal + :members: + :undoc-members: + Temporal Signal Batch Iterators ------------------------------- diff --git a/test/dataset_test.py b/test/dataset_test.py index 8c4d5837..efa40c8d 100644 --- a/test/dataset_test.py +++ b/test/dataset_test.py @@ -8,6 +8,7 @@ from torch_geometric_temporal.signal import DynamicGraphStaticSignal from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignal +from torch_geometric_temporal.signal import DynamicHeteroGraphStaticSignal from torch_geometric_temporal.dataset import METRLADatasetLoader, PemsBayDatasetLoader from torch_geometric_temporal.dataset import ( @@ -192,6 +193,48 @@ def test_static_hetero_graph_temporal_signal_edges(): assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] +def test_dynamic_hetero_graph_static_signal(): + dataset = DynamicHeteroGraphStaticSignal([None], [None], None, [None]) + for snapshot in dataset: + assert len(snapshot.node_types) == 0 + assert len(snapshot.node_stores) == 0 + assert len(snapshot.edge_types) == 0 + assert len(snapshot.edge_stores) == 0 + + +def test_dynamic_hetero_graph_static_signal_typing(): + dataset = DynamicHeteroGraphStaticSignal([None], [None], {'author': np.array([1])}, [{'author': np.array([2])}]) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == (1,) + assert len(snapshot.edge_types) == 0 + + +def test_dynamic_hetero_graph_static_signal_additional_attrs(): + dataset = DynamicHeteroGraphStaticSignal([None], [None], None, [None], + optional1=[{'author': np.array([1])}], + optional2=[{'author': np.array([2])}], + optional3=[None]) + assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"] + for snapshot in dataset: + assert snapshot.node_stores[0]['optional1'].shape == (1,) + assert 
snapshot.node_stores[0]['optional2'].shape == (1,) + assert "optional3" not in list(dict(snapshot.node_stores[0]).keys()) + + +def test_dynamic_hetero_graph_static_signal_edges(): + dataset = DynamicHeteroGraphStaticSignal([{("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}], + [{("author", "writes", "paper"): np.array([[0.1], [0.1]])}], + {"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}, + [None]) + for snapshot in dataset: + assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2) + assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1) + assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] + + def test_chickenpox(): loader = ChickenpoxDatasetLoader() diff --git a/torch_geometric_temporal/signal/__init__.py b/torch_geometric_temporal/signal/__init__.py index 59e4d7dc..03567543 100644 --- a/torch_geometric_temporal/signal/__init__.py +++ b/torch_geometric_temporal/signal/__init__.py @@ -10,4 +10,6 @@ from .static_hetero_graph_temporal_signal import * from .static_hetero_graph_temporal_signal_batch import * +from .dynamic_hetero_graph_static_signal import * + from .train_test_split import * diff --git a/torch_geometric_temporal/signal/dynamic_hetero_graph_static_signal.py b/torch_geometric_temporal/signal/dynamic_hetero_graph_static_signal.py new file mode 100644 index 00000000..63818d82 --- /dev/null +++ b/torch_geometric_temporal/signal/dynamic_hetero_graph_static_signal.py @@ -0,0 +1,154 @@ +import torch +import numpy as np +from typing import List, Dict, Union, Tuple +from torch_geometric.data import HeteroData + + +Edge_Indices = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]] +Edge_Weights = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]] +Node_Feature = Union[Dict[str, np.ndarray], None] +Targets = List[Union[Dict[str, np.ndarray], None]] +Additional_Features = List[Union[Dict[str, np.ndarray], None]] + + +class 
DynamicHeteroGraphStaticSignal(object): + r"""A data iterator object to contain a dynamic heterogeneous graph with a + changing edge set and weights. The node labels + (target) are also dynamic. The iterator returns a single discrete temporal + snapshot for a time period (e.g. day or week). This single snapshot is a + Pytorch Geometric HeteroData object. Between two temporal snapshots the edges, + edge weights, target matrices and optionally passed attributes might change. + + Args: + edge_index_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays): + List of relation type tuples and their edge index tensors. + edge_weight_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays): + List of relation type tuples and their edge weight tensors. + feature_dict (Dictionary of keys=Strings and values=Numpy arrays): Node types + and their node feature tensors. + target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): + List of node types and their label (target) tensors. + **kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays): List + of node types and their additional attributes. + """ + + def __init__( + self, + edge_index_dicts: Edge_Indices, + edge_weight_dicts: Edge_Weights, + feature_dict: Node_Feature, + target_dicts: Targets, + **kwargs: Additional_Features + ): + self.edge_index_dicts = edge_index_dicts + self.edge_weight_dicts = edge_weight_dicts + self.feature_dict = feature_dict + self.target_dicts = target_dicts + self.additional_feature_keys = [] + for key, value in kwargs.items(): + setattr(self, key, value) + self.additional_feature_keys.append(key) + self._check_temporal_consistency() + self._set_snapshot_count() + + def _check_temporal_consistency(self): + assert len(self.edge_index_dicts) == len( + self.edge_weight_dicts + ), "Temporal dimension inconsistency." + assert len(self.target_dicts) == len( + self.edge_index_dicts + ), "Temporal dimension inconsistency." 
+ for key in self.additional_feature_keys: + assert len(self.target_dicts) == len( + getattr(self, key) + ), "Temporal dimension inconsistency." + + def _set_snapshot_count(self): + self.snapshot_count = len(self.target_dicts) + + def _get_edge_index(self, time_index: int): + if self.edge_index_dicts[time_index] is None: + return self.edge_index_dicts[time_index] + else: + return {key: torch.LongTensor(value) for key, value in self.edge_index_dicts[time_index].items() + if value is not None} + + def _get_edge_weight(self, time_index: int): + if self.edge_weight_dicts[time_index] is None: + return self.edge_weight_dicts[time_index] + else: + return {key: torch.FloatTensor(value) for key, value in self.edge_weight_dicts[time_index].items() + if value is not None} + + def _get_feature(self): + if self.feature_dict is None: + return self.feature_dict + else: + return {key: torch.FloatTensor(value) for key, value in self.feature_dict.items()} + + def _get_target(self, time_index: int): + if self.target_dicts[time_index] is None: + return self.target_dicts[time_index] + else: + return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in self.target_dicts[time_index].items() + if value is not None} + + def _get_additional_feature(self, time_index: int, feature_key: str): + feature = getattr(self, feature_key)[time_index] + if feature is None: + return feature + else: + return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in feature.items() + if value is not None} + + def _get_additional_features(self, time_index: int): + additional_features = { + key: self._get_additional_feature(time_index, key) + for key in self.additional_feature_keys + } + return additional_features + + def __len__(self): + return len(self.target_dicts) + + def __getitem__(self, time_index: int): + x_dict = 
self._get_feature() + edge_index_dict = self._get_edge_index(time_index) + edge_weight_dict = self._get_edge_weight(time_index) + y_dict = self._get_target(time_index) + additional_features = self._get_additional_features(time_index) + + snapshot = HeteroData() + if x_dict: + for key, value in x_dict.items(): + snapshot[key].x = value + if edge_index_dict: + for key, value in edge_index_dict.items(): + snapshot[key].edge_index = value + if edge_weight_dict: + for key, value in edge_weight_dict.items(): + snapshot[key].edge_attr = value + if y_dict: + for key, value in y_dict.items(): + snapshot[key].y = value + if additional_features: + for feature_name, feature_dict in additional_features.items(): + if feature_dict: + for key, value in feature_dict.items(): + snapshot[key][feature_name] = value + return snapshot + + def __next__(self): + if self.t < len(self.target_dicts): + snapshot = self[self.t] + self.t = self.t + 1 + return snapshot + else: + self.t = 0 + raise StopIteration + + def __iter__(self): + self.t = 0 + return self From 591ca3b7f4d7bef1433057c922b374edc9790309 Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Thu, 3 Feb 2022 10:09:13 +0100 Subject: [PATCH 08/10] dynamic hetero graph static signal batch - data structure that works like DynamicGraphStaticSignal but with PyG HeteroData Batches - tests - documentation --- docs/source/modules/signal.rst | 4 + test/batch_test.py | 61 +++++++ torch_geometric_temporal/signal/__init__.py | 1 + ...ynamic_hetero_graph_static_signal_batch.py | 171 ++++++++++++++++++ 4 files changed, 237 insertions(+) create mode 100644 torch_geometric_temporal/signal/dynamic_hetero_graph_static_signal_batch.py diff --git a/docs/source/modules/signal.rst b/docs/source/modules/signal.rst index 1d450023..d23e70b5 100644 --- a/docs/source/modules/signal.rst +++ b/docs/source/modules/signal.rst @@ -52,6 +52,10 @@ Heterogeneous Temporal Signal Batch Iterators :members: :undoc-members: +.. 
automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal_batch + :members: + :undoc-members: + Temporal Signal Train-Test Split -------------------------------- diff --git a/test/batch_test.py b/test/batch_test.py index bafc64e2..8144fd6f 100644 --- a/test/batch_test.py +++ b/test/batch_test.py @@ -10,6 +10,7 @@ from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignalBatch +from torch_geometric_temporal.signal import DynamicHeteroGraphStaticSignalBatch def get_edge_array(node_count, node_start): @@ -96,6 +97,17 @@ def test_static_hetero_graph_temporal_signal_batch(): assert len(snapshot.edge_stores) == 0 +def test_dynamic_hetero_graph_static_signal_batch(): + dataset = DynamicHeteroGraphStaticSignalBatch( + [None], [None], None, [None], [None] + ) + for snapshot in dataset: + assert len(snapshot.node_types) == 0 + assert len(snapshot.node_stores) == 0 + assert len(snapshot.edge_types) == 0 + assert len(snapshot.edge_stores) == 0 + + def test_dynamic_graph_temporal_signal_batch(): dataset = DynamicGraphTemporalSignalBatch( [None, None], [None, None], [None, None], [None, None], [None, None] @@ -132,6 +144,18 @@ def test_static_hetero_graph_temporal_signal_typing_batch(): assert len(snapshot.edge_types) == 0 +def test_dynamic_hetero_graph_static_signal_typing_batch(): + dataset = DynamicHeteroGraphStaticSignalBatch( + [None], [None], {'author': np.array([1])}, [{'author': np.array([2])}], [None] + ) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == (1,) + assert 'batch' not in list(dict(snapshot.node_stores[0]).keys()) + assert len(snapshot.edge_types) == 0 + + def test_dynamic_graph_static_signal_typing_batch(): dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None]) for snapshot in dataset: @@ -172,6 +196,18 @@ 
def test_static_hetero_graph_temporal_signal_batch_additional_attrs(): assert "optional3" not in list(dict(snapshot.node_stores[0]).keys()) +def test_dynamic_hetero_graph_static_signal_batch_additional_attrs(): + dataset = DynamicHeteroGraphStaticSignalBatch([None], [None], None, [None], [None], + optional1=[{'author': np.array([1])}], + optional2=[{'author': np.array([2])}], + optional3=[None]) + assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"] + for snapshot in dataset: + assert snapshot.node_stores[0]['optional1'].shape == (1,) + assert snapshot.node_stores[0]['optional2'].shape == (1,) + assert "optional3" not in list(dict(snapshot.node_stores[0]).keys()) + + def test_dynamic_graph_static_signal_batch_additional_attrs(): dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None], optional1=[np.array([1])], optional2=[np.array([2])]) @@ -196,6 +232,19 @@ def test_static_hetero_graph_temporal_signal_batch_edges(): assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] +def test_dynamic_hetero_graph_static_signal_batch_edges(): + dataset = DynamicHeteroGraphStaticSignalBatch([{("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}], + [{("author", "writes", "paper"): np.array([[0.1], [0.1]])}], + {"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}, + [None], + [None]) + for snapshot in dataset: + assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2) + assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1) + assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] + + def test_static_hetero_graph_temporal_signal_batch_assigned(): dataset = StaticHeteroGraphTemporalSignalBatch( None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], {'author': np.array([1])} @@ -208,6 +257,18 @@ def test_static_hetero_graph_temporal_signal_batch_assigned(): assert len(snapshot.edge_types) == 
0 +def test_dynamic_hetero_graph_static_signal_batch_assigned(): + dataset = DynamicHeteroGraphStaticSignalBatch( + [None], [None], {'author': np.array([1])}, [{'author': np.array([2])}], [{'author': np.array([1])}] + ) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == (1,) + assert snapshot.node_stores[0]['batch'].shape == (1,) + assert len(snapshot.edge_types) == 0 + + def test_discrete_train_test_split_dynamic_batch(): snapshot_count = 250 diff --git a/torch_geometric_temporal/signal/__init__.py b/torch_geometric_temporal/signal/__init__.py index 03567543..fabb5ca3 100644 --- a/torch_geometric_temporal/signal/__init__.py +++ b/torch_geometric_temporal/signal/__init__.py @@ -11,5 +11,6 @@ from .static_hetero_graph_temporal_signal_batch import * from .dynamic_hetero_graph_static_signal import * +from .dynamic_hetero_graph_static_signal_batch import * from .train_test_split import * diff --git a/torch_geometric_temporal/signal/dynamic_hetero_graph_static_signal_batch.py b/torch_geometric_temporal/signal/dynamic_hetero_graph_static_signal_batch.py new file mode 100644 index 00000000..b08d75e3 --- /dev/null +++ b/torch_geometric_temporal/signal/dynamic_hetero_graph_static_signal_batch.py @@ -0,0 +1,171 @@ +import torch +import numpy as np +from typing import List, Dict, Union, Tuple +from torch_geometric.data import HeteroData, Batch + + +Edge_Indices = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]] +Edge_Weights = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]] +Node_Feature = Union[Dict[str, np.ndarray], None] +Targets = List[Union[Dict[str, np.ndarray], None]] +Batches = List[Union[Dict[str, np.ndarray], None]] +Additional_Features = List[Union[Dict[str, np.ndarray], None]] + + +class DynamicHeteroGraphStaticSignalBatch(object): + r"""A batch iterator object to contain a dynamic heterogeneous graph with a + changing edge set 
and weights. The node labels + (target) are also dynamic. The iterator returns a single discrete temporal + snapshot for a time period (e.g. day or week). This single snapshot is a + Pytorch Geometric Batch object. Between two temporal snapshots the edges, + batch memberships, edge weights, target matrices and optionally passed + attributes might change. + + Args: + edge_index_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays): + List of relation type tuples and their edge index tensors. + edge_weight_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays): + List of relation type tuples and their edge weight tensors. + feature_dict (Dictionary of keys=Strings and values=Numpy arrays): Node types + and their node feature tensors. + target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): + List of node types and their label (target) tensors. + batch_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): + List of batch index tensors for each node type. + **kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays): List + of node types and their additional attributes. + """ + + def __init__( + self, + edge_index_dicts: Edge_Indices, + edge_weight_dicts: Edge_Weights, + feature_dict: Node_Feature, + target_dicts: Targets, + batch_dicts: Batches, + **kwargs: Additional_Features + ): + self.edge_index_dicts = edge_index_dicts + self.edge_weight_dicts = edge_weight_dicts + self.feature_dict = feature_dict + self.target_dicts = target_dicts + self.batch_dicts = batch_dicts + self.additional_feature_keys = [] + for key, value in kwargs.items(): + setattr(self, key, value) + self.additional_feature_keys.append(key) + self._check_temporal_consistency() + self._set_snapshot_count() + + def _check_temporal_consistency(self): + assert len(self.edge_index_dicts) == len( + self.edge_weight_dicts + ), "Temporal dimension inconsistency." 
+ assert len(self.target_dicts) == len( + self.edge_index_dicts + ), "Temporal dimension inconsistency." + assert len(self.batch_dicts) == len( + self.edge_index_dicts + ), "Temporal dimension inconsistency." + for key in self.additional_feature_keys: + assert len(self.target_dicts) == len( + getattr(self, key) + ), "Temporal dimension inconsistency." + + def _set_snapshot_count(self): + self.snapshot_count = len(self.target_dicts) + + def _get_edge_index(self, time_index: int): + if self.edge_index_dicts[time_index] is None: + return self.edge_index_dicts[time_index] + else: + return {key: torch.LongTensor(value) for key, value in self.edge_index_dicts[time_index].items() + if value is not None} + + def _get_batch_index(self, time_index: int): + if self.batch_dicts[time_index] is None: + return self.batch_dicts[time_index] + else: + return {key: torch.LongTensor(value) for key, value in self.batch_dicts[time_index].items() + if value is not None} + + def _get_edge_weight(self, time_index: int): + if self.edge_weight_dicts[time_index] is None: + return self.edge_weight_dicts[time_index] + else: + return {key: torch.FloatTensor(value) for key, value in self.edge_weight_dicts[time_index].items() + if value is not None} + + def _get_feature(self): + if self.feature_dict is None: + return self.feature_dict + else: + return {key: torch.FloatTensor(value) for key, value in self.feature_dict.items()} + + def _get_target(self, time_index: int): + if self.target_dicts[time_index] is None: + return self.target_dicts[time_index] + else: + return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in self.target_dicts[time_index].items() + if value is not None} + + def _get_additional_feature(self, time_index: int, feature_key: str): + feature = getattr(self, feature_key)[time_index] + if feature is None: + return feature + else: + return {key: torch.FloatTensor(value) if value.dtype.kind 
== "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in feature.items() + if value is not None} + + def _get_additional_features(self, time_index: int): + additional_features = { + key: self._get_additional_feature(time_index, key) + for key in self.additional_feature_keys + } + return additional_features + + def __getitem__(self, time_index: int): + x_dict = self._get_feature() + edge_index_dict = self._get_edge_index(time_index) + edge_weight_dict = self._get_edge_weight(time_index) + batch_dict = self._get_batch_index(time_index) + y_dict = self._get_target(time_index) + additional_features = self._get_additional_features(time_index) + + snapshot = Batch.from_data_list([HeteroData()]) + if x_dict: + for key, value in x_dict.items(): + snapshot[key].x = value + if edge_index_dict: + for key, value in edge_index_dict.items(): + snapshot[key].edge_index = value + if edge_weight_dict: + for key, value in edge_weight_dict.items(): + snapshot[key].edge_attr = value + if y_dict: + for key, value in y_dict.items(): + snapshot[key].y = value + if batch_dict: + for key, value in batch_dict.items(): + snapshot[key].batch = value + if additional_features: + for feature_name, feature_dict in additional_features.items(): + if feature_dict: + for key, value in feature_dict.items(): + snapshot[key][feature_name] = value + return snapshot + + def __next__(self): + if self.t < len(self.target_dicts): + snapshot = self[self.t] + self.t = self.t + 1 + return snapshot + else: + self.t = 0 + raise StopIteration + + def __iter__(self): + self.t = 0 + return self From c354f9b25977f1398224481a7f56067d38ae5c54 Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Thu, 3 Feb 2022 10:31:16 +0100 Subject: [PATCH 09/10] dynamic hetero graph temporal signal - new data structure that works like DynamicGraphTemporalSignal but with PyG HeteroData objects - tests -documentation --- docs/source/modules/signal.rst | 4 + test/dataset_test.py | 45 +++++ 
torch_geometric_temporal/signal/__init__.py | 2 + .../dynamic_hetero_graph_temporal_signal.py | 155 ++++++++++++++++++ 4 files changed, 206 insertions(+) create mode 100644 torch_geometric_temporal/signal/dynamic_hetero_graph_temporal_signal.py diff --git a/docs/source/modules/signal.rst b/docs/source/modules/signal.rst index d23e70b5..481614c4 100644 --- a/docs/source/modules/signal.rst +++ b/docs/source/modules/signal.rst @@ -26,6 +26,10 @@ Heterogeneous Temporal Signal Iterators :members: :undoc-members: +.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_temporal_signal + :members: + :undoc-members: + .. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal :members: :undoc-members: diff --git a/test/dataset_test.py b/test/dataset_test.py index efa40c8d..1b6ecab7 100644 --- a/test/dataset_test.py +++ b/test/dataset_test.py @@ -8,6 +8,7 @@ from torch_geometric_temporal.signal import DynamicGraphStaticSignal from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignal +from torch_geometric_temporal.signal import DynamicHeteroGraphTemporalSignal from torch_geometric_temporal.signal import DynamicHeteroGraphStaticSignal from torch_geometric_temporal.dataset import METRLADatasetLoader, PemsBayDatasetLoader @@ -235,6 +236,50 @@ def test_dynamic_hetero_graph_static_signal_edges(): assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] +def test_dynamic_hetero_graph_temporal_signal(): + dataset = DynamicHeteroGraphTemporalSignal( + [None, None], [None, None], [None, None], [None, None] + ) + for snapshot in dataset: + assert len(snapshot.node_types) == 0 + assert len(snapshot.node_stores) == 0 + assert len(snapshot.edge_types) == 0 + assert len(snapshot.edge_stores) == 0 + + +def test_dynamic_hetero_graph_temporal_signal_typing(): + dataset = DynamicHeteroGraphTemporalSignal([None], [None], [{'author': np.array([1])}], [{'author': np.array([2])}]) + for snapshot 
in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == (1,) + assert len(snapshot.edge_types) == 0 + + +def test_dynamic_hetero_graph_temporal_signal_additional_attrs(): + dataset = DynamicHeteroGraphTemporalSignal([None], [None], [None], [None], + optional1=[{'author': np.array([1])}], + optional2=[{'author': np.array([2])}], + optional3=[None]) + assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"] + for snapshot in dataset: + assert snapshot.node_stores[0]['optional1'].shape == (1,) + assert snapshot.node_stores[0]['optional2'].shape == (1,) + assert "optional3" not in list(dict(snapshot.node_stores[0]).keys()) + + +def test_dynamic_hetero_graph_temporal_signal_edges(): + dataset = DynamicHeteroGraphTemporalSignal([{("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}], + [{("author", "writes", "paper"): np.array([[0.1], [0.1]])}], + [{"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}], + [None]) + for snapshot in dataset: + assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2) + assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1) + assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] + + def test_chickenpox(): loader = ChickenpoxDatasetLoader() diff --git a/torch_geometric_temporal/signal/__init__.py b/torch_geometric_temporal/signal/__init__.py index fabb5ca3..70865b58 100644 --- a/torch_geometric_temporal/signal/__init__.py +++ b/torch_geometric_temporal/signal/__init__.py @@ -7,6 +7,8 @@ from .dynamic_graph_static_signal import * from .dynamic_graph_static_signal_batch import * +from .dynamic_hetero_graph_temporal_signal import * + from .static_hetero_graph_temporal_signal import * from .static_hetero_graph_temporal_signal_batch import * diff --git a/torch_geometric_temporal/signal/dynamic_hetero_graph_temporal_signal.py 
b/torch_geometric_temporal/signal/dynamic_hetero_graph_temporal_signal.py new file mode 100644 index 00000000..41a0892c --- /dev/null +++ b/torch_geometric_temporal/signal/dynamic_hetero_graph_temporal_signal.py @@ -0,0 +1,155 @@ +import torch +import numpy as np +from typing import List, Dict, Union, Tuple +from torch_geometric.data import HeteroData + + +Edge_Indices = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]] +Edge_Weights = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]] +Node_Features = List[Union[Dict[str, np.ndarray], None]] +Targets = List[Union[Dict[str, np.ndarray], None]] +Additional_Features = List[Union[Dict[str, np.ndarray], None]] + + +class DynamicHeteroGraphTemporalSignal(object): + r"""A data iterator object to contain a dynamic heterogeneous graph with a + changing edge set and weights. The feature set and node labels + (target) are also dynamic. The iterator returns a single discrete temporal + snapshot for a time period (e.g. day or week). This single snapshot is a + Pytorch Geometric HeteroData object. Between two temporal snapshots the edges, + edge weights, feature matrices, target matrices and optionally passed attributes might change. + + Args: + edge_index_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays): + List of relation type tuples and their edge index tensors. + edge_weight_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays): + List of relation type tuples and their edge weight tensors. + feature_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node + types and their feature tensors. + target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node + types and their label (target) tensors. + **kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays): List + of node types and their additional attributes. 
+ """ + + def __init__( + self, + edge_index_dicts: Edge_Indices, + edge_weight_dicts: Edge_Weights, + feature_dicts: Node_Features, + target_dicts: Targets, + **kwargs: Additional_Features + ): + self.edge_index_dicts = edge_index_dicts + self.edge_weight_dicts = edge_weight_dicts + self.feature_dicts = feature_dicts + self.target_dicts = target_dicts + self.additional_feature_keys = [] + for key, value in kwargs.items(): + setattr(self, key, value) + self.additional_feature_keys.append(key) + self._check_temporal_consistency() + self._set_snapshot_count() + + def _check_temporal_consistency(self): + assert len(self.feature_dicts) == len( + self.target_dicts + ), "Temporal dimension inconsistency." + assert len(self.edge_index_dicts) == len( + self.edge_weight_dicts + ), "Temporal dimension inconsistency." + assert len(self.feature_dicts) == len( + self.edge_weight_dicts + ), "Temporal dimension inconsistency." + for key in self.additional_feature_keys: + assert len(self.target_dicts) == len( + getattr(self, key) + ), "Temporal dimension inconsistency." 
+ + def _set_snapshot_count(self): + self.snapshot_count = len(self.feature_dicts) + + def _get_edge_index(self, time_index: int): + if self.edge_index_dicts[time_index] is None: + return self.edge_index_dicts[time_index] + else: + return {key: torch.LongTensor(value) for key, value in self.edge_index_dicts[time_index].items() + if value is not None} + + def _get_edge_weight(self, time_index: int): + if self.edge_weight_dicts[time_index] is None: + return self.edge_weight_dicts[time_index] + else: + return {key: torch.FloatTensor(value) for key, value in self.edge_weight_dicts[time_index].items() + if value is not None} + + def _get_features(self, time_index: int): + if self.feature_dicts[time_index] is None: + return self.feature_dicts[time_index] + else: + return {key: torch.FloatTensor(value) for key, value in self.feature_dicts[time_index].items() + if value is not None} + + def _get_target(self, time_index: int): + if self.target_dicts[time_index] is None: + return self.target_dicts[time_index] + else: + return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in self.target_dicts[time_index].items() + if value is not None} + + def _get_additional_feature(self, time_index: int, feature_key: str): + feature = getattr(self, feature_key)[time_index] + if feature is None: + return feature + else: + return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value) + if value.dtype.kind == "i" else value for key, value in feature.items() + if value is not None} + + def _get_additional_features(self, time_index: int): + additional_features = { + key: self._get_additional_feature(time_index, key) + for key in self.additional_feature_keys + } + return additional_features + + def __getitem__(self, time_index): + x_dict = self._get_features(time_index) + edge_index_dict = self._get_edge_index(time_index) + edge_weight_dict = 
self._get_edge_weight(time_index) + y_dict = self._get_target(time_index) + additional_features = self._get_additional_features(time_index) + + snapshot = HeteroData() + if x_dict: + for key, value in x_dict.items(): + snapshot[key].x = value + if edge_index_dict: + for key, value in edge_index_dict.items(): + snapshot[key].edge_index = value + if edge_weight_dict: + for key, value in edge_weight_dict.items(): + snapshot[key].edge_attr = value + if y_dict: + for key, value in y_dict.items(): + snapshot[key].y = value + if additional_features: + for feature_name, feature_dict in additional_features.items(): + if feature_dict: + for key, value in feature_dict.items(): + snapshot[key][feature_name] = value + return snapshot + + def __next__(self): + if self.t < len(self.feature_dicts): + snapshot = self[self.t] + self.t = self.t + 1 + return snapshot + else: + self.t = 0 + raise StopIteration + + def __iter__(self): + self.t = 0 + return self From 0437310837b3826acc5152a77508377ba990c2f3 Mon Sep 17 00:00:00 2001 From: Gregor Donabauer Date: Thu, 3 Feb 2022 10:51:42 +0100 Subject: [PATCH 10/10] dynamic hetero graph temporal signal batch - data structure that works like DynamicGraphTemporalSignalBatch but with PyG HeteroData Batches instead - tests -documentation --- docs/source/modules/signal.rst | 4 + test/batch_test.py | 61 ++++++ torch_geometric_temporal/signal/__init__.py | 1 + ...amic_hetero_graph_temporal_signal_batch.py | 175 ++++++++++++++++++ 4 files changed, 241 insertions(+) create mode 100644 torch_geometric_temporal/signal/dynamic_hetero_graph_temporal_signal_batch.py diff --git a/docs/source/modules/signal.rst b/docs/source/modules/signal.rst index 481614c4..699dc9a7 100644 --- a/docs/source/modules/signal.rst +++ b/docs/source/modules/signal.rst @@ -56,6 +56,10 @@ Heterogeneous Temporal Signal Batch Iterators :members: :undoc-members: +.. 
automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_temporal_signal_batch + :members: + :undoc-members: + .. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal_batch :members: :undoc-members: diff --git a/test/batch_test.py b/test/batch_test.py index 8144fd6f..a5bcf4b6 100644 --- a/test/batch_test.py +++ b/test/batch_test.py @@ -10,6 +10,7 @@ from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignalBatch +from torch_geometric_temporal.signal import DynamicHeteroGraphTemporalSignalBatch from torch_geometric_temporal.signal import DynamicHeteroGraphStaticSignalBatch @@ -108,6 +109,17 @@ def test_dynamic_hetero_graph_static_signal_batch(): assert len(snapshot.edge_stores) == 0 +def test_dynamic_hetero_graph_temporal_signal_batch(): + dataset = DynamicHeteroGraphTemporalSignalBatch( + [None, None], [None, None], [None, None], [None, None], [None, None] + ) + for snapshot in dataset: + assert len(snapshot.node_types) == 0 + assert len(snapshot.node_stores) == 0 + assert len(snapshot.edge_types) == 0 + assert len(snapshot.edge_stores) == 0 + + def test_dynamic_graph_temporal_signal_batch(): dataset = DynamicGraphTemporalSignalBatch( [None, None], [None, None], [None, None], [None, None], [None, None] @@ -156,6 +168,18 @@ def test_dynamic_hetero_graph_static_signal_typing_batch(): assert len(snapshot.edge_types) == 0 +def test_dynamic_hetero_graph_temporal_signal_typing_batch(): + dataset = DynamicHeteroGraphTemporalSignalBatch( + [None], [None], [{'author': np.array([1])}], [{'author': np.array([2])}], [None] + ) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == (1,) + assert 'batch' not in list(dict(snapshot.node_stores[0]).keys()) + assert len(snapshot.edge_types) == 0 + + def 
test_dynamic_graph_static_signal_typing_batch(): dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None]) for snapshot in dataset: @@ -208,6 +232,18 @@ def test_dynamic_hetero_graph_static_signal_batch_additional_attrs(): assert "optional3" not in list(dict(snapshot.node_stores[0]).keys()) +def test_dynamic_hetero_graph_temporal_signal_batch_additional_attrs(): + dataset = DynamicHeteroGraphTemporalSignalBatch([None], [None], [None], [None], [None], + optional1=[{'author': np.array([1])}], + optional2=[{'author': np.array([2])}], + optional3=[None]) + assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"] + for snapshot in dataset: + assert snapshot.node_stores[0]['optional1'].shape == (1,) + assert snapshot.node_stores[0]['optional2'].shape == (1,) + assert "optional3" not in list(dict(snapshot.node_stores[0]).keys()) + + def test_dynamic_graph_static_signal_batch_additional_attrs(): dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None], optional1=[np.array([1])], optional2=[np.array([2])]) @@ -245,6 +281,19 @@ def test_dynamic_hetero_graph_static_signal_batch_edges(): assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] +def test_dynamic_hetero_graph_temporal_signal_batch_edges(): + dataset = DynamicHeteroGraphTemporalSignalBatch([{("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}], + [{("author", "writes", "paper"): np.array([[0.1], [0.1]])}], + [{"author": np.array([[0], [0]]), + "paper": np.array([[0], [0], [0]])}], + [None], + [None]) + for snapshot in dataset: + assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2) + assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1) + assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0] + + def test_static_hetero_graph_temporal_signal_batch_assigned(): dataset = StaticHeteroGraphTemporalSignalBatch( None, None, [{'author': 
np.array([1])}], [{'author': np.array([2])}], {'author': np.array([1])} @@ -269,6 +318,18 @@ def test_dynamic_hetero_graph_static_signal_batch_assigned(): assert len(snapshot.edge_types) == 0 +def test_dynamic_hetero_graph_temporal_signal_batch_assigned(): + dataset = DynamicHeteroGraphTemporalSignalBatch( + [None], [None], [{'author': np.array([1])}], [{'author': np.array([2])}], [{'author': np.array([1])}] + ) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + assert snapshot.node_stores[0]['y'].shape == (1,) + assert snapshot.node_stores[0]['batch'].shape == (1,) + assert len(snapshot.edge_types) == 0 + + def test_discrete_train_test_split_dynamic_batch(): snapshot_count = 250 diff --git a/torch_geometric_temporal/signal/__init__.py b/torch_geometric_temporal/signal/__init__.py index 70865b58..3b7c9b3f 100644 --- a/torch_geometric_temporal/signal/__init__.py +++ b/torch_geometric_temporal/signal/__init__.py @@ -8,6 +8,7 @@ from .dynamic_graph_static_signal_batch import * from .dynamic_hetero_graph_temporal_signal import * +from .dynamic_hetero_graph_temporal_signal_batch import * from .static_hetero_graph_temporal_signal import * from .static_hetero_graph_temporal_signal_batch import * diff --git a/torch_geometric_temporal/signal/dynamic_hetero_graph_temporal_signal_batch.py b/torch_geometric_temporal/signal/dynamic_hetero_graph_temporal_signal_batch.py new file mode 100644 index 00000000..b2c836da --- /dev/null +++ b/torch_geometric_temporal/signal/dynamic_hetero_graph_temporal_signal_batch.py @@ -0,0 +1,175 @@ +import torch +import numpy as np +from typing import List, Dict, Union, Tuple +from torch_geometric.data import HeteroData, Batch + + +Edge_Indices = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]] +Edge_Weights = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]] +Node_Features = List[Union[Dict[str, np.ndarray], None]] +Targets = List[Union[Dict[str, 
# Per-snapshot input containers. Every list holds one entry per time step and
# each entry is either a dictionary (keyed by relation-type tuple or node-type
# string) of NumPy arrays, or None when the snapshot carries no such data.
Edge_Indices = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]]
Edge_Weights = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]]
Node_Features = List[Union[Dict[str, np.ndarray], None]]
Targets = List[Union[Dict[str, np.ndarray], None]]
Batches = List[Union[Dict[str, np.ndarray], None]]
Additional_Features = List[Union[Dict[str, np.ndarray], None]]


class DynamicHeteroGraphTemporalSignalBatch(object):
    r"""A data iterator object to contain a dynamic heterogeneous graph with a
    changing edge set and weights. The feature set and node labels
    (target) are also dynamic. The iterator returns a single discrete temporal
    snapshot for a time period (e.g. day or week). This single snapshot is a
    PyTorch Geometric HeteroData Batch object. Between two temporal snapshots
    the edges, edge weights, the feature matrix, target matrices and
    optionally passed attributes might change.

    All positional lists must have the same length (one entry per snapshot);
    an entry may be ``None`` when a snapshot has no data of that kind.

    Args:
        edge_index_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays):
         List of relation type tuples and their edge index tensors.
        edge_weight_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays):
         List of relation type tuples and their edge weight tensors.
        feature_dicts (List of dictionaries where keys=Strings and values=Numpy arrays):
         List of node types and their feature tensors.
        target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays):
         List of node types and their label (target) tensors.
        batch_dicts (List of dictionaries where keys=Strings and values=Numpy arrays):
         List of batch index tensors for each node type.
        **kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays):
         List of node types and their additional attributes.
    """

    def __init__(
        self,
        edge_index_dicts: Edge_Indices,
        edge_weight_dicts: Edge_Weights,
        feature_dicts: Node_Features,
        target_dicts: Targets,
        batch_dicts: Batches,
        **kwargs: Additional_Features
    ):
        # Iteration cursor. Defined up front so that ``next(dataset)`` works
        # even when ``iter(dataset)`` has not been called first.
        self.t = 0
        self.edge_index_dicts = edge_index_dicts
        self.edge_weight_dicts = edge_weight_dicts
        self.feature_dicts = feature_dicts
        self.target_dicts = target_dicts
        self.batch_dicts = batch_dicts
        # Every keyword argument becomes an attribute of the same name; the
        # names are remembered in insertion order.
        self.additional_feature_keys = []
        for key, value in kwargs.items():
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()

    def _check_temporal_consistency(self):
        """Assert that every per-snapshot list has the same length.

        The pairwise checks are chained through ``feature_dicts`` and
        ``target_dicts``, so passing all of them implies that all lists
        (including the additional attributes) agree.

        Raises:
            AssertionError: if any two lists differ in length.
        """
        assert len(self.feature_dicts) == len(
            self.target_dicts
        ), "Temporal dimension inconsistency."
        assert len(self.edge_index_dicts) == len(
            self.edge_weight_dicts
        ), "Temporal dimension inconsistency."
        assert len(self.feature_dicts) == len(
            self.edge_weight_dicts
        ), "Temporal dimension inconsistency."
        assert len(self.feature_dicts) == len(
            self.batch_dicts
        ), "Temporal dimension inconsistency."
        for key in self.additional_feature_keys:
            assert len(self.target_dicts) == len(
                getattr(self, key)
            ), "Temporal dimension inconsistency."

    def _set_snapshot_count(self):
        """Store the number of snapshots (length of ``feature_dicts``)."""
        self.snapshot_count = len(self.feature_dicts)

    @staticmethod
    def _convert_values(type_dict, convert):
        """Apply ``convert`` to every non-None value of ``type_dict``.

        Returns ``None`` unchanged when ``type_dict`` is ``None``; entries
        whose value is ``None`` are dropped from the result.
        """
        if type_dict is None:
            return None
        return {
            key: convert(value)
            for key, value in type_dict.items()
            if value is not None
        }

    @staticmethod
    def _to_typed_tensor(value):
        """Convert by dtype kind: floats to FloatTensor, signed ints to
        LongTensor; any other kind is returned as it was passed in."""
        if value.dtype.kind == "f":
            return torch.FloatTensor(value)
        if value.dtype.kind == "i":
            return torch.LongTensor(value)
        return value

    def _get_edge_index(self, time_index: int):
        """Return the edge indices of a snapshot as LongTensors (or None)."""
        return self._convert_values(
            self.edge_index_dicts[time_index], torch.LongTensor
        )

    def _get_batch_index(self, time_index: int):
        """Return the batch indices of a snapshot as LongTensors (or None)."""
        return self._convert_values(
            self.batch_dicts[time_index], torch.LongTensor
        )

    def _get_edge_weight(self, time_index: int):
        """Return the edge weights of a snapshot as FloatTensors (or None)."""
        return self._convert_values(
            self.edge_weight_dicts[time_index], torch.FloatTensor
        )

    def _get_feature(self, time_index: int):
        """Return the node features of a snapshot as FloatTensors (or None)."""
        return self._convert_values(
            self.feature_dicts[time_index], torch.FloatTensor
        )

    def _get_target(self, time_index: int):
        """Return the targets of a snapshot, converted by dtype kind (or None)."""
        return self._convert_values(
            self.target_dicts[time_index], self._to_typed_tensor
        )

    def _get_additional_feature(self, time_index: int, feature_key: str):
        """Return one additional attribute of a snapshot, converted by dtype
        kind (or None)."""
        return self._convert_values(
            getattr(self, feature_key)[time_index], self._to_typed_tensor
        )

    def _get_additional_features(self, time_index: int):
        """Return ``{attribute name: converted dict or None}`` for a snapshot."""
        additional_features = {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }
        return additional_features

    def __getitem__(self, time_index: int):
        """Build the snapshot at ``time_index``.

        Returns:
            A ``Batch`` created from a single empty ``HeteroData`` object on
            which ``x``, ``edge_index``, ``edge_attr``, ``y``, ``batch`` and
            the additional attributes are set per node / relation type.
        """
        x_dict = self._get_feature(time_index)
        edge_index_dict = self._get_edge_index(time_index)
        edge_weight_dict = self._get_edge_weight(time_index)
        batch_dict = self._get_batch_index(time_index)
        y_dict = self._get_target(time_index)
        additional_features = self._get_additional_features(time_index)

        snapshot = Batch.from_data_list([HeteroData()])
        # (attribute name on the store, per-type values); None or empty
        # dictionaries are skipped.
        core_attributes = (
            ("x", x_dict),
            ("edge_index", edge_index_dict),
            ("edge_attr", edge_weight_dict),
            ("y", y_dict),
            ("batch", batch_dict),
        )
        for attribute_name, type_dict in core_attributes:
            if type_dict:
                for key, value in type_dict.items():
                    setattr(snapshot[key], attribute_name, value)
        if additional_features:
            for feature_name, feature_dict in additional_features.items():
                if feature_dict:
                    for key, value in feature_dict.items():
                        snapshot[key][feature_name] = value
        return snapshot

    def __next__(self):
        """Return the next snapshot; reset the cursor and stop at the end."""
        if self.t < len(self.feature_dicts):
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        else:
            # Rewind so the object can be iterated again afterwards.
            self.t = 0
            raise StopIteration

    def __iter__(self):
        """Restart iteration from the first snapshot and return ``self``."""
        self.t = 0
        return self