In [32]:
import numpy as np
import networkx as nx

from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
from torch_geometric_temporal.dataset import PedalMeDatasetLoader, ChickenpoxDatasetLoader

import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

from tqdm import tqdm

### https://pytorch-geometric-temporal.readthedocs.io/en/latest/notes/introduction.html#epidemiological-forecasting

In [29]:
#loader = PedalMeDatasetLoader()
loader = ChickenpoxDatasetLoader()
dataset = loader.get_dataset()

train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)

In [25]:
[ i for i in dir(dataset) if '__' not in i ]

['_check_temporal_consistency',
 '_get_edge_index',
 '_get_edge_weight',
 '_get_features',
 '_get_snapshot',
 '_get_target',
 '_set_snapshot_count',
 'edge_index',
 'edge_weight',
 'features',
 'snapshot_count',
 'targets']

In [30]:
dataset.snapshot_count

517

In [27]:
[ snapshot for snapshot in dataset ][:5]

[Data(edge_attr=[102], edge_index=[2, 102], x=[20, 4], y=[20]),
 Data(edge_attr=[102], edge_index=[2, 102], x=[20, 4], y=[20]),
 Data(edge_attr=[102], edge_index=[2, 102], x=[20, 4], y=[20]),
 Data(edge_attr=[102], edge_index=[2, 102], x=[20, 4], y=[20]),
 Data(edge_attr=[102], edge_index=[2, 102], x=[20, 4], y=[20])]

In [37]:
type(snapshot)

torch_geometric.data.data.Data

In [31]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = DCRNN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [33]:
model = RecurrentGCN(node_features = 4)

optimizer = torch.optim.Adam( model.parameters(), lr=1e-2 )

model.train()

for epoch in tqdm(range(200)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model( snapshot.x, snapshot.edge_index, snapshot.edge_attr )
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

100%|██████████| 200/200 [00:38<00:00,  5.18it/s]


In [34]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))

MSE: 1.0405


In [38]:
y_hat

tensor([[-0.1272],
        [-0.3243],
        [ 0.1616],
        [ 0.1606],
        [-0.0433],
        [ 0.0745],
        [ 0.1564],
        [-0.0787],
        [ 0.1371],
        [ 0.0501],
        [-0.0391],
        [ 0.0712],
        [-0.1210],
        [-0.0326],
        [ 0.1595],
        [-0.0230],
        [ 0.1796],
        [ 0.0653],
        [-0.1264],
        [-0.0870]], grad_fn=<AddmmBackward>)