In [1]:
'''
This notebook opens the Microsoft Academic Graph dataset (ogbn-mag) and trains HGT on it.
Based on the code provided with the original HGT paper.
'''
import torch
from hgt import *
from hgt_utils import *
from model import *
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.nodeproppred import Evaluator
import multiprocessing as mp
import argparse
import numpy as np
import time
import pandas as pd
# The plotting API (plt.figure, plt.plot, plt.subplots, ...) lives in matplotlib.pyplot;
# the top-level matplotlib package does not expose those functions.
import matplotlib.pyplot as plt
import seaborn as sb

print("Microsoft Academic Graph Dataset Experiment")

Microsoft Academic Graph Dataset Experiment


In [2]:
'''
Data Preprocessing
ogbn-mag ships input features for paper nodes only, so every other node type
uses the average of its connected paper nodes as its input features.
'''
print("Begin Data Preprocessing")
print()
print("Retrieving Data from Open Graph Benchmark ...")

# Fetch ogbn-mag through the OGB PyTorch Geometric loader
dataset = PygNodePropPredDataset(name="ogbn-mag")
print("... Retrieval complete")

# First entry of the dataset: the PyG graph object
data = dataset[0]

Begin Data Preprocessing

Retrieving Data from Open Graph Benchmark ...
... Retrieval complete


In [3]:
# OGB evaluator configured for the ogbn-mag benchmark
evaluator = Evaluator(name='ogbn-mag')
# Preparing Graph
# prepare_graph is not imported by name here (presumably it comes from one of the
# star imports). Per its printed log it populates the edge lists, computes node
# degrees, builds per-type node features and splits the data into train/val/test.
# NOTE(review): y / *_paper look like the labels and the paper-node split ids -- confirm
# against prepare_graph's definition.
graph, y, train_paper, valid_paper, test_paper = prepare_graph(data, dataset)

Populating edge lists into Graph object
('author', 'affiliated_with', 'institution')
('author', 'writes', 'paper')
('paper', 'cites', 'paper')
('paper', 'has_topic', 'field_of_study')

Reformatting edge lists and computing node degrees
institution author affiliated_with 8740
author institution rev_affiliated_with 852987
author paper rev_writes 1134649
paper author writes 736389
paper paper cites 629169
paper paper rev_cites 617924
paper field_of_study rev_has_topic 736389
field_of_study paper has_topic 59965

Constructing node feature vectors for each node type in graph
author
field_of_study
institution
paper

Constructing Node features for institutions

Splitting dataset into train, val and test

Creating Masks

Preprocessing complete


In [4]:
'''
Creating Model
An HGT encoder followed by a classification head, chained with nn.Sequential.
'''
print("Creating Model")

# Width shared by the encoder output and the classifier input. Defined once so the
# two layers cannot drift apart when the value is tuned.
hidden_dim = 256
# Size of the classifier output: largest label value + 1
# (assumes labels are 0-based class ids -- TODO confirm)
num_classes = graph.y.max() + 1

hgt_GNN = HGTModel(len(graph.node_feature['paper'][0]), # input_dim
                   hidden_dim,                          # hidden_dim
                   len(graph.get_types()),              # num_node_types
                   len(graph.get_meta_graph()),         # num_edge_types
                   8,                                   # num_heads
                   4,                                   # num_layers
                   0.2,                                 # dropout
                   prev_norm = True,                    # normalization on all but last layer
                   last_norm = False,                   # normalization on last layer
                   use_rte = False)                     # use relative temporal encoding
classifier = Classifier(hidden_dim, num_classes)

print(f'Classifier output dim is: {num_classes}')

HGT_classifier = nn.Sequential(hgt_GNN, classifier)

print(HGT_classifier)

Creating Model
Classifier output dim is: 349
Sequential(
  (0): HGTModel(
    (adapt_features): ModuleList(
      (0-3): 4 x Linear(in_features=129, out_features=256, bias=True)
    )
    (hgt_layers): ModuleList(
      (0-3): 4 x HGTLayer()
    )
    (drop): Dropout(p=0.2, inplace=False)
  )
  (1): Classifier(n_hid=256, n_out=349)
)


In [5]:
'''
Training configuration: hyperparameters, loss, optimizer and learning-rate schedule
'''
batch_number = 32 # number of sampled graphs for each epoch
batch_size = 128  # number of target paper nodes drawn for each sampled graph
num_epochs = 10
num_workers = 8   # size of the multiprocessing pool
clip = 1.0
sample_depth = 6
sample_width = 520
plot = False # True or false to plot data
# Candidate target nodes: every paper node index
target_nodes = np.arange(len(graph.node_feature['paper']))

# Negative Log Likelihood Loss
criterion = nn.NLLLoss()

# Get list of model parameters w/ associated names
parameters_optimizer = list(HGT_classifier.named_parameters())
# Parameters whose name contains one of these substrings get no weight decay
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in parameters_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in parameters_optimizer if any(nd in n for nd in no_decay)],     'weight_decay': 0.0}
]
# AdamW optimizer w/specified parameter groups and epsilon value
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-06)
# Create a OneCycleLR learning rate scheduler.
# total_steps must equal the number of scheduler steps taken over the whole run:
# batches per epoch (batch_number) times epochs. batch_size is the number of nodes
# per sample, not a step count; using it stretched the cycle to 1281 steps so the
# learning rate never completed its schedule.
# NOTE(review): assumes scheduler.step() is called once per sampled batch -- the
# training loop is not visible from this cell, confirm there.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.05, anneal_strategy='linear', final_div_factor=10,\
                        max_lr = 5e-4, total_steps = batch_number * num_epochs + 1)
stats = []
result = []
best_val = 0
training_step = 0

In [6]:
# Preparing parallel processing
pool = mp.Pool(num_workers)

# sample nodes: draw batch_size distinct target papers.
# sample_depth / sample_width are taken from the configuration cell; they are not
# re-assigned here so that tuning them in one place cannot be silently overridden.
samp_nodes = np.random.choice(target_nodes, batch_size, replace = False)
# One row per sampled paper: [node index, graph.years entry for that node]
inp = {'paper': np.concatenate([samp_nodes, graph.years[samp_nodes]]).reshape(2, -1).transpose()}
# randint is called without arguments, so it is not random.randint
# (presumably a helper from one of the star imports).
seed = randint()

In [8]:
# Inline replacement for a previous ogbn_sample(seed, samp_nodes, graph, sample_depth, sample_width) call:
# seed NumPy, gather the labels of the sampled papers, then sample a subgraph around them.
np.random.seed(seed)
ylabel = torch.LongTensor(graph.y[samp_nodes])
feature, times, edge_list, indxs, _ = sample_subgraph(
    graph,
    feature_extractor = feature_MAG,
    sampled_number = sample_width,
    sampled_depth = sample_depth,
    # One row per sampled paper: [node index, graph.years entry for that node]
    inp = {'paper': np.concatenate([samp_nodes, graph.years[samp_nodes]]).reshape(2, -1).T},
)

{'paper': array([[-0.125205  , -0.134075  , -0.20040999, ..., -0.114702  ,
        -0.434517  ,  1.32221929],
       [-0.26176801,  0.051687  , -0.13126799, ...,  0.164425  ,
        -0.183248  ,  1.30103   ],
       [-0.011906  ,  0.162616  ,  0.01099   , ..., -0.03533   ,
        -0.176313  ,  1.50514998],
       ...,
       [-0.080194  ,  0.12802701, -0.195675  , ..., -0.141542  ,
        -0.124276  ,  1.39794001],
       [-0.075198  , -0.056696  , -0.080165  , ..., -0.171854  ,
        -0.32188699,  1.5797836 ],
       [-0.41930601,  0.048521  , -0.201967  , ...,  0.018828  ,
        -0.52563   ,  1.44715803]]), 'author': array([[ 7.51385358e-02, -1.08575936e-01, -1.86865813e-01, ...,
        -8.24531386e-02, -3.51220616e-01,  2.57634135e+00],
       [-2.73372178e-02,  1.02903695e-01, -1.56183869e-01, ...,
        -6.17252172e-02, -3.33473914e-01,  1.39794001e+00],
       [-1.63193751e-01,  1.29883290e-03, -2.85722251e-01, ...,
        -2.33118668e-01, -1.62728502e-01,  1.11394335e