Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First commit for heterogeneous graph support #125

Merged
30 changes: 30 additions & 0 deletions docs/source/modules/signal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ Temporal Signal Iterators
:members:
:undoc-members:

Heterogeneous Temporal Signal Iterators
-------------------------

.. automodule:: torch_geometric_temporal.signal.static_hetero_graph_temporal_signal
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_temporal_signal
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal
:members:
:undoc-members:

Temporal Signal Batch Iterators
-------------------------------

Expand All @@ -34,6 +49,21 @@ Temporal Signal Batch Iterators
:members:
:undoc-members:

Heterogeneous Temporal Signal Batch Iterators
-------------------------------

.. automodule:: torch_geometric_temporal.signal.static_hetero_graph_temporal_signal_batch
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_temporal_signal_batch
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal_batch
:members:
:undoc-members:

Temporal Signal Train-Test Split
--------------------------------

Expand Down
186 changes: 186 additions & 0 deletions test/batch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from torch_geometric_temporal.signal import DynamicGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch

from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicHeteroGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicHeteroGraphStaticSignalBatch


def get_edge_array(node_count, node_start):
edges = []
Expand Down Expand Up @@ -83,6 +87,39 @@ def test_static_graph_temporal_signal_batch():
assert snapshot.batch is None


def test_static_hetero_graph_temporal_signal_batch():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [None, None], [None, None], None
)
for snapshot in dataset:
assert len(snapshot.node_types) == 0
assert len(snapshot.node_stores) == 0
assert len(snapshot.edge_types) == 0
assert len(snapshot.edge_stores) == 0


def test_dynamic_hetero_graph_static_signal_batch():
dataset = DynamicHeteroGraphStaticSignalBatch(
[None], [None], None, [None], [None]
)
for snapshot in dataset:
assert len(snapshot.node_types) == 0
assert len(snapshot.node_stores) == 0
assert len(snapshot.edge_types) == 0
assert len(snapshot.edge_stores) == 0


def test_dynamic_hetero_graph_temporal_signal_batch():
dataset = DynamicHeteroGraphTemporalSignalBatch(
[None, None], [None, None], [None, None], [None, None], [None, None]
)
for snapshot in dataset:
assert len(snapshot.node_types) == 0
assert len(snapshot.node_stores) == 0
assert len(snapshot.edge_types) == 0
assert len(snapshot.edge_stores) == 0


def test_dynamic_graph_temporal_signal_batch():
dataset = DynamicGraphTemporalSignalBatch(
[None, None], [None, None], [None, None], [None, None], [None, None]
Expand All @@ -107,6 +144,42 @@ def test_static_graph_temporal_signal_typing_batch():
assert snapshot.batch is None


def test_static_hetero_graph_temporal_signal_typing_batch():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], None
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert 'batch' not in list(dict(snapshot.node_stores[0]).keys())
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_static_signal_typing_batch():
dataset = DynamicHeteroGraphStaticSignalBatch(
[None], [None], {'author': np.array([1])}, [{'author': np.array([2])}], [None]
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert 'batch' not in list(dict(snapshot.node_stores[0]).keys())
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_temporal_signal_typing_batch():
dataset = DynamicHeteroGraphTemporalSignalBatch(
[None], [None], [{'author': np.array([1])}], [{'author': np.array([2])}], [None]
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert 'batch' not in list(dict(snapshot.node_stores[0]).keys())
assert len(snapshot.edge_types) == 0


def test_dynamic_graph_static_signal_typing_batch():
dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None])
for snapshot in dataset:
Expand Down Expand Up @@ -135,6 +208,42 @@ def test_static_graph_temporal_signal_batch_additional_attrs():
assert snapshot.optional2.shape == (1,)


def test_static_hetero_graph_temporal_signal_batch_additional_attrs():
dataset = StaticHeteroGraphTemporalSignalBatch(None, None, [None], [None], None,
optional1=[{'author': np.array([1])}],
optional2=[{'author': np.array([2])}],
optional3=[None])
assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"]
for snapshot in dataset:
assert snapshot.node_stores[0]['optional1'].shape == (1,)
assert snapshot.node_stores[0]['optional2'].shape == (1,)
assert "optional3" not in list(dict(snapshot.node_stores[0]).keys())


def test_dynamic_hetero_graph_static_signal_batch_additional_attrs():
dataset = DynamicHeteroGraphStaticSignalBatch([None], [None], None, [None], [None],
optional1=[{'author': np.array([1])}],
optional2=[{'author': np.array([2])}],
optional3=[None])
assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"]
for snapshot in dataset:
assert snapshot.node_stores[0]['optional1'].shape == (1,)
assert snapshot.node_stores[0]['optional2'].shape == (1,)
assert "optional3" not in list(dict(snapshot.node_stores[0]).keys())


def test_dynamic_hetero_graph_temporal_signal_batch_additional_attrs():
dataset = DynamicHeteroGraphTemporalSignalBatch([None], [None], [None], [None], [None],
optional1=[{'author': np.array([1])}],
optional2=[{'author': np.array([2])}],
optional3=[None])
assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"]
for snapshot in dataset:
assert snapshot.node_stores[0]['optional1'].shape == (1,)
assert snapshot.node_stores[0]['optional2'].shape == (1,)
assert "optional3" not in list(dict(snapshot.node_stores[0]).keys())


def test_dynamic_graph_static_signal_batch_additional_attrs():
dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None],
optional1=[np.array([1])], optional2=[np.array([2])])
Expand All @@ -144,6 +253,83 @@ def test_dynamic_graph_static_signal_batch_additional_attrs():
assert snapshot.optional2.shape == (1,)


def test_static_hetero_graph_temporal_signal_batch_edges():
dataset = StaticHeteroGraphTemporalSignalBatch({("author", "writes", "paper"): np.array([[0, 1], [1, 0]])},
{("author", "writes", "paper"): np.array([[0.1], [0.1]])},
[{"author": np.array([[0], [0]]),
"paper": np.array([[0], [0], [0]])},
{"author": np.array([[0.1], [0.1]]),
"paper": np.array([[0.1], [0.1], [0.1]])}],
[None, None],
None)
for snapshot in dataset:
assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2)
assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1)
assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0]


def test_dynamic_hetero_graph_static_signal_batch_edges():
dataset = DynamicHeteroGraphStaticSignalBatch([{("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}],
[{("author", "writes", "paper"): np.array([[0.1], [0.1]])}],
{"author": np.array([[0], [0]]),
"paper": np.array([[0], [0], [0]])},
[None],
[None])
for snapshot in dataset:
assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2)
assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1)
assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0]


def test_dynamic_hetero_graph_temporal_signal_batch_edges():
dataset = DynamicHeteroGraphTemporalSignalBatch([{("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}],
[{("author", "writes", "paper"): np.array([[0.1], [0.1]])}],
[{"author": np.array([[0], [0]]),
"paper": np.array([[0], [0], [0]])}],
[None],
[None])
for snapshot in dataset:
assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2)
assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1)
assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0]


def test_static_hetero_graph_temporal_signal_batch_assigned():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], {'author': np.array([1])}
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert snapshot.node_stores[0]['batch'].shape == (1,)
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_static_signal_batch_assigned():
dataset = DynamicHeteroGraphStaticSignalBatch(
[None], [None], {'author': np.array([1])}, [{'author': np.array([2])}], [{'author': np.array([1])}]
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert snapshot.node_stores[0]['batch'].shape == (1,)
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_temporal_signal_batch_assigned():
dataset = DynamicHeteroGraphTemporalSignalBatch(
[None], [None], [{'author': np.array([1])}], [{'author': np.array([2])}], [{'author': np.array([1])}]
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert snapshot.node_stores[0]['batch'].shape == (1,)
assert len(snapshot.edge_types) == 0


def test_discrete_train_test_split_dynamic_batch():

snapshot_count = 250
Expand Down