<a href="https://colab.research.google.com/github/bachnguyenTE/temporal-mgn/blob/prototype-mgvae/baselines/graph_lstm_baseline_covid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!nvidia-smi

Thu Apr 14 00:49:07 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from google.colab import drive
drive.mount("/content/drive")
!ls drive/MyDrive/Research/TMGN/model_checkpoints/baseline_models

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
graph_lstm  t_gcn


In [3]:
DATA_ROOT_DIR="drive/MyDrive/Research/TMGN/model_checkpoints/baseline_models/graph_lstm/"

In [4]:
# Add this in a Google Colab cell to install the correct version of Pytorch Geometric.
%%capture
import torch

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric 
!pip install torch-geometric-temporal

!pip install einops
!wget -c https://gist.githubusercontent.com/Luvata/55f7b3e9ae451122b9e3faf0a7387b4f/raw/440fac5c6e7153fd39e4eb9ebec6e51c9520ef1f/visualize.py
!pip install --upgrade graphviz

!pip install wandb -qqq
!pip install prettytable

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric_temporal.nn.recurrent import GConvLSTM

from torch_geometric_temporal.dataset import EnglandCovidDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

import wandb
import datetime

In [6]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [7]:
wandb.init(project="graph-lstm-baseline-covid", entity="bachnguyen")
wandb.config = {
    "learning_rate": 0.01,
    "epochs": 200,
    "batch_size": 1
}

[34m[1mwandb[0m: Currently logged in as: [33mbachnguyen[0m (use `wandb login --relogin` to force relogin)


In [8]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: 
            continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [9]:
# Fix all random seed
torch_geometric.seed.seed_everything(69420)

# Set device to gpu
device = torch.device('cuda')

In [45]:
loader = EnglandCovidDatasetLoader()
dataset = loader.get_dataset()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)

In [46]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = GConvLSTM(node_features, 64, 1)
        self.linear = torch.nn.Linear(64, 1)

    def forward(self, x, edge_index, edge_weight, h, c):
        h_0, c_0 = self.recurrent(x, edge_index, edge_weight, h, c)
        h = F.relu(h_0)
        h = self.linear(h)
        return h, h_0, c_0

In [57]:
model = RecurrentGCN(node_features=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

param_count = count_parameters(model)

for epoch in tqdm(range(50)):
    cost = 0
    h, c = None, None
    for time, snapshot in enumerate(train_dataset):
        snapshot.to(device)
        y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
        cost = cost + torch.mean((y_hat-snapshot.y)**2)
    cost = cost / (time+1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Log metrics
    wandb.log({"loss": cost})

+----------------------------------+------------+
|             Modules              | Parameters |
+----------------------------------+------------+
|         recurrent.w_c_i          |     64     |
|          recurrent.b_i           |     64     |
|         recurrent.w_c_f          |     64     |
|          recurrent.b_f           |     64     |
|          recurrent.b_c           |     64     |
|         recurrent.w_c_o          |     64     |
|          recurrent.b_o           |     64     |
|     recurrent.conv_x_i.bias      |     64     |
| recurrent.conv_x_i.lins.0.weight |    512     |
|     recurrent.conv_h_i.bias      |     64     |
| recurrent.conv_h_i.lins.0.weight |    4096    |
|     recurrent.conv_x_f.bias      |     64     |
| recurrent.conv_x_f.lins.0.weight |    512     |
|     recurrent.conv_h_f.bias      |     64     |
| recurrent.conv_h_f.lins.0.weight |    4096    |
|     recurrent.conv_x_c.bias      |     64     |
| recurrent.conv_x_c.lins.0.weight |    512     |


100%|██████████| 50/50 [00:10<00:00,  4.73it/s]


In [58]:
NOW = datetime.datetime.now()
timestamp = NOW.isoformat().replace(":", "_")
print(timestamp)
print('Saving trained model...')
torch.save(model.state_dict(), DATA_ROOT_DIR+f"graph_lstm_covid_{timestamp}.pth")

2022-04-14T00_55_08.252534
Saving trained model...


In [62]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    snapshot.to(device)
    y_hat, h, c = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h, c)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))

MSE: 1.0079
