In [1]:
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable):
        return iterable

import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import EvolveGCNO

from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

loader = ChickenpoxDatasetLoader()

dataset = loader.get_dataset()

train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)

In [2]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = EvolveGCNO(node_features)
        self.linear = torch.nn.Linear(node_features, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h
        
model = RecurrentGCN(node_features = 4)
for param in model.parameters():
    param.retain_grad()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [5]:
for time, snapshot in enumerate(train_dataset):
    print(snapshot)

Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[20, 4], edge_index=[2, 102], edge_attr=[102], y=[20])
Data(x=[

In [16]:
snapshot.y

tensor([-1.0814e-03,  4.5349e-01,  1.2904e+00, -2.3982e+00,  5.5468e-01,
         5.6077e-01,  3.6383e+00,  4.3413e-01, -4.2117e-01,  1.9873e+00,
         5.1625e-01,  6.9223e-01,  4.1767e-01,  7.6423e-01,  8.2758e-01,
         2.5313e-01,  1.2545e+00, -1.3190e+00,  7.0646e-01,  1.3516e+00])

In [None]:
model.train()

for epoch in tqdm(range(200)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward(retain_graph=True)
    optimizer.step()
    optimizer.zero_grad()
    

In [None]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    if time == 0:
        model.recurrent.weight = None
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))