In [26]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.dataset import PedalMeDatasetLoader
from torch_geometric_temporal.nn.recurrent import TGCN
from torch_geometric_temporal.signal import temporal_signal_split
from tqdm import tqdm

In [27]:
downloader = PedalMeDatasetLoader()

In [28]:
dataset = downloader.get_dataset()

In [29]:
dataset.edge_index.shape

(2, 225)

In [30]:
print(f"Total number of dataset: {len(dataset.features)}")
for feature in dataset.features:
    print(feature.shape)

# n_nodes * features

Total number of dataset: 31
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)
(15, 4)


In [31]:
print(f"Total number of dataset: {len(dataset.targets)}")
for target in dataset.targets:
    print(target.shape)

# n_node * 1 ==> 

Total number of dataset: 31
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)
(15,)


In [32]:
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)

In [33]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = TGCN(node_features, 32)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [34]:
model = RecurrentGCN(node_features=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in tqdm(range(50)):
    loss = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        loss = loss + torch.mean((y_hat - snapshot.y) ** 2)
    loss = loss / (time + 1)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

model.eval()
loss = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    loss = loss + torch.mean((y_hat - snapshot.y) ** 2)
loss = loss / (time + 1)
loss = loss.item()
print("MSE: {:.4f}".format(loss))


100%|██████████| 50/50 [00:00<00:00, 86.18it/s]

MSE: 0.5202



