This notebook reproduces the results for TaobaoTH. For annotated code, please see the SteemitTH notebook. The code that runs the experiments is the same for all datasets; the only differences are the relation names and the model names. We chose to keep one notebook per dataset so that the different prediction tasks are treated separately.

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, GRUCell
from torch_geometric.data import Data
from sklearn.metrics import roc_auc_score,average_precision_score

import random

import bisect

import gc
import copy

from itertools import permutations

import pandas as pd

from torch_geometric.utils import negative_sampling, structured_negative_sampling
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
from torch_geometric.transforms import RandomLinkSplit,NormalizeFeatures,Constant,OneHotDegree
from torch_geometric.utils import from_networkx
from torch_geometric.nn import GCNConv,SAGEConv,GATv2Conv, GINConv, Linear, GCN, GAT

import torch
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

import copy

In [None]:
from taobao import *

# Load dataset

In [None]:
# Load the Taobao snapshots. get_taobao_dataset is presumably provided by the
# `from taobao import *` star import above -- confirm against the taobao module.
snapshots = get_taobao_dataset()

In [None]:
# Bare expression: as the last statement of a notebook cell it displays the
# loaded snapshots.
snapshots

# TRAINING AND EVALUATION

### Functions

In [None]:
def reverse_insort(a, x, lo=0, hi=None):
    """Insert ``x`` into the descending-sorted list ``a`` and return its index.

    ``a`` must already be sorted from largest to smallest. When ``a``
    contains items equal to ``x``, the new item is placed after the last
    of them.

    ``lo`` (default 0) and ``hi`` (default ``len(a)``) restrict the search
    to the slice ``a[lo:hi]``.

    Raises ValueError when ``lo`` is negative.

    Used as a ranking helper when computing MRR.
    """
    if lo < 0:
        raise ValueError('lo must be non-negative')

    left = lo
    right = len(a) if hi is None else hi

    # Binary search for the first position whose element is strictly
    # smaller than x; equal elements push the insertion point rightwards.
    while left < right:
        pivot = (left + right) // 2
        if x > a[pivot]:
            right = pivot
        else:
            left = pivot + 1

    a.insert(left, x)
    return left

def compute_mrr(real_scores, fake_scores):
    """Return the mean reciprocal rank of real scores against paired fake scores.

    Each real score is ranked against the single fake score found at the
    same position: the reciprocal rank is 1.0 when the real score is
    strictly greater than the fake one and 0.5 otherwise (ties count
    against the real score).

    Args:
        real_scores: iterable of scores assigned to true edges.
        fake_scores: iterable of scores assigned to the corresponding
            corrupted edges; only as many pairs as the shorter of the two
            inputs are evaluated.

    Returns:
        The average reciprocal rank over the evaluated pairs, or 0.0 when
        there is no pair to evaluate (instead of dividing by zero).
    """
    srr = 0.0
    count = 0
    for score, fake in zip(real_scores, fake_scores):
        # Ranking against a one-element list: position 0 if the real score
        # wins outright, position 1 otherwise.
        rank = 0 if score > fake else 1
        srr += 1 / (rank + 1)  # rank is zero-based
        count += 1
    if count == 0:
        return 0.0
    return srr / count

In [None]:
def training(snapshots, hidden_conv_1, hidden_conv_2, device='cpu'):
    """Train and evaluate seven models side by side over consecutive snapshots.

    The models are TAOBAODurendal, TAOBAOGAT, TAOBAOHAN, TAOBAOGConvGRU,
    TAOBAOEvolveGCN, TAOBAOHEGCN and TAOBAOATU, each with its own Adam
    optimizer (lr=0.001, weight_decay=5e-3). For every index ``i`` in
    ``range(len(snapshots) - 1)`` the function:

    1. splits snapshot ``i`` with RandomLinkSplit (num_val=0.0,
       num_test=0.20), both in homogeneous and heterogeneous form;
    2. builds the test data from snapshot ``i + 1``, adding sampled negative
       ('user', 'buy', 'item') edges and corrupted edges;
    3. calls the ``*_train_single_snapshot`` helper of every model;
    4. prints the per-snapshot AVGPR and MRR test values.

    Finally it prints, per model, the metrics averaged over the
    ``len(snapshots) - 1`` steps.

    Args:
        snapshots: indexable sequence of heterogeneous graph snapshots
            (objects offering ``to_homogeneous``, ``x_dict``,
            ``edge_index_dict`` and ``metadata``).
        hidden_conv_1: size of the first hidden state / layer passed to the
            models.
        hidden_conv_2: size of the second hidden state / layer passed to the
            models.
        device: not referenced anywhere in this function.

    Returns:
        None. Results are only printed.
    """
    num_snap = len(snapshots)
    hetdata = copy.deepcopy(snapshots[0])
    homdata = copy.deepcopy(snapshots[0]).to_homogeneous()
    edge_types = list(hetdata.edge_index_dict.keys())

    # Optimizer hyper-parameters shared by all models.
    lr = 0.001
    weight_decay = 5e-3

    # Per node type: feature dimension (length of the first feature row) and
    # number of rows of the feature matrix.
    in_channels = {node: len(v[0]) for node,v in hetdata.x_dict.items()}
    num_nodes = {node: len(v) for node, v in hetdata.x_dict.items()}

    # Same quantities for the homogeneous version of the first snapshot.
    in_channels_homo = homdata.x.size(1)
    num_nodes_homo = homdata.x.size(0)

    #DURENDAL
    durendal = TAOBAODurendal(in_channels, num_nodes, hetdata.metadata(),
                        hidden_conv_1=hidden_conv_1,
                        hidden_conv_2=hidden_conv_2)

    durendal.reset_parameters()

    durendalopt = torch.optim.Adam(params=durendal.parameters(), lr=lr, weight_decay = weight_decay)

    #GAT
    gat = TAOBAOGAT(in_channels_homo, hidden_conv_1, hidden_conv_2)
    gat.reset_parameters()
    gatopt = torch.optim.Adam(params=gat.parameters(), lr=lr, weight_decay = weight_decay)

    #HAN
    han = TAOBAOHAN(in_channels, hidden_conv_1, hidden_conv_2, hetdata.metadata())
    han.reset_parameters()
    hanopt = torch.optim.Adam(params=han.parameters(), lr=lr, weight_decay = weight_decay)

    #GConvGRU
    gcgru = TAOBAOGConvGRU(in_channels_homo, hidden_conv_2)
    gcgru.reset_parameters()
    gcgruopt = torch.optim.Adam(params=gcgru.parameters(), lr=lr, weight_decay = weight_decay)

    #EvolveGCN
    ev = TAOBAOEvolveGCN(in_channels_homo, num_nodes_homo)
    ev.reset_parameters()
    evopt = torch.optim.Adam(params=ev.parameters(), lr=lr, weight_decay = weight_decay)

    #HetEvolveGCN
    hev = TAOBAOHEGCN(in_channels_homo, num_nodes_homo, list(hetdata.edge_index_dict.keys()))
    hev.reset_parameters()
    hevopt = torch.optim.Adam(params=hev.parameters(), lr=lr, weight_decay = weight_decay)

    #ATU
    atu = TAOBAOATU(in_channels, num_nodes, hetdata.metadata(),
                        hidden_conv_1=hidden_conv_1,
                        hidden_conv_2=hidden_conv_2)
    atu.reset_parameters()
    atuopt = torch.optim.Adam(params=atu.parameters(), lr=lr, weight_decay = weight_decay)

    # Zero-filled state of width hidden_conv_1, stored per node type and per
    # relation name for both endpoints of every edge type. It is passed to and
    # returned by durendal_train_single_snapshot at every step.
    past_dict_1 = {}
    for node in hetdata.x_dict.keys():
        past_dict_1[node] = {}
    for src,r,dst in hetdata.edge_index_dict.keys():
        past_dict_1[src][r] = torch.Tensor([[0 for j in range(hidden_conv_1)] for i in range(hetdata[src].num_nodes)])
        past_dict_1[dst][r] = torch.Tensor([[0 for j in range(hidden_conv_1)] for i in range(hetdata[dst].num_nodes)])

    # Same structure with width hidden_conv_2.
    past_dict_2 = {}
    for node in hetdata.x_dict.keys():
        past_dict_2[node] = {}
    for src,r,dst in hetdata.edge_index_dict.keys():
        past_dict_2[src][r] = torch.Tensor([[0 for j in range(hidden_conv_2)] for i in range(hetdata[src].num_nodes)])
        past_dict_2[dst][r] = torch.Tensor([[0 for j in range(hidden_conv_2)] for i in range(hetdata[dst].num_nodes)])

    # ATU gets independent copies of the initial states.
    past_dict_1_atu = copy.deepcopy(past_dict_1)
    past_dict_2_atu = copy.deepcopy(past_dict_2)

    # Running sums of the per-snapshot test metrics.
    durendal_avgpr = 0
    durendal_mrr = 0
    gat_avgpr = 0
    gat_mrr = 0
    han_avgpr = 0
    han_mrr = 0
    gcgru_avgpr = 0
    gcgru_mrr = 0
    ev_avgpr = 0
    ev_mrr = 0
    hev_avgpr = 0
    hev_mrr = 0
    atu_avgpr = 0
    atu_mrr = 0

    # H is threaded through gcgru_train_single_snapshot across steps.
    # H_1, C_1, H_2 and C_2 are never read in this function.
    H = None
    H_1 = None
    C_1 = None
    H_2 = None
    C_2 = None

    for i in range(num_snap-1):
        #CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        snapshot = copy.deepcopy(snapshots[i])
        hom_snapshot = snapshot.to_homogeneous()
        # With num_val=0.0 the third element returned by the transform (the
        # 20% split) is the one used here as validation data.
        hom_transform = RandomLinkSplit(num_val=0.0, num_test=0.20)
        hom_train_data, _, hom_val_data = hom_transform(hom_snapshot)

        het_transform = RandomLinkSplit(num_val=0.0,num_test=0.20, edge_types=edge_types)
        het_train_data, _, het_val_data = het_transform(snapshot)

        # The test data comes from the NEXT snapshot.
        het_test_data = copy.deepcopy(snapshots[i+1])
        # NOTE(review): num_nodes is the number of 'user' rows only, although
        # the relation goes from 'user' to 'item' -- confirm this is the
        # intended sampling range for the destination side.
        het_future_neg_edge_index = negative_sampling(
            edge_index=het_test_data['user','buy','item'].edge_index, #positive edges
            num_nodes=len(het_test_data['user'].x), # number of nodes
            num_neg_samples=het_test_data['user','buy','item'].edge_index.size(1)) # number of neg_sample equal to number of pos_edges
        #edge index ok, edge_label concat, edge_label_index concat
        num_pos_edge = het_test_data['user','buy','item'].edge_index.size(1)
        # Labels: num_pos_edge ones followed by num_pos_edge zeros.
        # NOTE(review): this assumes negative_sampling returned exactly
        # num_pos_edge negatives -- confirm.
        het_test_data['user','buy','item'].edge_label = torch.Tensor(\
            np.array([1 for i in range(num_pos_edge)] + [0 for i in range(num_pos_edge)]))
        het_test_data['user','buy','item'].edge_label_index = \
            torch.cat([het_test_data['user','buy','item'].edge_index, het_future_neg_edge_index], dim=-1)

        hom_test_data = copy.deepcopy(snapshots[i+1]).to_homogeneous()
        """
        hom_future_neg_edge_index = negative_sampling(
            edge_index = hom_test_data.edge_index,
            num_nodes = hom_test_data.num_nodes,
            num_neg_samples = hom_test_data.edge_index.size(1)
        )
        """
        #num_pos_edge = hom_test_data.edge_index.size(1)
        hom_test_data.edge_label = torch.Tensor(\
            np.array([1 for i in range(num_pos_edge)] + [0 for i in range(num_pos_edge)]))
        #hom_test_data.edge_label_index = \
            #torch.cat([hom_test_data.edge_index, hom_future_neg_edge_index], dim=-1)
        # The homogeneous test set reuses the label indices of the
        # heterogeneous 'buy' relation.
        # NOTE(review): confirm these indices match the node numbering of the
        # homogeneous graph.
        hom_test_data.edge_label_index = het_test_data['user','buy','item'].edge_label_index

        #corrupted_edges as field of test_data and val_data
        src_t, _, corrupted_dst =\
            structured_negative_sampling(het_val_data['user','buy','item'].edge_index)

        corrupted_edge_index_val = torch.stack([src_t, corrupted_dst])

        src_t_test, _, corrupted_dst_test =\
            structured_negative_sampling(het_test_data['user','buy','item'].edge_index)

        corrupted_edge_index_test = torch.stack([src_t_test, corrupted_dst_test])

        # The same corrupted edges are attached to both the homogeneous and
        # the heterogeneous validation / test data.
        hom_val_data.corrupted_edge_index = corrupted_edge_index_val
        het_val_data['user','buy','item'].corrupted_edge_index = corrupted_edge_index_val
        hom_test_data.corrupted_edge_index = corrupted_edge_index_test
        het_test_data['user','buy','item'].corrupted_edge_index = corrupted_edge_index_test

        #TRAIN AND TEST THE MODEL FOR THE CURRENT SNAP
        durendal, dur_avgpr_test, dur_mrr_test , past_dict_1, past_dict_2, durendalopt =\
            durendal_train_single_snapshot(durendal, snapshot, i, het_train_data, het_val_data, het_test_data,\
                                  past_dict_1, past_dict_2, durendalopt)

        gat, gat_avgpr_test, gat_mrr_test, gatopt =\
            hom_train_single_snapshot(gat, snapshot, hom_train_data, hom_val_data, hom_test_data, gatopt)

        han, han_avgpr_test, han_mrr_test, hanopt =\
            het_train_single_snapshot(han, snapshot, het_train_data, het_val_data, het_test_data, hanopt)

        gcgru, gcgru_avgpr_test, gcgru_mrr_test, H, gcgruopt =\
            gcgru_train_single_snapshot(gcgru, snapshot, hom_train_data, hom_val_data, hom_test_data, gcgruopt, H)

        ev, ev_avgpr_test, ev_mrr_test, evopt =\
            hom_train_single_snapshot(ev, snapshot, hom_train_data, hom_val_data, hom_test_data, evopt)

        hev, hev_avgpr_test, hev_mrr_test, hevopt =\
            het_train_single_snapshot(hev, snapshot, het_train_data, het_val_data, het_test_data, hevopt)

        # ATU goes through the same helper as DURENDAL, with its own states.
        atu, atu_avgpr_test, atu_mrr_test , past_dict_1_atu, past_dict_2_atu, atuopt =\
            durendal_train_single_snapshot(atu, snapshot, i, het_train_data, het_val_data, het_test_data,\
                                  past_dict_1_atu, past_dict_2_atu, atuopt)

        #SAVE AND DISPLAY EVALUATION
        print(f'Snapshot: {i}\n')
        print(f' DURENDAL AVGPR Test: {dur_avgpr_test} \n MRR Test: {dur_mrr_test}\n')
        print(f' GAT AVGPR Test: {gat_avgpr_test} \n MRR Test: {gat_mrr_test}\n')
        print(f' HAN AVGPR Test: {han_avgpr_test} \n MRR Test: {han_mrr_test}\n')
        print(f' GConvGRU AVGPR Test: {gcgru_avgpr_test} \n MRR Test: {gcgru_mrr_test}\n')
        print(f' EvolveGCN AVGPR Test: {ev_avgpr_test} \n MRR Test: {ev_mrr_test}\n')
        print(f' HetEvolveGCN AVGPR Test: {hev_avgpr_test} \n MRR Test: {hev_mrr_test}\n')
        print(f' ATU AVGPR Test: {atu_avgpr_test} \n MRR Test: {atu_mrr_test}\n')
        durendal_avgpr += dur_avgpr_test
        durendal_mrr += dur_mrr_test
        gat_avgpr += gat_avgpr_test
        gat_mrr += gat_mrr_test
        han_avgpr += han_avgpr_test
        han_mrr += han_mrr_test
        gcgru_avgpr += gcgru_avgpr_test
        gcgru_mrr += gcgru_mrr_test
        ev_avgpr += ev_avgpr_test
        ev_mrr += ev_mrr_test
        hev_avgpr += hev_avgpr_test
        hev_mrr += hev_mrr_test
        atu_avgpr += atu_avgpr_test
        atu_mrr += atu_mrr_test


    # Averages over the num_snap-1 evaluated steps.
    # NOTE: with a single snapshot the loop above does not run and these
    # divisions raise ZeroDivisionError.
    durendal_avgpr_all = durendal_avgpr / (num_snap-1)
    durendal_mrr_all = durendal_mrr / (num_snap-1)
    gat_avgpr_all = gat_avgpr / (num_snap-1)
    gat_mrr_all = gat_mrr / (num_snap-1)
    han_avgpr_all = han_avgpr / (num_snap-1)
    han_mrr_all = han_mrr / (num_snap-1)
    gcgru_avgpr_all = gcgru_avgpr / (num_snap-1)
    gcgru_mrr_all = gcgru_mrr / (num_snap-1)
    ev_avgpr_all = ev_avgpr / (num_snap-1)
    ev_mrr_all = ev_mrr / (num_snap-1)
    hev_avgpr_all = hev_avgpr / (num_snap-1)
    hev_mrr_all = hev_mrr / (num_snap-1)
    atu_avgpr_all = atu_avgpr / (num_snap-1)
    atu_mrr_all = atu_mrr / (num_snap-1)

    print('DURENDAL')
    print(f'\tAVGPR over time: Test: {durendal_avgpr_all}')
    print(f'\tMRR over time: Test: {durendal_mrr_all}')
    print()
    print('GAT')
    print(f'\tAVGPR over time: Test: {gat_avgpr_all}')
    print(f'\tMRR over time: Test: {gat_mrr_all}')
    print()
    print('HAN')
    print(f'\tAVGPR over time: Test: {han_avgpr_all}')
    print(f'\tMRR over time: Test: {han_mrr_all}')
    print()
    print('GConvGRU')
    print(f'\tAVGPR over time: Test: {gcgru_avgpr_all}')
    print(f'\tMRR over time: Test: {gcgru_mrr_all}')
    print()
    print('EvolveGCN')
    print(f'\tAVGPR over time: Test: {ev_avgpr_all}')
    print(f'\tMRR over time: Test: {ev_mrr_all}')
    print()
    print('HetEvolveGCN')
    print(f'\tAVGPR over time: Test: {hev_avgpr_all}')
    print(f'\tMRR over time: Test: {hev_mrr_all}')
    # Unlike the other sections, no blank line is printed before 'ATU'.
    print('ATU')
    print(f'\tAVGPR over time: Test: {atu_avgpr_all}')
    print(f'\tMRR over time: Test: {atu_mrr_all}')

    return

In [None]:
def training_han(snapshots, hidden_conv_1, hidden_conv_2, device='cpu'):
    """Train and evaluate the HAN baseline over consecutive snapshot pairs.

    For every index ``i`` in ``range(len(snapshots) - 1)`` the model is
    trained on a RandomLinkSplit of snapshot ``i`` (num_val=0.0,
    num_test=0.20, whose 20% split is used as validation data) and tested on
    the ('user', 'buy', 'item') edges of snapshot ``i + 1`` plus an equal
    number of requested negative samples. Per-snapshot and averaged AVGPR /
    MRR test values are printed.

    Args:
        snapshots: indexable sequence of heterogeneous graph snapshots.
        hidden_conv_1: first hidden size forwarded to TAOBAOHAN.
        hidden_conv_2: second hidden size forwarded to TAOBAOHAN.
        device: currently unused; kept for interface compatibility.

    Returns:
        None. Results are only printed.

    Raises:
        ValueError: if fewer than two snapshots are given, since no
            (train, test) pair can be formed and the averages would be
            undefined.
    """
    buy = ('user', 'buy', 'item')
    num_steps = len(snapshots) - 1
    if num_steps < 1:
        raise ValueError('training_han needs at least two snapshots')

    hetdata = copy.deepcopy(snapshots[0])
    edge_types = list(hetdata.edge_index_dict.keys())

    lr = 0.001
    weight_decay = 5e-3

    # Feature dimension per node type (length of the first feature row).
    in_channels = {node: len(v[0]) for node, v in hetdata.x_dict.items()}

    #HAN
    han = TAOBAOHAN(in_channels, hidden_conv_1, hidden_conv_2, hetdata.metadata())
    han.reset_parameters()
    hanopt = torch.optim.Adam(params=han.parameters(), lr=lr, weight_decay=weight_decay)

    han_avgpr = 0
    han_mrr = 0

    for i in range(num_steps):
        #CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        snapshot = copy.deepcopy(snapshots[i])

        het_transform = RandomLinkSplit(num_val=0.0, num_test=0.20, edge_types=edge_types)
        het_train_data, _, het_val_data = het_transform(snapshot)

        # The test data comes from the NEXT snapshot.
        het_test_data = copy.deepcopy(snapshots[i + 1])
        pos_edge_index = het_test_data[buy].edge_index
        num_pos_edge = pos_edge_index.size(1)
        # NOTE(review): num_nodes is the number of 'user' rows only, although
        # the relation goes from 'user' to 'item' -- confirm this is intended.
        het_future_neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,  # positive edges
            num_nodes=len(het_test_data['user'].x),
            num_neg_samples=num_pos_edge)  # as many negatives as positives
        # Labels: ones for the positive edges followed by zeros for the
        # negatives (float tensor, same values as the former list-based
        # construction).
        het_test_data[buy].edge_label = torch.cat(
            [torch.ones(num_pos_edge), torch.zeros(num_pos_edge)])
        het_test_data[buy].edge_label_index = torch.cat(
            [pos_edge_index, het_future_neg_edge_index], dim=-1)

        #corrupted_edges as field of test_data and val_data
        src_t, _, corrupted_dst = structured_negative_sampling(het_val_data[buy].edge_index)
        corrupted_edge_index_val = torch.stack([src_t, corrupted_dst])

        src_t_test, _, corrupted_dst_test = structured_negative_sampling(pos_edge_index)
        corrupted_edge_index_test = torch.stack([src_t_test, corrupted_dst_test])

        het_val_data[buy].corrupted_edge_index = corrupted_edge_index_val
        het_test_data[buy].corrupted_edge_index = corrupted_edge_index_test

        #TRAIN AND TEST THE MODEL FOR THE CURRENT SNAP
        han, han_avgpr_test, han_mrr_test, hanopt =\
            het_train_single_snapshot(han, snapshot, het_train_data, het_val_data, het_test_data, hanopt)

        #SAVE AND DISPLAY EVALUATION
        print(f'Snapshot: {i}\n')
        print(f' HAN AVGPR Test: {han_avgpr_test} \n MRR Test: {han_mrr_test}\n')
        han_avgpr += han_avgpr_test
        han_mrr += han_mrr_test

    han_avgpr_all = han_avgpr / num_steps
    han_mrr_all = han_mrr / num_steps

    print('HAN')
    print(f'\tAVGPR over time: Test: {han_avgpr_all}')
    print(f'\tMRR over time: Test: {han_mrr_all}')
    print()

    return

In [None]:
def training_gat(snapshots, hidden_conv_1, hidden_conv_2, device='cpu', max_steps=9):
    """Train and evaluate the GAT baseline over consecutive snapshot pairs.

    At step ``i`` the model is trained on a RandomLinkSplit (num_val=0.0,
    num_test=0.20, whose 20% split is used as validation data) of the
    homogeneous version of snapshot ``i`` and tested on the homogeneous
    version of snapshot ``i + 1``. Test label indices are the
    ('user', 'buy', 'item') edges of snapshot ``i + 1`` plus an equal number
    of requested negative samples. Per-snapshot and averaged AVGPR / MRR test
    values are printed.

    Args:
        snapshots: indexable sequence of heterogeneous graph snapshots.
        hidden_conv_1: first hidden size forwarded to TAOBAOGAT.
        hidden_conv_2: second hidden size forwarded to TAOBAOGAT.
        device: currently unused; kept for interface compatibility.
        max_steps: maximum number of (train, test) snapshot pairs to
            evaluate; the default 9 matches the previously hard-coded cap.
            ``None`` evaluates every available pair.

    Returns:
        None. Results are only printed.

    Raises:
        ValueError: if no snapshot pair can be evaluated (fewer than two
            snapshots, or ``max_steps`` below 1), since the averages would
            be undefined.
    """
    buy = ('user', 'buy', 'item')
    num_steps = len(snapshots) - 1
    if max_steps is not None:
        num_steps = min(num_steps, max_steps)
    if num_steps < 1:
        raise ValueError('training_gat needs at least two snapshots and max_steps >= 1')

    edge_types = list(snapshots[0].edge_index_dict.keys())
    homdata = copy.deepcopy(snapshots[0]).to_homogeneous()

    lr = 0.001
    weight_decay = 5e-3

    in_channels_homo = homdata.x.size(1)

    #GAT
    gat = TAOBAOGAT(in_channels_homo, hidden_conv_1, hidden_conv_2)
    gat.reset_parameters()
    gatopt = torch.optim.Adam(params=gat.parameters(), lr=lr, weight_decay=weight_decay)

    gat_avgpr = 0
    gat_mrr = 0

    for step in range(num_steps):
        #CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        snapshot = copy.deepcopy(snapshots[step])
        hom_snapshot = snapshot.to_homogeneous()
        hom_transform = RandomLinkSplit(num_val=0.0, num_test=0.20)
        hom_train_data, _, hom_val_data = hom_transform(hom_snapshot)

        # The heterogeneous split is only needed for its validation 'buy'
        # edges, from which the corrupted validation edges are sampled.
        het_transform = RandomLinkSplit(num_val=0.0, num_test=0.20, edge_types=edge_types)
        _, _, het_val_data = het_transform(snapshot)

        # Positive test edges come from the NEXT snapshot; they are only
        # read, so no copy of the heterogeneous snapshot is required.
        next_snapshot = snapshots[step + 1]
        pos_edge_index = next_snapshot[buy].edge_index
        num_pos_edge = pos_edge_index.size(1)
        # NOTE(review): num_nodes is the number of 'user' rows only, although
        # the relation goes from 'user' to 'item' -- confirm this is intended.
        het_future_neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,  # positive edges
            num_nodes=len(next_snapshot['user'].x),
            num_neg_samples=num_pos_edge)  # as many negatives as positives

        hom_test_data = copy.deepcopy(next_snapshot).to_homogeneous()
        # Labels: ones for the positive edges followed by zeros for the
        # negatives.
        hom_test_data.edge_label = torch.cat(
            [torch.ones(num_pos_edge), torch.zeros(num_pos_edge)])
        # NOTE(review): these indices come from the heterogeneous 'buy'
        # relation -- confirm they match the homogeneous node numbering.
        hom_test_data.edge_label_index = torch.cat(
            [pos_edge_index, het_future_neg_edge_index], dim=-1)

        #corrupted_edges as field of test_data and val_data
        src_t, _, corrupted_dst = structured_negative_sampling(het_val_data[buy].edge_index)
        corrupted_edge_index_val = torch.stack([src_t, corrupted_dst])

        src_t_test, _, corrupted_dst_test = structured_negative_sampling(pos_edge_index)
        corrupted_edge_index_test = torch.stack([src_t_test, corrupted_dst_test])

        hom_val_data.corrupted_edge_index = corrupted_edge_index_val
        hom_test_data.corrupted_edge_index = corrupted_edge_index_test

        #TRAIN AND TEST THE MODEL FOR THE CURRENT SNAP
        gat, gat_avgpr_test, gat_mrr_test, gatopt =\
            hom_train_single_snapshot(gat, snapshot, hom_train_data, hom_val_data, hom_test_data, gatopt)

        #SAVE AND DISPLAY EVALUATION
        print(f'Snapshot: {step}\n')
        print(f' GAT AVGPR Test: {gat_avgpr_test} \n MRR Test: {gat_mrr_test}\n')
        gat_avgpr += gat_avgpr_test
        gat_mrr += gat_mrr_test

    gat_avgpr_all = gat_avgpr / num_steps
    gat_mrr_all = gat_mrr / num_steps

    print('GAT')
    print(f'\tAVGPR over time: Test: {gat_avgpr_all}')
    print(f'\tMRR over time: Test: {gat_mrr_all}')
    print()
    return

In [None]:
def training_durendal(snapshots, hidden_conv_1, hidden_conv_2, device='cpu', max_steps=9):
    """Train and evaluate DURENDAL over consecutive snapshot pairs.

    At step ``i`` the model is trained on a RandomLinkSplit (num_val=0.0,
    num_test=0.20, whose 20% split is used as validation data) of snapshot
    ``i`` and tested on the ('user', 'buy', 'item') edges of snapshot
    ``i + 1`` plus an equal number of requested negative samples. Two
    per-relation states (widths ``hidden_conv_1`` and ``hidden_conv_2``),
    initialised to zero, are passed to and returned by
    ``durendal_train_single_snapshot`` at every step. Per-snapshot and
    averaged AVGPR / MRR test values are printed.

    Args:
        snapshots: indexable sequence of heterogeneous graph snapshots.
        hidden_conv_1: first hidden size (model and first state width).
        hidden_conv_2: second hidden size (model and second state width).
        device: currently unused; kept for interface compatibility.
        max_steps: maximum number of (train, test) snapshot pairs to
            evaluate. The default 9 reproduces the former hard-coded
            ``num_snap = 10`` while no longer indexing past the end of
            shorter snapshot lists. ``None`` evaluates every available pair.

    Returns:
        None. Results are only printed.

    Raises:
        ValueError: if no snapshot pair can be evaluated (fewer than two
            snapshots, or ``max_steps`` below 1).
    """
    def _zero_state(data, hidden):
        # One zero tensor of shape (num_nodes, hidden) per node type and
        # relation name, for both endpoints of every edge type.
        state = {node: {} for node in data.x_dict.keys()}
        for src, r, dst in data.edge_index_dict.keys():
            state[src][r] = torch.zeros(data[src].num_nodes, hidden)
            state[dst][r] = torch.zeros(data[dst].num_nodes, hidden)
        return state

    buy = ('user', 'buy', 'item')
    num_steps = len(snapshots) - 1
    if max_steps is not None:
        num_steps = min(num_steps, max_steps)
    if num_steps < 1:
        raise ValueError('training_durendal needs at least two snapshots and max_steps >= 1')

    first_snap = snapshots[0]
    edge_types = list(first_snap.edge_index_dict.keys())

    lr = 0.001
    weight_decay = 5e-3

    # Per node type: feature dimension and number of feature rows.
    in_channels = {node: len(v[0]) for node, v in first_snap.x_dict.items()}
    num_nodes = {node: len(v) for node, v in first_snap.x_dict.items()}

    #DURENDAL
    durendal = TAOBAODurendal(in_channels, num_nodes, first_snap.metadata(),
                        hidden_conv_1=hidden_conv_1,
                        hidden_conv_2=hidden_conv_2)

    durendal.reset_parameters()

    durendalopt = torch.optim.Adam(params=durendal.parameters(), lr=lr, weight_decay=weight_decay)

    past_dict_1 = _zero_state(first_snap, hidden_conv_1)
    past_dict_2 = _zero_state(first_snap, hidden_conv_2)

    durendal_avgpr = 0
    durendal_mrr = 0

    # first_snap is only an alias of snapshots[0]; dropping the name and
    # collecting mirrors the original clean-up step.
    del first_snap
    gc.collect()

    for step in range(num_steps):
        #CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        snapshot = copy.deepcopy(snapshots[step])

        het_transform = RandomLinkSplit(num_val=0.0, num_test=0.20, edge_types=edge_types)
        het_train_data, _, het_val_data = het_transform(snapshot)

        # The test data comes from the NEXT snapshot.
        het_test_data = copy.deepcopy(snapshots[step + 1])
        pos_edge_index = het_test_data[buy].edge_index
        num_pos_edge = pos_edge_index.size(1)
        # NOTE(review): num_nodes is the number of 'user' rows only, although
        # the relation goes from 'user' to 'item' -- confirm this is intended.
        het_future_neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,  # positive edges
            num_nodes=len(het_test_data['user'].x),
            num_neg_samples=num_pos_edge)  # as many negatives as positives
        # Labels: ones for the positive edges followed by zeros for the
        # negatives.
        het_test_data[buy].edge_label = torch.cat(
            [torch.ones(num_pos_edge), torch.zeros(num_pos_edge)])
        het_test_data[buy].edge_label_index = torch.cat(
            [pos_edge_index, het_future_neg_edge_index], dim=-1)

        #corrupted_edges as field of test_data and val_data
        src_t, _, corrupted_dst = structured_negative_sampling(het_val_data[buy].edge_index)
        corrupted_edge_index_val = torch.stack([src_t, corrupted_dst])

        src_t_test, _, corrupted_dst_test = structured_negative_sampling(pos_edge_index)
        corrupted_edge_index_test = torch.stack([src_t_test, corrupted_dst_test])

        het_val_data[buy].corrupted_edge_index = corrupted_edge_index_val
        het_test_data[buy].corrupted_edge_index = corrupted_edge_index_test

        #TRAIN AND TEST THE MODEL FOR THE CURRENT SNAP
        durendal, dur_avgpr_test, dur_mrr_test, past_dict_1, past_dict_2, durendalopt =\
            durendal_train_single_snapshot(durendal, snapshot, step, het_train_data, het_val_data, het_test_data,\
                                  past_dict_1, past_dict_2, durendalopt)

        #SAVE AND DISPLAY EVALUATION
        print(f'Snapshot: {step}\n')
        print(f' DURENDAL AVGPR Test: {dur_avgpr_test} \n MRR Test: {dur_mrr_test}\n')
        durendal_avgpr += dur_avgpr_test
        durendal_mrr += dur_mrr_test

    durendal_avgpr_all = durendal_avgpr / num_steps
    durendal_mrr_all = durendal_mrr / num_steps

    print('DURENDAL')
    print(f'\tAVGPR over time: Test: {durendal_avgpr_all}')
    print(f'\tMRR over time: Test: {durendal_mrr_all}')
    return

In [None]:
def training_gconvgru(snapshots, hidden_conv_1, hidden_conv_2, device='cpu', max_steps=9):
    """Train and evaluate the GConvGRU baseline over consecutive snapshot pairs.

    At step ``i`` the model is trained on a RandomLinkSplit (num_val=0.0,
    num_test=0.20, whose 20% split is used as validation data) of the
    homogeneous version of snapshot ``i`` and tested on the homogeneous
    version of snapshot ``i + 1``. Test label indices are the
    ('user', 'buy', 'item') edges of snapshot ``i + 1`` plus an equal number
    of requested negative samples. The state ``H`` starts as ``None`` and is
    passed to and returned by ``gcgru_train_single_snapshot`` at every step.
    Per-snapshot and averaged AVGPR / MRR test values are printed.

    The per-relation ``past_dict`` tensors and other locals that the previous
    version built but never used are no longer created.

    Args:
        snapshots: indexable sequence of heterogeneous graph snapshots.
        hidden_conv_1: not used by this model; kept for a uniform signature.
        hidden_conv_2: hidden size forwarded to TAOBAOGConvGRU.
        device: currently unused; kept for interface compatibility.
        max_steps: maximum number of (train, test) snapshot pairs to
            evaluate; the default 9 matches the previously hard-coded cap.
            ``None`` evaluates every available pair.

    Returns:
        None. Results are only printed.

    Raises:
        ValueError: if no snapshot pair can be evaluated (fewer than two
            snapshots, or ``max_steps`` below 1), since the averages would
            be undefined.
    """
    buy = ('user', 'buy', 'item')
    num_steps = len(snapshots) - 1
    if max_steps is not None:
        num_steps = min(num_steps, max_steps)
    if num_steps < 1:
        raise ValueError('training_gconvgru needs at least two snapshots and max_steps >= 1')

    edge_types = list(snapshots[0].edge_index_dict.keys())
    homdata = copy.deepcopy(snapshots[0]).to_homogeneous()

    lr = 0.001
    weight_decay = 5e-3

    in_channels_homo = homdata.x.size(1)

    #GConvGRU
    gcgru = TAOBAOGConvGRU(in_channels_homo, hidden_conv_2)
    gcgru.reset_parameters()
    gcgruopt = torch.optim.Adam(params=gcgru.parameters(), lr=lr, weight_decay=weight_decay)

    gcgru_avgpr = 0
    gcgru_mrr = 0

    # State carried from one snapshot to the next.
    H = None

    for step in range(num_steps):
        #CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        snapshot = copy.deepcopy(snapshots[step])
        hom_snapshot = snapshot.to_homogeneous()
        hom_transform = RandomLinkSplit(num_val=0.0, num_test=0.20)
        hom_train_data, _, hom_val_data = hom_transform(hom_snapshot)

        # The heterogeneous split is only needed for its validation 'buy'
        # edges, from which the corrupted validation edges are sampled.
        het_transform = RandomLinkSplit(num_val=0.0, num_test=0.20, edge_types=edge_types)
        _, _, het_val_data = het_transform(snapshot)

        # Positive test edges come from the NEXT snapshot; they are only
        # read, so no copy of the heterogeneous snapshot is required.
        next_snapshot = snapshots[step + 1]
        pos_edge_index = next_snapshot[buy].edge_index
        num_pos_edge = pos_edge_index.size(1)
        # NOTE(review): num_nodes is the number of 'user' rows only, although
        # the relation goes from 'user' to 'item' -- confirm this is intended.
        het_future_neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,  # positive edges
            num_nodes=len(next_snapshot['user'].x),
            num_neg_samples=num_pos_edge)  # as many negatives as positives

        hom_test_data = copy.deepcopy(next_snapshot).to_homogeneous()
        # Labels: ones for the positive edges followed by zeros for the
        # negatives.
        hom_test_data.edge_label = torch.cat(
            [torch.ones(num_pos_edge), torch.zeros(num_pos_edge)])
        # NOTE(review): these indices come from the heterogeneous 'buy'
        # relation -- confirm they match the homogeneous node numbering.
        hom_test_data.edge_label_index = torch.cat(
            [pos_edge_index, het_future_neg_edge_index], dim=-1)

        #corrupted_edges as field of test_data and val_data
        src_t, _, corrupted_dst = structured_negative_sampling(het_val_data[buy].edge_index)
        corrupted_edge_index_val = torch.stack([src_t, corrupted_dst])

        src_t_test, _, corrupted_dst_test = structured_negative_sampling(pos_edge_index)
        corrupted_edge_index_test = torch.stack([src_t_test, corrupted_dst_test])

        hom_val_data.corrupted_edge_index = corrupted_edge_index_val
        hom_test_data.corrupted_edge_index = corrupted_edge_index_test

        #TRAIN AND TEST THE MODEL FOR THE CURRENT SNAP
        gcgru, gcgru_avgpr_test, gcgru_mrr_test, H, gcgruopt =\
            gcgru_train_single_snapshot(gcgru, snapshot, hom_train_data, hom_val_data, hom_test_data, gcgruopt, H)

        #SAVE AND DISPLAY EVALUATION
        print(f'Snapshot: {step}\n')
        print(f' GConvGRU AVGPR Test: {gcgru_avgpr_test} \n MRR Test: {gcgru_mrr_test}\n')
        gcgru_avgpr += gcgru_avgpr_test
        gcgru_mrr += gcgru_mrr_test

    gcgru_avgpr_all = gcgru_avgpr / num_steps
    gcgru_mrr_all = gcgru_mrr / num_steps

    print('GConvGRU')
    print(f'\tAVGPR over time: Test: {gcgru_avgpr_all}')
    print(f'\tMRR over time: Test: {gcgru_mrr_all}')
    print()

    return

In [None]:
def training_ev(snapshots, hidden_conv_1, hidden_conv_2, device='cpu'):
    """Train and evaluate EvolveGCN over consecutive snapshots.

    For each snapshot ``i`` (at most nine are processed) the model is
    fine-tuned on a random link split of snapshot ``i`` and scored on the
    ('user', 'buy', 'item') edges of snapshot ``i + 1`` together with the
    negatives returned by ``negative_sampling``.  Per-snapshot and averaged
    AVGPR / MRR values are printed.

    Args:
        snapshots: Sequence of heterogeneous graph snapshots; each one must
            support ``to_homogeneous()`` and contain the
            ('user', 'buy', 'item') edge type.
        hidden_conv_1: Not used by this function; present so that every
            ``training_*`` function has the same signature.
        hidden_conv_2: Not used by this function (see ``hidden_conv_1``).
        device: Not used by this function; the per-snapshot trainer is
            called with its own default device.

    Returns:
        None.  Results are only printed.
    """
    target_rel = ('user', 'buy', 'item')
    num_snap = len(snapshots)
    hetdata = copy.deepcopy(snapshots[0])
    homdata = copy.deepcopy(snapshots[0]).to_homogeneous()
    edge_types = list(hetdata.edge_index_dict.keys())

    lr = 0.001
    weight_decay = 5e-3

    in_channels_homo = homdata.x.size(1)
    num_nodes_homo = homdata.x.size(0)

    # EvolveGCN
    ev = TAOBAOEvolveGCN(in_channels_homo, num_nodes_homo)
    ev.reset_parameters()
    evopt = torch.optim.Adam(params=ev.parameters(), lr=lr, weight_decay=weight_decay)

    ev_avgpr = 0
    ev_mrr = 0

    ch = 0  # number of snapshots actually processed
    for i in range(num_snap - 1):
        if ch >= 9:
            break

        # CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        # The 20% held-out part of the split (third output) serves as the
        # validation set; no separate validation fraction is requested.
        snapshot = copy.deepcopy(snapshots[i])
        hom_snapshot = snapshot.to_homogeneous()
        hom_transform = RandomLinkSplit(num_val=0.0, num_test=0.20)
        hom_train_data, _, hom_val_data = hom_transform(hom_snapshot)

        het_transform = RandomLinkSplit(num_val=0.0, num_test=0.20, edge_types=edge_types)
        _, _, het_val_data = het_transform(snapshot)

        # Test set: positive edges of the NEXT snapshot plus sampled negatives.
        het_test_data = copy.deepcopy(snapshots[i + 1])
        future_pos_edge_index = het_test_data[target_rel].edge_index
        num_pos_edge = future_pos_edge_index.size(1)
        # NOTE(review): num_nodes is the number of 'user' nodes although the
        # relation is user -> item, and negative_sampling documents the
        # number of returned samples as approximate -- confirm both are
        # acceptable for this dataset.
        het_future_neg_edge_index = negative_sampling(
            edge_index=future_pos_edge_index,
            num_nodes=len(het_test_data['user'].x),
            num_neg_samples=num_pos_edge)
        # Labels: num_pos_edge ones followed by num_pos_edge zeros, built
        # directly in torch instead of going through Python lists and numpy.
        het_test_data[target_rel].edge_label = torch.cat(
            [torch.ones(num_pos_edge), torch.zeros(num_pos_edge)])
        het_test_data[target_rel].edge_label_index = torch.cat(
            [future_pos_edge_index, het_future_neg_edge_index], dim=-1)

        # The homogeneous test graph reuses the heterogeneous labels/indices.
        hom_test_data = copy.deepcopy(snapshots[i + 1]).to_homogeneous()
        hom_test_data.edge_label = torch.cat(
            [torch.ones(num_pos_edge), torch.zeros(num_pos_edge)])
        hom_test_data.edge_label_index = het_test_data[target_rel].edge_label_index

        # Corrupted edges (same source, resampled destination) for MRR,
        # stored as a field of both the validation and the test data.
        src_t, _, corrupted_dst = structured_negative_sampling(
            het_val_data[target_rel].edge_index)
        corrupted_edge_index_val = torch.stack([src_t, corrupted_dst])

        src_t_test, _, corrupted_dst_test = structured_negative_sampling(
            het_test_data[target_rel].edge_index)
        corrupted_edge_index_test = torch.stack([src_t_test, corrupted_dst_test])

        hom_val_data.corrupted_edge_index = corrupted_edge_index_val
        het_val_data[target_rel].corrupted_edge_index = corrupted_edge_index_val
        hom_test_data.corrupted_edge_index = corrupted_edge_index_test
        het_test_data[target_rel].corrupted_edge_index = corrupted_edge_index_test

        # TRAIN AND TEST THE MODEL FOR THE CURRENT SNAP
        ev, ev_avgpr_test, ev_mrr_test, evopt = hom_train_single_snapshot(
            ev, snapshot, hom_train_data, hom_val_data, hom_test_data, evopt)

        # SAVE AND DISPLAY EVALUATION
        print(f'Snapshot: {i}\n')
        print(f' EvolveGCN AVGPR Test: {ev_avgpr_test} \n MRR Test: {ev_mrr_test}\n')
        ev_avgpr += ev_avgpr_test
        ev_mrr += ev_mrr_test
        ch += 1

    ev_avgpr_all = ev_avgpr / ch
    ev_mrr_all = ev_mrr / ch

    print('EvolveGCN')
    print(f'\tAVGPR over time: Test: {ev_avgpr_all}')
    print(f'\tMRR over time: Test: {ev_mrr_all}')
    print()

    return

In [None]:
def training_hev(snapshots, hidden_conv_1, hidden_conv_2, device='cpu'):
    """Train and evaluate HetEvolveGCN over consecutive snapshots.

    For each snapshot ``i`` (at most nine are processed) the model is
    fine-tuned on a random link split of snapshot ``i`` and scored on the
    ('user', 'buy', 'item') edges of snapshot ``i + 1`` together with the
    negatives returned by ``negative_sampling``.  Per-snapshot and averaged
    AVGPR / MRR values are printed.

    Args:
        snapshots: Sequence of heterogeneous graph snapshots; each one must
            support ``to_homogeneous()`` and contain the
            ('user', 'buy', 'item') edge type.
        hidden_conv_1: Not used by this function; present so that every
            ``training_*`` function has the same signature.
        hidden_conv_2: Not used by this function (see ``hidden_conv_1``).
        device: Not used by this function; the per-snapshot trainer is
            called with its own default device.

    Returns:
        None.  Results are only printed.
    """
    target_rel = ('user', 'buy', 'item')
    num_snap = len(snapshots)
    hetdata = copy.deepcopy(snapshots[0])
    homdata = copy.deepcopy(snapshots[0]).to_homogeneous()
    edge_types = list(hetdata.edge_index_dict.keys())

    lr = 0.001
    weight_decay = 5e-3

    in_channels_homo = homdata.x.size(1)
    num_nodes_homo = homdata.x.size(0)

    # HetEvolveGCN
    hev = TAOBAOHEGCN(in_channels_homo, num_nodes_homo, list(edge_types))
    hev.reset_parameters()
    hevopt = torch.optim.Adam(params=hev.parameters(), lr=lr, weight_decay=weight_decay)

    hev_avgpr = 0
    hev_mrr = 0

    ch = 0  # number of snapshots actually processed
    for i in range(num_snap - 1):
        if ch >= 9:
            break

        # CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        snapshot = copy.deepcopy(snapshots[i])

        # The homogeneous split is not consumed by this model; it is still
        # executed so that the sequence of random draws stays the same as in
        # the other training_* functions (and as before this cleanup).
        hom_transform = RandomLinkSplit(num_val=0.0, num_test=0.20)
        hom_transform(snapshot.to_homogeneous())

        # The 20% held-out part of the split (third output) serves as the
        # validation set; no separate validation fraction is requested.
        het_transform = RandomLinkSplit(num_val=0.0, num_test=0.20, edge_types=edge_types)
        het_train_data, _, het_val_data = het_transform(snapshot)

        # Test set: positive edges of the NEXT snapshot plus sampled negatives.
        het_test_data = copy.deepcopy(snapshots[i + 1])
        future_pos_edge_index = het_test_data[target_rel].edge_index
        num_pos_edge = future_pos_edge_index.size(1)
        # NOTE(review): num_nodes is the number of 'user' nodes although the
        # relation is user -> item, and negative_sampling documents the
        # number of returned samples as approximate -- confirm both are
        # acceptable for this dataset.
        het_future_neg_edge_index = negative_sampling(
            edge_index=future_pos_edge_index,
            num_nodes=len(het_test_data['user'].x),
            num_neg_samples=num_pos_edge)
        # Labels: num_pos_edge ones followed by num_pos_edge zeros, built
        # directly in torch instead of going through Python lists and numpy.
        het_test_data[target_rel].edge_label = torch.cat(
            [torch.ones(num_pos_edge), torch.zeros(num_pos_edge)])
        het_test_data[target_rel].edge_label_index = torch.cat(
            [future_pos_edge_index, het_future_neg_edge_index], dim=-1)

        # Corrupted edges (same source, resampled destination) for MRR,
        # stored as a field of both the validation and the test data.
        src_t, _, corrupted_dst = structured_negative_sampling(
            het_val_data[target_rel].edge_index)
        het_val_data[target_rel].corrupted_edge_index = torch.stack([src_t, corrupted_dst])

        src_t_test, _, corrupted_dst_test = structured_negative_sampling(
            het_test_data[target_rel].edge_index)
        het_test_data[target_rel].corrupted_edge_index = torch.stack(
            [src_t_test, corrupted_dst_test])

        # TRAIN AND TEST THE MODEL FOR THE CURRENT SNAP
        hev, hev_avgpr_test, hev_mrr_test, hevopt = het_train_single_snapshot(
            hev, snapshot, het_train_data, het_val_data, het_test_data, hevopt)

        # SAVE AND DISPLAY EVALUATION
        # (A stray "ATU" print that referenced variables never defined in
        # this function used to live here and raised NameError.)
        print(f'Snapshot: {i}\n')
        print(f' HetEvolveGCN AVGPR Test: {hev_avgpr_test} \n MRR Test: {hev_mrr_test}\n')
        hev_avgpr += hev_avgpr_test
        hev_mrr += hev_mrr_test
        ch += 1

    hev_avgpr_all = hev_avgpr / ch
    hev_mrr_all = hev_mrr / ch
    print('HetEvolveGCN')
    print(f'\tAVGPR over time: Test: {hev_avgpr_all}')
    print(f'\tMRR over time: Test: {hev_mrr_all}')

    return

In [None]:
def training_atu(snapshots, hidden_conv_1, hidden_conv_2, device='cpu'):
    """Train and evaluate the ATU model over consecutive snapshots.

    For each snapshot ``i`` (at most nine are processed) the model is
    fine-tuned on a random link split of snapshot ``i`` and scored on the
    ('user', 'buy', 'item') edges of snapshot ``i + 1`` together with the
    negatives returned by ``negative_sampling``.  The two per-relation
    "past" state dictionaries are threaded from one snapshot to the next.
    Per-snapshot and averaged AVGPR / MRR values are printed.

    Args:
        snapshots: Sequence of heterogeneous graph snapshots; each one must
            support ``to_homogeneous()`` and contain the
            ('user', 'buy', 'item') edge type.
        hidden_conv_1: Width of the first past-state tensors; also passed to
            the model constructor.
        hidden_conv_2: Width of the second past-state tensors; also passed
            to the model constructor.
        device: Not used by this function; the per-snapshot trainer is
            called with its own default device.

    Returns:
        None.  Results are only printed.
    """
    target_rel = ('user', 'buy', 'item')
    num_snap = len(snapshots)
    hetdata = copy.deepcopy(snapshots[0])
    edge_types = list(hetdata.edge_index_dict.keys())

    lr = 0.001
    weight_decay = 5e-3

    in_channels = {node: len(v[0]) for node, v in hetdata.x_dict.items()}
    num_nodes = {node: len(v) for node, v in hetdata.x_dict.items()}

    # ATU
    atu = TAOBAOATU(in_channels, num_nodes, hetdata.metadata(),
                    hidden_conv_1=hidden_conv_1,
                    hidden_conv_2=hidden_conv_2)
    atu.reset_parameters()
    atuopt = torch.optim.Adam(params=atu.parameters(), lr=lr, weight_decay=weight_decay)

    # Zero-initialised past states, one (num_nodes x hidden) tensor per
    # (node type, relation name).  torch.zeros replaces the former nested
    # Python-list construction, which allocated num_nodes * hidden Python ints.
    past_dict_1_atu = {node: {} for node in hetdata.x_dict.keys()}
    past_dict_2_atu = {node: {} for node in hetdata.x_dict.keys()}
    for src, r, dst in hetdata.edge_index_dict.keys():
        for endpoint in (src, dst):
            endpoint_nodes = hetdata[endpoint].num_nodes
            past_dict_1_atu[endpoint][r] = torch.zeros(endpoint_nodes, hidden_conv_1)
            past_dict_2_atu[endpoint][r] = torch.zeros(endpoint_nodes, hidden_conv_2)

    atu_avgpr = 0
    atu_mrr = 0

    ch = 0  # number of snapshots actually processed
    for i in range(num_snap - 1):
        if ch >= 9:
            break

        # CREATE TRAIN + VAL + TEST SET FOR THE CURRENT SNAP
        snapshot = copy.deepcopy(snapshots[i])

        # The homogeneous split is not consumed by this model; it is still
        # executed so that the sequence of random draws stays the same as in
        # the other training_* functions (and as before this cleanup).
        hom_transform = RandomLinkSplit(num_val=0.0, num_test=0.20)
        hom_transform(snapshot.to_homogeneous())

        # The 20% held-out part of the split (third output) serves as the
        # validation set; no separate validation fraction is requested.
        het_transform = RandomLinkSplit(num_val=0.0, num_test=0.20, edge_types=edge_types)
        het_train_data, _, het_val_data = het_transform(snapshot)

        # Test set: positive edges of the NEXT snapshot plus sampled negatives.
        het_test_data = copy.deepcopy(snapshots[i + 1])
        future_pos_edge_index = het_test_data[target_rel].edge_index
        num_pos_edge = future_pos_edge_index.size(1)
        # NOTE(review): num_nodes is the number of 'user' nodes although the
        # relation is user -> item, and negative_sampling documents the
        # number of returned samples as approximate -- confirm both are
        # acceptable for this dataset.
        het_future_neg_edge_index = negative_sampling(
            edge_index=future_pos_edge_index,
            num_nodes=len(het_test_data['user'].x),
            num_neg_samples=num_pos_edge)
        # Labels: num_pos_edge ones followed by num_pos_edge zeros, built
        # directly in torch instead of going through Python lists and numpy.
        het_test_data[target_rel].edge_label = torch.cat(
            [torch.ones(num_pos_edge), torch.zeros(num_pos_edge)])
        het_test_data[target_rel].edge_label_index = torch.cat(
            [future_pos_edge_index, het_future_neg_edge_index], dim=-1)

        # Corrupted edges (same source, resampled destination) for MRR,
        # stored as a field of both the validation and the test data.
        src_t, _, corrupted_dst = structured_negative_sampling(
            het_val_data[target_rel].edge_index)
        het_val_data[target_rel].corrupted_edge_index = torch.stack([src_t, corrupted_dst])

        src_t_test, _, corrupted_dst_test = structured_negative_sampling(
            het_test_data[target_rel].edge_index)
        het_test_data[target_rel].corrupted_edge_index = torch.stack(
            [src_t_test, corrupted_dst_test])

        # TRAIN AND TEST THE MODEL FOR THE CURRENT SNAP
        atu, atu_avgpr_test, atu_mrr_test, past_dict_1_atu, past_dict_2_atu, atuopt = \
            durendal_train_single_snapshot(
                atu, snapshot, i, het_train_data, het_val_data, het_test_data,
                past_dict_1_atu, past_dict_2_atu, atuopt)

        # SAVE AND DISPLAY EVALUATION
        print(f'Snapshot: {i}\n')
        print(f' ATU AVGPR Test: {atu_avgpr_test} \n MRR Test: {atu_mrr_test}\n')
        atu_avgpr += atu_avgpr_test
        atu_mrr += atu_mrr_test
        ch += 1

    atu_avgpr_all = atu_avgpr / ch
    atu_mrr_all = atu_mrr / ch

    print('ATU')
    print(f'\tAVGPR over time: Test: {atu_avgpr_all}')
    print(f'\tMRR over time: Test: {atu_mrr_all}')

    return

In [None]:
def durendal_train_single_snapshot(model, data, i_snap, train_data, val_data, test_data,\
                          past_dict_1, past_dict_2,\
                          optimizer, device='cpu', num_epochs=50, verbose=False):
    """Fine-tune a past-state heterogeneous model on one snapshot, then test it.

    Each epoch performs one optimisation step on `train_data` and one
    validation pass.  Training stops at the first epoch whose validation
    AVGPR falls more than 0.05 below the last accepted value.

    Returns:
        (best_model, avgpr_test, mrr_test, best_past_dict_1,
        best_past_dict_2, optimizer).  `best_model` is the same object as
        `model` (no copy is taken); the two past dicts are the ones produced
        by the last accepted epoch, or empty dicts if no epoch was accepted.
    """
    buy_rel = ('user', 'buy', 'item')
    tolerance = 5e-2

    train_data = train_data.to(device)

    best_model, best_epoch = model, -1
    best_past_dict_1, best_past_dict_2 = {}, {}
    avgpr_val_max = 0

    for epoch in range(num_epochs):
        model.train()

        # One optimisation step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        pred, past_dict_1, past_dict_2 = model(
            train_data.x_dict,
            train_data.edge_index_dict,
            train_data[buy_rel].edge_label_index,
            i_snap, past_dict_1, past_dict_2)
        targets = train_data[buy_rel].edge_label.type_as(pred)
        loss = model.loss(pred, targets)  # loss to fine tune on current snapshot
        # The autograd graph is retained after the backward pass.
        loss.backward(retain_graph=True)
        optimizer.step()

        avgpr_score_val, _ = durendal_test(model, i_snap, val_data, data, device)

        # Early stopping on validation AVGPR with tolerance.
        if not (avgpr_val_max - tolerance <= avgpr_score_val):
            break
        avgpr_val_max = avgpr_score_val
        best_epoch = epoch
        best_past_dict_1 = past_dict_1
        best_past_dict_2 = past_dict_2
        best_model = model

    # Final evaluation uses the model in its state after the loop.
    avgpr_score_test, mrr_test = durendal_test(model, i_snap, test_data, data, device)

    if verbose:
        print(f'Best Epoch: {best_epoch}')

    return best_model, avgpr_score_test, mrr_test, best_past_dict_1, best_past_dict_2, optimizer

In [None]:
def hom_train_single_snapshot(model, data, train_data, val_data, test_data,\
                          optimizer, device='cpu', num_epochs=50, verbose=False):
    """Fine-tune a homogeneous model on one snapshot, then test it.

    Each epoch performs one optimisation step on `train_data` and one
    validation pass.  Training stops at the first epoch whose validation
    AVGPR falls more than 0.05 below the last accepted value.

    Returns:
        (best_model, avgpr_test, mrr_test, optimizer).  `best_model` is the
        same object as `model` (no copy is taken).
    """
    tolerance = 5e-2

    train_data = train_data.to(device)

    best_model, best_epoch = model, -1
    avgpr_val_max = 0

    for epoch in range(num_epochs):
        model.train()

        # One optimisation step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        pred = model(train_data.x, train_data.edge_index, train_data.edge_label_index)
        targets = train_data.edge_label.type_as(pred)
        loss = model.loss(pred, targets)  # loss to fine tune on current snapshot
        # The autograd graph is retained after the backward pass.
        loss.backward(retain_graph=True)
        optimizer.step()

        avgpr_score_val, _ = hom_test(model, val_data, data, device)

        # Early stopping on validation AVGPR with tolerance.
        if not (avgpr_val_max - tolerance <= avgpr_score_val):
            break
        avgpr_val_max = avgpr_score_val
        best_epoch = epoch
        best_model = model

    # Final evaluation uses the model in its state after the loop.
    avgpr_score_test, mrr_test = hom_test(model, test_data, data, device)

    if verbose:
        print(f'Best Epoch: {best_epoch}')

    return best_model, avgpr_score_test, mrr_test, optimizer

In [None]:
def het_train_single_snapshot(model, data, train_data, val_data, test_data,\
                          optimizer, device='cpu', num_epochs=50, verbose=False):
    """Fine-tune a heterogeneous model on one snapshot, then test it.

    Each epoch performs one optimisation step on the ('user', 'buy', 'item')
    supervision edges of `train_data` and one validation pass.  Training
    stops at the first epoch whose validation AVGPR falls more than 0.05
    below the last accepted value.

    Returns:
        (best_model, avgpr_test, mrr_test, optimizer).  `best_model` is the
        same object as `model` (no copy is taken).
    """
    buy_rel = ('user', 'buy', 'item')
    tolerance = 5e-2

    train_data = train_data.to(device)

    best_model, best_epoch = model, -1
    avgpr_val_max = 0

    for epoch in range(num_epochs):
        model.train()

        # One optimisation step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        pred = model(
            train_data.x_dict,
            train_data.edge_index_dict,
            train_data[buy_rel].edge_label_index)
        targets = train_data[buy_rel].edge_label.type_as(pred)
        loss = model.loss(pred, targets)  # loss to fine tune on current snapshot
        # The autograd graph is retained after the backward pass.
        loss.backward(retain_graph=True)
        optimizer.step()

        avgpr_score_val, _ = het_test(model, val_data, data, device)

        # Early stopping on validation AVGPR with tolerance.
        if not (avgpr_val_max - tolerance <= avgpr_score_val):
            break
        avgpr_val_max = avgpr_score_val
        best_epoch = epoch
        best_model = model

    # Final evaluation uses the model in its state after the loop.
    avgpr_score_test, mrr_test = het_test(model, test_data, data, device)

    if verbose:
        print(f'Best Epoch: {best_epoch}')

    return best_model, avgpr_score_test, mrr_test, optimizer

In [None]:
def gcgru_train_single_snapshot(model, data, train_data, val_data, test_data,\
                          optimizer, H=None, device='cpu', num_epochs=50, verbose=False):
    """Fine-tune a recurrent homogeneous model on one snapshot, then test it.

    The hidden state `H` returned by each forward pass is fed into the next
    epoch's forward pass.  Training stops at the first epoch whose
    validation AVGPR falls more than 0.05 below the last accepted value.

    Returns:
        (best_model, avgpr_test, mrr_test, best_H, optimizer).  `best_model`
        is the same object as `model` (no copy is taken); `best_H` is a
        clone of the hidden state from the last accepted epoch, or None if
        no epoch was accepted.
    """
    tolerance = 5e-2

    train_data = train_data.to(device)

    best_model, best_epoch, best_H = model, -1, None
    avgpr_val_max = 0

    for epoch in range(num_epochs):
        model.train()

        # One optimisation step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        pred, H = model(train_data.x, train_data.edge_index, train_data.edge_label_index, H)
        targets = train_data.edge_label.type_as(pred)
        loss = model.loss(pred, targets)  # loss to fine tune on current snapshot
        # The autograd graph is retained after the backward pass; H from
        # this epoch is an input of the next forward pass.
        loss.backward(retain_graph=True)
        optimizer.step()

        avgpr_score_val, _ = gcgru_test(model, val_data, data, device)

        # Early stopping on validation AVGPR with tolerance.
        if not (avgpr_val_max - tolerance <= avgpr_score_val):
            break
        avgpr_val_max = avgpr_score_val
        best_H = H.clone()
        best_epoch = epoch
        best_model = model

    # Final evaluation uses the model in its state after the loop.
    avgpr_score_test, mrr_test = gcgru_test(model, test_data, data, device)

    if verbose:
        print(f'Best Epoch: {best_epoch}')

    return best_model, avgpr_score_test, mrr_test, best_H, optimizer

In [None]:
def htlstm_train_single_snapshot(model, data, train_data, val_data, test_data,\
                                 H_1, C_1, H_2, C_2, optimizer, device='cpu', num_epochs=50, verbose=False):
    """Fine-tune an LSTM-style heterogeneous model on one snapshot, then test it.

    Every epoch feeds the same incoming states (`H_1`, `C_1`, `H_2`, `C_2`)
    to the model; the states it returns are only recorded, not fed back.
    Training stops at the first epoch whose validation AVGPR falls more
    than 0.05 below the last accepted value.

    Returns:
        (best_model, avgpr_test, mrr_test, best_H_1, best_C_1, best_H_2,
        best_C_2, optimizer).  `best_model` is the same object as `model`
        (no copy is taken); the four states come from the last accepted
        epoch, or are None if no epoch was accepted.
    """
    buy_rel = ('user', 'buy', 'item')
    tolerance = 5e-2

    train_data = train_data.to(device)

    best_model, best_epoch = model, -1
    best_H_1 = best_C_1 = best_H_2 = best_C_2 = None
    avgpr_val_max = 0

    for epoch in range(num_epochs):
        model.train()

        # One optimisation step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        pred, cH_1, cC_1, cH_2, cC_2 = model(
            train_data.x_dict,
            train_data.edge_index_dict,
            train_data[buy_rel].edge_label_index,
            H_1, C_1, H_2, C_2)
        targets = train_data[buy_rel].edge_label.type_as(pred)
        loss = model.loss(pred, targets)  # loss to fine tune on current snapshot
        # The autograd graph is retained after the backward pass.
        loss.backward(retain_graph=True)
        optimizer.step()

        avgpr_score_val, _ = htlstm_test(model, val_data, data, H_1, C_1, H_2, C_2, device)

        # Early stopping on validation AVGPR with tolerance.
        if not (avgpr_val_max - tolerance <= avgpr_score_val):
            break
        avgpr_val_max = avgpr_score_val
        best_epoch = epoch
        best_H_1, best_C_1 = cH_1, cC_1
        best_H_2, best_C_2 = cH_2, cC_2
        best_model = model

    # Final evaluation uses the model in its state after the loop and the
    # incoming (not the updated) recurrent states.
    avgpr_score_test, mrr_test = htlstm_test(model, test_data, data, H_1, C_1, H_2, C_2, device)

    if verbose:
        print(f'Best Epoch: {best_epoch}')

    return best_model, avgpr_score_test, mrr_test, best_H_1, best_C_1, best_H_2, best_C_2, optimizer

In [None]:
def durendal_test(model, i_snap, test_data, data, device='cpu'):
    """Evaluate a past-state heterogeneous model on ('user', 'buy', 'item') edges.

    Scores the labelled edges and the corrupted edges of `test_data`, then
    returns ``(avgpr, mrr)``: average precision of the labelled-edge scores
    against `edge_label`, and the MRR of the first half of those scores
    against the corrupted-edge scores.  `data` is accepted but not used.
    """
    buy_rel = ('user', 'buy', 'item')

    model.eval()
    test_data = test_data.to(device)

    # Half of the labelled edges; assumes positives come first and are
    # followed by equally many negatives -- TODO confirm for split data.
    num_pos = test_data[buy_rel].edge_label_index.size(1) // 2

    # Forward passes stay in this order: labelled edges, then corrupted ones.
    real_logits, *_ = model(test_data.x_dict, test_data.edge_index_dict,
                            test_data[buy_rel].edge_label_index, i_snap)
    fake_logits, *_ = model(test_data.x_dict, test_data.edge_index_dict,
                            test_data[buy_rel].corrupted_edge_index, i_snap)

    pred_cont = torch.sigmoid(real_logits).detach().cpu().numpy()
    fake_preds = torch.sigmoid(fake_logits).detach().cpu().numpy()
    label = test_data[buy_rel].edge_label.detach().cpu().numpy()

    return (average_precision_score(label, pred_cont),
            compute_mrr(pred_cont[:num_pos], fake_preds))

In [None]:
def hom_test(model, test_data, data, device='cpu'):
    """Evaluate a homogeneous model on the labelled edges of `test_data`.

    Scores the labelled edges and the corrupted edges, then returns
    ``(avgpr, mrr)``: average precision of the labelled-edge scores against
    `edge_label`, and the MRR of the first half of those scores against the
    corrupted-edge scores.  `data` is accepted but not used.
    """
    model.eval()
    test_data = test_data.to(device)

    # Half of the labelled edges; assumes positives come first and are
    # followed by equally many negatives -- TODO confirm for split data.
    num_pos = test_data.edge_label_index.size(1) // 2

    # Forward passes stay in this order: labelled edges, then corrupted ones.
    real_logits = model(test_data.x, test_data.edge_index, test_data.edge_label_index)
    fake_logits = model(test_data.x, test_data.edge_index, test_data.corrupted_edge_index)

    pred_cont = torch.sigmoid(real_logits).detach().cpu().numpy()
    fake_preds = torch.sigmoid(fake_logits).detach().cpu().numpy()
    label = test_data.edge_label.detach().cpu().numpy()

    return (average_precision_score(label, pred_cont),
            compute_mrr(pred_cont[:num_pos], fake_preds))

In [None]:
def het_test(model, test_data, data, device='cpu'):
    """Evaluate a heterogeneous model on ('user', 'buy', 'item') edges.

    Scores the labelled edges and the corrupted edges of `test_data`, then
    returns ``(avgpr, mrr)``: average precision of the labelled-edge scores
    against `edge_label`, and the MRR of the first half of those scores
    against the corrupted-edge scores.  `data` is accepted but not used.
    """
    buy_rel = ('user', 'buy', 'item')

    model.eval()
    test_data = test_data.to(device)

    # Half of the labelled edges; assumes positives come first and are
    # followed by equally many negatives -- TODO confirm for split data.
    num_pos = test_data[buy_rel].edge_label_index.size(1) // 2

    # Forward passes stay in this order: labelled edges, then corrupted ones.
    real_logits = model(test_data.x_dict, test_data.edge_index_dict,
                        test_data[buy_rel].edge_label_index)
    fake_logits = model(test_data.x_dict, test_data.edge_index_dict,
                        test_data[buy_rel].corrupted_edge_index)

    pred_cont = torch.sigmoid(real_logits).detach().cpu().numpy()
    fake_preds = torch.sigmoid(fake_logits).detach().cpu().numpy()
    label = test_data[buy_rel].edge_label.detach().cpu().numpy()

    return (average_precision_score(label, pred_cont),
            compute_mrr(pred_cont[:num_pos], fake_preds))

In [None]:
def gcgru_test(model, test_data, data, device='cpu'):
    """Evaluate a recurrent homogeneous model on the labelled edges of `test_data`.

    The model is called without a hidden-state argument and the hidden state
    it returns is discarded.  Returns ``(avgpr, mrr)``: average precision of
    the labelled-edge scores against `edge_label`, and the MRR of the first
    half of those scores against the corrupted-edge scores.  `data` is
    accepted but not used.
    """
    model.eval()
    test_data = test_data.to(device)

    # Half of the labelled edges; assumes positives come first and are
    # followed by equally many negatives -- TODO confirm for split data.
    num_pos = test_data.edge_label_index.size(1) // 2

    # Forward passes stay in this order: labelled edges, then corrupted ones.
    real_logits, _ = model(test_data.x, test_data.edge_index, test_data.edge_label_index)
    fake_logits, _ = model(test_data.x, test_data.edge_index, test_data.corrupted_edge_index)

    pred_cont = torch.sigmoid(real_logits).detach().cpu().numpy()
    fake_preds = torch.sigmoid(fake_logits).detach().cpu().numpy()
    label = test_data.edge_label.detach().cpu().numpy()

    return (average_precision_score(label, pred_cont),
            compute_mrr(pred_cont[:num_pos], fake_preds))

In [None]:
def htlstm_test(model, test_data, data, H_1, C_1, H_2, C_2, device='cpu'):
    """Evaluate an LSTM-style heterogeneous model on ('user', 'buy', 'item') edges.

    Both forward passes receive the given recurrent states (`H_1`, `C_1`,
    `H_2`, `C_2`); the states the model returns are discarded.  Returns
    ``(avgpr, mrr)``: average precision of the labelled-edge scores against
    `edge_label`, and the MRR of the first half of those scores against the
    corrupted-edge scores.  `data` is accepted but not used.
    """
    buy_rel = ('user', 'buy', 'item')

    model.eval()
    test_data = test_data.to(device)

    # Half of the labelled edges; assumes positives come first and are
    # followed by equally many negatives -- TODO confirm for split data.
    num_pos = test_data[buy_rel].edge_label_index.size(1) // 2

    # Forward passes stay in this order: labelled edges, then corrupted ones.
    real_logits, *_ = model(test_data.x_dict, test_data.edge_index_dict,
                            test_data[buy_rel].edge_label_index,
                            H_1, C_1, H_2, C_2)
    fake_logits, *_ = model(test_data.x_dict, test_data.edge_index_dict,
                            test_data[buy_rel].corrupted_edge_index,
                            H_1, C_1, H_2, C_2)

    pred_cont = torch.sigmoid(real_logits).detach().cpu().numpy()
    fake_preds = torch.sigmoid(fake_logits).detach().cpu().numpy()
    label = test_data[buy_rel].edge_label.detach().cpu().numpy()

    return (average_precision_score(label, pred_cont),
            compute_mrr(pred_cont[:num_pos], fake_preds))

In [None]:
# Hidden sizes handed to every training_* call below.
hidden_conv_1, hidden_conv_2 = 256, 128

In [None]:
import random

# Device handle for the experiments (the training_* functions in this
# notebook are called with their own default device).
device = torch.device('cuda')

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed every random number generator used in the notebook with 0.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

# Release cached GPU memory before the runs start.
torch.cuda.empty_cache()

In [None]:
from IPython import get_ipython
from IPython.core import magic_arguments
from IPython.core.magic import register_cell_magic
from IPython.utils.capture import capture_output

# `%%tee [output]` cell magic: runs the cell while capturing its output,
# optionally stores the capture object in the user namespace under the given
# name, and then replays the captured output.
# (Documented with comments rather than a docstring on purpose: the
# magic_arguments decorators use the function docstring for the help text.)
@magic_arguments.magic_arguments()
@magic_arguments.argument('output', type=str, default='', nargs='?',
    help="""The name of the variable in which to store output.
    This is a utils.io.CapturedIO object with stdout/err attributes
    for the text of the captured output.
    CapturedOutput also has a show() method for displaying the output,
    and __call__ as well, so you can use that to quickly display the
    output.
    If unspecified, captured output is discarded.
    """
)
@magic_arguments.argument('--no-stderr', action="store_true",
    help="""Don't capture stderr."""
)
@magic_arguments.argument('--no-stdout', action="store_true",
    help="""Don't capture stdout."""
)
@magic_arguments.argument('--no-display', action="store_true",
    help="""Don't capture IPython's rich display."""
)
@register_cell_magic
def tee(line, cell):
    parsed = magic_arguments.parse_argstring(tee, line)

    # Each --no-* flag switches off capturing of the matching stream.
    with capture_output(not parsed.no_stdout,
                        not parsed.no_stderr,
                        not parsed.no_display) as captured:
        get_ipython().run_cell(cell)

    # Keep the capture object around only when a variable name was given.
    if parsed.output:
        get_ipython().user_ns[parsed.output] = captured

    # Replay what was captured so it still shows up under the cell.
    captured()

Due to memory issues, we strongly suggest running the training for each candidate model separately, using the `training_<model>` functions, where `<model>` can be one of \{gat, han, gconvgru, ev, hev, durendal, atu\}.

In [None]:
%%tee exp_durendal
# Runs one training function; the %%tee magic stores the captured output in
# `exp_durendal`. NOTE(review): the variable name says "durendal" while the
# call below is training_gat -- confirm which model this cell is meant to run.
training_gat(snapshots, hidden_conv_1, hidden_conv_2)

In [None]:
# Print the output object stored by the %%tee cell above.
print(exp_durendal)