In [1]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.nn.recurrent import GConvGRU
from tqdm import tqdm

In [2]:
# Load the WikiMaths temporal signal and split it into train / test snapshot sequences.
# LAGS becomes the per-node feature width of every snapshot (x is [num_nodes, LAGS]);
# temporal_signal_split is expected to keep snapshot order (first TRAIN_RATIO -> train).
LAGS = 14          # number of lagged values used as node features
TRAIN_RATIO = 0.5  # fraction of snapshots assigned to the training split

loader = WikiMathsDatasetLoader()
dataset = loader.get_dataset(lags=LAGS)
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=TRAIN_RATIO)

In [11]:
# Peek at each test snapshot: its position in the split and the PyG `Data` summary (tensor shapes).
for time, snapshot in enumerate(test_dataset):
    print(f"{time} {snapshot}")

0 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
1 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
2 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
3 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
4 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
5 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
6 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
7 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
8 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
9 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
10 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
11 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
12 Data(x=[1068, 14], edge_index=[2, 27079], edge_attr=[27079], y=[1068])
13 Data(x=[1068, 14], edge_index=[2, 27079], edg

In [None]:
class RecurrentGCN(torch.nn.Module):
    """One GConvGRU graph-recurrent layer followed by a linear read-out to one value per node.

    Args:
        node_features: Number of input features per node (the column count of a
            snapshot's ``x``; in this notebook it equals the ``lags`` used to build the dataset).
        filters: Number of output channels of the GConvGRU layer (hidden size).
        K: Filter size passed as GConvGRU's third argument (Chebyshev filter size in
            torch_geometric_temporal). Defaults to 2, the previously hard-coded value.
    """

    def __init__(self, node_features, filters, K=2):
        super().__init__()
        self.recurrent = GConvGRU(node_features, filters, K)
        # Map the `filters`-dim hidden representation to a single scalar per node.
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight):
        """Predict one value per node for a single graph snapshot.

        Args:
            x: Node features, shape (num_nodes, node_features).
            edge_index: Graph connectivity, shape (2, num_edges).
            edge_weight: Per-edge weights, shape (num_edges,).

        Returns:
            Tensor of shape (num_nodes, 1). Note the trailing singleton dimension:
            squeeze it before comparing with a (num_nodes,) target, otherwise
            broadcasting produces a (num_nodes, num_nodes) difference.
        """
        # NOTE(review): no hidden state is passed to GConvGRU here, so (per the library's
        # default) each call starts from a fresh state; temporal context comes only from
        # the lag features in `x`. Confirm this is the intended design.
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        return self.linear(h)

In [None]:
# Seed torch so weight initialisation (and hence the reported MSE) is reproducible on Restart & Run All.
torch.manual_seed(42)

# node_features must equal the `lags` passed to get_dataset (snapshot x is [num_nodes, lags]).
model = RecurrentGCN(node_features=14, filters=32)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

# One optimiser step per training snapshot, repeated for 50 epochs.
for epoch in tqdm(range(50)):
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # y_hat is (num_nodes, 1) while snapshot.y is (num_nodes,): squeeze the trailing
        # dim so the subtraction is element-wise instead of broadcasting to (num_nodes, num_nodes).
        cost = torch.mean((y_hat.squeeze(-1) - snapshot.y) ** 2)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Mean of the per-snapshot MSE over the test split.
model.eval()
total_mse = 0.0
n_snapshots = 0
# no_grad: evaluation needs no gradients, and skipping graph construction avoids
# holding every snapshot's autograd graph in memory while summing.
with torch.no_grad():
    for time, snapshot in enumerate(test_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # Squeeze (num_nodes, 1) -> (num_nodes,) to match snapshot.y and avoid
        # an all-pairs (num_nodes, num_nodes) broadcast.
        total_mse += torch.mean((y_hat.squeeze(-1) - snapshot.y) ** 2).item()
        n_snapshots += 1
if n_snapshots == 0:
    raise ValueError("test_dataset is empty; cannot compute MSE")
cost = total_mse / n_snapshots
print("MSE: {:.4f}".format(cost))