In [1]:
import dgl
from dgl.data.utils import load_graphs

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from model.model import HTGNN, NodePredictor
from utils.pytorchtools import EarlyStopping
from utils.data import load_COVID_data

dgl.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def evaluate(model, val_feats, val_labels):
    val_mae_list, val_rmse_list = [], []
    model.eval()
    with torch.no_grad():
        for (G_feat, G_label) in zip(val_feats, val_labels):
            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            label = G_label.nodes['state'].data['feat']
            loss = F.l1_loss(pred, label.to(device))
            rmse = torch.sqrt(F.mse_loss(pred, label.to(device)))

            val_mae_list.append(loss.item())
            val_rmse_list.append(rmse.item())
            
        loss = sum(val_mae_list) / len(val_mae_list)
        rmse = sum(val_rmse_list) / len(val_rmse_list)

    return loss, rmse

In [3]:
device = torch.device('cuda')
glist, _ = load_graphs('data/covid_graphs.bin')
time_window = 7

train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_COVID_data(glist, time_window)

In [12]:
val_feats[20].number_of_nodes('state')

51

In [8]:
train_labels[0].nodes['state'].data['feat']

tensor([[3.3800e+02],
        [3.0000e+00],
        [5.7300e+02],
        [5.3000e+01],
        [2.1600e+03],
        [4.5600e+02],
        [6.2800e+02],
        [2.4500e+02],
        [1.7200e+02],
        [3.6800e+02],
        [5.6400e+02],
        [0.0000e+00],
        [2.7000e+01],
        [2.8880e+03],
        [6.4500e+02],
        [3.9500e+02],
        [3.3400e+02],
        [1.6500e+02],
        [2.0300e+02],
        [4.4000e+01],
        [1.0860e+03],
        [1.6120e+03],
        [7.0000e+02],
        [7.2200e+02],
        [4.0300e+02],
        [1.6700e+02],
        [2.0000e+00],
        [6.4100e+02],
        [6.5000e+01],
        [1.0300e+02],
        [1.7410e+03],
        [1.7000e+02],
        [2.6390e+03],
        [4.6600e+02],
        [5.4000e+01],
        [8.8600e+02],
        [9.1000e+01],
        [7.9000e+01],
        [1.4150e+03],
        [2.4900e+02],
        [2.2500e+02],
        [2.3900e+02],
        [3.4400e+02],
        [1.3400e+03],
        [1.9500e+02],
        [2

In [3]:
device = torch.device('cuda')
glist, _ = load_graphs('data/covid_graphs.bin')
time_window = 7

train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_COVID_data(glist, time_window)

graph_atom = test_feats[0]
mae_list, rmse_list = [], []
model_out_path = 'output/COVID19'

for k in range(5):
    htgnn = HTGNN(graph=graph_atom, n_inp=1, n_hid=8, n_layers=2, n_heads=1, time_window=time_window, norm=False, device=device)
    predictor = NodePredictor(n_inp=8, n_classes=1)
    model = nn.Sequential(htgnn, predictor).to(device)

    print(f'---------------Repeat time: {k+1}---------------------')
    print(f'# params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
    
    early_stopping = EarlyStopping(patience=10, verbose=True, path=f'{model_out_path}/checkpoint_HTGNN_{k}.pt')
    optim = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
    
    train_mae_list, train_rmse_list = [], []
    idx = np.random.permutation(len(train_feats))

    for epoch in range(500):
        model.train()
        for i in idx:

            G_feat = train_feats[i]
            G_label = train_labels[i]
            
            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            label = G_label.nodes['state'].data['feat']
            loss = F.l1_loss(pred, label.to(device))
            rmse = torch.sqrt(F.mse_loss(pred, label.to(device)))

            train_mae_list.append(loss.item())
            train_rmse_list.append(rmse.item())
            optim.zero_grad()
            loss.backward()
            optim.step()
        print(sum(train_mae_list) / len(train_mae_list), sum(train_rmse_list) / len(train_rmse_list))

        loss, rmse = evaluate(model, val_feats, val_labels)
        early_stopping(loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    model.load_state_dict(torch.load(f'{model_out_path}/checkpoint_HTGNN_{k}.pt'))
    mae, rmse = evaluate(model, test_feats, test_labels)

    print(f'mae: {mae}, rmse: {rmse}')
    mae_list.append(mae)
    rmse_list.append(rmse)

---------------Repeat time: 1---------------------
# params: 7349
504.359294714304 1057.8320653086473
Validation loss decreased (inf --> 1242.325362).  Saving model ...
496.99236737021914 1051.9867027218331
Validation loss decreased (1242.325362 --> 1227.936530).  Saving model ...
493.44771888185653 1049.8767646864665
EarlyStopping counter: 1 out of 10
492.7293161319781 1049.9878524828562
EarlyStopping counter: 2 out of 10
491.0248128899039 1048.6511033263387
EarlyStopping counter: 3 out of 10
490.5049708724525 1047.6889496547092
EarlyStopping counter: 4 out of 10
489.89075274981957 1046.9816573949518
EarlyStopping counter: 5 out of 10
489.37654220806394 1046.4771463176871
Validation loss decreased (1227.936530 --> 1210.795846).  Saving model ...
489.4383456788597 1047.1144363503472
EarlyStopping counter: 1 out of 10
488.88819354415443 1046.886245476445
EarlyStopping counter: 2 out of 10
488.43787471535654 1046.637757927968
EarlyStopping counter: 3 out of 10
488.01810787163373 1046.311

In [4]:
import statistics

print(f'MAE: {statistics.mean(mae_list)}, {statistics.stdev(mae_list)}')
print(f'RMSE: {statistics.mean(rmse_list)}, {statistics.stdev(rmse_list)}')

MAE: 555.4028908284505, 34.10586975793963
RMSE: 1136.4205775960286, 65.13613775925027
