In [1]:
import os

import dgl
from dgl.data.utils import load_graphs

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from model.model import HTGNN, NodePredictor
from utils.pytorchtools import EarlyStopping
from utils.data import load_COVID_data

# Seed every random number generator used by the notebook with one shared
# value. The seeders are called in a fixed order: DGL, NumPy, then PyTorch
# (CPU, current CUDA device, all CUDA devices).
SEED = 0
for set_seed in (
    dgl.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    set_seed(SEED)

In [2]:
# Everything below runs on the CPU. Change the name to 'cuda:1' to target
# CUDA device index 1 instead.
device_name = 'cpu'
device = torch.device(device_name)

In [3]:
def evaluate(model, val_feats, val_labels):
    """Return the mean MAE and mean RMSE of ``model`` over paired graphs.

    Each feature graph is passed through ``model[0]`` (called with the node
    type ``'state'``) and then ``model[1]``. The prediction is compared with
    the ``'feat'`` data stored on the ``'state'`` nodes of the matching label
    graph. MAE and RMSE are computed per graph and then averaged over graphs.

    The model is put in eval mode and is not switched back; a caller that
    keeps training must call ``model.train()`` itself. Tensors are moved to
    the notebook-level ``device``.

    Args:
        model: indexable pair ``(encoder, predictor)``, e.g. ``nn.Sequential``.
        val_feats: iterable of input graphs.
        val_labels: iterable of label graphs, aligned with ``val_feats``.

    Returns:
        ``(mae, rmse)`` as Python floats.

    Raises:
        ValueError: if there is no (feature, label) pair to evaluate.
    """
    val_mae_list, val_rmse_list = [], []
    model.eval()
    with torch.no_grad():
        for G_feat, G_label in zip(val_feats, val_labels):
            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            # Move the target once and reuse it for both metrics.
            label = G_label.nodes['state'].data['feat'].to(device)

            val_mae_list.append(F.l1_loss(pred, label).item())
            val_rmse_list.append(torch.sqrt(F.mse_loss(pred, label)).item())

    # Fail with a clear message instead of an opaque ZeroDivisionError.
    if not val_mae_list:
        raise ValueError('evaluate() needs at least one (feature, label) graph pair')

    loss = sum(val_mae_list) / len(val_mae_list)
    rmse = sum(val_rmse_list) / len(val_rmse_list)
    return loss, rmse

In [4]:

# Window length passed to load_COVID_data and HTGNN below (presumably the
# number of time steps per sample -- confirm against utils.data).
time_window = 7

# Only the first element returned by load_graphs (the graphs) is used; the
# second is discarded.
graph_file = 'data/covid_graphs.bin'
glist, _ = load_graphs(graph_file)

In [32]:
# First element of the in_edges() result for the 'affiliate_r' edges entering
# the 'state' nodes of the first graph (with DGL's default form this should be
# the source node IDs -- confirm against the DGL docs).
glist[0].in_edges(glist[0].nodes('state'), etype='affiliate_r')[0]

tensor([   0,    1,    2,  ..., 3138, 3139, 3140], dtype=torch.int32)

In [34]:
# glist[0].nodes['state'].data["feat"]

In [5]:
# _graph.nodes['county'].data['feat'].shape, _graph.nodes['state'].data['feat'].shape

In [6]:
# Shapes of the 'feat' tensors on the county and state nodes of the second graph.
tuple(glist[1].nodes[ntype].data['feat'].shape for ntype in ('county', 'state'))

(torch.Size([3223, 1]), torch.Size([51, 1]))

In [7]:
# Node IDs of type 'county' in the second graph of glist.
glist[1].nodes('county')

tensor([   0,    1,    2,  ..., 3220, 3221, 3222], dtype=torch.int32)

In [8]:
# train_feats[45].nodes['county'].data['feat'].shape, train_feats[45].nodes['state'].data['feat'].shape

In [9]:
# Build the train / validation / test splits: each split is a pair of
# (input graphs, label graphs).
(
    train_feats, train_labels,
    val_feats, val_labels,
    test_feats, test_labels,
) = load_COVID_data(glist, time_window)

In [29]:
# IDs of the 'state' nodes in the first training graph.
train_feats[0].nodes('state')

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
       dtype=torch.int32)

In [30]:
# _l = []
# for i in range(len(train_feats)):
#     _l.append(train_feats[i].nodes['county'].data['t0'].shape)
#     if i == 0:
#         prev = train_feats[i].nodes['county'].data['t0'].shape
#         print(train_feats[i].nodes['county'].data['t0'].shape)
        
#     else:
#         if train_feats[i].nodes['county'].data['t0'].shape == prev:
#             pass
#         else:
#             print(train_feats[i].nodes['county'].data['t0'].shape)

In [35]:
# Shape of the county-level 't6' feature in the first training graph.
# An explicit index replaces the stray loop variable `i`, which is undefined
# when the notebook is run top to bottom on a fresh kernel.
train_feats[0].nodes['county'].data['t6'].shape

torch.Size([3223, 1])

In [10]:
# County-level 't0' feature shapes for one training graph and one validation graph.
tuple(
    graph.nodes['county'].data['t0'].shape
    for graph in (train_feats[45], val_feats[2])
)

(torch.Size([3223, 1]), torch.Size([3223, 1]))

In [11]:
# List the (source type, edge type, destination type) triples of the first training graph.
train_feats[0].canonical_etypes

[('county', 'affiliate_r_t0', 'state'),
 ('county', 'affiliate_r_t1', 'state'),
 ('county', 'affiliate_r_t2', 'state'),
 ('county', 'affiliate_r_t3', 'state'),
 ('county', 'affiliate_r_t4', 'state'),
 ('county', 'affiliate_r_t5', 'state'),
 ('county', 'affiliate_r_t6', 'state'),
 ('county', 'nearby_county_t0', 'county'),
 ('county', 'nearby_county_t1', 'county'),
 ('county', 'nearby_county_t2', 'county'),
 ('county', 'nearby_county_t3', 'county'),
 ('county', 'nearby_county_t4', 'county'),
 ('county', 'nearby_county_t5', 'county'),
 ('county', 'nearby_county_t6', 'county'),
 ('state', 'affiliate_t0', 'county'),
 ('state', 'affiliate_t1', 'county'),
 ('state', 'affiliate_t2', 'county'),
 ('state', 'affiliate_t3', 'county'),
 ('state', 'affiliate_t4', 'county'),
 ('state', 'affiliate_t5', 'county'),
 ('state', 'affiliate_t6', 'county'),
 ('state', 'nearby_state_t0', 'state'),
 ('state', 'nearby_state_t1', 'state'),
 ('state', 'nearby_state_t2', 'state'),
 ('state', 'nearby_state_t3', 'st

In [12]:

# Train HTGNN + NodePredictor on the 'state' node regression task five times
# and collect the test MAE / RMSE of the best checkpoint from each repeat.
train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_COVID_data(glist, time_window)

graph_atom = test_feats[0]  # graph handed to the HTGNN constructor
mae_list, rmse_list = [], []  # test metrics, one entry per repeat
model_out_path = 'output/COVID19'
# EarlyStopping is given a path inside this directory and the checkpoint is
# loaded back from it below. Create it up front in case the checkpoint writer
# does not (assumption -- EarlyStopping's implementation is not shown here).
os.makedirs(model_out_path, exist_ok=True)

for k in range(5):
    checkpoint_path = f'{model_out_path}/checkpoint_HTGNN_{k}.pt'

    htgnn = HTGNN(graph=graph_atom, n_inp=1, n_hid=8, n_layers=2, n_heads=1, time_window=time_window, norm=False, device=device)
    predictor = NodePredictor(n_inp=8, n_classes=1)
    model = nn.Sequential(htgnn, predictor).to(device)

    print(f'---------------Repeat time: {k+1}---------------------')
    print(f'# params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    early_stopping = EarlyStopping(patience=10, verbose=True, path=checkpoint_path)
    optim = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

    # The visiting order is drawn once per repeat and reused for every epoch.
    idx = np.random.permutation(len(train_feats))

    for epoch in range(500):
        model.train()
        # Reset per epoch so the printed numbers describe this epoch only,
        # not a running average over every epoch so far.
        train_mae_list, train_rmse_list = [], []
        for i in idx:
            G_feat = train_feats[i]
            G_label = train_labels[i]

            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            # Move the target once and reuse it for both metrics.
            label = G_label.nodes['state'].data['feat'].to(device)
            loss = F.l1_loss(pred, label)
            rmse = torch.sqrt(F.mse_loss(pred, label))

            train_mae_list.append(loss.item())
            train_rmse_list.append(rmse.item())
            optim.zero_grad()
            loss.backward()
            optim.step()
        print(sum(train_mae_list) / len(train_mae_list), sum(train_rmse_list) / len(train_rmse_list))

        # Separate names keep the validation result from overwriting the
        # training loss tensor above.
        val_loss, val_rmse = evaluate(model, val_feats, val_labels)
        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # map_location lets a checkpoint written on another device load on the
    # currently selected one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    mae, rmse = evaluate(model, test_feats, test_labels)

    print(f'mae: {mae}, rmse: {rmse}')
    mae_list.append(mae)
    rmse_list.append(rmse)

---------------Repeat time: 1---------------------
# params: 7797


KeyboardInterrupt: 

In [None]:

# Train HTGNN + NodePredictor on the 'state' node regression task five times
# and collect the test MAE / RMSE of the best checkpoint from each repeat.
train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_COVID_data(glist, time_window)

graph_atom = test_feats[0]  # graph handed to the HTGNN constructor
mae_list, rmse_list = [], []  # test metrics, one entry per repeat
model_out_path = 'output/COVID19'
# EarlyStopping is given a path inside this directory and the checkpoint is
# loaded back from it below. Create it up front in case the checkpoint writer
# does not (assumption -- EarlyStopping's implementation is not shown here).
os.makedirs(model_out_path, exist_ok=True)

for k in range(5):
    checkpoint_path = f'{model_out_path}/checkpoint_HTGNN_{k}.pt'

    htgnn = HTGNN(graph=graph_atom, n_inp=1, n_hid=8, n_layers=2, n_heads=1, time_window=time_window, norm=False, device=device)
    predictor = NodePredictor(n_inp=8, n_classes=1)
    model = nn.Sequential(htgnn, predictor).to(device)

    print(f'---------------Repeat time: {k+1}---------------------')
    print(f'# params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    early_stopping = EarlyStopping(patience=10, verbose=True, path=checkpoint_path)
    optim = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

    # The visiting order is drawn once per repeat and reused for every epoch.
    idx = np.random.permutation(len(train_feats))

    for epoch in range(500):
        model.train()
        # Reset per epoch so the printed numbers describe this epoch only,
        # not a running average over every epoch so far.
        train_mae_list, train_rmse_list = [], []
        for i in idx:
            G_feat = train_feats[i]
            G_label = train_labels[i]

            h = model[0](G_feat.to(device), 'state')
            pred = model[1](h)
            # Move the target once and reuse it for both metrics.
            label = G_label.nodes['state'].data['feat'].to(device)
            loss = F.l1_loss(pred, label)
            rmse = torch.sqrt(F.mse_loss(pred, label))

            train_mae_list.append(loss.item())
            train_rmse_list.append(rmse.item())
            optim.zero_grad()
            loss.backward()
            optim.step()
        print(sum(train_mae_list) / len(train_mae_list), sum(train_rmse_list) / len(train_rmse_list))

        # Separate names keep the validation result from overwriting the
        # training loss tensor above.
        val_loss, val_rmse = evaluate(model, val_feats, val_labels)
        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # map_location lets a checkpoint written on another device load on the
    # currently selected one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    mae, rmse = evaluate(model, test_feats, test_labels)

    print(f'mae: {mae}, rmse: {rmse}')
    mae_list.append(mae)
    rmse_list.append(rmse)

In [4]:
import statistics

# Mean and sample standard deviation of the test metrics across repeats.
mae_mean, mae_std = statistics.mean(mae_list), statistics.stdev(mae_list)
print(f'MAE: {mae_mean}, {mae_std}')

rmse_mean, rmse_std = statistics.mean(rmse_list), statistics.stdev(rmse_list)
print(f'RMSE: {rmse_mean}, {rmse_std}')

MAE: 555.4028908284505, 34.10586975793963
RMSE: 1136.4205775960286, 65.13613775925027


In [4]:
from utils.distributed_data import DataPartitioner, DGraphDataset
# Toy dataset holding the integers 1..7, handed to a partitioner with 4 parts.
toy_items = list(range(1, 8))
ddataset = DGraphDataset(toy_items)
partitioner = DataPartitioner(ddataset, 4)

# Display the index lists assigned to each part.
partitioner.partitions

[[0, 1], [2, 3], [4, 5], [6]]

In [3]:
# Fetch item 3 of the toy dataset.
ddataset[3]

4

[[0], [1], [2], [3]]

In [10]:
# Index list of the object returned by selecting partition 2.
partitioner.use(2).index

[4, 5]