In [1]:
'''
This code opens the Microsoft Academic Graph Dataset and trains HGT
Based on code provided by original HGT paper
'''
import torch
from hgt import *
from hgt_utils import *
from model import *
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.nodeproppred import Evaluator
import multiprocessing as mp
import argparse
import numpy as np
import time
import pandas as pd
import matplotlib as plt
import seaborn as sb

print("Microsoft Academic Graph Dataset Experiment")

Microsoft Academic Graph Dataset Experiment


In [2]:
'''
Data Preprocessing
ogbn-mag only comes with paper node features, thus for other nodes types we take the average
of connected paper nodes as input features. 
'''
print("Begin Data Preprocessing")
print("")
print("Retrieving Data from Open Graph Benchmark ...")

# Get dataset using Pytorch Geometric Loader
dataset = PygNodePropPredDataset(name='ogbn-mag')
print("... Retrieval complete")
data = dataset[0] # pyg graph object

Begin Data Preprocessing

Retrieving Data from Open Graph Benchmark ...
... Retrieval complete


In [3]:
evaluator = Evaluator(name='ogbn-mag')
# Preparing Graph
graph, y, train_paper, valid_paper, test_paper = prepare_graph(data, dataset)

Populating edge lists into Graph object
('author', 'affiliated_with', 'institution')
('author', 'writes', 'paper')
('paper', 'cites', 'paper')
('paper', 'has_topic', 'field_of_study')

Reformatting edge lists and computing node degrees
institution author affiliated_with 8740
author institution rev_affiliated_with 852987
author paper rev_writes 1134649
paper author writes 736389
paper paper cites 629169
paper paper rev_cites 617924
paper field_of_study rev_has_topic 736389
field_of_study paper has_topic 59965

Constructing node feature vectors for each node type in graph
author
field_of_study
institution
paper

Constructing Node features for institutions

Splitting dataset into train, val and test

Creating Masks

Preprocessing complete


In [4]:
print(dir(graph))
print(len(graph.node_feature['paper'][0]))

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'add_edge', 'add_node', 'edge_list', 'get_meta_graph', 'get_types', 'node_bacward', 'node_feature', 'node_forward', 'test_mask', 'test_paper', 'times', 'train_mask', 'train_paper', 'update_node', 'valid_mask', 'valid_paper', 'y', 'years']
129


In [5]:
'''
Preprocessing data
'''
batch_number = 32 # number of sampled graphs for each epoch
batch_size = 128
num_epochs = 10
num_workers = 8
clip = 1.0
sample_depth = 6
sample_width = 520
plot = False # True or false to plot data
target_nodes = np.arange(len(graph.node_feature['paper']))

stats = []
result = []
best_val = 0
training_step = 0

In [6]:
# sample nodes:
samp_nodes = np.random.choice(target_nodes, batch_size, replace = False)
sample_depth = 6
sample_width = 520           
inp = {'paper': np.concatenate([samp_nodes, graph.years[samp_nodes]]).reshape(2, -1).transpose()}
seed = randint()
n_batch = 32
batch_size = 128
# node_feature, node_type, edge_time, edge_index, edge_type, (train_mask, valid_mask, test_mask), ylabel = ogbn_sample(seed, samp_nodes, graph, sample_depth, sample_width)

In [7]:
# Sampling Data from MAG
datas = []
print("Starting Sampling...")
for batch_id in np.arange(n_batch):
    print(f'Batch number: {batch_id}')
    node_feature, node_type, edge_time, edge_index, edge_type, (train_mask, valid_mask, test_mask), ylabel = ogbn_sample(seed, samp_nodes, graph, sample_depth, sample_width)
    p = (
        node_feature,
        node_type,
        edge_time,
        edge_index,
        edge_type,
        (train_mask, valid_mask, test_mask),
        ylabel
    )
    datas.append(p)
    print(f'Batch Number: {batch_id}, complete')
print("...Preprocessing complete ")

Starting Sampling...
Batch number: 0


  node_feature = torch.FloatTensor(node_feature)


Batch Number: 0, complete
Batch number: 1
Batch Number: 1, complete
Batch number: 2
Batch Number: 2, complete
Batch number: 3
Batch Number: 3, complete
Batch number: 4
Batch Number: 4, complete
Batch number: 5
Batch Number: 5, complete
Batch number: 6
Batch Number: 6, complete
Batch number: 7
Batch Number: 7, complete
Batch number: 8
Batch Number: 8, complete
Batch number: 9
Batch Number: 9, complete
Batch number: 10
Batch Number: 10, complete
Batch number: 11
Batch Number: 11, complete
Batch number: 12
Batch Number: 12, complete
Batch number: 13
Batch Number: 13, complete
Batch number: 14
Batch Number: 14, complete
Batch number: 15
Batch Number: 15, complete
Batch number: 16
Batch Number: 16, complete
Batch number: 17
Batch Number: 17, complete
Batch number: 18
Batch Number: 18, complete
Batch number: 19
Batch Number: 19, complete
Batch number: 20
Batch Number: 20, complete
Batch number: 21
Batch Number: 21, complete
Batch number: 22
Batch Number: 22, complete
Batch number: 23
Batch N

In [8]:
def het_mutual_attention(target_node_rep, source_node_rep, key_source_linear, query_source_linear, edge_type_index, num_heads, head_dim, rel_attention, rel_priority, sqrt_head_dim):
    '''
    Heterogeneous Mutual Attention calculation
    Input:
        - target_node_rep      - Node representation of target
        - source_node_rep      - Node representation of source
        - key_source_linear    - Linear projection of key source    (nn.ModuleList(), looped nn.Linear layers)
        - query_source_linear  - Linear projection of query source  (nn.MOduleList(), looped nn.Linear layers)
        - edge_type_index      - index
    Output:
        - res_attention - Tensor storing computed attention coefficients between source and target nodes. 
    '''
    # Apply linear layers for Key (source) and Query (target)
    print(target_node_rep.shape)
    print(query_source_linear.weight.shape)

    query_lin_matrix = query_source_linear(target_node_rep).view(-1, num_heads, head_dim)
    key_lin_matrix = key_source_linear(source_node_rep).view(-1, num_heads, head_dim)


    # Calculate Relation Attention with Key matrix
    key_lin_attention_matrix = torch.bmm(key_lin_matrix.transpose(1,0), rel_attention[edge_type_index]).transpose(1,0)

    # Dot product between new Key matrix and query, then include meta relation triplet tensor divided by root of head dim
    res_attention = (query_lin_matrix * key_lin_attention_matrix).sum(dim = -1) * (rel_priority[edge_type_index] / sqrt_head_dim)

    return res_attention

In [9]:
# Example heterogeneous mutual attention
stat = []

# p = (
#     0 node_feature,
#     1 node_type,
#     2 edge_time,
#     3 edge_index,     # An edge_index pair is the same index in the first and second tensors, edge between nodes of that index
#     4 edge_type,      # There is an edge type associated with every edge_index "pair"
#     5 (train_mask, valid_mask, test_mask),
#     6 ylabel
# )

node_feature = datas[0][0]
node_type = datas[0][1]
edge_time = datas[0][2]
edge_index = datas[0][3]
edge_type = datas[0][4]


num_heads = 8
source_type_index = 0
edge_type_index = 2
target_type_index = 0

num_edge_types = 8
num_node_types = 4
in_dim = 129
out_dim = 256
head_dim = out_dim // num_heads

# edge_mask holds true for all edges of type edge_type_index
edge_mask = (edge_type == int(edge_type_index))

# create mask for all edges that have source node of type source_type_index
source_nodes_mask = (node_type == int(source_type_index))
source_nodes_indexes = source_nodes_mask.nonzero(as_tuple = True)[0] # holds all indexes where a node is of type source_node_type
source_edges_mask = torch.isin(edge_index[0], source_nodes_indexes)

# create mask for all edges that have target nod eof type target_type_index
target_nodes_mask = (node_type == int(target_type_index))
target_nodes_indexes = target_nodes_mask.nonzero(as_tuple = True)[0]
target_edges_mask = torch.isin(edge_index[0], target_nodes_indexes)

# Meta relation triple, True at indexes where typing matches up
meta_relation_mask = edge_mask & source_edges_mask & target_edges_mask

# apply meta_relation_mask on to get indexes of node_feature
source_node_index_location = edge_index[0][meta_relation_mask]
target_node_index_location = edge_index[1][meta_relation_mask]

# get Node representations based on index_location
source_node_rep = node_feature[source_node_index_location]
target_node_rep = node_feature[target_node_index_location]

key_lin_list = nn.ModuleList()
value_lin_list = nn.ModuleList()
query_lin_list = nn.ModuleList()

for i in range(num_node_types):
    key_lin_list.append(nn.Linear(in_dim, out_dim))
    value_lin_list.append(nn.Linear(in_dim, out_dim))
    query_lin_list.append(nn.Linear(in_dim, out_dim))

rel_attention  = nn.Parameter(torch.Tensor(num_edge_types, num_heads, head_dim, head_dim))
rel_priority   = nn.Parameter(torch.ones(num_edge_types, num_heads))
sqrt_head_dim  = math.sqrt(head_dim)
# het_mutual_attention(target_node_rep, source_node_rep, key_source_linear, query_source_linear, edge_type_index, num_heads, head_dim, rel_attention, rel_priority, sqrt_head_dim)
glorot(rel_attention)

key_source_linear = key_lin_list[source_type_index]
value_source_linear = value_lin_list[source_type_index]
query_source_linear = query_lin_list[target_type_index]

target_node_rep = target_node_rep.to(torch.float32)

output = het_mutual_attention(target_node_rep, source_node_rep, key_source_linear, query_source_linear, edge_type_index, num_heads, head_dim, rel_attention, rel_priority, sqrt_head_dim)

torch.Size([8587, 129])
torch.Size([256, 129])


In [10]:
'''
Creating Model
'''
print("Creating Model")
print(len(graph.get_meta_graph()))
hgt_GNN = HGTModel(len(graph.node_feature['paper'][0]), # input_dim
                   256,                                 # hidden_dim
                   len(graph.get_types()),              # num_node_types
                   len(graph.get_meta_graph()),         # num_edge_types
                   8,                                   # num_heads
                   4,                                   # num_layers
                   0.2,                                 # dropout
                   prev_norm = True,                    # normalization on all but last layer
                   last_norm = False,                   # normalization on last layer
                   use_rte = False)                     # use relative temporal encoding 


Creating Model
8


In [11]:
print(f'node_feature shape is: :{node_feature.shape}')
node_rep = hgt_GNN.forward(node_feature, node_type, edge_index, edge_type, edge_time)

node_feature shape is: :torch.Size([10519, 129])
PRE-NSERT shape is: torch.Size([10519, 256])
PRE_DROP shape is (result): torch.Size([10519, 256])
POST_DROP shape is: torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
AGGREGATION HAPPENING
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size([10519, 256])
torch.Size(

In [12]:
print(node_rep.shape)
print(node_rep)

torch.Size([10519, 256])
tensor([[-1.1247, -0.0228,  0.2347,  ..., -0.7701, -0.6769, -1.3484],
        [-0.8180, -1.8696,  0.5648,  ..., -1.0659, -0.6553, -1.8062],
        [-1.0636, -1.7499,  0.1735,  ...,  0.1019,  0.0990, -0.8137],
        ...,
        [-1.8460, -1.1319, -2.3900,  ..., -0.9407,  0.0154,  2.1887],
        [-2.3474, -0.2455, -1.5087,  ..., -0.4388,  1.0901,  1.9827],
        [-1.8562, -0.3172, -0.6318,  ..., -0.3494,  0.6434,  2.3511]],
       grad_fn=<IndexPutBackward0>)


In [13]:
# # Negative Log Likelihood Loss
# criterion = nn.NLLLoss()

# # Get list of model parameters w/ associated names
# parameters_optimizer = list(HGT_classifier.named_parameters())
# no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
# optimizer_grouped_parameters = [
#     {'params': [p for n, p in parameters_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
#     {'params': [p for n, p in parameters_optimizer if any(nd in n for nd in no_decay)],     'weight_decay': 0.0}
# ]
# # AdamW optimizer w/specified parameter groups and epsilon value
# optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-06)
# # Create a OneCycleLR learning rate scheduler
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.05, anneal_strategy='linear', final_div_factor=10,\
#                         max_lr = 5e-4, total_steps = batch_size * num_epochs + 1)
