In [1]:
import dgl
from dgl.data.utils import load_graphs

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from model.model_gcn import HTGNN, NodePredictor
from utils.pytorchtools import EarlyStopping
from utils.data import load_COVID_data

dgl.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def evaluate(model, val_feats, val_labels):
    val_mae_list, val_rmse_list = [], []
    model.eval()
    with torch.no_grad():
        for (G_feat, G_label) in zip(val_feats, val_labels):
            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            label = G_label.nodes['state'].data['feat']
            loss = F.l1_loss(pred, label.to(device))
            rmse = torch.sqrt(F.mse_loss(pred, label.to(device)))

            val_mae_list.append(loss.item())
            val_rmse_list.append(rmse.item())
            
        loss = sum(val_mae_list) / len(val_mae_list)
        rmse = sum(val_rmse_list) / len(val_rmse_list)

    return loss, rmse

In [3]:
device = torch.device('cuda')
glist, _ = load_graphs('data/covid_graphs.bin')
time_window = 15

train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_COVID_data(glist, time_window)

In [8]:
test_feats[0]

Graph(num_nodes={'county': 3223, 'state': 51},
      num_edges={('county', 'affiliate_r_t0', 'state'): 3141, ('county', 'affiliate_r_t1', 'state'): 3141, ('county', 'affiliate_r_t10', 'state'): 3141, ('county', 'affiliate_r_t11', 'state'): 3141, ('county', 'affiliate_r_t12', 'state'): 3141, ('county', 'affiliate_r_t13', 'state'): 3141, ('county', 'affiliate_r_t14', 'state'): 3141, ('county', 'affiliate_r_t2', 'state'): 3141, ('county', 'affiliate_r_t3', 'state'): 3141, ('county', 'affiliate_r_t4', 'state'): 3141, ('county', 'affiliate_r_t5', 'state'): 3141, ('county', 'affiliate_r_t6', 'state'): 3141, ('county', 'affiliate_r_t7', 'state'): 3141, ('county', 'affiliate_r_t8', 'state'): 3141, ('county', 'affiliate_r_t9', 'state'): 3141, ('county', 'nearby_county_t0', 'county'): 22176, ('county', 'nearby_county_t1', 'county'): 22176, ('county', 'nearby_county_t10', 'county'): 22176, ('county', 'nearby_county_t11', 'county'): 22176, ('county', 'nearby_county_t12', 'county'): 22176, ('county

In [6]:
len(test_feats)

30

In [3]:
device = torch.device('cuda')
glist, _ = load_graphs('data/covid_graphs.bin')
time_window = 15

train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_COVID_data(glist, time_window)

graph_atom = test_feats[0]
mae_list, rmse_list = [], []
model_out_path = 'output/COVID19'

for k in range(1):
    htgnn = HTGNN(graph=graph_atom, n_inp=1, n_hid=8, n_layers=2, n_heads=1, time_window=time_window, norm=False, device=device)
    predictor = NodePredictor(n_inp=8, n_classes=1)
    model = nn.Sequential(htgnn, predictor).to(device)

    print(f'---------------Repeat time: {k+1}---------------------')
    print(f'# params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
    
    early_stopping = EarlyStopping(patience=10, verbose=True, path=f'{model_out_path}/checkpoint_HTGNN_{k}_w15.pt')
    optim = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
    
    train_mae_list, train_rmse_list = [], []
    idx = np.random.permutation(len(train_feats))

    for epoch in range(500):
        model.train()
        for i in idx:

            G_feat = train_feats[i]
            G_label = train_labels[i]
            
            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            label = G_label.nodes['state'].data['feat']
            loss = F.l1_loss(pred, label.to(device))
            rmse = torch.sqrt(F.mse_loss(pred, label.to(device)))

            train_mae_list.append(loss.item())
            train_rmse_list.append(rmse.item())
            optim.zero_grad()
            loss.backward()
            optim.step()
        print(sum(train_mae_list) / len(train_mae_list), sum(train_rmse_list) / len(train_rmse_list))

        loss, rmse = evaluate(model, val_feats, val_labels)
        early_stopping(loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    model.load_state_dict(torch.load(f'{model_out_path}/checkpoint_HTGNN_{k}_w15.pt'))
    mae, rmse = evaluate(model, test_feats, test_labels)

    print(f'mae: {mae}, rmse: {rmse}')
    mae_list.append(mae)
    rmse_list.append(rmse)

---------------Repeat time: 1---------------------
# params: 12789
568.4905486460857 1192.6530862999796
Validation loss decreased (inf --> 1278.985883).  Saving model ...
558.1615822637966 1177.5434535996883
Validation loss decreased (1278.985883 --> 1271.857255).  Saving model ...
554.2137448395531 1172.9796439313818
Validation loss decreased (1271.857255 --> 1268.975551).  Saving model ...
554.176777294109 1172.1263535178905
EarlyStopping counter: 1 out of 10
553.3977641609558 1172.756441194626
EarlyStopping counter: 2 out of 10
553.1357552640824 1172.635151547983
Validation loss decreased (1268.975551 --> 1261.758828).  Saving model ...
552.6517899025996 1172.6969718433363
EarlyStopping counter: 1 out of 10
552.3200276470601 1172.5582042677433
EarlyStopping counter: 2 out of 10
551.6183533721731 1171.9199905580597
EarlyStopping counter: 3 out of 10
550.8775803811685 1171.3035742947109
EarlyStopping counter: 4 out of 10
549.410109065269 1169.1360139744581
Validation loss decreased (1

FileNotFoundError: [Errno 2] No such file or directory: 'output/COVID19/checkpoint_HTGNN_0_w15.pt'

In [4]:
import statistics

print(f'MAE: {statistics.mean(mae_list)}, {statistics.stdev(mae_list)}')
print(f'RMSE: {statistics.mean(rmse_list)}, {statistics.stdev(rmse_list)}')

MAE: 555.4028908284505, 34.10586975793963
RMSE: 1136.4205775960286, 65.13613775925027
