In [None]:
import torch
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, GRUCell
from torch_geometric.data import Data
from sklearn.metrics import roc_auc_score,average_precision_score

import random

import bisect

import gc
import copy

from itertools import permutations

import pandas as pd

from torch_geometric.utils import negative_sampling, structured_negative_sampling
import torch_geometric.transforms as T
from torch_geometric.transforms import SVDFeatureReduction
from torch_geometric.utils import train_test_split_edges
from torch_geometric.transforms import RandomLinkSplit,NormalizeFeatures,Constant,OneHotDegree
from torch_geometric.utils import from_networkx
from torch_geometric.nn import GCNConv,SAGEConv,GATv2Conv, GINConv, Linear, GCN, GAT

import torch
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

import copy
import itertools
import json

import sys
# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.insert(1, '../src')
sys.path.insert(1, '../steemitth')

# LOAD DATASET

In [None]:
from durendalrepurpose import *
from steemit import *

In [None]:
# Load the Steemit snapshots with the 'constant' preprocessing option.
# Alternative using text features (available upon request):
#   snapshots = get_steemit_dataset(preprocess='textBERT')
snapshots = get_steemit_dataset(preprocess='constant')

# TRAINING AND EVALUATION

### Functions

In [None]:
def reverse_insort(a, x, lo=0, hi=None):
    """Insert ``x`` into the descending-sorted list ``a``, keeping it sorted.

    Ties are resolved by placing ``x`` after every element equal to it
    (i.e. to the right of the rightmost equal item).

    Args:
        a: list sorted from largest to smallest; modified in place.
        x: item to insert.
        lo: lower bound (inclusive) of the slice of ``a`` to search; default 0.
        hi: upper bound (exclusive) of the slice to search; default ``len(a)``.

    Returns:
        The index at which ``x`` was inserted.

    Raises:
        ValueError: if ``lo`` is negative.

    Used by the MRR computation to obtain the rank of a score.
    """
    if lo < 0:
        raise ValueError('lo must be non-negative')

    left = lo
    right = len(a) if hi is None else hi

    # Binary search for the first position whose element is strictly smaller than x.
    while left < right:
        pivot = (left + right) // 2
        if not x > a[pivot]:
            left = pivot + 1
        else:
            right = pivot

    a.insert(left, x)
    return left

def compute_mrr(real_scores, fake_scores):
    """Mean reciprocal rank of each real score against its paired fake score.

    The i-th real score is ranked against the i-th fake score only, so every
    reciprocal rank is either 1 (real strictly greater than fake) or 1/2.
    Pairs are formed positionally; surplus entries in the longer sequence
    are ignored.

    Args:
        real_scores: sequence of scores for true edges.
        fake_scores: sequence of scores for corrupted edges.

    Returns:
        The mean reciprocal rank as a float, or 0.0 when there is no pair
        to rank (either sequence is empty) instead of dividing by zero.
    """
    srr = 0.0
    count = 0
    for score, fake in zip(real_scores, fake_scores):
        # Rank the real score inside a one-element descending list.
        rank = reverse_insort([fake], score)
        srr += 1 / (rank + 1)  # rank is zero-based
        count += 1
    if count == 0:
        return 0.0
    return srr / count

In [None]:
def _zero_past_state(hetdata, hidden_channels):
    """Return {node_type: {relation: zeros(num_nodes, hidden_channels)}} for every edge type."""
    state = {node_type: {} for node_type in hetdata.x_dict.keys()}
    for src, rel, dst in hetdata.edge_index_dict.keys():
        state[src][rel] = torch.zeros(hetdata[src].num_nodes, hidden_channels)
        state[dst][rel] = torch.zeros(hetdata[dst].num_nodes, hidden_channels)
    return state


def _follow_link_splits(snapshot, next_snapshot):
    """Build train/val data from ``snapshot`` and test data from ``next_snapshot`` for the 'follow' relation."""
    follow = ('node', 'follow', 'node')

    # Split the 'follow' edges of the current snapshot on a homogeneous view.
    # num_val=0.0, so the 20% "test" split of RandomLinkSplit is used as validation.
    hom_transform = RandomLinkSplit(num_val=0.0, num_test=0.20)
    rel0_hom = Data()
    rel0_hom.x = copy.deepcopy(snapshot['node'].x)
    rel0_hom.edge_index = copy.deepcopy(snapshot[follow].edge_index)
    rel0_train, _, rel0_val = hom_transform(rel0_hom)

    het_train_data = copy.deepcopy(snapshot)
    het_val_data = copy.deepcopy(snapshot)
    for het, hom in ((het_train_data, rel0_train), (het_val_data, rel0_val)):
        het[follow].edge_index = hom.edge_index
        het[follow].edge_label_index = hom.edge_label_index
        het[follow].edge_label = hom.edge_label

    # Test set: every 'follow' edge of the next snapshot plus sampled negatives.
    het_test_data = copy.deepcopy(next_snapshot)
    pos_edge_index = het_test_data[follow].edge_index
    num_pos_edge = pos_edge_index.size(1)
    het_future_neg_edge_index = negative_sampling(
        edge_index=pos_edge_index,  # positive edges
        num_nodes=het_test_data['node'].num_nodes,
        num_neg_samples=num_pos_edge)  # ask for as many negatives as positives
    # Size the zero labels from the negatives actually returned so that
    # edge_label and edge_label_index always have matching lengths.
    het_test_data[follow].edge_label = torch.cat(
        [torch.ones(num_pos_edge), torch.zeros(het_future_neg_edge_index.size(1))])
    het_test_data[follow].edge_label_index = torch.cat(
        [pos_edge_index, het_future_neg_edge_index], dim=-1)

    # Corrupted edges (same source, resampled destination), stored on val and test data.
    src_val, _, corrupted_dst_val = structured_negative_sampling(het_val_data[follow].edge_index)
    src_test, _, corrupted_dst_test = structured_negative_sampling(het_test_data[follow].edge_index)
    het_val_data[follow].corrupted_edge_index = torch.stack([src_val, corrupted_dst_val])
    het_test_data[follow].corrupted_edge_index = torch.stack([src_test, corrupted_dst_test])

    return het_train_data, het_val_data, het_test_data


def training(snapshots, hidden_conv_1, hidden_conv_2, device='cpu'):
    """Train and evaluate the RGCN, HAN and HGT DurendalRepurpose variants over time.

    For each snapshot i (all but the last) the three models are trained on
    snapshot i and tested on the 'follow' edges of snapshot i+1. Models,
    optimizers and past-state dictionaries are carried from one snapshot
    to the next. Per-model averages of the test metrics are printed.

    Args:
        snapshots: sequence of heterogeneous graph snapshots; must hold at
            least two (one to train on, the following one to test on).
        hidden_conv_1: passed to DurendalRepurpose as ``hidden_conv_1``; also
            the width of the first past-state tensors.
        hidden_conv_2: passed to DurendalRepurpose as ``hidden_conv_2``; also
            the width of the second past-state tensors.
        device: accepted for interface compatibility; this function does not
            use it (the per-snapshot trainer runs with its own default).

    Returns:
        Tuple ``(rgcn_avgpr, han_avgpr, hgt_avgpr)``: the test average
        precision of each model averaged over the ``len(snapshots) - 1``
        evaluated snapshots.

    Raises:
        ValueError: if fewer than two snapshots are given.
    """
    num_snap = len(snapshots)
    if num_snap < 2:
        raise ValueError(f'training needs at least 2 snapshots, got {num_snap}')
    hetdata = copy.deepcopy(snapshots[0])

    lr = 0.001
    weight_decay = 5e-3

    in_channels = {node: len(v[0]) for node, v in hetdata.x_dict.items()}

    # One model / optimizer / pair of past-state dicts per variant,
    # created in the fixed order rgcn -> han -> hgt.
    variants = ('rgcn', 'han', 'hgt')
    models, optimizers, past_1, past_2 = {}, {}, {}, {}
    for name in variants:
        model = DurendalRepurpose(in_channels, hetdata.metadata(),
                                  hidden_conv_1=hidden_conv_1,
                                  hidden_conv_2=hidden_conv_2,
                                  model=name)
        model.reset_parameters()
        models[name] = model
        optimizers[name] = torch.optim.Adam(params=model.parameters(), lr=lr,
                                            weight_decay=weight_decay)
        past_1[name] = _zero_past_state(hetdata, hidden_conv_1)
        past_2[name] = _zero_past_state(hetdata, hidden_conv_2)

    avgpr_sum = {name: 0 for name in variants}
    mrr_sum = {name: 0 for name in variants}

    for i in range(num_snap - 1):
        # CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        snapshot = copy.deepcopy(snapshots[i])
        het_train_data, het_val_data, het_test_data = \
            _follow_link_splits(snapshot, snapshots[i + 1])

        # TRAIN AND TEST EACH MODEL FOR THE CURRENT SNAP
        for name in variants:
            (models[name], avgpr_test, mrr_test,
             past_1[name], past_2[name], optimizers[name]) = \
                durendal_train_single_snapshot(
                    models[name], snapshot, i,
                    het_train_data, het_val_data, het_test_data,
                    past_1[name], past_2[name], optimizers[name])
            avgpr_sum[name] += avgpr_test
            mrr_sum[name] += mrr_test

        print(f'Snapshot: {i} done\n')

    num_eval = num_snap - 1
    avgpr_all = {name: avgpr_sum[name] / num_eval for name in variants}
    mrr_all = {name: mrr_sum[name] / num_eval for name in variants}

    for name in variants:
        print(f'DURENDAL-{name.upper()}')
        print(f'\tAVGPR over time: Test: {avgpr_all[name]}')
        print(f'\tMRR over time: Test: {mrr_all[name]}')
        print()

    return avgpr_all['rgcn'], avgpr_all['han'], avgpr_all['hgt']

In [None]:
def durendal_train_single_snapshot(model, data, i_snap, train_data, val_data, test_data,
                          past_dict_1, past_dict_2,
                          optimizer, device='cpu', num_epochs=50, verbose=False):
    """Fine-tune ``model`` on one snapshot with early stopping, then test it.

    Each epoch performs one optimizer step on ``train_data`` and evaluates on
    ``val_data``. Training stops at the first epoch whose validation average
    precision drops more than ``tol`` (0.05) below the last accepted value.

    Args:
        model: model called as ``model(x_dict, edge_index_dict,
            edge_label_index, i_snap, past_dict_1, past_dict_2)`` and exposing
            ``loss(pred, label)``.
        data: forwarded unchanged to ``durendal_test``.
        i_snap: index of the current snapshot, forwarded to the model.
        train_data, val_data, test_data: data objects for the three splits of
            the ('node', 'follow', 'node') relation.
        past_dict_1, past_dict_2: past-state dictionaries fed to the model.
        optimizer: optimizer over ``model``'s parameters.
        device: device ``train_data`` is moved to (also passed to ``durendal_test``).
        num_epochs: maximum number of training epochs.
        verbose: when true, print the last accepted epoch.

    Returns:
        ``(model, avgpr_test, mrr_test, past_dict_1, past_dict_2, optimizer)``
        where the past dicts are those produced at the last accepted epoch,
        or the incoming ones unchanged if no epoch was accepted.
    """
    follow = ('node', 'follow', 'node')
    tol = 5e-2

    avgpr_val_max = 0
    best_epoch = -1
    # NOTE(review): this is an alias, not a copy - the returned model always
    # carries the weights of the last executed epoch. It must stay the same
    # object because ``optimizer`` holds references to its parameters.
    best_model = model
    # Start from the incoming state so that it is passed through (rather than
    # replaced by empty dicts) when no epoch is accepted, e.g. num_epochs == 0.
    best_past_dict_1 = past_dict_1
    best_past_dict_2 = past_dict_2

    train_data = train_data.to(device)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        pred, past_dict_1, past_dict_2 = model(
            train_data.x_dict, train_data.edge_index_dict,
            train_data[follow].edge_label_index,
            i_snap, past_dict_1, past_dict_2)

        # Loss to fine-tune on the current snapshot.
        loss = model.loss(pred, train_data[follow].edge_label.type_as(pred))

        # retain_graph=True: presumably required because the past dicts fed
        # back into the model stay attached to earlier graphs - TODO confirm.
        loss.backward(retain_graph=True)
        optimizer.step()

        avgpr_score_val, mrr_val = durendal_test(model, i_snap, val_data, data, device)

        # NOTE(review): the reference value is overwritten even when the new
        # score is slightly lower (within tol), so it can drift downwards.
        if avgpr_val_max - tol <= avgpr_score_val:
            avgpr_val_max = avgpr_score_val
            best_epoch = epoch
            best_past_dict_1 = past_dict_1
            best_past_dict_2 = past_dict_2
            best_model = model
        else:
            break

    avgpr_score_test, mrr_test = durendal_test(model, i_snap, test_data, data, device)

    if verbose:
        print(f'Best Epoch: {best_epoch}')

    return best_model, avgpr_score_test, mrr_test, best_past_dict_1, best_past_dict_2, optimizer

In [None]:
def durendal_test(model, i_snap, test_data, data, device='cpu'):
    """Evaluate ``model`` on the ('node', 'follow', 'node') edges of ``test_data``.

    Args:
        model: model called as ``model(x_dict, edge_index_dict, edge_index, i_snap)``;
            the first returned value is taken as the edge logits.
        i_snap: snapshot index forwarded to the model.
        test_data: data object whose 'follow' store provides ``edge_label_index``,
            ``edge_label`` and ``corrupted_edge_index``.
        data: unused; kept for interface compatibility.
        device: device ``test_data`` is moved to.

    Returns:
        ``(avgpr_score, mrr_score)``: average precision over all labelled edges
        and the MRR of the first half of the predictions against the
        corrupted-edge predictions.
    """
    model.eval()

    test_data = test_data.to(device)
    follow_store = test_data['node','follow','node']

    # The first half of edge_label_index is treated as the positive edges.
    num_pos = len(follow_store.edge_label_index[0])//2

    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        h, *_ = model(test_data.x_dict, test_data.edge_index_dict,
                      follow_store.edge_label_index, i_snap)
        fake, *_ = model(test_data.x_dict, test_data.edge_index_dict,
                         follow_store.corrupted_edge_index, i_snap)

    pred_cont = torch.sigmoid(h).cpu().detach().numpy()
    fake_preds = torch.sigmoid(fake).cpu().detach().numpy()

    label = follow_store.edge_label.cpu().detach().numpy()

    avgpr_score = average_precision_score(label, pred_cont)
    mrr_score = compute_mrr(pred_cont[:num_pos], fake_preds)

    return avgpr_score, mrr_score

### TRAINING AND EVALUATION

In [None]:
# Hidden sizes handed to training() as hidden_conv_1 and hidden_conv_2.
hidden_conv_1, hidden_conv_2 = 256, 128

In [None]:
import random

# Use the GPU when one is available, otherwise fall back to the CPU
# instead of hard-coding a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): torch's own RNG is left unseeded here (the two calls below are
# disabled), so model initialisation and sampling done through torch differ
# between notebook runs even though numpy and random are seeded.
#torch.manual_seed(123456)
#torch.cuda.manual_seed_all(123456)
np.random.seed(0)
random.seed(0)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.cuda.empty_cache()

In [None]:
import os

# Repeat the whole experiment nrun times, appending each run's average
# precision (one value per line) to a per-model results file.
nrun = 10

# Create the output directory up front so a missing folder is not discovered
# only after a full training run has finished.
os.makedirs('results', exist_ok=True)

for run in range(nrun):
    rgcn_avgpr, han_avgpr, hgt_avgpr = training(snapshots, hidden_conv_1, hidden_conv_2)
    run_results = (
        ('results/durendal-rgcn-gru.txt', rgcn_avgpr),
        ('results/durendal-han-gru.txt', han_avgpr),
        ('results/durendal-hgt-gru.txt', hgt_avgpr),
    )
    for path, value in run_results:
        with open(path,'a') as wfile:
            wfile.write(f'{value}\n')