Skip to content

Commit

Permalink
dynamic hetero graph temporal signal batch
Browse files Browse the repository at this point in the history
- data structure that works like DynamicGraphTemporalSignalBatch but with PyG HeteroData Batches instead
- tests
- documentation
  • Loading branch information
doGregor committed Feb 3, 2022
1 parent c354f9b commit 0437310
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/source/modules/signal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ Heterogeneous Temporal Signal Batch Iterators
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_temporal_signal_batch
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal_batch
:members:
:undoc-members:
Expand Down
61 changes: 61 additions & 0 deletions test/batch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch

from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicHeteroGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicHeteroGraphStaticSignalBatch


Expand Down Expand Up @@ -108,6 +109,17 @@ def test_dynamic_hetero_graph_static_signal_batch():
assert len(snapshot.edge_stores) == 0


def test_dynamic_hetero_graph_temporal_signal_batch():
    """Two all-``None`` time steps must yield snapshots without any node or edge data."""
    # Edge indices, edge weights, features, targets and batch vectors, in that order.
    empty_inputs = [[None, None] for _ in range(5)]
    dataset = DynamicHeteroGraphTemporalSignalBatch(*empty_inputs)
    for snapshot in dataset:
        assert len(snapshot.node_types) == 0
        assert len(snapshot.node_stores) == 0
        assert len(snapshot.edge_types) == 0
        assert len(snapshot.edge_stores) == 0


def test_dynamic_graph_temporal_signal_batch():
dataset = DynamicGraphTemporalSignalBatch(
[None, None], [None, None], [None, None], [None, None], [None, None]
Expand Down Expand Up @@ -156,6 +168,18 @@ def test_dynamic_hetero_graph_static_signal_typing_batch():
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_temporal_signal_typing_batch():
    """Features/targets given for 'author' end up as ``x``/``y``; no ``batch`` entry is created."""
    features = [{'author': np.array([1])}]
    targets = [{'author': np.array([2])}]
    dataset = DynamicHeteroGraphTemporalSignalBatch(
        [None], [None], features, targets, [None]
    )
    for snapshot in dataset:
        assert snapshot.node_types[0] == 'author'
        author_store = snapshot.node_stores[0]
        assert author_store['x'].shape == (1,)
        assert author_store['y'].shape == (1,)
        # The batch vector was passed as None, so the store must not carry one.
        assert 'batch' not in dict(snapshot.node_stores[0])
        assert len(snapshot.edge_types) == 0


def test_dynamic_graph_static_signal_typing_batch():
dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None])
for snapshot in dataset:
Expand Down Expand Up @@ -208,6 +232,18 @@ def test_dynamic_hetero_graph_static_signal_batch_additional_attrs():
assert "optional3" not in list(dict(snapshot.node_stores[0]).keys())


def test_dynamic_hetero_graph_temporal_signal_batch_additional_attrs():
    """Keyword attributes are recorded by name and copied into the node store unless ``None``."""
    extra_attributes = {
        "optional1": [{'author': np.array([1])}],
        "optional2": [{'author': np.array([2])}],
        "optional3": [None],
    }
    dataset = DynamicHeteroGraphTemporalSignalBatch(
        [None], [None], [None], [None], [None], **extra_attributes
    )
    # Every keyword is registered, including the one whose only snapshot is None.
    assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"]
    for snapshot in dataset:
        assert snapshot.node_stores[0]['optional1'].shape == (1,)
        assert snapshot.node_stores[0]['optional2'].shape == (1,)
        assert "optional3" not in dict(snapshot.node_stores[0])


def test_dynamic_graph_static_signal_batch_additional_attrs():
dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None],
optional1=[np.array([1])], optional2=[np.array([2])])
Expand Down Expand Up @@ -245,6 +281,19 @@ def test_dynamic_hetero_graph_static_signal_batch_edges():
assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0]


def test_dynamic_hetero_graph_temporal_signal_batch_edges():
    """Edge indices and edge weights of one relation type are stored with matching edge counts."""
    relation = ("author", "writes", "paper")
    dataset = DynamicHeteroGraphTemporalSignalBatch(
        [{relation: np.array([[0, 1], [1, 0]])}],
        [{relation: np.array([[0.1], [0.1]])}],
        [{"author": np.array([[0], [0]]),
          "paper": np.array([[0], [0], [0]])}],
        [None],
        [None],
    )
    for snapshot in dataset:
        edge_store = snapshot.edge_stores[0]
        assert edge_store['edge_index'].shape == (2, 2)
        assert edge_store['edge_attr'].shape == (2, 1)
        # edge_index is laid out as (2, num_edges) and edge_attr as
        # (num_edges, num_edge_features), so the edge count lives on axis 1 of
        # edge_index; comparing axis 0 would only pass by coincidence when
        # there are exactly two edges.
        assert edge_store['edge_index'].shape[1] == edge_store['edge_attr'].shape[0]


def test_static_hetero_graph_temporal_signal_batch_assigned():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], {'author': np.array([1])}
Expand All @@ -269,6 +318,18 @@ def test_dynamic_hetero_graph_static_signal_batch_assigned():
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_temporal_signal_batch_assigned():
    """A batch vector supplied for 'author' is stored next to ``x`` and ``y``."""
    features = [{'author': np.array([1])}]
    targets = [{'author': np.array([2])}]
    batch_vectors = [{'author': np.array([1])}]
    dataset = DynamicHeteroGraphTemporalSignalBatch(
        [None], [None], features, targets, batch_vectors
    )
    for snapshot in dataset:
        assert snapshot.node_types[0] == 'author'
        author_store = snapshot.node_stores[0]
        for attribute in ('x', 'y', 'batch'):
            assert author_store[attribute].shape == (1,)
        assert len(snapshot.edge_types) == 0


def test_discrete_train_test_split_dynamic_batch():

snapshot_count = 250
Expand Down
1 change: 1 addition & 0 deletions torch_geometric_temporal/signal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .dynamic_graph_static_signal_batch import *

from .dynamic_hetero_graph_temporal_signal import *
from .dynamic_hetero_graph_temporal_signal_batch import *

from .static_hetero_graph_temporal_signal import *
from .static_hetero_graph_temporal_signal_batch import *
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import torch
import numpy as np
from typing import List, Dict, Union, Tuple
from torch_geometric.data import HeteroData, Batch


# Type aliases for the constructor arguments. Each is a list with one entry per
# time step; an entry is either a dictionary of NumPy arrays (keyed by relation
# tuple for edges, by node type for everything else) or None when nothing is
# provided for that step.
Edge_Indices = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]]
Edge_Weights = List[Union[Dict[Tuple[str, str, str], np.ndarray], None]]
Node_Features = List[Union[Dict[str, np.ndarray], None]]
Targets = List[Union[Dict[str, np.ndarray], None]]
Batches = List[Union[Dict[str, np.ndarray], None]]
Additional_Features = List[Union[Dict[str, np.ndarray], None]]


class DynamicHeteroGraphTemporalSignalBatch(object):
    r"""A data iterator object to contain a dynamic heterogeneous graph with a
    changing edge set and weights. The feature set and node labels
    (target) are also dynamic. The iterator returns a single discrete temporal
    snapshot for a time period (e.g. day or week). This single snapshot is a
    PyTorch Geometric HeteroData Batch object. Between two temporal snapshots the edges,
    edge weights, the feature matrix, target matrices and optionally passed
    attributes might change.

    Indexing with an integer builds the snapshot for that time step; indexing
    with a slice returns a new ``DynamicHeteroGraphTemporalSignalBatch``
    restricted to the selected time steps.

    Args:
        edge_index_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays):
            List of relation type tuples and their edge index tensors.
        edge_weight_dicts (List of dictionaries where keys=Tuples and values=Numpy arrays):
            List of relation type tuples and their edge weight tensors.
        feature_dicts (List of dictionaries where keys=Strings and values=Numpy arrays):
            List of node types and their feature tensors.
        target_dicts (List of dictionaries where keys=Strings and values=Numpy arrays):
            List of node types and their label (target) tensors.
        batch_dicts (List of dictionaries where keys=Strings and values=Numpy arrays):
            List of batch index tensor for each node type.
        **kwargs (optional List of dictionaries where keys=Strings and values=Numpy arrays):
            List of node types and their additional attributes.

    Raises:
        AssertionError: If the per-snapshot lists do not all have the same length.
    """

    def __init__(
        self,
        edge_index_dicts: Edge_Indices,
        edge_weight_dicts: Edge_Weights,
        feature_dicts: Node_Features,
        target_dicts: Targets,
        batch_dicts: Batches,
        **kwargs: Additional_Features
    ):
        self.edge_index_dicts = edge_index_dicts
        self.edge_weight_dicts = edge_weight_dicts
        self.feature_dicts = feature_dicts
        self.target_dicts = target_dicts
        self.batch_dicts = batch_dicts
        # Every extra keyword becomes an attribute of the same name; the names
        # are remembered so the values can be attached to each snapshot.
        self.additional_feature_keys = []
        for key, value in kwargs.items():
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()
        # Iteration cursor. Initialised here so that ``next(dataset)`` works
        # even when ``iter(dataset)`` has not been called first.
        self.t = 0

    def _check_temporal_consistency(self):
        """Assert that every per-snapshot list covers the same number of time steps."""
        message = "Temporal dimension inconsistency."
        assert len(self.feature_dicts) == len(self.target_dicts), message
        assert len(self.edge_index_dicts) == len(self.edge_weight_dicts), message
        assert len(self.feature_dicts) == len(self.edge_weight_dicts), message
        assert len(self.feature_dicts) == len(self.batch_dicts), message
        for key in self.additional_feature_keys:
            assert len(self.target_dicts) == len(getattr(self, key)), message

    def _set_snapshot_count(self):
        """Store the number of time steps, taken from the feature list."""
        self.snapshot_count = len(self.feature_dicts)

    @staticmethod
    def _convert_values(array_dict, convert):
        """Apply ``convert`` to every non-``None`` value of ``array_dict``.

        ``None`` is passed through unchanged and entries whose value is
        ``None`` are dropped from the result.
        """
        if array_dict is None:
            return None
        return {
            key: convert(value)
            for key, value in array_dict.items()
            if value is not None
        }

    @staticmethod
    def _tensor_by_dtype(value):
        """Turn float arrays into FloatTensors and signed-int arrays into LongTensors.

        Arrays of any other dtype kind are returned untouched.
        """
        kind = value.dtype.kind
        if kind == "f":
            return torch.FloatTensor(value)
        if kind == "i":
            return torch.LongTensor(value)
        return value

    def _get_edge_index(self, time_index: int):
        """Return the edge indices of one time step as LongTensors (or ``None``)."""
        return self._convert_values(self.edge_index_dicts[time_index], torch.LongTensor)

    def _get_batch_index(self, time_index: int):
        """Return the batch vectors of one time step as LongTensors (or ``None``)."""
        return self._convert_values(self.batch_dicts[time_index], torch.LongTensor)

    def _get_edge_weight(self, time_index: int):
        """Return the edge weights of one time step as FloatTensors (or ``None``)."""
        return self._convert_values(self.edge_weight_dicts[time_index], torch.FloatTensor)

    def _get_feature(self, time_index: int):
        """Return the node features of one time step as FloatTensors (or ``None``)."""
        return self._convert_values(self.feature_dicts[time_index], torch.FloatTensor)

    def _get_target(self, time_index: int):
        """Return the targets of one time step, converted according to their dtype."""
        return self._convert_values(self.target_dicts[time_index], self._tensor_by_dtype)

    def _get_additional_feature(self, time_index: int, feature_key: str):
        """Return one extra attribute of one time step, converted according to its dtype."""
        feature = getattr(self, feature_key)[time_index]
        return self._convert_values(feature, self._tensor_by_dtype)

    def _get_additional_features(self, time_index: int):
        """Return ``{attribute name: converted dictionary or None}`` for one time step."""
        return {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }

    def __getitem__(self, time_index: Union[int, slice]):
        """Return the snapshot at ``time_index`` or, for a slice, a sub-dataset.

        Args:
            time_index: Integer position of a time step, or a slice selecting
                several time steps.

        Returns:
            For an integer, the HeteroData ``Batch`` snapshot of that time
            step. For a slice, a new ``DynamicHeteroGraphTemporalSignalBatch``
            built from the sliced per-snapshot lists.
        """
        if isinstance(time_index, slice):
            return DynamicHeteroGraphTemporalSignalBatch(
                self.edge_index_dicts[time_index],
                self.edge_weight_dicts[time_index],
                self.feature_dicts[time_index],
                self.target_dicts[time_index],
                self.batch_dicts[time_index],
                **{
                    key: getattr(self, key)[time_index]
                    for key in self.additional_feature_keys
                }
            )

        x_dict = self._get_feature(time_index)
        edge_index_dict = self._get_edge_index(time_index)
        edge_weight_dict = self._get_edge_weight(time_index)
        batch_dict = self._get_batch_index(time_index)
        y_dict = self._get_target(time_index)
        additional_features = self._get_additional_features(time_index)

        snapshot = Batch.from_data_list([HeteroData()])
        # Attributes are written in a fixed order (x, edge_index, edge_attr, y,
        # batch); None or empty dictionaries are skipped entirely.
        standard_attributes = (
            ("x", x_dict),
            ("edge_index", edge_index_dict),
            ("edge_attr", edge_weight_dict),
            ("y", y_dict),
            ("batch", batch_dict),
        )
        for attribute_name, type_dict in standard_attributes:
            if type_dict:
                for key, value in type_dict.items():
                    setattr(snapshot[key], attribute_name, value)
        for feature_name, feature_dict in additional_features.items():
            if feature_dict:
                for key, value in feature_dict.items():
                    snapshot[key][feature_name] = value
        return snapshot

    def __next__(self):
        """Return the next snapshot; rewind the cursor and stop once all were served."""
        if self.t < len(self.feature_dicts):
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        else:
            self.t = 0
            raise StopIteration

    def __iter__(self):
        """Rewind the cursor and return the dataset itself as the iterator."""
        self.t = 0
        return self

0 comments on commit 0437310

Please sign in to comment.