In [1]:
'''
Microsoft Academic Graph (ogbn-mag) experiment: loads the dataset and trains HGT.
Based on the code released with the original HGT paper.
'''
import torch

# Project-local modules. They are expected to provide HGTModel, Classifier,
# prepare_graph and get_n_batches_training_data, which later cells use.
from hgt import *
from hgt_utils import *
from model import *

# Explicit imports stay *after* the star imports so they take precedence over
# any same-named symbols those modules happen to re-export.
import argparse
import multiprocessing as mp
import time

import matplotlib.pyplot as plt  # pyplot: the top-level `matplotlib` package has no plot()/subplots()
import numpy as np
import pandas as pd
import seaborn as sb
from torch import nn  # used for nn.Sequential / nn.NLLLoss in later cells
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset

print("Microsoft Academic Graph Dataset Experiment")

Microsoft Academic Graph Dataset Experiment


In [2]:
# Data preprocessing.
# ogbn-mag ships node features only for paper nodes; node types without their
# own features take the average of their connected paper nodes as input
# features (built later by prepare_graph).
print("Begin Data Preprocessing", "", "Retrieving Data from Open Graph Benchmark ...", sep="\n")

# Fetch the dataset through the PyTorch Geometric loader provided by OGB.
dataset = PygNodePropPredDataset(name='ogbn-mag')
print("... Retrieval complete")

data = dataset[0]  # PyG graph object

Begin Data Preprocessing

Retrieving Data from Open Graph Benchmark ...
... Retrieval complete


In [3]:
# Build the custom Graph object from the PyG data. Judging by its log output,
# prepare_graph also builds per-type node features and the train/valid/test
# split of the paper nodes (plus masks).
(
    graph,
    y,
    train_paper,
    valid_paper,
    test_paper,
) = prepare_graph(data, dataset)

Populating edge lists into Graph object
('author', 'affiliated_with', 'institution')
('author', 'writes', 'paper')
('paper', 'cites', 'paper')
('paper', 'has_topic', 'field_of_study')

Reformatting edge lists and computing node degrees
institution author affiliated_with 8740
author institution rev_affiliated_with 852987
author paper rev_writes 1134649
paper author writes 736389
paper paper cites 629169
paper paper rev_cites 617924
paper field_of_study rev_has_topic 736389
field_of_study paper has_topic 59965

Constructing node feature vectors for each node type in graph
author
field_of_study
institution
paper

Constructing Node features for institutions

Splitting dataset into train, val and test

Creating Masks

Preprocessing complete


In [10]:
# ---- Customize parameters here ----

# Subgraph sampling / batching (passed to get_n_batches_training_data)
n_batch = 32           # number of sampled graphs per epoch
batch_size = 128
sample_depth = 6
sample_width = 520

# Optimization
num_epochs = 1
clip = 1.0             # max gradient norm for clip_grad_norm_

# Model architecture
hidden_dim = 256
num_heads = 8
num_layers = 4
dropout = 0.2

# Set to True to plot data
plot = False

In [11]:
# Creating Model
print("Creating Model")

# Indices of every paper node; handed to the batch sampler as `target_nodes`.
paper_features = graph.node_feature['paper']
target_nodes = np.arange(len(paper_features))

# HGT encoder. Positional arguments, per the original inline notes:
# input_dim, hidden_dim, num_node_types, num_edge_types, num_heads, num_layers, dropout.
hgt_GNN = HGTModel(
    len(paper_features[0]),
    hidden_dim,
    len(graph.get_types()),
    len(graph.get_meta_graph()),
    num_heads,
    num_layers,
    dropout,
    prev_norm=True,    # normalization on all but the last layer
    last_norm=False,   # no normalization on the last layer
    use_rte=False,     # relative temporal encoding disabled
)

# Output head; class count assumes labels are 0..graph.y.max().
classifier = Classifier(hidden_dim, graph.y.max() + 1)
model = nn.Sequential(hgt_GNN, classifier)
print(model)

Creating Model
Sequential(
  (0): HGTModel(
    (adapt_features): ModuleList(
      (0-3): 4 x Linear(in_features=129, out_features=256, bias=True)
    )
    (hgt_layers): ModuleList(
      (0-3): 4 x HGTLayer()
    )
    (drop): Dropout(p=0.2, inplace=False)
  )
  (1): Classifier(n_hid=256, n_out=349)
)


In [12]:
# Defining Optimizer, Scheduler, Loss, etc.
# NLLLoss consumes log-probabilities, so Classifier presumably ends in log_softmax.
criterion = nn.NLLLoss()
evaluator = Evaluator(name='ogbn-mag')

# AdamW parameter groups: weight decay 0.01, except for parameters whose name
# contains one of the `no_decay` substrings (those get 0.0).
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
# NOTE(review): matching is on parameter *names*; unless a submodule attribute
# is literally named "LayerNorm", LayerNorm weights are still decayed. Confirm intent.


def _exempt_from_decay(param_name):
    """Return True if `param_name` contains any substring listed in `no_decay`."""
    return any(pattern in param_name for pattern in no_decay)


optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not _exempt_from_decay(n)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if _exempt_from_decay(n)], 'weight_decay': 0.0},
]

optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-06)

# One-cycle schedule: linear ramp to max_lr over the first 5% of steps, then
# linear annealing; sized for n_batch * num_epochs + 1 scheduler steps.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-4,
    total_steps=n_batch * num_epochs + 1,
    pct_start=0.05,
    anneal_strategy='linear',
    final_div_factor=10,
)

In [13]:
# Training-loop state. Re-run this cell (and the optimizer cell) before
# re-running the training cell so the step counter starts fresh.
stats, res = [], []
best_val = 0
train_step = 0   # global optimizer-step counter, incremented by the training loop

In [16]:
# Model Training
# Runs `num_epochs` epochs. Each epoch samples `n_batch` subgraphs and takes one
# optimizer step per subgraph. Per-batch rows [loss, train_acc, valid_acc, test_acc]
# are gathered in `stat`, which is appended to `stats` when the epoch ends.
#
# NOTE: `train_step`, `optimizer` and `scheduler` are created in earlier cells.
# OneCycleLR is sized for n_batch * num_epochs + 1 steps and refuses to step past
# that, so re-run the optimizer and state cells before re-running this one.


def _batch_accuracy(evaluator, logits, labels):
    """Return the OGB accuracy of ``logits.argmax(dim=1)`` against ``labels``.

    Returns NaN when ``labels`` is empty (no nodes of that split in this
    batch), since accuracy is undefined there.
    """
    if labels.numel() == 0:
        return float('nan')
    return evaluator.eval({
        'y_true': labels.unsqueeze(-1),
        'y_pred': logits.argmax(dim=1).unsqueeze(-1),
    })['acc']


start_time = time.time()
epoch_counter = 0

for epoch in range(num_epochs):
    print(f'Current Epoch is: {epoch} ---------------------')

    # Get Data for Training
    datas = get_n_batches_training_data(n_batch, graph, sample_depth, sample_width, target_nodes, batch_size)

    # TRAINING
    model.train()
    stat = []
    # Named `batch`, not `data`: `data` holds the ogbn-mag PyG graph loaded earlier
    # and is read again by prepare_graph(data, dataset); it must not be clobbered.
    for batch in datas:
        node_feature = batch[0]
        node_type = batch[1]
        edge_time = batch[2]
        edge_index = batch[3]
        edge_type = batch[4]
        train_mask, valid_mask, test_mask = batch[5]
        ylabel = torch.as_tensor(batch[6], dtype=torch.long)

        # Forward: call the modules (not .forward) so nn.Module hooks run.
        node_rep = hgt_GNN(node_feature, node_type, edge_index, edge_type, edge_time)
        # Labels and masks are aligned with the first len(ylabel) rows of node_rep.
        paper_rep = node_rep[:len(ylabel)]
        train_res = classifier(paper_rep[train_mask])
        # Valid/test logits are only scored, never back-propagated: skip autograd.
        with torch.no_grad():
            valid_res = classifier(paper_rep[valid_mask])
            test_res = classifier(paper_rep[test_mask])

        train_loss = criterion(train_res, ylabel[train_mask])

        optimizer.zero_grad()
        train_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        train_step += 1
        # No epoch argument: passing one to scheduler.step() is deprecated in PyTorch.
        scheduler.step()

        train_acc = _batch_accuracy(evaluator, train_res, ylabel[train_mask])
        valid_acc = _batch_accuracy(evaluator, valid_res, ylabel[valid_mask])
        test_acc = _batch_accuracy(evaluator, test_res, ylabel[test_mask])
        loss_value = train_loss.item()
        print(f'loss {loss_value:.4f} | train acc {train_acc:.4f} | '
              f'valid acc {valid_acc:.4f} | test acc {test_acc:.4f}')
        stat.append([loss_value, train_acc, valid_acc, test_acc])
        del node_rep, paper_rep, train_loss, ylabel

    # Keep every epoch's batch statistics instead of discarding them.
    stats.append(stat)
    epoch_counter += 1
    print()

stop_time = time.time()
time_elapsed = stop_time - start_time
print(f'time elapsed is: {time_elapsed}')

Current Epoch is: 0 ---------------------
Starting Sampling...
Batch number: 0
Batch number: 1
Batch number: 2
Batch number: 3
Batch number: 4
Batch number: 5
Batch number: 6
Batch number: 7
Batch number: 8
