Skip to content

Commit

Permalink
Merge pull request #125 from doGregor/hetero_support
Browse files Browse the repository at this point in the history
First commit for heterogeneous graph support
  • Loading branch information
benedekrozemberczki committed Feb 3, 2022
2 parents 5d5c7dd + 0437310 commit 01b4394
Show file tree
Hide file tree
Showing 10 changed files with 1,379 additions and 0 deletions.
30 changes: 30 additions & 0 deletions docs/source/modules/signal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ Temporal Signal Iterators
:members:
:undoc-members:

Heterogeneous Temporal Signal Iterators
-------------------------

.. automodule:: torch_geometric_temporal.signal.static_hetero_graph_temporal_signal
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_temporal_signal
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal
:members:
:undoc-members:

Temporal Signal Batch Iterators
-------------------------------

Expand All @@ -34,6 +49,21 @@ Temporal Signal Batch Iterators
:members:
:undoc-members:

Heterogeneous Temporal Signal Batch Iterators
-------------------------------

.. automodule:: torch_geometric_temporal.signal.static_hetero_graph_temporal_signal_batch
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_temporal_signal_batch
:members:
:undoc-members:

.. automodule:: torch_geometric_temporal.signal.dynamic_hetero_graph_static_signal_batch
:members:
:undoc-members:

Temporal Signal Train-Test Split
--------------------------------

Expand Down
186 changes: 186 additions & 0 deletions test/batch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from torch_geometric_temporal.signal import DynamicGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicGraphStaticSignalBatch

from torch_geometric_temporal.signal import StaticHeteroGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicHeteroGraphTemporalSignalBatch
from torch_geometric_temporal.signal import DynamicHeteroGraphStaticSignalBatch


def get_edge_array(node_count, node_start):
edges = []
Expand Down Expand Up @@ -83,6 +87,39 @@ def test_static_graph_temporal_signal_batch():
assert snapshot.batch is None


def test_static_hetero_graph_temporal_signal_batch():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [None, None], [None, None], None
)
for snapshot in dataset:
assert len(snapshot.node_types) == 0
assert len(snapshot.node_stores) == 0
assert len(snapshot.edge_types) == 0
assert len(snapshot.edge_stores) == 0


def test_dynamic_hetero_graph_static_signal_batch():
dataset = DynamicHeteroGraphStaticSignalBatch(
[None], [None], None, [None], [None]
)
for snapshot in dataset:
assert len(snapshot.node_types) == 0
assert len(snapshot.node_stores) == 0
assert len(snapshot.edge_types) == 0
assert len(snapshot.edge_stores) == 0


def test_dynamic_hetero_graph_temporal_signal_batch():
dataset = DynamicHeteroGraphTemporalSignalBatch(
[None, None], [None, None], [None, None], [None, None], [None, None]
)
for snapshot in dataset:
assert len(snapshot.node_types) == 0
assert len(snapshot.node_stores) == 0
assert len(snapshot.edge_types) == 0
assert len(snapshot.edge_stores) == 0


def test_dynamic_graph_temporal_signal_batch():
dataset = DynamicGraphTemporalSignalBatch(
[None, None], [None, None], [None, None], [None, None], [None, None]
Expand All @@ -107,6 +144,42 @@ def test_static_graph_temporal_signal_typing_batch():
assert snapshot.batch is None


def test_static_hetero_graph_temporal_signal_typing_batch():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], None
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert 'batch' not in list(dict(snapshot.node_stores[0]).keys())
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_static_signal_typing_batch():
dataset = DynamicHeteroGraphStaticSignalBatch(
[None], [None], {'author': np.array([1])}, [{'author': np.array([2])}], [None]
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert 'batch' not in list(dict(snapshot.node_stores[0]).keys())
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_temporal_signal_typing_batch():
dataset = DynamicHeteroGraphTemporalSignalBatch(
[None], [None], [{'author': np.array([1])}], [{'author': np.array([2])}], [None]
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert 'batch' not in list(dict(snapshot.node_stores[0]).keys())
assert len(snapshot.edge_types) == 0


def test_dynamic_graph_static_signal_typing_batch():
dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None])
for snapshot in dataset:
Expand Down Expand Up @@ -135,6 +208,42 @@ def test_static_graph_temporal_signal_batch_additional_attrs():
assert snapshot.optional2.shape == (1,)


def test_static_hetero_graph_temporal_signal_batch_additional_attrs():
dataset = StaticHeteroGraphTemporalSignalBatch(None, None, [None], [None], None,
optional1=[{'author': np.array([1])}],
optional2=[{'author': np.array([2])}],
optional3=[None])
assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"]
for snapshot in dataset:
assert snapshot.node_stores[0]['optional1'].shape == (1,)
assert snapshot.node_stores[0]['optional2'].shape == (1,)
assert "optional3" not in list(dict(snapshot.node_stores[0]).keys())


def test_dynamic_hetero_graph_static_signal_batch_additional_attrs():
dataset = DynamicHeteroGraphStaticSignalBatch([None], [None], None, [None], [None],
optional1=[{'author': np.array([1])}],
optional2=[{'author': np.array([2])}],
optional3=[None])
assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"]
for snapshot in dataset:
assert snapshot.node_stores[0]['optional1'].shape == (1,)
assert snapshot.node_stores[0]['optional2'].shape == (1,)
assert "optional3" not in list(dict(snapshot.node_stores[0]).keys())


def test_dynamic_hetero_graph_temporal_signal_batch_additional_attrs():
dataset = DynamicHeteroGraphTemporalSignalBatch([None], [None], [None], [None], [None],
optional1=[{'author': np.array([1])}],
optional2=[{'author': np.array([2])}],
optional3=[None])
assert dataset.additional_feature_keys == ["optional1", "optional2", "optional3"]
for snapshot in dataset:
assert snapshot.node_stores[0]['optional1'].shape == (1,)
assert snapshot.node_stores[0]['optional2'].shape == (1,)
assert "optional3" not in list(dict(snapshot.node_stores[0]).keys())


def test_dynamic_graph_static_signal_batch_additional_attrs():
dataset = DynamicGraphStaticSignalBatch([None], [None], None, [None], [None],
optional1=[np.array([1])], optional2=[np.array([2])])
Expand All @@ -144,6 +253,83 @@ def test_dynamic_graph_static_signal_batch_additional_attrs():
assert snapshot.optional2.shape == (1,)


def test_static_hetero_graph_temporal_signal_batch_edges():
dataset = StaticHeteroGraphTemporalSignalBatch({("author", "writes", "paper"): np.array([[0, 1], [1, 0]])},
{("author", "writes", "paper"): np.array([[0.1], [0.1]])},
[{"author": np.array([[0], [0]]),
"paper": np.array([[0], [0], [0]])},
{"author": np.array([[0.1], [0.1]]),
"paper": np.array([[0.1], [0.1], [0.1]])}],
[None, None],
None)
for snapshot in dataset:
assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2)
assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1)
assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0]


def test_dynamic_hetero_graph_static_signal_batch_edges():
dataset = DynamicHeteroGraphStaticSignalBatch([{("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}],
[{("author", "writes", "paper"): np.array([[0.1], [0.1]])}],
{"author": np.array([[0], [0]]),
"paper": np.array([[0], [0], [0]])},
[None],
[None])
for snapshot in dataset:
assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2)
assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1)
assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0]


def test_dynamic_hetero_graph_temporal_signal_batch_edges():
dataset = DynamicHeteroGraphTemporalSignalBatch([{("author", "writes", "paper"): np.array([[0, 1], [1, 0]])}],
[{("author", "writes", "paper"): np.array([[0.1], [0.1]])}],
[{"author": np.array([[0], [0]]),
"paper": np.array([[0], [0], [0]])}],
[None],
[None])
for snapshot in dataset:
assert snapshot.edge_stores[0]['edge_index'].shape == (2, 2)
assert snapshot.edge_stores[0]['edge_attr'].shape == (2, 1)
assert snapshot.edge_stores[0]['edge_index'].shape[0] == snapshot.edge_stores[0]['edge_attr'].shape[0]


def test_static_hetero_graph_temporal_signal_batch_assigned():
dataset = StaticHeteroGraphTemporalSignalBatch(
None, None, [{'author': np.array([1])}], [{'author': np.array([2])}], {'author': np.array([1])}
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert snapshot.node_stores[0]['batch'].shape == (1,)
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_static_signal_batch_assigned():
dataset = DynamicHeteroGraphStaticSignalBatch(
[None], [None], {'author': np.array([1])}, [{'author': np.array([2])}], [{'author': np.array([1])}]
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert snapshot.node_stores[0]['batch'].shape == (1,)
assert len(snapshot.edge_types) == 0


def test_dynamic_hetero_graph_temporal_signal_batch_assigned():
dataset = DynamicHeteroGraphTemporalSignalBatch(
[None], [None], [{'author': np.array([1])}], [{'author': np.array([2])}], [{'author': np.array([1])}]
)
for snapshot in dataset:
assert snapshot.node_types[0] == 'author'
assert snapshot.node_stores[0]['x'].shape == (1,)
assert snapshot.node_stores[0]['y'].shape == (1,)
assert snapshot.node_stores[0]['batch'].shape == (1,)
assert len(snapshot.edge_types) == 0


def test_discrete_train_test_split_dynamic_batch():

snapshot_count = 250
Expand Down

0 comments on commit 01b4394

Please sign in to comment.