In [193]:
import argparse, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.data import BitcoinOTC
import datetime
from dgl.nn.pytorch import GraphConv
import time
from sklearn.metrics import f1_score
import os
import json
from collections import defaultdict, Counter
from tqdm import tqdm
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
import matplotlib.pyplot as plt
from os import listdir
from os.path import isfile, join
import networkx as nx
import pickle as pkl
import importlib.util
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

def module_from_file(module_name, file_path):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

graph_utils = module_from_file("graph_utils", "../graph_utils.py")
gen_utils = module_from_file("gen_utils", "../utils.py")
topic_utils = module_from_file("topic_utils", "../utils.py")

## Hyperparams

In [502]:
node_dim = 300
hid_dim = 512
n_layers = 2
dropout = 0
learning_rate = 0.01
wt_decay = 0
stpsize = 5000
n_epochs = 1000
n_classes = 45
out_path = '/misc/vlgscratch4/BrunaGroup/rj1408/inference/models/iter1/'
self_loop = True
data_path = '/misc/vlgscratch4/BrunaGroup/rj1408/inference/data/'
graph_file = 'graph_df.pkl'
feature_file = 'text_embed_en.pkl'
label_file = 'en_outlinks_tokens_df.pkl'
activation = F.relu

In [463]:
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    device = 'cuda'
else:
    device = 'cpu'

## Data preparation

In [464]:
def removeSelfEdges(edgeList, colFrom, colTo):
    mask = edgeList[:, colFrom] - edgeList[:, colTo] != 0
    edgeList = edgeList[mask]
    return edgeList

In [465]:
def load_graphs(data, self_loop):
    g = graph_utils.build_graph(data)
    if self_loop == True:
        g.add_edges(g.nodes(), g.nodes())
    return g

In [171]:
#load graph
with open(os.path.join(data_path, graph_file), "rb") as f:
    wiki_graph_df = pkl.load(f)
    
graph = load_graphs(wiki_graph_df, True)

In [172]:
#load feature dataframes
with open(os.path.join(data_path, feature_file), "rb") as f:
    wiki_feature_df = pkl.load(f)

with open(os.path.join(data_path, label_file), "rb") as f:
    wiki_label_df = pkl.load(f)
    
joined_df = wiki_feature_df.join(wiki_graph_df, lsuffix='1')
joined_df = joined_df.join(wiki_label_df, lsuffix='2').sort_values(by='node_id')

In [176]:
features = joined_df['text_1000_embed']
features = np.nan_to_num(np.stack(features))
mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(joined_df.mid_level_categories)
labels = torch.FloatTensor(labels)
graph.ndata['feat'] = torch.FloatTensor(features)
graph.ndata['labels'] = labels

In [181]:
#Sample training/validation and test
train_p = 0.6
val_p = 0.15
test_p = 0.25

indices = np.arange(graph.number_of_nodes())
indices_train,indices_test = train_test_split(indices, test_size=test_p, random_state=42)
indices_train, indices_val = train_test_split(indices_train, train_size = train_p/(1 - test_p), random_state=42)
train_mask = torch.zeros(graph.number_of_nodes()).bool()
val_mask = torch.zeros(graph.number_of_nodes()).bool()
test_mask = torch.zeros(graph.number_of_nodes()).bool()
train_mask[indices_train] = True
val_mask[indices_val] = True
test_mask[indices_test] = True

In [328]:
#For class imbalance
pos_examples = np.sum(labels[train_mask].long().numpy(), axis = 0)
neg_examples = labels[train_mask].shape[0] - pos_examples
psweights = neg_examples/pos_examples
psweights = torch.FloatTensor(weights)

In [346]:
#For interclass imbalance
max_examples = np.max(pos_examples)
class_weights = max_examples/pos_examples
class_weights = torch.FloatTensor(class_weights).to(device)

## Model definition

In [190]:
class GCN(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_layers,
                 n_classes,
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
        
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        
        # output layer
        self.outlayer = nn.Linear(n_hidden, n_classes, bias = True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features, g):
        h = features
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(g, h)
        outputs = self.outlayer(h)
        return outputs

## Training loop

In [476]:
def evaluate_loss(model, criterion, device, graph, mask):
    model.eval()
    
    #validation phase
    with torch.set_grad_enabled(False):
        feat = graph.ndata['feat'].to(device)
        outputs = model(feat, graph)
        labels = graph.ndata['labels']
        labels = labels.to(device)
        outputs = outputs[mask]
        labels = labels[mask]
        loss = criterion(outputs, labels)
        loss = loss*class_weights
        loss = torch.mean(loss)
    return loss.item()

In [396]:
def predict(model, criterion, device, graph, mask):
    model.eval()
    
    #validation phase
    with torch.set_grad_enabled(False):
        feat = graph.ndata['feat'].to(device)
        outputs = model(feat, graph)
        outputs = outputs[mask]
                    
    return outputs

In [397]:
def evaluate_metrics(model, criterion, device, graph, mask, labels):
    logits = predict(model, criterion, device, graph, mask)
    preds = (logits > 0).long()
    ground_labels = labels[mask].long()
    dict_metrics = gen_utils.get_metrics_dict(ground_labels, preds)
    return dict_metrics

In [488]:
#Code for supervised training
def train_model(model, criterion, optimizer, scheduler, device, checkpoint_path, hyperparams, graph, train_mask, val_mask, num_epochs=25, check_iter=50):
    metrics_dict = {}
    metrics_dict["train"] = {}
    metrics_dict["valid"] = {}
    metrics_dict["train"]["loss"] = {}
    metrics_dict["train"]["loss"]["epochwise"] = []
    metrics_dict["valid"]["loss"] = {}
    metrics_dict["valid"]["loss"]["epochwise"] = []
        
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 9999999999999999

    for epoch in range(num_epochs):
        
        #train phase
        scheduler.step()
        model.train() 
        optimizer.zero_grad()
        # forward
        # track history if only in train
        forward_start_time  = time.time()
        with torch.set_grad_enabled(True):
            feats = graph.ndata['feat'].to(device)
            outputs = model(feats, graph)
            labels = graph.ndata['labels']
            labels = labels.to(device)
            outputs = outputs[train_mask]
            labels = labels[train_mask]
            loss = criterion(outputs, labels)
            loss = loss*class_weights
            loss = torch.mean(loss)
            epoch_loss = loss.item()
            loss.backward()
            optimizer.step()
        forward_time = time.time() - forward_start_time
        
        
        metrics_dict["train"]["loss"]["epochwise"].append(epoch_loss)
        
        #validation phase
        val_epoch_loss = evaluate_loss(model, criterion, device, graph, val_mask)
        metrics_dict["valid"]["loss"]["epochwise"].append(val_epoch_loss)
        
        if epoch%check_iter==0:
            print('Epoch {}/{} \n'.format(epoch, num_epochs - 1))
            print('-' * 10)
            print('\n')
            print('Train Loss: {:.4f} \n'.format(epoch_loss))
            print('Validation Loss: {:.4f} \n'.format(val_epoch_loss))
        
        
        # deep copy the model
        if val_epoch_loss < best_loss:
            best_loss = val_epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            
        if epoch%check_iter==0:
            torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'full_metrics': metrics_dict,
            'hyperparams': hyperparams
            }, '%s/net_epoch_%d.pth' % (checkpoint_path, epoch))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s \n'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f} \n'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [503]:
 # create GCN model
model = GCN(node_dim, hid_dim, n_layers, n_classes, activation, dropout)
model.to(device)
criterion = nn.BCEWithLogitsLoss()#pos_weight=class_weights.to(device), reduction='none')
model_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(model_parameters, lr=learning_rate, weight_decay = wt_decay)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=stpsize, gamma=0.1)
hyper_params = {'node_dim' : node_dim,
    'hid_dim': hid_dim,
    'n_layers' : n_layers,
    'dropout' : dropout,
    'wt_decay' : wt_decay,
    }

bst_model = train_model(model, criterion, optimizer, exp_lr_scheduler, device, out_path, hyper_params, graph, train_mask, val_mask, n_epochs)

Epoch 0/999 

----------


Train Loss: 25.1099 

Validation Loss: 20.9552 

Epoch 50/999 

----------


Train Loss: 5.3910 

Validation Loss: 5.4193 

Epoch 100/999 

----------


Train Loss: 4.0666 

Validation Loss: 4.1054 

Epoch 150/999 

----------


Train Loss: 3.4064 

Validation Loss: 3.4436 

Epoch 200/999 

----------


Train Loss: 3.0915 

Validation Loss: 3.1529 

Epoch 250/999 

----------


Train Loss: 2.9588 

Validation Loss: 3.0158 

Epoch 300/999 

----------


Train Loss: 2.8257 

Validation Loss: 2.9144 

Epoch 350/999 

----------


Train Loss: 2.7866 

Validation Loss: 2.8570 

Epoch 400/999 

----------


Train Loss: 2.7098 

Validation Loss: 2.8235 

Epoch 450/999 

----------


Train Loss: 2.6635 

Validation Loss: 2.7473 

Epoch 500/999 

----------


Train Loss: 2.6020 

Validation Loss: 2.7078 

Epoch 550/999 

----------


Train Loss: 2.5580 

Validation Loss: 2.6924 

Epoch 600/999 

----------


Train Loss: 2.5922 

Validation Loss: 2.6604 

Epoch 650/999

In [504]:
def predict_logits(model, device, graph, mask=None):
    model.eval()
    with torch.no_grad():
        features = graph.ndata['feat'].to(device)
        logits = model(features, graph)
        
        if mask is not None:
            logits = logits[mask]
    return logits

In [505]:
predicted_logits = predict_logits(bst_model, device, graph, test_mask)
hard_predictions = (predicted_logits > 0).long()
hard_predictions = hard_predictions.cpu().detach().numpy()
ground_truth_labels = labels[test_mask].long().cpu().detach().numpy()

In [506]:
topic_utils.get_metrics_dict(ground_truth_labels, hard_predictions)

{'precision_macro': 0.712,
 'recall_macro': 0.391,
 'f1_macro': 0.473,
 'precision_micro': 0.826,
 'recall_micro': 0.596,
 'f1_micro': 0.693}