In [1]:
import os

import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.utils import load_graphs

from model.model import HTGNN, NodePredictor
from utils.data import load_COVID_data
from utils.pytorchtools import EarlyStopping

# Seed every random number generator this notebook touches (DGL, NumPy and the
# PyTorch CPU/CUDA seeding entry points) with one fixed value, in the same
# order as before, so repeated runs start from the same RNG state.
SEED = 0
for seed_fn in (
    dgl.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(SEED)

In [2]:
def evaluate(model, val_feats, val_labels, device=None):
    """Return the mean per-snapshot MAE and RMSE of ``model`` on a dataset.

    Each feature graph is run through ``model[0]`` (called with the graph and
    the node type ``'state'``) and then ``model[1]``; the prediction is compared
    against the ``'feat'`` data stored on the ``'state'`` nodes of the matching
    label graph. MAE and RMSE are computed per graph pair and then averaged
    over all pairs.

    Args:
        model: Indexable two-stage model (e.g. ``nn.Sequential``): element 0
            is called as ``model[0](graph, 'state')`` and element 1 is called
            on its output.
        val_feats: Iterable of input graphs; each must support ``.to(device)``.
        val_labels: Iterable of label graphs, paired positionally with
            ``val_feats``; each must expose ``nodes['state'].data['feat']``.
        device: Device to run on. When ``None`` (the default) it is taken from
            the model's first parameter, so the function no longer relies on a
            global ``device`` variable defined in another cell. Falls back to
            CPU for a parameter-free model.

    Returns:
        Tuple ``(mean_mae, mean_rmse)`` of Python floats.

    Raises:
        ValueError: If no (feature, label) pair was supplied.

    Note:
        The model is left in eval mode; callers that keep training must call
        ``model.train()`` again.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    mae_per_graph, rmse_per_graph = [], []
    model.eval()
    with torch.no_grad():
        for G_feat, G_label in zip(val_feats, val_labels):
            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            # Move the target once and reuse it for both metrics.
            label = G_label.nodes['state'].data['feat'].to(device)

            mae_per_graph.append(F.l1_loss(pred, label).item())
            rmse_per_graph.append(torch.sqrt(F.mse_loss(pred, label)).item())

    if not mae_per_graph:
        raise ValueError('evaluate() received no (feature, label) graph pairs')

    mean_mae = sum(mae_per_graph) / len(mae_per_graph)
    mean_rmse = sum(rmse_per_graph) / len(rmse_per_graph)
    return mean_mae, mean_rmse

In [None]:
# --- Experiment configuration and data loading ---------------------------
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without a usable GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load_graphs returns a pair; only the first element (the graph list) is used.
glist, _ = load_graphs('data/covid_graphs.bin')
# Passed both to load_COVID_data and to the HTGNN constructor below.
time_window = 7

train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_COVID_data(glist, time_window)

# A single feature graph handed to HTGNN(graph=...) when models are built.
# NOTE(review): presumably used only as a structural template - confirm in model.model.
graph_atom = test_feats[0]

# Test metrics collected across repeats (one entry per repeat).
mae_list, rmse_list = [], []

# Directory for the per-repeat checkpoints; create it up front so saving a
# checkpoint cannot fail merely because the folder is missing.
model_out_path = 'output/COVID19'
os.makedirs(model_out_path, exist_ok=True)

# Train and test the model five times; each repeat builds a fresh model,
# optimiser and early-stopping state, and appends its test MAE/RMSE to
# mae_list / rmse_list.
for k in range(5):
    htgnn = HTGNN(graph=graph_atom, n_inp=1, n_hid=8, n_layers=2, n_heads=1, time_window=time_window, norm=False, device=device)
    predictor = NodePredictor(n_inp=8, n_classes=1)
    model = nn.Sequential(htgnn, predictor).to(device)

    print(f'---------------Repeat time: {k+1}---------------------')
    print(f'# params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    checkpoint_path = f'{model_out_path}/checkpoint_HTGNN_{k}.pt'
    early_stopping = EarlyStopping(patience=10, verbose=True, path=checkpoint_path)
    optim = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

    # Visiting order of the training snapshots: drawn once per repeat and
    # reused unchanged for every epoch of that repeat.
    idx = np.random.permutation(len(train_feats))

    for epoch in range(500):
        model.train()
        # Reset the running metrics every epoch so the printed numbers are the
        # mean over THIS epoch, not a cumulative mean over all epochs so far.
        train_mae_list, train_rmse_list = [], []

        for i in idx:
            G_feat = train_feats[i]
            G_label = train_labels[i]

            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            # Move the target to the device once and reuse it for both metrics.
            label = G_label.nodes['state'].data['feat'].to(device)

            # MAE is the optimised objective.
            loss = F.l1_loss(pred, label)
            # RMSE is only logged, so keep it out of the autograd graph.
            with torch.no_grad():
                rmse = torch.sqrt(F.mse_loss(pred, label))

            train_mae_list.append(loss.item())
            train_rmse_list.append(rmse.item())

            optim.zero_grad()
            loss.backward()
            optim.step()

        print(sum(train_mae_list) / len(train_mae_list), sum(train_rmse_list) / len(train_rmse_list))

        # Early stopping monitors the validation MAE.
        val_mae, val_rmse = evaluate(model, val_feats, val_labels)
        early_stopping(val_mae, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the weights from this repeat's checkpoint file before testing;
    # map_location keeps the load valid on whichever device is in use.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    mae, rmse = evaluate(model, test_feats, test_labels)

    print(f'mae: {mae}, rmse: {rmse}')
    mae_list.append(mae)
    rmse_list.append(rmse)

---------------Repeat time: 1---------------------
# params: 7797
503.6185465945473 1059.5776576432497
Validation loss decreased (inf --> 1246.618062).  Saving model ...


In [None]:
import statistics

# Report the mean and sample standard deviation of each test metric over all
# repeats collected in mae_list / rmse_list.
for metric_name, per_repeat in (('MAE', mae_list), ('RMSE', rmse_list)):
    avg = statistics.mean(per_repeat)
    spread = statistics.stdev(per_repeat)
    print(f'{metric_name}: {avg}, {spread}')