Skip to content

Commit

Permalink
Static hetero graph temporla signal batch
Browse files Browse the repository at this point in the history
This includes:
- a new data structure, that represents PyG HeteroDataBatch snapshots for static heterogeneous graphs with temporal signals
- documentation
- tests
  • Loading branch information
doGregor committed Jan 20, 2022
1 parent a3f66a1 commit 62a86b4
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/source/modules/signal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ Temporal Signal Batch Iterators
:members:
:undoc-members:

Heterogeneous Temporal Signal Batch Iterators
-------------------------------

.. automodule:: torch_geometric_temporal.signal.static_hetero_graph_temporal_signal_batch
:members:
:undoc-members:

Temporal Signal Train-Test Split
--------------------------------

Expand Down
35 changes: 35 additions & 0 deletions test/batch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from torch_geometric_temporal.signal import DynamicGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch

from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignalBatch


def get_edge_array(node_count, node_start):
edges = []
Expand Down Expand Up @@ -83,6 +85,17 @@ def test_static_graph_temporal_signal_batch():
assert snapshot.batch is None


def test_static_hetero_graph_temporal_signal_batch():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [None, None], [None, None], None
)
for snapshot in dataset:
assert len(snapshot.node_types) == 0
assert len(snapshot.node_stores) == 0
assert len(snapshot.edge_types) == 0
assert len(snapshot.edge_stores) == 0


def test_dynamic_graph_temporal_signal_batch():
dataset = DynamicGraphTemporalSignalBatch(
[None, None], [None, None], [None, None], [None, None], [None, None]
Expand All @@ -107,6 +120,18 @@ def test_static_graph_temporal_signal_typing_batch():
assert snapshot.batch is None


def test_static_hetero_graph_temporal_signal_typing_batch():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], None
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert 'batch' not in list(dict(snapshot.node_stores[0]).keys())
assert len(snapshot.edge_types) == 0


def test_dynamic_graph_static_signal_typing_batch():
dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None])
for snapshot in dataset:
Expand Down Expand Up @@ -135,6 +160,16 @@ def test_static_graph_temporal_signal_batch_additional_attrs():
assert snapshot.optional2.shape == (1,)


def test_static_hetero_graph_temporal_signal_batch_additional_attrs():
dataset = StaticHeteroGraphTemporalSignalBatch(None, None, [None], [None], None,
optional1=[{'author': np.array([1])}],
optional2=[{'author': np.array([2])}])
assert dataset.additional_feature_keys == ["optional1", "optional2"]
for snapshot in dataset:
assert snapshot.node_stores[0]['optional1'].shape == (1,)
assert snapshot.node_stores[0]['optional2'].shape == (1,)


def test_dynamic_graph_static_signal_batch_additional_attrs():
dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None],
optional1=[np.array([1])], optional2=[np.array([2])])
Expand Down
1 change: 1 addition & 0 deletions torch_geometric_temporal/signal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
from .dynamic_graph_static_signal_batch import *

from .static_hetero_graph_temporal_signal import *
from .static_hetero_graph_temporal_signal_batch import *

from .train_test_split import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import torch
import numpy as np
from typing import List, Dict, Union, Tuple
from torch_geometric.data import Batch, HeteroData

Edge_Index = Union[Dict[Tuple[str, str, str], np.ndarray], None]
Edge_Weight = Union[Dict[Tuple[str, str, str], np.ndarray], None]
Node_Features = List[Union[Dict[str, np.ndarray], None]]
Targets = List[Union[Dict[str, np.ndarray], None]]
Batches = Union[Dict[str, np.ndarray], None]
Additional_Features = List[Union[Dict[str, np.ndarray], None]]


class StaticHeteroGraphTemporalSignalBatch(object):
r"""A data iterator object to contain a static heterogeneous graph with a dynamically
changing constant time difference temporal feature set (multiple signals).
The node labels (target) are also temporal. The iterator returns a single
constant time difference temporal snapshot for a time period (e.g. day or week).
This single temporal snapshot is a Pytorch Geometric Batch object. Between two
temporal snapshots the feature matrix, target matrices and optionally passed
attributes might change. However, the underlying graph is the same.
Args:
edge_index_dict (Dictionary of keys=Tuples and values=Numpy arrays): Relation type tuples
and their edge index tensors.
edge_weight_dict (Dictionary of keys=Tuples and values=Numpy arrays): Relation type tuples
and their edge weight tensors.
feature_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node
types and their feature tensors.
target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays): List of node
types and their label (target) tensors.
batch_dict (Dictionary of keys=Strings and values=Numpy arrays): Batch index tensor of each
node type.
**kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays): List
of node types and their additional attributes.
"""

def __init__(
self,
edge_index_dict: Edge_Index,
edge_weight_dict: Edge_Weight,
feature_dicts: Node_Features,
target_dicts: Targets,
batch_dict: Batches,
**kwargs: Additional_Features
):
self.edge_index_dict = edge_index_dict
self.edge_weight_dict = edge_weight_dict
self.feature_dicts = feature_dicts
self.target_dicts = target_dicts
self.batch_dict = batch_dict
self.additional_feature_keys = []
for key, value in kwargs.items():
setattr(self, key, value)
self.additional_feature_keys.append(key)
self._check_temporal_consistency()
self._set_snapshot_count()

def _check_temporal_consistency(self):
assert len(self.feature_dicts) == len(
self.target_dicts
), "Temporal dimension inconsistency."
for key in self.additional_feature_keys:
assert len(self.target_dicts) == len(
getattr(self, key)
), "Temporal dimension inconsistency."

def _set_snapshot_count(self):
self.snapshot_count = len(self.feature_dicts)

def _get_edge_index(self):
if self.edge_index_dict is None:
return self.edge_index_dict
else:
return {key: torch.LongTensor(value) for key, value in self.edge_index_dict.items()}

def _get_batch_index(self):
if self.batch_dict is None:
return self.batch_dict
else:
return {key: torch.LongTensor(value) for key, value in self.batch_dict.items()}

def _get_edge_weight(self):
if self.edge_weight_dict is None:
return self.edge_weight_dict
else:
return {key: torch.FloatTensor(value) for key, value in self.edge_weight_dict.items()}

def _get_features(self, time_index: int):
if self.feature_dicts[time_index] is None:
return self.feature_dicts[time_index]
else:
return {key: torch.FloatTensor(value) for key, value in self.feature_dicts[time_index].items()
if value is not None}

def _get_target(self, time_index: int):
if self.target_dicts[time_index] is None:
return self.target_dicts[time_index]
else:
return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value)
if value.dtype.kind == "i" else value for key, value in self.target_dicts[time_index].items()
if value is not None}

def _get_additional_feature(self, time_index: int, feature_key: str):
feature = getattr(self, feature_key)[time_index]
if feature is None:
return feature
else:
return {key: torch.FloatTensor(value) if value.dtype.kind == "f" else torch.LongTensor(value)
if value.dtype.kind == "i" else value for key, value in feature.items()
if value is not None}

def _get_additional_features(self, time_index: int):
additional_features = {
key: self._get_additional_feature(time_index, key)
for key in self.additional_feature_keys
}
return additional_features

def __get_item__(self, time_index: int):
x_dict = self._get_features(time_index)
edge_index_dict = self._get_edge_index()
edge_weight_dict = self._get_edge_weight()
batch_dict = self._get_batch_index()
y_dict = self._get_target(time_index)
additional_features = self._get_additional_features(time_index)

snapshot = Batch.from_data_list([HeteroData()])
if x_dict:
for key, value in x_dict.items():
snapshot[key].x = value
if edge_index_dict:
for key, value in edge_index_dict.items():
snapshot[key].edge_index = value
if edge_weight_dict:
for key, value in edge_weight_dict.items():
snapshot[key].edge_attr = value
if y_dict:
for key, value in y_dict.items():
snapshot[key].y = value
if batch_dict:
for key, value in batch_dict.items():
snapshot[key].batch = value
if additional_features:
for feature_name, feature_dict in additional_features.items():
if feature_dict:
for key, value in feature_dict.items():
snapshot[key][feature_name] = value
return snapshot

def __next__(self):
if self.t < len(self.feature_dicts):
snapshot = self.__get_item__(self.t)
self.t = self.t + 1
return snapshot
else:
self.t = 0
raise StopIteration

def __iter__(self):
self.t = 0
return self

0 comments on commit 62a86b4

Please sign in to comment.