In [1]:
import dgl
from dgl.data.utils import load_graphs

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from model.model_gcn import HTGNN, LinkPredictor
from utils.pytorchtools import EarlyStopping
from utils.utils import compute_metric, compute_loss
from utils.data import load_MAG_data

dgl.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def evaluate(model, val_feats, val_labels):
    val_mae_list, val_rmse_list = [], []
    model.eval()
    with torch.no_grad():
        for (G_feat, (pos_label, neg_label)) in zip(val_feats, val_labels):

            G_feat = G_feat.to(device)
            pos_label = pos_label.to(device)
            neg_label = neg_label.to(device)

            h = model[0](G_feat, 'author')
            pos_score = model[1](pos_label, h)
            neg_score = model[1](neg_label, h)

            loss = compute_loss(pos_score, neg_score, device)
            auc, ap = compute_metric(pos_score, neg_score)
    
    return auc, ap

In [3]:
glist, label_dict = load_graphs('data/ogbn_graphs.bin')
device = torch.device('cuda')
time_window = 3

train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_MAG_data(glist, time_window, device)

graph_atom = test_feats[0]
model_out_path = 'output/OGBN-MAG'
auc_list, ap_list = [], []


loading mp2vec
generating train, val, test sets 


In [5]:
len(val_feats)

1

: 

In [10]:
htgnn = HTGNN(graph=graph_atom, n_inp=128, n_hid=32, n_layers=2, n_heads=1, time_window=time_window, norm=True, device=device)
predictor = LinkPredictor(n_inp=32, n_classes=1)
model = nn.Sequential(htgnn, predictor).to(device)
model.load_state_dict(torch.load('/home/jiazhengli/xdgnn/HTGNN/output/OGBN-MAG/checkpoint_HTGNN_0.pt'))

<All keys matched successfully>

In [12]:
sg, inverse_indices = dgl.khop_in_subgraph(train_feats[0], {'author': 3}, k=2, store_ids=True)

In [17]:
test_feats[0]

Graph(num_nodes={'author': 17764, 'field_of_study': 23109, 'institution': 2276, 'paper': 84344},
      num_edges={('author', 'affiliated_with_t0', 'institution'): 40307, ('author', 'affiliated_with_t1', 'institution'): 40307, ('author', 'affiliated_with_t2', 'institution'): 40307, ('author', 'writes_t0', 'paper'): 256070, ('author', 'writes_t1', 'paper'): 235822, ('author', 'writes_t2', 'paper'): 315154, ('field_of_study', 'has_topic_r_t0', 'paper'): 310210, ('field_of_study', 'has_topic_r_t1', 'paper'): 281615, ('field_of_study', 'has_topic_r_t2', 'paper'): 258168, ('institution', 'affiliated_with_r_t0', 'author'): 40307, ('institution', 'affiliated_with_r_t1', 'author'): 40307, ('institution', 'affiliated_with_r_t2', 'author'): 40307, ('paper', 'cites_r_t0', 'paper'): 25054, ('paper', 'cites_r_t1', 'paper'): 22423, ('paper', 'cites_r_t2', 'paper'): 23017, ('paper', 'cites_t0', 'paper'): 25054, ('paper', 'cites_t1', 'paper'): 22423, ('paper', 'cites_t2', 'paper'): 23017, ('paper', 'ha

In [26]:
train_labels[0][0]

Graph(num_nodes=17764, num_edges=74693,
      ndata_schemes={}
      edata_schemes={})

In [12]:
for k in range(5):
    htgnn = HTGNN(graph=graph_atom, n_inp=128, n_hid=32, n_layers=2, n_heads=1, time_window=time_window, norm=True, device=device)
    predictor = LinkPredictor(n_inp=32, n_classes=1)
    model = nn.Sequential(htgnn, predictor).to(device)

    print(f'---------------Repeat time: {k+1}---------------------')
    print(f'# params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
    optim = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

    early_stopping = EarlyStopping(patience=50, verbose=True, path=f'{model_out_path}/checkpoint_HTGNN_{k}.pt')
    for epoch in range(500):
        model.train()
        for (G_feat, (pos_label, neg_label)) in zip(train_feats, train_labels):

            G_feat = G_feat.to(device)

            pos_label = pos_label.to(device)
            neg_label = neg_label.to(device)

            h = model[0](G_feat, 'author')

            pos_score = model[1](pos_label, h)
            neg_score = model[1](neg_label, h)
            
            loss = compute_loss(pos_score, neg_score, device)
            auc, ap = compute_metric(pos_score, neg_score)

            optim.zero_grad()
            loss.backward()
            optim.step()
        
        auc, ap = evaluate(model, val_feats, val_labels)
        early_stopping(loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    model.load_state_dict(torch.load(f'{model_out_path}/checkpoint_HTGNN_{k}.pt'))
    auc, ap = evaluate(model, test_feats, test_labels)

    print(f'auc: {auc}, ap: {ap}')
    auc_list.append(auc)
    ap_list.append(ap)

---------------Repeat time: 1---------------------
# params: 126281
Validation loss decreased (inf --> 0.693301).  Saving model ...
Validation loss decreased (0.693301 --> 0.597250).  Saving model ...
Validation loss decreased (0.597250 --> 0.561513).  Saving model ...
Validation loss decreased (0.561513 --> 0.538988).  Saving model ...
Validation loss decreased (0.538988 --> 0.518697).  Saving model ...
Validation loss decreased (0.518697 --> 0.494623).  Saving model ...
Validation loss decreased (0.494623 --> 0.482902).  Saving model ...
Validation loss decreased (0.482902 --> 0.464226).  Saving model ...
Validation loss decreased (0.464226 --> 0.448665).  Saving model ...
Validation loss decreased (0.448665 --> 0.436250).  Saving model ...
Validation loss decreased (0.436250 --> 0.423156).  Saving model ...
Validation loss decreased (0.423156 --> 0.409451).  Saving model ...
Validation loss decreased (0.409451 --> 0.398901).  Saving model ...
Validation loss decreased (0.398901 --> 

KeyboardInterrupt: 

In [4]:
import statistics

print(f'AUC: {statistics.mean(auc_list)}, {statistics.stdev(auc_list)}')
print(f'AP: {statistics.mean(ap_list)}, {statistics.stdev(ap_list)}')

AUC: 0.9100877512024033, 0.007723909849077316
AP: 0.8917769733866674, 0.01237942594912229
