In [24]:
import dgl
from dgl.data.utils import load_graphs
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from model.model import HTGNN, NodePredictor
from utils.pytorchtools import EarlyStopping
from sklearn import metrics

dgl.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [25]:
# device = torch.device('cuda:1')
device = torch.device('cpu')

In [26]:

train_feats, _ = load_graphs('./data/dgraph/train_feats.bin')
valid_feats, _ = load_graphs('./data/dgraph/valid_feats.bin')
test_feats, _ = load_graphs('./data/dgraph/test_feats.bin')

train_labels = torch.load("./data/dgraph/train_labels.pt")
valid_labels = torch.load("./data/dgraph/valid_labels.pt")
test_labels = torch.load("./data/dgraph/test_labels.pt")

In [27]:
train_feats[0], valid_feats[0]

(Graph(num_nodes={'A': 521, 'B': 916, 'C': 307},
       num_edges={('A', '10_t1', 'A'): 42, ('A', '10_t1', 'B'): 46, ('A', '10_t1', 'C'): 32, ('A', '11_t1', 'A'): 10, ('A', '11_t1', 'B'): 15, ('A', '11_t1', 'C'): 9, ('A', '9_t1', 'A'): 15, ('A', '9_t1', 'B'): 58, ('A', '9_t1', 'C'): 14, ('B', '10_t1', 'A'): 64, ('B', '10_t1', 'B'): 142, ('B', '10_t1', 'C'): 78, ('B', '11_t1', 'A'): 23, ('B', '11_t1', 'B'): 72, ('B', '11_t1', 'C'): 17, ('B', '9_t1', 'A'): 89, ('B', '9_t1', 'B'): 18, ('B', '9_t1', 'C'): 26, ('C', '10_t1', 'A'): 23, ('C', '10_t1', 'B'): 25, ('C', '10_t1', 'C'): 12, ('C', '11_t1', 'A'): 1, ('C', '11_t1', 'B'): 6, ('C', '11_t1', 'C'): 6, ('C', '9_t1', 'A'): 21, ('C', '9_t1', 'B'): 11, ('C', '9_t1', 'C'): 5},
       metagraph=[('A', 'A', '10_t1'), ('A', 'A', '11_t1'), ('A', 'A', '9_t1'), ('A', 'B', '10_t1'), ('A', 'B', '11_t1'), ('A', 'B', '9_t1'), ('A', 'C', '10_t1'), ('A', 'C', '11_t1'), ('A', 'C', '9_t1'), ('B', 'A', '10_t1'), ('B', 'A', '11_t1'), ('B', 'A', '9_t1'), ('B'

In [28]:
train_feats[0].nodes['A'].data['feat'].shape

torch.Size([521, 16])

In [29]:
# train_labels[0], valid_labels[0]

In [30]:
time_window = 2

In [31]:
def valid_graph_feat(g_feat, time_window):
    all_etype_t = sorted(
        list(set([etype.split("_")[-1] for _, etype, _ in g_feat.canonical_etypes]))
    )

    if len(all_etype_t) >= time_window:
        return True
    else:
        return False

In [38]:
def write_to_file(value, fpath, name=None):
    with open(fpath, 'a') as fout:
        fout.write(f"{value}\n")

In [33]:
def evaluate(model, val_feats, val_labels, pred_node_type="ALL"):
    val_mae_list, val_rmse_list = [], []
    val_auc_list, val_ap_list = [], []

    model.eval()

    with torch.no_grad():
        for i, (G_feat, G_label) in enumerate(zip(val_feats, val_labels)):
            if not valid_graph_feat(G_feat, time_window):
                continue
            try:
                h = model[0](G_feat.to(device), pred_node_type)
                f_labels = []
                f_pred = []
                for ntype in G_label.keys():
                    pred = model[1](h[ntype])
                    label = G_label[ntype].view(-1, 1)

                    label_mask = (label == 0) | (label == 1)

                    masked_label = label[label_mask]
                    masked_pred = pred[label_mask]

                    f_labels.append(masked_label)
                    f_pred.append(masked_pred)

                f_labels = torch.cat(f_labels)
                f_pred = torch.cat(f_pred)

                loss = F.l1_loss(f_pred, f_labels)
                rmse = torch.sqrt(F.mse_loss(f_pred, f_labels))
            except Exception as e:
                print(f"failed val index: {i}")
                raise Exception(e)

            val_mae_list.append(loss.item())
            val_rmse_list.append(rmse.item())

            if f_labels.unique().shape[0] >= 2:
                # AUC
                fpr, tpr, thresholds = metrics.roc_curve(
                    f_labels.numpy(), f_pred.numpy()
                )
                auc = metrics.auc(fpr, tpr)

                # AP
                precision, recall, thresholds = metrics.precision_recall_curve(
                    f_labels.numpy(), f_pred.numpy()
                )
                ap = metrics.auc(recall, precision)

                val_auc_list.append(auc)
                val_ap_list.append(ap)

        loss = sum(val_mae_list) / len(val_mae_list)
        rmse = sum(val_rmse_list) / len(val_rmse_list)

        auc = sum(val_auc_list) / len(val_auc_list)
        ap = sum(val_ap_list) / len(val_ap_list)

        print(f"\tEval MAE/RMSE: {loss} / {rmse}")
        print(f"\tEval AUC/AP: {auc} / {ap}")

    return loss, rmse, auc, ap

In [34]:

graph_atom = train_feats[10]
mae_list, rmse_list = [], []
model_out_path = 'checkpoint'


In [35]:
htgnn = HTGNN(graph=graph_atom, n_inp=16, n_hid=8, n_layers=2, n_heads=1, time_window=time_window, norm=False, device=device)
predictor = NodePredictor(n_inp=8, n_classes=1)
model = nn.Sequential(htgnn, predictor)

In [36]:
early_stopping = EarlyStopping(patience=10, verbose=True, path=f'{model_out_path}/checkpoint_HTGNN.pt')
optim = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

train_mae_list, train_rmse_list = [], []
idx = np.random.permutation(len(train_feats))

In [39]:
pred_node_type = "ALL"

for epoch in range(200):
    model.train()

    print(f"============ Epoch {epoch} ============")
    for i in tqdm(idx):
        G_feat = train_feats[i]
        G_label = train_labels[i]

        # check if graph contains more than 2 windows
        if not valid_graph_feat(G_feat, time_window):
            continue

        h = model[0](G_feat, pred_node_type)

        f_labels = []
        f_pred = []
        for ntype in G_label.keys():
            pred = model[1](h[ntype])
            label = G_label[ntype].view(-1, 1)

            label_mask = (label == 0) | (label == 1)

            masked_label = label[label_mask]
            masked_pred = pred[label_mask]

            f_labels.append(masked_label)
            f_pred.append(masked_pred)

        f_labels = torch.cat(f_labels)
        f_pred = torch.cat(f_pred)

        loss = F.l1_loss(f_pred, f_labels)
        rmse = torch.sqrt(F.mse_loss(f_pred, f_labels))

        train_mae_list.append(loss.item())
        train_rmse_list.append(rmse.item())
        optim.zero_grad()
        loss.backward()
        optim.step()

    epoch_mae = sum(train_mae_list) / len(train_mae_list)
    epoch_rmse = sum(train_rmse_list) / len(train_rmse_list)
    print(f"Epoch MAE/RMSE: {epoch_mae} / {epoch_rmse}")

    write_to_file(epoch_mae, "./results/dgraph/train_mae.txt")
    write_to_file(epoch_rmse, "./results/dgraph/train_rmse.txt")

    if epoch % 2 == 0:
        loss, rmse, auc, ap = evaluate(model, valid_feats, valid_labels)
        write_to_file(loss, "./results/dgraph/eval_mae.txt")
        write_to_file(rmse, "./results/dgraph/eval_rmse.txt")
        write_to_file(auc, "./results/dgraph/eval_auc.txt")
        write_to_file(ap, "./results/dgraph/eval_ap.txt")
        early_stopping(loss, model)



  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.30266185104846954 / 0.3324214965105057
	Eval MAE/RMSE: 0.1980592906475067 / 0.2372611165046692
	Eval AUC/AP: 0.5733204134366925 / 0.02407563303866711
Validation loss decreased (inf --> 0.198059).  Saving model ...


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.263452226916949 / 0.29489146173000336


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.2290414460003376 / 0.2632531076669693
	Eval MAE/RMSE: 0.09635315090417862 / 0.16672970354557037
	Eval AUC/AP: 0.6030361757105943 / 0.026009562491575662
Validation loss decreased (0.198059 --> 0.096353).  Saving model ...


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.1996922492980957 / 0.2374409854412079


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.1749105310688416 / 0.21695010488231978
	Eval MAE/RMSE: 0.037790823727846146 / 0.15158723294734955
	Eval AUC/AP: 0.6941214470284238 / 0.033173376397178184
Validation loss decreased (0.096353 --> 0.037791).  Saving model ...


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.15366981684097222 / 0.20104173570871353


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.1363360066898167 / 0.18865781743079424
	Eval MAE/RMSE: 0.023528387770056725 / 0.1508377492427826
	Eval AUC/AP: 0.4786821705426357 / 0.011363636363636364
Validation loss decreased (0.037791 --> 0.023528).  Saving model ...


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.12241268737448587 / 0.17893750303321415


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.11120052821934223 / 0.17117129266262054
	Eval MAE/RMSE: 0.022727273404598236 / 0.15075567364692688
	Eval AUC/AP: 0.5 / 0.5113636363636364
Validation loss decreased (0.023528 --> 0.022727).  Saving model ...


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.10202379025180232 / 0.16481698778542606


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.09437650861218572 / 0.15952173372109732
	Eval MAE/RMSE: 0.022727273404598236 / 0.15075567364692688
	Eval AUC/AP: 0.5 / 0.5113636363636364
Validation loss decreased (0.022727 --> 0.022727).  Saving model ...


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.08790573184020244 / 0.15504113412820375


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.08235935174993106 / 0.15120062019143785
	Eval MAE/RMSE: 0.022727273404598236 / 0.15075567364692688
	Eval AUC/AP: 0.5 / 0.5113636363636364
Validation loss decreased (0.022727 --> 0.022727).  Saving model ...


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.0775524890050292 / 0.14787217477957407


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.07334648410324007 / 0.14495978504419327
	Eval MAE/RMSE: 0.022727273404598236 / 0.15075567364692688
	Eval AUC/AP: 0.5 / 0.5113636363636364
Validation loss decreased (0.022727 --> 0.022727).  Saving model ...


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.06963530330754378 / 0.14239002939532786


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.06633647593359153 / 0.14010580215189192
	Eval MAE/RMSE: 0.022727273404598236 / 0.15075567364692688
	Eval AUC/AP: 0.5 / 0.5113636363636364
Validation loss decreased (0.022727 --> 0.022727).  Saving model ...


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.06338489354637108 / 0.13806201988144925


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.060728469397872686 / 0.13622261583805084
	Eval MAE/RMSE: 0.022727273404598236 / 0.15075567364692688
	Eval AUC/AP: 0.5 / 0.5113636363636364
Validation loss decreased (0.022727 --> 0.022727).  Saving model ...


  0%|          | 0/821 [00:01<?, ?it/s]


Epoch MAE/RMSE: 0.05832503802542176 / 0.13455839313211895


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.05614010041410273 / 0.13304546339945358
	Eval MAE/RMSE: 0.022727273404598236 / 0.15075567364692688
	Eval AUC/AP: 0.5 / 0.5113636363636364
Validation loss decreased (0.022727 --> 0.022727).  Saving model ...


  0%|          | 0/821 [00:00<?, ?it/s]


Epoch MAE/RMSE: 0.054145157377681004 / 0.13166409277397653


  0%|          | 0/821 [00:01<?, ?it/s]


KeyboardInterrupt: 