In [38]:
%load_ext autoreload
%autoreload 2
import os, sys, re, datetime, random, gzip, json, copy
from pathlib import Path
import networkx as nx
import numpy as np
import pandas as pd
import itertools
import collections
import networkx as nx
import numpy as np
import pandas as pd
import itertools
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn import linear_model
from sklearn.metrics import roc_auc_score
PROJ_PATH = Path(os.path.join(re.sub("/CTGCN.*$", '', os.getcwd()), 'CTGCN'))
sys.path.insert(1, str(str(PROJ_PATH.parents[0] / 'DySubG/src/')))
from ranking import Evaluation
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, precision_score, recall_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Utils

In [36]:
def get_training_data(data, node_embedding, all_labels, label_mapping):
    train_pos = []
    train_neg = []
    val_pos = []
    val_neg = []
    test_pos = []
    test_neg = []
    for i,d in data.items():
        node_id = d['node_id']
        time_id = d['time_id']
        pos_labels = d['label']
        neg_labels = [i for i in all_labels if i not in pos_labels]
        if len(pos_labels) > 0:
            for l in pos_labels:
                label_nid = label_mapping[l]
                pos_pair = (
                    np.array(node_embedding[time_id][node_id]), 
                    np.array(node_embedding[time_id][label_nid]))
                if d['dataset'] == 'train':
                    train_pos.append(pos_pair)
                elif d['dataset'] == 'val':
                    val_pos.append(pos_pair)
                elif d['dataset'] == 'test':
                    test_pos.append(pos_pair)
        if len(neg_labels) > 0:   
            for l in neg_labels:
                label_nid = label_mapping[l]
                neg_pair = (
                    np.array(node_embedding[time_id][node_id]), 
                    np.array(node_embedding[time_id][label_nid]))
                if d['dataset'] == 'train':
                    train_neg.append(neg_pair)
                elif d['dataset'] == 'val':
                    val_neg.append(neg_pair)
                elif d['dataset'] == 'test':
                    test_neg.append(neg_pair)
    return train_pos, train_neg, val_pos, val_neg, test_pos, test_neg


def get_link_score(fu, fv, operator='HAD'):
    """Given a pair of embeddings, compute link feature based on operator (such as Hadammad product, etc.)"""
    fu = np.array(fu)
    fv = np.array(fv)
    if operator == 'HAD':
        return np.multiply(fu, fv)
    elif operator == 'AVG':
        return (fu + fv) / 2
    elif operator == 'L1':
        return np.abs(fu - fv)
    elif operator == 'L2':
        return (fu - fv) ** 2
    else:
        raise NotImplementedError

def sigmoid(x):
        return 1 / (1 + np.exp(-x))
    
def predict_link_without_classifier(train_pos, train_neg, val_pos, val_neg, test_pos, test_neg):
    pred_train = [sigmoid(np.dot(e[0], e[1].T)) for e in train_pos + train_neg]
    label_train = [1] * len(train_pos) + [0] * len(train_neg)
    pred_val = [sigmoid(np.dot(e[0], e[1].T)) for e in val_pos + val_neg]
    label_val = [1] * len(val_pos) + [0] * len(val_neg)
    pred_test = [sigmoid(np.dot(e[0], e[1].T)) for e in test_pos + test_neg]
    label_test = [1] * len(test_pos) + [0] * len(test_neg)
    return pred_train, label_train, pred_val, label_val, pred_test, label_test


def predict_link_with_classifier(train_pos, train_neg, val_pos, val_neg, test_pos, test_neg, operator):
    train_feats = np.array([get_link_score(e[0], e[1], operator) for e in train_pos + train_neg])
    val_feats = np.array([get_link_score(e[0], e[1], operator) for e in val_pos + val_neg])
    test_feats = np.array([get_link_score(e[0], e[1], operator) for e in test_pos + test_neg])
    label_train = np.array([1] * len(train_pos) + [0] * len(train_neg))
    label_val = np.array([1] * len(val_pos) + [0] * len(val_neg))
    label_test = np.array([1] * len(test_pos) + [0] * len(test_neg))
    
    clf = linear_model.LogisticRegression(max_iter=5000)
    clf.fit(train_feats, label_train)
    pred_train = clf.predict_proba(train_feats)[:, 1]
    pred_val = clf.predict_proba(val_feats)[:, 1]
    pred_test = clf.predict_proba(test_feats)[:, 1]

    return pred_train, label_train, pred_val, label_val, pred_test, label_test, clf

def calc_auc(label_train, pred_train):
    auc = roc_auc_score(label_train, pred_train)
    if auc < 0.5:
        return 1 - auc
    else: 
        return auc
    
def evaluate_classifier(train_pos, train_neg, val_pos, val_neg, test_pos, test_neg, operators=['HAD'], threshold=0.5):
    results = {}
    models = {}
    # Predict without classifier
    pred_train, label_train, pred_val, label_val, pred_test, label_test = predict_link_without_classifier(
        train_pos, train_neg, val_pos, val_neg, test_pos, test_neg)
    results['sigmoid_auc'] = {
        'train': calc_auc(label_train, pred_train), 
        'val': calc_auc(label_val, pred_val),
        'test': calc_auc(label_test, pred_test), 
    }
    results['sigmoid_f1']  = {
        'train': f1_score(label_train, [1 if i >= threshold else 0 for i in pred_train]), 
        'val': f1_score(label_val, [1 if i >= threshold else 0 for i in pred_val]),
        'test': f1_score(label_test,[1 if i >= threshold else 0 for i in pred_test]),
    }
    
    # Predict with classifier
    for operator in operators:
        pred_train, label_train, pred_val, label_val, pred_test, label_test, clf = predict_link_with_classifier(
            train_pos, train_neg, val_pos, val_neg, test_pos, test_neg, operator)
        results[f'{operator}_auc'] = {
            'train': calc_auc(label_train, pred_train), 
            'val': calc_auc(label_val, pred_val),
            'test': calc_auc(label_test, pred_test), 
        }
        results[f'{operator}_f1']  = {
            'train': f1_score(label_train, [1 if i >= threshold else 0 for i in pred_train]), 
            'val': f1_score(label_val, [1 if i >= threshold else 0 for i in pred_val]),
            'test': f1_score(label_test,[1 if i >= threshold else 0 for i in pred_test]),
        }
        models[operator] = clf
    return results, models

def eval_lp(data, all_labels, label_mapping, method='CTGCN-C', num_time_steps=5, exp='imdb', operators=['HAD'], verbose=True, threshold=0.5):
    if method in ['DynAE', 'DynRNN', 'DynAERNN']:
        data = {i:j for i,j in data.items() if j['time_id']>=1}
    nodes = pd.read_csv(f'./data/{exp}/nodes_set/nodes.csv', names=['nodes'])['nodes'].values
    start_idx = min([d['time_id'] for i,d in data.items()])
    node_embedding = {}
    for time_id in range(start_idx, num_time_steps):
        embs = pd.read_csv(
            './data/{}/2.embedding/{}/{:02d}.csv'.format(exp, method, time_id), index_col=0, sep='\t').values
        dict_embs = {k:v for k,v in zip(nodes, embs)}
        node_embedding[time_id] = dict_embs
        
    train_pos, train_neg, val_pos, val_neg, test_pos, test_neg = get_training_data(
        data, node_embedding, all_labels, label_mapping)
    if verbose:
        print(len(train_pos), len(train_neg), len(val_pos), len(val_neg), len(test_pos), len(test_neg))
    results = evaluate_classifier(train_pos, train_neg, val_pos, val_neg, test_pos, test_neg, operators, threshold)
    return results


def print_report(exp='imdb', methods=[], selected_methods=[], threshold=0.5):
    if len(methods) == 0:
        methods = [
            'GCN', 'GAT', 'SAGE', 'GIN', 
            'TgGCN', 'TgGAT', 'TgSAGE', 'TgGIN', 
            'GCRN', 'TIMERS', 'DynAE', 'DynRNN', 'DynAERNN', 'DynGEM', 'DySAT',
            'VGRNN', 'EvolveGCN', 'CTGCN-C',
        ]
    if len(selected_methods) == 0:
        selected_methods = [
            'GCN', 'TgGAT', 'TgSAGE', 'TgGIN',
            'GCRN', 'TIMERS', 'DynAE', 'DynRNN', 'DynAERNN', 'DynGEM', 'DySAT',
            'VGRNN', 'EvolveGCN', 'CTGCN-C']
        
    try:
        pd_edges = pd.read_csv(f'./data/{exp}/0.input/temporal_edge_list.txt', sep=' ', names=['source_id', 'target_id', 'time_id'])
    except:
        pd_edges = pd.read_csv(f'./data/{exp}/0.input/edges.csv')
    
    num_time_steps = pd_edges['time_id'].max() + 1
    if exp.startswith('imdb'):
        data = pd.read_pickle(f'./data/{exp}/0.input/data.pkl')
        all_labels = list(set(itertools.chain(*[d['label'] for i, d in data.items()])))
        label_mapping = pd.read_pickle(f'./data/{exp}/0.input/entity_mapping.pkl')['genre']
    elif exp.startswith('dblp'):
        data = pd.read_pickle(f'./data/{exp}/0.input/data.pkl')
        all_labels = list(set(itertools.chain(*[d['label'] for i, d in data.items()])))
        cid2cname = pd.read_pickle('../DySubG/dataset/dblp/cid2cname.pkl')
        cname2cid = {j:i for i,j in cid2cname.items() if j in all_labels}
        label_mapping = {j:i for i,j in cid2cname.items() if j in all_labels}
        
    operators = ['HAD', 'AVG', 'L1', 'L2']
    res = []
    model_ls = {}
    for method in methods:
        print(method)
        results, models = eval_lp(
            data, all_labels, label_mapping, method, num_time_steps, exp, operators, threshold=threshold)
        tmp = pd.DataFrame(results)
        tmp['method'] = method
        res.append(tmp)
        model_ls[method] = models
    
    df = pd.concat(res)
    df = df.reset_index().rename(columns={'index':'dataset'})
    df['best_AUC'] = df[['sigmoid_auc', 'HAD_auc', 'AVG_auc', 'L1_auc', 'L2_auc']].max(axis=1)
    df['best_F1'] = df[['sigmoid_f1', 'HAD_f1', 'AVG_f1', 'L1_f1', 'L2_f1']].max(axis=1)

    print('Full report')
    display(df)

    print('Test report')
    display(df[df['dataset']=='test'])

    print('Selected methods')
    display(df[(df['dataset']=='test')&(df['method'].isin(selected_methods))][[
        'method',
        'sigmoid_auc', 'HAD_auc', 'AVG_auc', 'L1_auc', 'L2_auc', 'best_AUC',
        'sigmoid_f1', 'HAD_f1', 'AVG_f1', 'L1_f1', 'L2_f1', 'best_F1']])
    
    return df, model_ls

def get_ground_truth(data, all_labels, label_mapping):
    ground_truth = {}
    pred_idx = {}
    for i,d in data.items():
        node_id = d['node_id']
        time_id = d['time_id']
        pos_labels = d['label']
        neg_labels = [i for i in all_labels if i not in pos_labels]
        if len(pos_labels) > 0:
            ground_truth[node_id] = [label_mapping[l] for l in pos_labels]
        pred_idx[node_id] = [label_mapping[l] for l in all_labels]
    return ground_truth, pred_idx

def make_prediction(pred_idx, node_embedding, models):
    ranking = {}
    distances = []
    for s, ts in pred_idx.items():
        for t in ts:
            source = node_embedding[s]
            target = node_embedding[t]
            probs = sigmoid(np.dot(source, target.T))
            distances.append((s, t, probs))
        distances.sort(key=lambda tup: tup[2])
    ranking['sigmoid'] = distances  
    
    for operator, model in models.items():
        distances = []
        for s, ts in pred_idx.items():
            for t in ts:
                source = node_embedding[s]
                target = node_embedding[t]
                feats = np.array([get_link_score(source, target, operator)])
                probs = model.predict_proba(feats)[:, 1][0]
                distances.append((s, t, probs))
            distances.sort(key=lambda tup: tup[2])
        ranking[operator] = distances
    return ranking

# def eval_ranking(pred_dict, true_dict, k):
#     predicted_indices = [] # collections.OrderedDict()
#     true_indices = [] # collections.OrderedDict()
#     author_indices = sorted(list(true_dict.keys()))
#     for aid in author_indices:
#         predicted_indices.append(pred_dict[aid])
#         true_indices.append(true_dict[aid])
#     eval_agent = Evaluation(predicted_indices, true_indices, k)
#     return eval_agent.result

def calculate_multilabel(gt, prd):
    micro_f1 = f1_score(gt, prd, average='micro')
    macro_f1 = f1_score(gt, prd, average='macro')
    f1 = f1_score(gt, prd, average='weighted')
    micro_recall = recall_score(gt, prd, average='micro')
    macro_recall = recall_score(gt, prd, average='macro')
    recall = recall_score(gt, prd, average='weighted')
    micro_precision = precision_score(gt, prd, average='micro')
    macro_precision = precision_score(gt, prd, average='macro')
    precision = precision_score(gt, prd, average='weighted')
    return f1, micro_f1, macro_f1, recall, micro_recall, macro_recall, precision, micro_precision, macro_precision

def evaluation(pred, ground_truth, multilabel_binarizer, k=5):
    gt = multilabel_binarizer.transform(list(ground_truth.values()))
    result = {}
    for k in range(1, k+1):
        prd = multilabel_binarizer.transform(list({i: j[:k] for i, j in pred.items()}.values()))
        result[k] = calculate_multilabel(gt, prd)
    df = pd.DataFrame(result).T
    df.columns = [
    'f1', 'micro_f1', 'macro_f1', 
    'recall', 'micro_recall', 'macro_recall', 
    'precision', 'micro_precision', 'macro_precision']
    df['k'] = range(1, k+1)
    return df

# Evaluate

## IMDB

In [14]:
methods = [
    'GCN', 'GAT', #'SAGE', 'GIN', 
    'TgGCN', 'TgGAT', 'TgSAGE', 'TgGIN', 
    'GCRN', 'TIMERS', 'DynAE', #'DynRNN', 
    'DynAERNN', 'DynGEM', 'DySAT',
    'VGRNN', 'EvolveGCN', 'CTGCN-C',
]
df, model_ls = print_report(exp='imdb_lp_t2', methods=methods, threshold=0.1)

GCN
3021 3897 984 1326 954 1350
GAT
3021 3897 984 1326 954 1350
TgGCN
3021 3897 984 1326 954 1350
TgGAT
3021 3897 984 1326 954 1350
TgSAGE
3021 3897 984 1326 954 1350
TgGIN
3021 3897 984 1326 954 1350
GCRN
3021 3897 984 1326 954 1350
TIMERS
3021 3897 984 1326 954 1350
DynAE
2266 3068 744 1020 954 1350
DynAERNN
2266 3068 744 1020 954 1350
DynGEM
3021 3897 984 1326 954 1350
DySAT
3021 3897 984 1326 954 1350
VGRNN
3021 3897 984 1326 954 1350
EvolveGCN
3021 3897 984 1326 954 1350
CTGCN-C
3021 3897 984 1326 954 1350
Full report


Unnamed: 0,dataset,sigmoid_auc,sigmoid_f1,HAD_auc,HAD_f1,AVG_auc,AVG_f1,L1_auc,L1_f1,L2_auc,L2_f1,method,best_AUC,best_F1
0,train,0.602329,0.607908,0.617667,0.607908,0.558497,0.607908,0.587745,0.607908,0.585885,0.607908,GCN,0.617667,0.607908
1,val,0.594455,0.59745,0.581173,0.59745,0.527983,0.59745,0.579758,0.59745,0.579522,0.59745,GCN,0.594455,0.59745
2,test,0.567834,0.585635,0.564504,0.585635,0.511898,0.585635,0.563776,0.585635,0.565322,0.585635,GCN,0.567834,0.585635
3,train,0.5,0.607908,0.561742,0.607908,0.560585,0.607908,0.559757,0.607908,0.558289,0.607908,GAT,0.561742,0.607908
4,val,0.5,0.59745,0.544179,0.59745,0.543878,0.59745,0.544207,0.59745,0.543671,0.59745,GAT,0.544207,0.59745
5,test,0.5,0.585635,0.546391,0.585635,0.546957,0.585635,0.550546,0.585635,0.551084,0.585635,GAT,0.551084,0.585635
6,train,0.54471,0.607908,0.549228,0.607908,0.535487,0.607908,0.544163,0.607908,0.504257,0.607908,TgGCN,0.549228,0.607908
7,val,0.52182,0.59745,0.522528,0.59745,0.521296,0.59745,0.530537,0.59745,0.512865,0.59745,TgGCN,0.530537,0.59745
8,test,0.547681,0.585635,0.548403,0.585635,0.529051,0.585635,0.523279,0.585635,0.526214,0.585635,TgGCN,0.548403,0.585635
9,train,0.563055,0.607908,0.57885,0.607908,0.548456,0.607908,0.563383,0.607908,0.56257,0.607908,TgGAT,0.57885,0.607908


Test report


Unnamed: 0,dataset,sigmoid_auc,sigmoid_f1,HAD_auc,HAD_f1,AVG_auc,AVG_f1,L1_auc,L1_f1,L2_auc,L2_f1,method,best_AUC,best_F1
2,test,0.567834,0.585635,0.564504,0.585635,0.511898,0.585635,0.563776,0.585635,0.565322,0.585635,GCN,0.567834,0.585635
5,test,0.5,0.585635,0.546391,0.585635,0.546957,0.585635,0.550546,0.585635,0.551084,0.585635,GAT,0.551084,0.585635
8,test,0.547681,0.585635,0.548403,0.585635,0.529051,0.585635,0.523279,0.585635,0.526214,0.585635,TgGCN,0.548403,0.585635
11,test,0.577643,0.585635,0.563528,0.585635,0.526472,0.585635,0.566599,0.585635,0.567734,0.585635,TgGAT,0.577643,0.585635
14,test,0.510193,0.585635,0.517289,0.585635,0.524668,0.585635,0.549904,0.585635,0.550647,0.585635,TgSAGE,0.550647,0.585635
17,test,0.561513,0.585635,0.522744,0.585635,0.511166,0.585635,0.522218,0.585635,0.521914,0.585635,TgGIN,0.561513,0.585635
20,test,0.602111,0.528409,0.607511,0.585815,0.547576,0.585635,0.606435,0.585635,0.605285,0.585815,GCRN,0.607511,0.585815
23,test,0.633524,0.585635,0.651375,0.585635,0.501401,0.585635,0.536335,0.585635,0.554804,0.585635,TIMERS,0.651375,0.585635
26,test,0.521641,0.585635,0.515584,0.585635,0.516116,0.585635,0.516379,0.585635,0.507053,0.585635,DynAE,0.521641,0.585635
29,test,0.500877,0.585635,0.501443,0.585635,0.500489,0.585635,0.507225,0.585635,0.504821,0.585635,DynAERNN,0.507225,0.585635


Selected methods


Unnamed: 0,method,sigmoid_auc,HAD_auc,AVG_auc,L1_auc,L2_auc,best_AUC,sigmoid_f1,HAD_f1,AVG_f1,L1_f1,L2_f1,best_F1
2,GCN,0.567834,0.564504,0.511898,0.563776,0.565322,0.567834,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635
11,TgGAT,0.577643,0.563528,0.526472,0.566599,0.567734,0.577643,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635
14,TgSAGE,0.510193,0.517289,0.524668,0.549904,0.550647,0.550647,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635
17,TgGIN,0.561513,0.522744,0.511166,0.522218,0.521914,0.561513,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635
20,GCRN,0.602111,0.607511,0.547576,0.606435,0.605285,0.607511,0.528409,0.585815,0.585635,0.585635,0.585815,0.585815
23,TIMERS,0.633524,0.651375,0.501401,0.536335,0.554804,0.651375,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635
26,DynAE,0.521641,0.515584,0.516116,0.516379,0.507053,0.521641,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635
29,DynAERNN,0.500877,0.501443,0.500489,0.507225,0.504821,0.507225,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635
32,DynGEM,0.5647,0.58294,0.513904,0.582381,0.581785,0.58294,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635
35,DySAT,0.50549,0.508677,0.514076,0.511095,0.516292,0.516292,0.585635,0.585635,0.585635,0.585635,0.585635,0.585635


In [15]:
name = 'imdb_lp_t2'
time_id = 4 - 1
data = pd.read_pickle(f'./data/{name}/0.input/data.pkl')
all_labels = list(set(itertools.chain(*[d['label'] for i, d in data.items()])))
label_mapping = pd.read_pickle(f'./data/{name}/0.input/entity_mapping.pkl')['genre']
test_data = {i: d for i,d in data.items() if d['dataset']=='test'}
ground_truth, pred_idx = get_ground_truth(test_data, all_labels, label_mapping)

In [16]:
ls = []
for method in methods:
    print(method)
    node_embedding = pd.read_csv(
                './data/{}/2.embedding/{}/{:02d}.csv'.format(
                    name, method, time_id), index_col=0, sep='\t').values
    models = model_ls[method]
    ranking = make_prediction(pred_idx, node_embedding, models)
    eval_metrics = []
    for operator, rk in ranking.items():
        pd_pred = pd.DataFrame(rk, columns=['source', 'target', 'sims']).sort_values(['sims'], ascending=False)
        pred = pd_pred.groupby('source').agg({'target': list}).to_dict()['target']
        res = eval_ranking(pred, ground_truth, k=5)
        res['operator'] = operator
        eval_metrics.append(res)
    pd_res = pd.concat(eval_metrics)
    pd_res['method'] = method
    ls.append(pd_res)
results = pd.concat(ls)

GCN
GAT
TgGCN
TgGAT
TgSAGE
TgGIN
GCRN
TIMERS
DynAE
DynAERNN
DynGEM
DySAT
VGRNN
EvolveGCN
CTGCN-C


In [17]:
k = 1
operator = 'HAD'
excl = ['TIMERS']
incl = [
    'GCN', 'TgGAT', 'TgSAGE', 'TgGIN', 'GCRN', 
    'DynAE', 'DynAERNN', 'DynGEM', 'DySAT', 'VGRNN',
    'EvolveGCN', 'CTGCN-C']
display(results[(results['k']==k)&(results['operator']==operator)&(results['method'].isin(incl))])

Unnamed: 0,k,recall,mrr,map,ndcg,operator,method
1,1,0.3852,0.477865,0.3852,0.477865,HAD,GCN
1,1,0.405382,0.486979,0.405382,0.486979,HAD,TgGAT
1,1,0.338976,0.420573,0.338976,0.420573,HAD,TgSAGE
1,1,0.344184,0.415365,0.344184,0.415365,HAD,TgGIN
1,1,0.422309,0.511719,0.422309,0.511719,HAD,GCRN
1,1,0.331163,0.415365,0.331163,0.415365,HAD,DynAE
1,1,0.338976,0.408854,0.338976,0.408854,HAD,DynAERNN
1,1,0.403429,0.489583,0.403429,0.489583,HAD,DynGEM
1,1,0.314887,0.402344,0.314887,0.402344,HAD,DySAT
1,1,0.371528,0.442708,0.371528,0.442708,HAD,VGRNN


In [18]:
k = 2
operator = 'HAD'
excl = ['TIMERS']
incl = [
    'GCN', 'TgGAT', 'TgSAGE', 'TgGIN', 'GCRN', 
    'DynAE', 'DynAERNN', 'DynGEM', 'DySAT', 'VGRNN',
    'EvolveGCN', 'CTGCN-C']
display(results[(results['k']==k)&(results['operator']==operator)&(results['method'].isin(incl))])

Unnamed: 0,k,recall,mrr,map,ndcg,operator,method
2,2,0.722873,0.628255,0.57997,0.639171,HAD,GCN
2,2,0.722873,0.635417,0.587457,0.644952,HAD,TgGAT
2,2,0.652561,0.570312,0.515842,0.575151,HAD,TgSAGE
2,2,0.670139,0.580078,0.523655,0.586272,HAD,TgGIN
2,2,0.735243,0.654297,0.60178,0.659563,HAD,GCRN
2,2,0.661024,0.566406,0.521701,0.579405,HAD,DynAE
2,2,0.668837,0.574219,0.522352,0.583977,HAD,DynAERNN
2,2,0.718316,0.632161,0.58648,0.642387,HAD,DynGEM
2,2,0.660373,0.565104,0.507704,0.571679,HAD,DySAT
2,2,0.670139,0.59375,0.537326,0.596363,HAD,VGRNN


In [19]:
k = 3
operator = 'HAD'
excl = ['TIMERS']
incl = [
    'GCN', 'TgGAT', 'TgSAGE', 'TgGIN', 'GCRN', 
    'DynAE', 'DynAERNN', 'DynGEM', 'DySAT', 'VGRNN',
    'EvolveGCN', 'CTGCN-C']
display(results[(results['k']==k)&(results['operator']==operator)&(results['method'].isin(incl))])

Unnamed: 0,k,recall,mrr,map,ndcg,operator,method
3,3,1.0,0.70204,0.695421,0.775795,HAD,GCN
3,3,1.0,0.707465,0.704644,0.782165,HAD,TgGAT
3,3,1.0,0.663628,0.658637,0.748257,HAD,TgSAGE
3,3,1.0,0.665148,0.662977,0.751399,HAD,TgGIN
3,3,1.0,0.722005,0.715061,0.790665,HAD,GCRN
3,3,1.0,0.66059,0.657986,0.747028,HAD,DynAE
3,3,1.0,0.661024,0.660807,0.749313,HAD,DynAERNN
3,3,1.0,0.707248,0.703668,0.781363,HAD,DynGEM
3,3,1.0,0.655816,0.647895,0.740879,HAD,DySAT
3,3,1.0,0.678819,0.676649,0.761491,HAD,VGRNN


In [21]:
from ranking import *

## DBLP

In [3]:
methods = [
    'GCN', 'GAT', #'SAGE', 'GIN', 
    'TgGCN', 'TgGAT', 'TgSAGE', 'TgGIN', 
    'GCRN', 'TIMERS', 'DynAE',# 'DynRNN',
    'DynAERNN', 'DynGEM', 'DySAT',
    'VGRNN', 'EvolveGCN', 'CTGCN-C',
]
df, model_ls = print_report(exp='dblp', methods=methods, threshold=0.5)

GCN
5949 12311 1409 2786 759 1886
GAT
5949 12311 1409 2786 759 1886
TgGCN
5949 12311 1409 2786 759 1886
TgGAT
5949 12311 1409 2786 759 1886
TgSAGE
5949 12311 1409 2786 759 1886
TgGIN
5949 12311 1409 2786 759 1886
GCRN
5949 12311 1409 2786 759 1886
TIMERS
5949 12311 1409 2786 759 1886
DynAE
5193 10667 1409 2786 759 1886
DynAERNN
5193 10667 1409 2786 759 1886
DynGEM
5949 12311 1409 2786 759 1886
DySAT
5949 12311 1409 2786 759 1886
VGRNN
5949 12311 1409 2786 759 1886
EvolveGCN
5949 12311 1409 2786 759 1886
CTGCN-C
5949 12311 1409 2786 759 1886
Full report


Unnamed: 0,dataset,sigmoid_auc,sigmoid_f1,HAD_auc,HAD_f1,AVG_auc,AVG_f1,L1_auc,L1_f1,L2_auc,L2_f1,method,best_AUC,best_F1
0,train,0.580619,0.489292,0.555749,0.0,0.55407,0.0,0.577562,0.0,0.574619,0.0,GCN,0.580619,0.489292
1,val,0.604251,0.501357,0.505254,0.0,0.512637,0.0,0.592016,0.0,0.592289,0.0,GCN,0.604251,0.501357
2,test,0.612739,0.444248,0.500279,0.0,0.520053,0.0,0.602784,0.0,0.60319,0.0,GCN,0.612739,0.444248
3,train,0.5,0.49147,0.556625,0.0,0.555903,0.0,0.565819,0.0,0.565544,0.0,GAT,0.565819,0.49147
4,val,0.5,0.502855,0.557438,0.0,0.557737,0.0,0.57378,0.0,0.574046,0.0,GAT,0.574046,0.502855
5,test,0.5,0.445946,0.52364,0.0,0.52324,0.0,0.568161,0.0,0.570321,0.0,GAT,0.570321,0.445946
6,train,0.550593,0.202253,0.549041,0.0,0.538356,0.0,0.521215,0.0,0.501099,0.0,TgGCN,0.550593,0.202253
7,val,0.556282,0.214088,0.556713,0.0,0.564624,0.0,0.519864,0.0,0.508987,0.0,TgGCN,0.564624,0.214088
8,test,0.561694,0.265976,0.562751,0.0,0.537827,0.0,0.522147,0.0,0.519261,0.0,TgGCN,0.562751,0.265976
9,train,0.535343,0.475571,0.554349,0.0,0.543831,0.0,0.567408,0.0,0.566748,0.0,TgGAT,0.567408,0.475571


Test report


Unnamed: 0,dataset,sigmoid_auc,sigmoid_f1,HAD_auc,HAD_f1,AVG_auc,AVG_f1,L1_auc,L1_f1,L2_auc,L2_f1,method,best_AUC,best_F1
2,test,0.612739,0.444248,0.500279,0.0,0.520053,0.0,0.602784,0.0,0.60319,0.0,GCN,0.612739,0.444248
5,test,0.5,0.445946,0.52364,0.0,0.52324,0.0,0.568161,0.0,0.570321,0.0,GAT,0.570321,0.445946
8,test,0.561694,0.265976,0.562751,0.0,0.537827,0.0,0.522147,0.0,0.519261,0.0,TgGCN,0.562751,0.265976
11,test,0.540466,0.453242,0.508228,0.0,0.507311,0.0,0.591335,0.0,0.591402,0.0,TgGAT,0.591402,0.453242
14,test,0.532003,0.398821,0.518423,0.0,0.502975,0.0,0.565975,0.0,0.566371,0.0,TgSAGE,0.566371,0.398821
17,test,0.549542,0.213725,0.582931,0.0,0.52102,0.0,0.523648,0.0,0.5531,0.0,TgGIN,0.582931,0.213725
20,test,0.536459,0.437811,0.561702,0.208531,0.561518,0.0,0.564566,0.277347,0.56598,0.091121,GCRN,0.56598,0.437811
23,test,0.684041,0.509038,0.759404,0.370594,0.541781,0.0,0.588769,0.330745,0.59383,0.330745,TIMERS,0.759404,0.509038
26,test,0.504181,0.445946,0.533682,0.0,0.597001,0.0,0.518056,0.0,0.563979,0.0,DynAE,0.597001,0.445946
29,test,0.585067,0.445946,0.588267,0.0,0.594146,0.0,0.552162,0.0,0.552223,0.0,DynAERNN,0.594146,0.445946


Selected methods


Unnamed: 0,method,sigmoid_auc,HAD_auc,AVG_auc,L1_auc,L2_auc,best_AUC,sigmoid_f1,HAD_f1,AVG_f1,L1_f1,L2_f1,best_F1
2,GCN,0.612739,0.500279,0.520053,0.602784,0.60319,0.612739,0.444248,0.0,0.0,0.0,0.0,0.444248
11,TgGAT,0.540466,0.508228,0.507311,0.591335,0.591402,0.591402,0.453242,0.0,0.0,0.0,0.0,0.453242
14,TgSAGE,0.532003,0.518423,0.502975,0.565975,0.566371,0.566371,0.398821,0.0,0.0,0.0,0.0,0.398821
17,TgGIN,0.549542,0.582931,0.52102,0.523648,0.5531,0.582931,0.213725,0.0,0.0,0.0,0.0,0.213725
20,GCRN,0.536459,0.561702,0.561518,0.564566,0.56598,0.56598,0.437811,0.208531,0.0,0.277347,0.091121,0.437811
23,TIMERS,0.684041,0.759404,0.541781,0.588769,0.59383,0.759404,0.509038,0.370594,0.0,0.330745,0.330745,0.509038
26,DynAE,0.504181,0.533682,0.597001,0.518056,0.563979,0.597001,0.445946,0.0,0.0,0.0,0.0,0.445946
29,DynAERNN,0.585067,0.588267,0.594146,0.552162,0.552223,0.594146,0.445946,0.0,0.0,0.0,0.0,0.445946
32,DynGEM,0.567981,0.529378,0.521807,0.560547,0.517568,0.567981,0.445946,0.189767,0.015686,0.027228,0.367195,0.445946
35,DySAT,0.505889,0.513793,0.510786,0.533857,0.530416,0.533857,0.4361,0.0,0.0,0.007853,0.01292,0.4361


In [4]:
name = 'dblp'
time_id = 7 - 1
data = pd.read_pickle(f'./data/{name}/0.input/data.pkl')
all_labels = list(set(itertools.chain(*[d['label'] for i, d in data.items()])))
cid2cname = pd.read_pickle('../DySubG/dataset/dblp/cid2cname.pkl')
cname2cid = {j:i for i,j in cid2cname.items() if j in all_labels}
label_mapping = {j:i for i,j in cid2cname.items() if j in all_labels}
test_data = {i: d for i,d in data.items() if d['dataset']=='test'}
ground_truth, pred_idx = get_ground_truth(test_data, all_labels, label_mapping)
# ground_truth = pd.read_pickle(f'./data/{name}/0.input/ground_truth.pkl')
print(len(all_labels))

5


In [37]:
ls = []
for method in methods:
    print(method)
    node_embedding = pd.read_csv(
                './data/{}/2.embedding/{}/{:02d}.csv'.format(
                    name, method, time_id), index_col=0, sep='\t').values
    models = model_ls[method]
    ranking = make_prediction(pred_idx, node_embedding, models)
    eval_metrics = []
    multilabel_binarizer = MultiLabelBinarizer().fit(list(ground_truth.values()))
    for operator, rk in ranking.items():
        pd_pred = pd.DataFrame(rk, columns=['source', 'target', 'sims']).sort_values(['sims'], ascending=False)
        pred = pd_pred.groupby('source').agg({'target': list}).to_dict()['target']
        res = evaluation(pred, ground_truth, multilabel_binarizer, k=5)
        res['operator'] = operator
        eval_metrics.append(res)
    pd_res = pd.concat(eval_metrics)
    pd_res['method'] = method
    ls.append(pd_res)
results = pd.concat(ls)

GCN


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


GAT


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TgGCN


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

TgGAT


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TgSAGE


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TgGIN


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

GCRN


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TIMERS


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

DynAE


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


DynAERNN


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


DynGEM


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

DySAT
VGRNN


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

EvolveGCN


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

CTGCN-C


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [6]:
pred = pd_pred.groupby('source').agg({'target': list}).to_dict()['target']

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [39]:
k = 1
operator = 'HAD'
excl = ['TIMERS']
incl = [
    'GCN', 'TgGAT', 'TgSAGE', 'TgGIN', 'GCRN', 
    'DynAE', 'DynAERNN', 'DynGEM', 'DySAT', 'VGRNN',
    'EvolveGCN', 'CTGCN-C']
display(results[(results['k']==k)&(results['operator']==operator)&(results['method'].isin(incl))])

Unnamed: 0,f1,micro_f1,macro_f1,recall,micro_recall,macro_recall,precision,micro_precision,macro_precision,k,operator,method
1,0.159546,0.223602,0.157728,0.189723,0.189723,0.2059,0.281803,0.272212,0.25844,1,HAD,GCN
1,0.265185,0.273292,0.259933,0.231884,0.231884,0.230147,0.324737,0.332703,0.315596,1,HAD,TgGAT
1,0.228076,0.231366,0.225328,0.196311,0.196311,0.20008,0.323405,0.281664,0.303094,1,HAD,TgSAGE
1,0.097163,0.190994,0.115633,0.162055,0.162055,0.193575,0.193503,0.232514,0.194667,1,HAD,TgGIN
1,0.285916,0.321429,0.265623,0.272727,0.272727,0.243929,0.381796,0.391304,0.370952,1,HAD,GCRN
1,0.057585,0.184783,0.073457,0.156785,0.156785,0.2,0.035269,0.224953,0.044991,1,HAD,DynAE
1,0.064788,0.197205,0.077439,0.167325,0.167325,0.2,0.040171,0.240076,0.048015,1,HAD,DynAERNN
1,0.218707,0.26087,0.213784,0.221344,0.221344,0.223223,0.493079,0.31758,0.464548,1,HAD,DynGEM
1,0.227458,0.256211,0.210463,0.217391,0.217391,0.196341,0.327516,0.311909,0.305475,1,HAD,DySAT
1,0.142824,0.268634,0.13189,0.227931,0.227931,0.204912,0.185555,0.327032,0.200656,1,HAD,VGRNN


In [40]:
k = 2
operator = 'HAD'
excl = ['TIMERS']
incl = [
    'GCN', 'TgGAT', 'TgSAGE', 'TgGIN', 'GCRN', 
    'DynAE', 'DynAERNN', 'DynGEM', 'DySAT', 'VGRNN',
    'EvolveGCN', 'CTGCN-C']
display(results[(results['k']==k)&(results['operator']==operator)&(results['method'].isin(incl))])

Unnamed: 0,f1,micro_f1,macro_f1,recall,micro_recall,macro_recall,precision,micro_precision,macro_precision,k,operator,method
2,0.302616,0.325812,0.287371,0.389987,0.389987,0.394628,0.291959,0.279773,0.271965,2,HAD,GCN
2,0.352372,0.347826,0.341574,0.416337,0.416337,0.416403,0.314059,0.298677,0.297062,2,HAD,TgGAT
2,0.335668,0.328013,0.32261,0.392622,0.392622,0.389789,0.303069,0.281664,0.283461,2,HAD,TgSAGE
2,0.16388,0.286186,0.184723,0.342556,0.342556,0.403705,0.177003,0.245747,0.180488,2,HAD,TgGIN
2,0.376878,0.39956,0.358989,0.478261,0.478261,0.458891,0.358472,0.3431,0.343797,2,HAD,GCRN
2,0.122372,0.270776,0.150896,0.324111,0.324111,0.4,0.07544,0.232514,0.093006,2,HAD,DynAE
2,0.122372,0.270776,0.150896,0.324111,0.324111,0.4,0.07544,0.232514,0.093006,2,HAD,DynAERNN
2,0.321354,0.379747,0.306325,0.454545,0.454545,0.43416,0.392976,0.326087,0.375946,2,HAD,DynGEM
2,0.345798,0.351128,0.328342,0.42029,0.42029,0.400712,0.297919,0.301512,0.283233,2,HAD,DySAT
2,0.247822,0.305999,0.237531,0.366271,0.366271,0.354864,0.25488,0.26276,0.249853,2,HAD,VGRNN


In [41]:
k = 3
operator = 'HAD'
excl = ['TIMERS']
incl = [
    'GCN', 'TgGAT', 'TgSAGE', 'TgGIN', 'GCRN', 
    'DynAE', 'DynAERNN', 'DynGEM', 'DySAT', 'VGRNN',
    'EvolveGCN', 'CTGCN-C']
display(results[(results['k']==k)&(results['operator']==operator)&(results['method'].isin(incl))])

Unnamed: 0,f1,micro_f1,macro_f1,recall,micro_recall,macro_recall,precision,micro_precision,macro_precision,k,operator,method
3,0.383088,0.386189,0.364979,0.596838,0.596838,0.593673,0.297234,0.285444,0.279577,3,HAD,GCN
3,0.40868,0.398977,0.393179,0.616601,0.616601,0.61666,0.317256,0.294896,0.297947,3,HAD,TgGAT
3,0.388755,0.378517,0.372161,0.58498,0.58498,0.579995,0.296339,0.279773,0.278907,3,HAD,TgSAGE
3,0.242238,0.364024,0.253508,0.562582,0.562582,0.6,0.155193,0.269061,0.161528,3,HAD,TgGIN
3,0.435571,0.435635,0.416642,0.673254,0.673254,0.658188,0.330133,0.321991,0.314459,3,HAD,GCRN
3,0.235855,0.358056,0.2499,0.55336,0.55336,0.6,0.150845,0.26465,0.15879,3,HAD,DynAE
3,0.235855,0.358056,0.2499,0.55336,0.55336,0.6,0.150845,0.26465,0.15879,3,HAD,DynAERNN
3,0.367885,0.412617,0.34617,0.637681,0.637681,0.622372,0.378941,0.304978,0.357608,3,HAD,DynGEM
3,0.388029,0.382779,0.375092,0.591568,0.591568,0.595776,0.299492,0.282924,0.283915,3,HAD,DySAT
3,0.364578,0.391304,0.338862,0.604743,0.604743,0.582896,0.307043,0.289225,0.291023,3,HAD,VGRNN


## DBLP 6 

In [3]:
methods = [
    'GCN', 'GAT', #'SAGE', 'GIN', 
    'TgGCN', 'TgGAT', 'TgSAGE', 'TgGIN', 
    'GCRN', 'TIMERS', 'DynAE',# 'DynRNN',
    'DynAERNN', 'DynGEM', 'DySAT',
    'VGRNN', 'EvolveGCN', 'CTGCN-C',
]
df, model_ls = print_report(exp='dblp', methods=methods, threshold=0.5)

GCN
6766 16778 1634 3880 939 2691
GAT
6766 16778 1634 3880 939 2691
TgGCN
6766 16778 1634 3880 939 2691
TgGAT
6766 16778 1634 3880 939 2691
TgSAGE
6766 16778 1634 3880 939 2691
TgGIN
6766 16778 1634 3880 939 2691
GCRN
6766 16778 1634 3880 939 2691
TIMERS
6766 16778 1634 3880 939 2691
DynAE
5926 14534 1634 3880 939 2691
DynAERNN
5926 14534 1634 3880 939 2691
DynGEM
6766 16778 1634 3880 939 2691
DySAT
6766 16778 1634 3880 939 2691
VGRNN
6766 16778 1634 3880 939 2691
EvolveGCN
6766 16778 1634 3880 939 2691
CTGCN-C
6766 16778 1634 3880 939 2691
Full report


Unnamed: 0,dataset,sigmoid_auc,sigmoid_f1,HAD_auc,HAD_f1,AVG_auc,AVG_f1,L1_auc,L1_f1,L2_auc,L2_f1,method,best_AUC,best_F1
0,train,0.574367,0.446195,0.564089,0.0,0.557196,0.0,0.573204,0.0,0.570806,0.0,GCN,0.574367,0.446195
1,val,0.599924,0.456618,0.516876,0.0,0.52281,0.0,0.58749,0.0,0.587456,0.0,GCN,0.599924,0.456618
2,test,0.601287,0.40975,0.522419,0.0,0.503811,0.0,0.607165,0.0,0.6075,0.0,GCN,0.6075,0.40975
3,train,0.5,0.446453,0.55113,0.0,0.551264,0.0,0.568495,0.0,0.568382,0.0,GAT,0.568495,0.446453
4,val,0.5,0.457191,0.559633,0.0,0.559687,0.0,0.571471,0.0,0.571209,0.0,GAT,0.571471,0.457191
5,test,0.5,0.411031,0.536258,0.0,0.5361,0.0,0.574108,0.0,0.576238,0.0,GAT,0.576238,0.411031
6,train,0.539605,0.203712,0.538854,0.0,0.559803,0.0,0.52817,0.0,0.514325,0.0,TgGCN,0.559803,0.203712
7,val,0.552159,0.205827,0.551789,0.0,0.569383,0.0,0.505544,0.0,0.515341,0.0,TgGCN,0.569383,0.205827
8,test,0.57705,0.258741,0.577708,0.0,0.523415,0.0,0.513608,0.0,0.513327,0.0,TgGCN,0.577708,0.258741
9,train,0.528802,0.434968,0.52656,0.0,0.550404,0.0,0.562297,0.0,0.561617,0.0,TgGAT,0.562297,0.434968


Test report


Unnamed: 0,dataset,sigmoid_auc,sigmoid_f1,HAD_auc,HAD_f1,AVG_auc,AVG_f1,L1_auc,L1_f1,L2_auc,L2_f1,method,best_AUC,best_F1
2,test,0.601287,0.40975,0.522419,0.0,0.503811,0.0,0.607165,0.0,0.6075,0.0,GCN,0.6075,0.40975
5,test,0.5,0.411031,0.536258,0.0,0.5361,0.0,0.574108,0.0,0.576238,0.0,GAT,0.576238,0.411031
8,test,0.57705,0.258741,0.577708,0.0,0.523415,0.0,0.513608,0.0,0.513327,0.0,TgGCN,0.577708,0.258741
11,test,0.538524,0.416469,0.54079,0.0,0.51558,0.0,0.589482,0.0,0.59009,0.0,TgGAT,0.59009,0.416469
14,test,0.532665,0.373384,0.534901,0.0,0.505136,0.0,0.573757,0.0,0.574738,0.0,TgSAGE,0.574738,0.373384
17,test,0.565183,0.209192,0.560961,0.0,0.504169,0.0,0.520415,0.0,0.539824,0.0,TgGIN,0.565183,0.209192
20,test,0.533966,0.402253,0.561693,0.038229,0.555413,0.0,0.558352,0.224443,0.559792,0.0,GCRN,0.561693,0.402253
23,test,0.680806,0.477099,0.744242,0.273128,0.500532,0.0,0.562489,0.272548,0.566395,0.276086,TIMERS,0.744242,0.477099
26,test,0.504422,0.411031,0.524896,0.0,0.564302,0.0,0.538787,0.0,0.532422,0.0,DynAE,0.564302,0.411031
29,test,0.556475,0.411031,0.554615,0.0,0.565709,0.0,0.521546,0.0,0.521284,0.0,DynAERNN,0.565709,0.411031


Selected methods


Unnamed: 0,method,sigmoid_auc,HAD_auc,AVG_auc,L1_auc,L2_auc,best_AUC,sigmoid_f1,HAD_f1,AVG_f1,L1_f1,L2_f1,best_F1
2,GCN,0.601287,0.522419,0.503811,0.607165,0.6075,0.6075,0.40975,0.0,0.0,0.0,0.0,0.40975
11,TgGAT,0.538524,0.54079,0.51558,0.589482,0.59009,0.59009,0.416469,0.0,0.0,0.0,0.0,0.416469
14,TgSAGE,0.532665,0.534901,0.505136,0.573757,0.574738,0.574738,0.373384,0.0,0.0,0.0,0.0,0.373384
17,TgGIN,0.565183,0.560961,0.504169,0.520415,0.539824,0.565183,0.209192,0.0,0.0,0.0,0.0,0.209192
20,GCRN,0.533966,0.561693,0.555413,0.558352,0.559792,0.561693,0.402253,0.038229,0.0,0.224443,0.0,0.402253
23,TIMERS,0.680806,0.744242,0.500532,0.562489,0.566395,0.744242,0.477099,0.273128,0.0,0.272548,0.276086,0.477099
26,DynAE,0.504422,0.524896,0.564302,0.538787,0.532422,0.564302,0.411031,0.0,0.0,0.0,0.0,0.411031
29,DynAERNN,0.556475,0.554615,0.565709,0.521546,0.521284,0.565709,0.411031,0.0,0.0,0.0,0.0,0.411031
32,DynGEM,0.574835,0.552129,0.521987,0.504054,0.523481,0.574835,0.411031,0.13609,0.006369,0.0,0.295067,0.411031
35,DySAT,0.516052,0.523657,0.516769,0.526067,0.529218,0.529218,0.403969,0.0,0.0,0.008466,0.014629,0.403969


In [4]:
name = 'dblp'
data = pd.read_pickle(f'./data/{name}/0.input/data.pkl')
all_labels = list(set(itertools.chain(*[d['label'] for i, d in data.items()])))
cid2cname = pd.read_pickle('../DySubG/dataset/dblp/cid2cname.pkl')
cname2cid = {j:i for i,j in cid2cname.items() if j in all_labels}
label_mapping = {j:i for i,j in cid2cname.items() if j in all_labels}
test_data = {i: d for i,d in data.items() if d['dataset']=='test'}
ground_truth, pred_idx = get_ground_truth(test_data, all_labels, label_mapping)

In [5]:
ls = []
for method in methods:
    print(method)
    node_embedding = pd.read_csv(
                './data/{}/2.embedding/{}/{:02d}.csv'.format(
                    name, method, time_id), index_col=0, sep='\t').values
    models = model_ls[method]
    ranking = make_prediction(pred_idx, node_embedding, models)
    eval_metrics = []
    for operator, rk in ranking.items():
        pd_pred = pd.DataFrame(rk, columns=['source', 'target', 'sims']).sort_values(['sims'], ascending=False)
        pred = pd_pred.groupby('source').agg({'target': list}).to_dict()['target']
        res = eval_ranking(pred, ground_truth, k=5)
        res['operator'] = operator
        eval_metrics.append(res)
    pd_res = pd.concat(eval_metrics)
    pd_res['method'] = method
    ls.append(pd_res)
results = pd.concat(ls)

GCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.276501,0.421488,0.276501,0.421488,sigmoid
1,1,0.166033,0.256198,0.166033,0.256198,HAD
1,1,0.201295,0.297521,0.201295,0.297521,AVG
1,1,0.275675,0.414876,0.275675,0.414876,L1
1,1,0.273609,0.418182,0.273609,0.418182,L2


GAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.173333,0.261157,0.173333,0.261157,sigmoid
1,1,0.118926,0.208264,0.118926,0.208264,HAD
1,1,0.118926,0.208264,0.118926,0.208264,AVG
1,1,0.173196,0.269421,0.173196,0.269421,L1
1,1,0.179532,0.287603,0.179532,0.287603,L2


TgGCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.211901,0.31405,0.211901,0.31405,sigmoid
1,1,0.206391,0.304132,0.206391,0.304132,HAD
1,1,0.187796,0.287603,0.187796,0.287603,AVG
1,1,0.242617,0.352066,0.242617,0.352066,L1
1,1,0.201295,0.297521,0.201295,0.297521,L2


TgGAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.186832,0.300826,0.186832,0.300826,sigmoid
1,1,0.127466,0.206612,0.127466,0.206612,HAD
1,1,0.120992,0.209917,0.120992,0.209917,AVG
1,1,0.199917,0.305785,0.199917,0.305785,L1
1,1,0.198402,0.307438,0.198402,0.307438,L2


TgSAGE


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.177879,0.27438,0.177879,0.27438,sigmoid
1,1,0.153223,0.231405,0.153223,0.231405,HAD
1,1,0.187796,0.287603,0.187796,0.287603,AVG
1,1,0.147163,0.239669,0.147163,0.239669,L1
1,1,0.164931,0.259504,0.164931,0.259504,L2


TgGIN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.139449,0.239669,0.139449,0.239669,sigmoid
1,1,0.149366,0.244628,0.149366,0.244628,HAD
1,1,0.201295,0.297521,0.201295,0.297521,AVG
1,1,0.120992,0.209917,0.120992,0.209917,L1
1,1,0.120992,0.209917,0.120992,0.209917,L2


GCRN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.192617,0.28595,0.192617,0.28595,sigmoid
1,1,0.236143,0.350413,0.236143,0.350413,HAD
1,1,0.242617,0.352066,0.242617,0.352066,AVG
1,1,0.196281,0.300826,0.196281,0.300826,L1
1,1,0.22011,0.330579,0.22011,0.330579,L2


TIMERS


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.336419,0.500826,0.336419,0.500826,sigmoid
1,1,0.374711,0.545455,0.374711,0.545455,HAD
1,1,0.201295,0.297521,0.201295,0.297521,AVG
1,1,0.242617,0.352066,0.242617,0.352066,L1
1,1,0.242617,0.352066,0.242617,0.352066,L2


DynAE


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.118926,0.208264,0.118926,0.208264,sigmoid
1,1,0.242617,0.352066,0.242617,0.352066,HAD
1,1,0.242617,0.352066,0.242617,0.352066,AVG
1,1,0.242617,0.352066,0.242617,0.352066,L1
1,1,0.187796,0.287603,0.187796,0.287603,L2


DynAERNN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.120992,0.209917,0.120992,0.209917,sigmoid
1,1,0.201295,0.297521,0.201295,0.297521,HAD
1,1,0.242617,0.352066,0.242617,0.352066,AVG
1,1,0.242617,0.352066,0.242617,0.352066,L1
1,1,0.120992,0.209917,0.120992,0.209917,L2


DynGEM


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.216171,0.319008,0.216171,0.319008,sigmoid
1,1,0.172782,0.276033,0.172782,0.276033,HAD
1,1,0.118926,0.208264,0.118926,0.208264,AVG
1,1,0.127879,0.219835,0.127879,0.219835,L1
1,1,0.187796,0.287603,0.187796,0.287603,L2


DySAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.178843,0.271074,0.178843,0.271074,sigmoid
1,1,0.192424,0.28595,0.192424,0.28595,HAD
1,1,0.182562,0.27438,0.182562,0.27438,AVG
1,1,0.182562,0.272727,0.182562,0.272727,L1
1,1,0.190634,0.290909,0.190634,0.290909,L2


VGRNN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.2327,0.342149,0.2327,0.342149,sigmoid
1,1,0.190275,0.289256,0.190275,0.289256,HAD
1,1,0.187796,0.287603,0.187796,0.287603,AVG
1,1,0.128375,0.196694,0.128375,0.196694,L1
1,1,0.133333,0.201653,0.133333,0.201653,L2


EvolveGCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.175262,0.261157,0.175262,0.261157,sigmoid
1,1,0.126777,0.216529,0.126777,0.216529,HAD
1,1,0.187796,0.287603,0.187796,0.287603,AVG
1,1,0.120992,0.209917,0.120992,0.209917,L1
1,1,0.120992,0.209917,0.120992,0.209917,L2


CTGCN-C


Unnamed: 0,k,recall,mrr,map,ndcg,operator
1,1,0.168788,0.266116,0.168788,0.266116,sigmoid
1,1,0.16562,0.252893,0.16562,0.252893,HAD
1,1,0.201295,0.297521,0.201295,0.297521,AVG
1,1,0.170661,0.257851,0.170661,0.257851,L1
1,1,0.15303,0.234711,0.15303,0.234711,L2


In [None]:
0.276501	0.421488	0.276501	0.421488	
0.198402	0.307438	0.198402	0.307438	
0.187796	0.287603	0.187796	0.287603	
0.201295	0.297521	0.201295	0.297521
0.242617	0.352066	0.242617	0.352066	
0.242617	0.352066	0.242617	0.352066	
0.242617	0.352066	0.242617	0.352066	
0.216171	0.319008	0.216171	0.319008	
0.192424	0.285950	0.192424	0.285950
0.232700	0.342149	0.232700	0.342149
0.187796	0.287603	0.187796	0.287603
0.201295	0.297521	0.201295	0.297521

In [6]:
exp = 'dblp'
k = 2
time_id = 7 - 1
for method in methods:
    node_embedding = pd.read_csv(
                './data/{}/2.embedding/{}/{:02d}.csv'.format(
                    exp, method, time_id), index_col=0, sep='\t').values
    models = model_ls[method]
    ranking = make_prediction(pred_idx, node_embedding, models)
    eval_metrics = []
    for operator, rk in ranking.items():
        pd_pred = pd.DataFrame(rk, columns=['source', 'target', 'sims']).sort_values(['sims'], ascending=False)
        pred = pd_pred.groupby('source').agg({'target': list}).to_dict()['target']
        res = eval_ranking(pred, ground_truth, k=5)
        res['operator'] = operator
        eval_metrics.append(res)
    pd_res = pd.concat(eval_metrics)
    print(method)
    display(pd_res[pd_res['k']==k])

GCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.459201,0.51157,0.389642,0.464808,sigmoid
2,2,0.325868,0.356198,0.256309,0.321632,HAD
2,2,0.32022,0.386777,0.266157,0.326112,AVG
2,2,0.465675,0.51157,0.390331,0.465732,L1
2,2,0.456997,0.504959,0.387231,0.460794,L2


GAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.372424,0.395868,0.277314,0.349229,sigmoid
2,2,0.320771,0.340496,0.226763,0.297349,HAD
2,2,0.32022,0.342149,0.224972,0.296474,AVG
2,2,0.374904,0.393388,0.28365,0.356047,L1
2,2,0.359339,0.392562,0.279174,0.348789,L2


TgGCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.350523,0.404959,0.288884,0.351889,sigmoid
2,2,0.350799,0.399174,0.286267,0.349173,HAD
2,2,0.308788,0.345455,0.26416,0.322048,AVG
2,2,0.430413,0.471074,0.344256,0.420689,L1
2,2,0.322287,0.387603,0.266777,0.328768,L2


TgGAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.374215,0.400826,0.295565,0.364806,sigmoid
2,2,0.2727,0.296694,0.208099,0.263329,HAD
2,2,0.308788,0.306612,0.230758,0.296444,AVG
2,2,0.380551,0.41157,0.305069,0.37364,L1
2,2,0.372562,0.407438,0.300592,0.368731,L2


TgSAGE


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.354876,0.380165,0.277328,0.342097,sigmoid
2,2,0.341708,0.356198,0.255964,0.32342,HAD
2,2,0.316171,0.376033,0.255152,0.319353,AVG
2,2,0.307273,0.341322,0.234752,0.299044,L1
2,2,0.329394,0.360331,0.255771,0.318741,L2


TgGIN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.355482,0.369421,0.258168,0.333664,sigmoid
2,2,0.327658,0.36281,0.243567,0.31332,HAD
2,2,0.32022,0.386777,0.266157,0.326112,AVG
2,2,0.322287,0.343802,0.226625,0.300684,L1
2,2,0.249366,0.296694,0.188623,0.24998,L2


GCRN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.367052,0.394215,0.290399,0.355891,sigmoid
2,2,0.454242,0.460331,0.365882,0.434344,HAD
2,2,0.430413,0.471074,0.344256,0.420689,AVG
2,2,0.417273,0.418182,0.324545,0.394113,L1
2,2,0.402121,0.424793,0.328333,0.393584,L2


TIMERS


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.518705,0.577686,0.457342,0.531899,sigmoid
2,2,0.557686,0.618182,0.498113,0.57333,HAD
2,2,0.329669,0.366116,0.277466,0.326073,AVG
2,2,0.442259,0.447934,0.364229,0.423886,L1
2,2,0.430413,0.471074,0.344256,0.420689,L2


DynAE


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.2473,0.298347,0.185868,0.245769,sigmoid
2,2,0.370992,0.418182,0.319614,0.374225,HAD
2,2,0.430413,0.471074,0.344256,0.420689,AVG
2,2,0.430413,0.471074,0.344256,0.420689,L1
2,2,0.430413,0.438843,0.316846,0.400438,L2


DynAERNN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.322287,0.343802,0.226625,0.300684,sigmoid
2,2,0.322287,0.387603,0.266777,0.328768,HAD
2,2,0.361543,0.440496,0.30686,0.374263,AVG
2,2,0.361543,0.440496,0.30686,0.374263,L1
2,2,0.322287,0.343802,0.226625,0.300684,L2


DynGEM


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.442948,0.440496,0.347011,0.413908,sigmoid
2,2,0.380275,0.395868,0.287094,0.356462,HAD
2,2,0.306722,0.3,0.231997,0.292234,AVG
2,2,0.306722,0.305785,0.236474,0.295796,L1
2,2,0.308512,0.342149,0.267121,0.32067,L2


DySAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.328623,0.371074,0.260234,0.325402,sigmoid
2,2,0.358595,0.395041,0.283705,0.349732,HAD
2,2,0.333994,0.382645,0.263333,0.330055,AVG
2,2,0.339229,0.380165,0.266088,0.330823,L1
2,2,0.333113,0.384298,0.270758,0.336254,L2


VGRNN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.41719,0.459504,0.332686,0.408686,sigmoid
2,2,0.343636,0.386777,0.275937,0.344124,HAD
2,2,0.430413,0.438843,0.316846,0.400438,AVG
2,2,0.370992,0.340496,0.262493,0.326793,L1
2,2,0.370992,0.342975,0.264972,0.328623,L2


EvolveGCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.360854,0.391736,0.273664,0.343109,sigmoid
2,2,0.328898,0.350413,0.232824,0.30706,HAD
2,2,0.430413,0.438843,0.316846,0.400438,AVG
2,2,0.322287,0.343802,0.226625,0.300684,L1
2,2,0.322287,0.343802,0.226625,0.300684,L2


CTGCN-C


Unnamed: 0,k,recall,mrr,map,ndcg,operator
2,2,0.362094,0.385124,0.273251,0.342914,sigmoid
2,2,0.331873,0.356198,0.258182,0.318189,HAD
2,2,0.329669,0.366116,0.277466,0.326073,AVG
2,2,0.305289,0.342149,0.248099,0.302222,L1
2,2,0.283113,0.32314,0.225716,0.280597,L2


In [None]:
0.465675	0.511570	0.390331	0.465732
0.380551	0.411570	0.305069	0.373640	
0.354876	0.380165	0.277328	0.342097
0.355482	0.369421	0.258168	0.333664
0.454242	0.460331	0.365882	0.434344
0.430413	0.471074	0.344256	0.420689	
0.361543	0.440496	0.306860	0.374263
0.442948	0.440496	0.347011	0.413908
0.358595	0.395041	0.283705	0.349732
0.430413	0.438843	0.316846	0.400438	
0.430413	0.438843	0.316846	0.400438
0.362094	0.385124	0.273251	0.342914

In [7]:
exp = 'dblp'
k = 3
time_id = 7 - 1
for method in methods:
    node_embedding = pd.read_csv(
                './data/{}/2.embedding/{}/{:02d}.csv'.format(
                    exp, method, time_id), index_col=0, sep='\t').values
    models = model_ls[method]
    ranking = make_prediction(pred_idx, node_embedding, models)
    eval_metrics = []
    for operator, rk in ranking.items():
        pd_pred = pd.DataFrame(rk, columns=['source', 'target', 'sims']).sort_values(['sims'], ascending=False)
        pred = pd_pred.groupby('source').agg({'target': list}).to_dict()['target']
        res = eval_ranking(pred, ground_truth, k=5)
        res['operator'] = operator
        eval_metrics.append(res)
    pd_res = pd.concat(eval_metrics)
    print(method)
    display(pd_res[pd_res['k']==k])

GCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.620689,0.560055,0.459486,0.532226,sigmoid
3,3,0.524904,0.423967,0.335556,0.413396,HAD
3,3,0.441212,0.424793,0.318186,0.383031,AVG
3,3,0.601405,0.550689,0.452048,0.52194,L1
3,3,0.61146,0.550689,0.457443,0.526566,L2


GAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.516694,0.437741,0.340638,0.417233,sigmoid
3,3,0.448375,0.380165,0.281455,0.356196,HAD
3,3,0.441212,0.380165,0.277002,0.352164,AVG
3,3,0.518898,0.44022,0.342796,0.419093,L1
3,3,0.505675,0.440496,0.33933,0.413355,L2


TgGCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.464215,0.445179,0.336965,0.402579,sigmoid
3,3,0.464215,0.438843,0.33444,0.399982,HAD
3,3,0.510083,0.430303,0.337218,0.412691,AVG
3,3,0.558788,0.511295,0.3977,0.473378,L1
3,3,0.441212,0.425069,0.318393,0.383302,L2


TgGAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.540386,0.452066,0.368255,0.440384,sigmoid
3,3,0.446529,0.355096,0.276097,0.345031,HAD
3,3,0.427713,0.335813,0.291556,0.345359,AVG
3,3,0.551625,0.464463,0.376556,0.450056,L1
3,3,0.556639,0.464187,0.378287,0.452856,L2


TgSAGE


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.533306,0.439118,0.351102,0.426115,sigmoid
3,3,0.487218,0.404683,0.316534,0.390144,HAD
3,3,0.435096,0.407989,0.309412,0.372231,AVG
3,3,0.494601,0.406336,0.309637,0.38827,L1
3,3,0.504242,0.42259,0.326561,0.403262,L2


TgGIN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.466143,0.406887,0.305973,0.379685,sigmoid
3,3,0.466832,0.410193,0.299775,0.376932,HAD
3,3,0.441212,0.424793,0.318186,0.383031,AVG
3,3,0.44438,0.382369,0.279803,0.355108,L1
3,3,0.437163,0.358402,0.263912,0.338698,L2


GCRN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.536749,0.446556,0.363448,0.435958,sigmoid
3,3,0.627934,0.508815,0.443981,0.514269,HAD
3,3,0.558788,0.511295,0.3977,0.473378,AVG
3,3,0.609201,0.471074,0.4109,0.483933,L1
3,3,0.613196,0.484848,0.42309,0.494311,L2


TIMERS


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.636804,0.604683,0.515983,0.575516,sigmoid
3,3,0.716143,0.652893,0.57714,0.638533,HAD
3,3,0.572287,0.436088,0.381405,0.448648,AVG
3,3,0.572287,0.481543,0.4241,0.481003,L1
3,3,0.558788,0.511295,0.3977,0.473378,L2


DynAE


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.435096,0.355647,0.263361,0.336628,sigmoid
3,3,0.558788,0.493664,0.389486,0.464608,HAD
3,3,0.549339,0.502479,0.399867,0.471268,AVG
3,3,0.549339,0.502479,0.399867,0.471268,L1
3,3,0.549339,0.470248,0.372456,0.450315,L2


DynAERNN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.450661,0.384573,0.279702,0.35531,sigmoid
3,3,0.450661,0.428375,0.319853,0.385412,HAD
3,3,0.549339,0.492287,0.387401,0.461945,AVG
3,3,0.549339,0.492287,0.387401,0.461945,L1
3,3,0.450661,0.384573,0.279702,0.35531,L2


DynGEM


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.601901,0.482369,0.419187,0.488167,sigmoid
3,3,0.563113,0.45978,0.362309,0.445376,HAD
3,3,0.508017,0.384848,0.305331,0.385105,AVG
3,3,0.42854,0.340496,0.294151,0.347512,L1
3,3,0.437025,0.395041,0.316203,0.373495,L2


DySAT


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.458926,0.411846,0.31455,0.38187,sigmoid
3,3,0.495427,0.436915,0.342282,0.411211,HAD
3,3,0.472287,0.423967,0.320863,0.391389,AVG
3,3,0.479174,0.424793,0.323939,0.394631,L1
3,3,0.47573,0.428926,0.330666,0.399143,L2


VGRNN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.538953,0.497521,0.383926,0.458069,sigmoid
3,3,0.510771,0.439118,0.346791,0.419622,HAD
3,3,0.549339,0.470248,0.372456,0.450315,AVG
3,3,0.582066,0.420937,0.342466,0.4311,L1
3,3,0.563747,0.419559,0.337048,0.423428,L2


EvolveGCN


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.502231,0.431956,0.334417,0.408797,sigmoid
3,3,0.449477,0.38843,0.285266,0.360589,HAD
3,3,0.558788,0.479063,0.370289,0.452425,AVG
3,3,0.441212,0.381267,0.278242,0.3532,L1
3,3,0.441212,0.381267,0.278242,0.3532,L2


CTGCN-C


Unnamed: 0,k,recall,mrr,map,ndcg,operator
3,3,0.527025,0.439669,0.341212,0.420165,sigmoid
3,3,0.488264,0.410193,0.323232,0.393997,HAD
3,3,0.572287,0.436088,0.381405,0.448648,AVG
3,3,0.487025,0.406612,0.321276,0.39082,L1
3,3,0.432893,0.367218,0.29017,0.352096,L2
