diff --git a/docs/source/modules/signal.rst b/docs/source/modules/signal.rst index 50c9ffc3..88b47332 100644 --- a/docs/source/modules/signal.rst +++ b/docs/source/modules/signal.rst @@ -41,6 +41,13 @@ Temporal Signal Batch Iterators :members: :undoc-members: +Heterogeneous Temporal Signal Batch Iterators +--------------------------------------------- + +.. automodule:: torch_geometric_temporal.signal.static_hetero_graph_temporal_signal_batch + :members: + :undoc-members: + Temporal Signal Train-Test Split -------------------------------- diff --git a/test/batch_test.py b/test/batch_test.py index dbf6fd10..3191c654 100644 --- a/test/batch_test.py +++ b/test/batch_test.py @@ -9,6 +9,8 @@ from torch_geometric_temporal.signal import DynamicGraphTemporalSignalBatch from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch +from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignalBatch + def get_edge_array(node_count, node_start): edges = [] @@ -83,6 +85,17 @@ def test_static_graph_temporal_signal_batch(): assert snapshot.batch is None +def test_static_hetero_graph_temporal_signal_batch(): + dataset = StaticHeteroGraphTemporalSignalBatch( + None, None, [None, None], [None, None], None + ) + for snapshot in dataset: + assert len(snapshot.node_types) == 0 + assert len(snapshot.node_stores) == 0 + assert len(snapshot.edge_types) == 0 + assert len(snapshot.edge_stores) == 0 + + def test_dynamic_graph_temporal_signal_batch(): dataset = DynamicGraphTemporalSignalBatch( [None, None], [None, None], [None, None], [None, None], [None, None] @@ -107,6 +120,18 @@ def test_static_graph_temporal_signal_typing_batch(): assert snapshot.batch is None +def test_static_hetero_graph_temporal_signal_typing_batch(): + dataset = StaticHeteroGraphTemporalSignalBatch( + None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], None + ) + for snapshot in dataset: + assert snapshot.node_types[0] == 'author' + assert snapshot.node_stores[0]['x'].shape == (1,) + 
assert snapshot.node_stores[0]['y'].shape == (1,) + assert 'batch' not in list(dict(snapshot.node_stores[0]).keys()) + assert len(snapshot.edge_types) == 0 + + def test_dynamic_graph_static_signal_typing_batch(): dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None]) for snapshot in dataset: @@ -135,6 +160,16 @@ def test_static_graph_temporal_signal_batch_additional_attrs(): assert snapshot.optional2.shape == (1,) +def test_static_hetero_graph_temporal_signal_batch_additional_attrs(): + dataset = StaticHeteroGraphTemporalSignalBatch(None, None, [None], [None], None, + optional1=[{'author': np.array([1])}], + optional2=[{'author': np.array([2])}]) + assert dataset.additional_feature_keys == ["optional1", "optional2"] + for snapshot in dataset: + assert snapshot.node_stores[0]['optional1'].shape == (1,) + assert snapshot.node_stores[0]['optional2'].shape == (1,) + + def test_dynamic_graph_static_signal_batch_additional_attrs(): dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None], optional1=[np.array([1])], optional2=[np.array([2])]) diff --git a/torch_geometric_temporal/signal/__init__.py b/torch_geometric_temporal/signal/__init__.py index ea226485..59e4d7dc 100644 --- a/torch_geometric_temporal/signal/__init__.py +++ b/torch_geometric_temporal/signal/__init__.py @@ -8,5 +8,6 @@ from .dynamic_graph_static_signal_batch import * from .static_hetero_graph_temporal_signal import * +from .static_hetero_graph_temporal_signal_batch import * from .train_test_split import * diff --git a/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py new file mode 100644 index 00000000..3516410c --- /dev/null +++ b/torch_geometric_temporal/signal/static_hetero_graph_temporal_signal_batch.py @@ -0,0 +1,162 @@ +import torch +import numpy as np +from typing import List, Dict, Union, Tuple +from torch_geometric.data import Batch, HeteroData + 
# Type aliases describing the constructor arguments. Every container may be
# ``None`` (or hold ``None`` entries) to signal "no data".
Edge_Index = Union[Dict[Tuple[str, str, str], np.ndarray], None]
Edge_Weight = Union[Dict[Tuple[str, str, str], np.ndarray], None]
Node_Features = List[Union[Dict[str, np.ndarray], None]]
Targets = List[Union[Dict[str, np.ndarray], None]]
Batches = Union[Dict[str, np.ndarray], None]
Additional_Features = List[Union[Dict[str, np.ndarray], None]]


class StaticHeteroGraphTemporalSignalBatch(object):
    r"""A data iterator object to contain a static heterogeneous graph with a dynamically
    changing constant time difference temporal feature set (multiple signals).
    The node labels (target) are also temporal. The iterator returns a single
    constant time difference temporal snapshot for a time period (e.g. day or week).
    This single temporal snapshot is a Pytorch Geometric Batch object. Between two
    temporal snapshots the feature matrix, target matrices and optionally passed
    attributes might change. However, the underlying graph is the same.

    Besides iteration, the object supports ``signal[i]`` (one snapshot) and
    ``signal[a:b]`` (a new iterator restricted to the selected time steps).

    Args:
        edge_index_dict (Dictionary of keys=Tuples and values=Numpy arrays): Relation type tuples
         and their edge index tensors.
        edge_weight_dict (Dictionary of keys=Tuples and values=Numpy arrays): Relation type tuples
         and their edge weight tensors.
        feature_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node
         types and their feature tensors.
        target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node
         types and their label (target) tensors.
        batch_dict (Dictionary of keys=Strings and values=Numpy arrays): Batch index tensor of each
         node type.
        **kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays): List
         of node types and their additional attributes.
    """

    def __init__(
        self,
        edge_index_dict: Edge_Index,
        edge_weight_dict: Edge_Weight,
        feature_dicts: Node_Features,
        target_dicts: Targets,
        batch_dict: Batches,
        **kwargs: Additional_Features
    ):
        self.edge_index_dict = edge_index_dict
        self.edge_weight_dict = edge_weight_dict
        self.feature_dicts = feature_dicts
        self.target_dicts = target_dicts
        self.batch_dict = batch_dict
        # Iteration cursor. Created here (and not only in ``__iter__``) so that
        # calling ``next()`` on a fresh object does not raise AttributeError.
        self.t = 0
        # Every keyword argument becomes an attribute of the same name; the
        # names are remembered so the values can be looked up per time step.
        self.additional_feature_keys = []
        for key, value in kwargs.items():
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()

    def _check_temporal_consistency(self):
        """Assert that features, targets and every additional attribute
        cover the same number of time steps.

        Raises:
            AssertionError: if any of the sequences differ in length.
        """
        assert len(self.feature_dicts) == len(
            self.target_dicts
        ), "Temporal dimension inconsistency."
        for key in self.additional_feature_keys:
            assert len(self.target_dicts) == len(
                getattr(self, key)
            ), "Temporal dimension inconsistency."

    def _set_snapshot_count(self):
        """Store the number of time steps in ``self.snapshot_count``."""
        self.snapshot_count = len(self.feature_dicts)

    @staticmethod
    def _to_typed_tensor(value):
        """Convert an array by dtype kind: ``"f"`` -> ``torch.FloatTensor``,
        ``"i"`` -> ``torch.LongTensor``; any other kind is returned unchanged."""
        kind = value.dtype.kind
        if kind == "f":
            return torch.FloatTensor(value)
        if kind == "i":
            return torch.LongTensor(value)
        return value

    def _get_edge_index(self):
        """Return the edge indices as ``{relation: LongTensor}``, or ``None``."""
        if self.edge_index_dict is None:
            return None
        return {
            key: torch.LongTensor(value)
            for key, value in self.edge_index_dict.items()
        }

    def _get_batch_index(self):
        """Return the batch indices as ``{node_type: LongTensor}``, or ``None``."""
        if self.batch_dict is None:
            return None
        return {
            key: torch.LongTensor(value)
            for key, value in self.batch_dict.items()
        }

    def _get_edge_weight(self):
        """Return the edge weights as ``{relation: FloatTensor}``, or ``None``."""
        if self.edge_weight_dict is None:
            return None
        return {
            key: torch.FloatTensor(value)
            for key, value in self.edge_weight_dict.items()
        }

    def _get_features(self, time_index: int):
        """Return the features of one time step as ``{node_type: FloatTensor}``.

        Node types whose value is ``None`` are dropped; a ``None`` entry for
        the whole time step yields ``None``.
        """
        features = self.feature_dicts[time_index]
        if features is None:
            return None
        return {
            key: torch.FloatTensor(value)
            for key, value in features.items()
            if value is not None
        }

    def _get_target(self, time_index: int):
        """Return the targets of one time step, converted per dtype kind.

        Node types whose value is ``None`` are dropped; a ``None`` entry for
        the whole time step yields ``None``.
        """
        targets = self.target_dicts[time_index]
        if targets is None:
            return None
        return {
            key: self._to_typed_tensor(value)
            for key, value in targets.items()
            if value is not None
        }

    def _get_additional_feature(self, time_index: int, feature_key: str):
        """Return one additional attribute at one time step, converted per
        dtype kind (``None`` entries are handled as in ``_get_target``)."""
        feature = getattr(self, feature_key)[time_index]
        if feature is None:
            return None
        return {
            key: self._to_typed_tensor(value)
            for key, value in feature.items()
            if value is not None
        }

    def _get_additional_features(self, time_index: int):
        """Return ``{attribute_name: converted dict or None}`` for one time step."""
        return {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }

    def __getitem__(self, time_index: Union[int, slice]):
        """Index the signal.

        Args:
            time_index: an integer selects one snapshot; a slice selects a
                range of time steps.

        Returns:
            The snapshot for an integer index, or a new
            ``StaticHeteroGraphTemporalSignalBatch`` over the sliced time
            steps that reuses the same static graph and batch dictionaries.
        """
        if isinstance(time_index, slice):
            sliced_extras = {
                key: getattr(self, key)[time_index]
                for key in self.additional_feature_keys
            }
            return StaticHeteroGraphTemporalSignalBatch(
                self.edge_index_dict,
                self.edge_weight_dict,
                self.feature_dicts[time_index],
                self.target_dicts[time_index],
                self.batch_dict,
                **sliced_extras
            )
        return self.__get_item__(time_index)

    def __get_item__(self, time_index: int):
        """Build the snapshot for ``time_index``.

        Kept under its original name for backward compatibility; it is also
        what ``__next__`` and integer ``__getitem__`` call.
        """
        x_dict = self._get_features(time_index)
        edge_index_dict = self._get_edge_index()
        edge_weight_dict = self._get_edge_weight()
        batch_dict = self._get_batch_index()
        y_dict = self._get_target(time_index)
        additional_features = self._get_additional_features(time_index)

        # Start from an empty heterogeneous batch and attach only the stores
        # for which data was supplied (``None``/empty dicts are skipped).
        snapshot = Batch.from_data_list([HeteroData()])
        if x_dict:
            for key, value in x_dict.items():
                snapshot[key].x = value
        if edge_index_dict:
            for key, value in edge_index_dict.items():
                snapshot[key].edge_index = value
        if edge_weight_dict:
            for key, value in edge_weight_dict.items():
                snapshot[key].edge_attr = value
        if y_dict:
            for key, value in y_dict.items():
                snapshot[key].y = value
        if batch_dict:
            for key, value in batch_dict.items():
                snapshot[key].batch = value
        if additional_features:
            for feature_name, feature_dict in additional_features.items():
                if feature_dict:
                    for key, value in feature_dict.items():
                        snapshot[key][feature_name] = value
        return snapshot

    def __next__(self):
        """Return the next snapshot; reset the cursor and stop at the end."""
        if self.t < self.snapshot_count:
            snapshot = self.__get_item__(self.t)
            self.t = self.t + 1
            return snapshot
        else:
            # Rewind so the object can be iterated again.
            self.t = 0
            raise StopIteration

    def __iter__(self):
        """Restart iteration from the first time step and return ``self``."""
        self.t = 0
        return self