In [1]:
%load_ext autoreload
%autoreload 2
import os, sys, re, datetime, random, gzip, json, copy
from pathlib import Path
import networkx as nx
import numpy as np
import pandas as pd
import itertools
import collections
import networkx as nx
import numpy as np
import pandas as pd
import itertools
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn import linear_model
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score

# Resolve the CTGCN project root: drop everything from the first "/CTGCN" in the
# current working directory onward, then re-append "CTGCN". If the cwd contains no
# "/CTGCN", "CTGCN" is simply appended to the cwd. The regex matches '/', so this
# assumes POSIX-style paths.
PROJ_PATH = Path(os.path.join(re.sub("/CTGCN.*$", '', os.getcwd()), 'CTGCN'))
# Put the sibling "DyHNet" directory (next to CTGCN) on sys.path so the
# DyHNet imports below resolve.
sys.path.insert(1, str(PROJ_PATH.parents[0] / 'DyHNet'))
import DyHNet
from DyHNet.src.evaluation import predict_node_classification, eval_node_classification, eval_link_prediction

# Utils

In [34]:
def load_node_embedding(exp, method, data, temporal_methods, non_temporal_methods, data_dir=None):
    """Load node embeddings for every snapshot (``time_id``) referenced in ``data``.

    Parameters
    ----------
    exp : str
        Experiment name; selects the experiment folder under the data root.
    method : str
        Embedding method name. It must appear in ``non_temporal_methods`` or
        ``temporal_methods``; the non-temporal list is checked first.
    data : dict
        Sample records. Only each record's ``'time_id'`` is read here.
    temporal_methods, non_temporal_methods : collection of str
        Method names used to decide which embedding file(s) to read.
    data_dir : str or pathlib.Path, optional
        Root directory holding the experiment folders. Defaults to
        ``PROJ_PATH / 'data'``.

    Returns
    -------
    dict
        ``{time_id: {node_id: [embedding values]}}`` for every distinct
        ``time_id`` in ``data``.

        - Non-temporal methods reuse the single table
          ``<exp>_homo/2.embedding/<method>/00.csv`` for every snapshot.
        - Temporal methods read ``<exp>/2.embedding/<method>/<tt>.csv``,
          where ``<tt>`` is the two-digit, zero-padded ``time_id``.

    Raises
    ------
    ValueError
        If ``method`` is in neither method list.
    """
    if method in non_temporal_methods:
        is_temporal = False
    elif method in temporal_methods:
        is_temporal = True
    else:
        # Fail fast: with an unknown method there is no file to read, and
        # continuing would only crash later on an unbound variable.
        raise ValueError('Unknown method: {!r}'.format(method))

    root = PROJ_PATH / 'data' if data_dir is None else Path(data_dir)

    def _read(path):
        # Embedding files are tab-separated; the first column (node ids) becomes the index.
        return pd.read_csv(str(path), index_col=0, sep='\t')

    time_ids = sorted({d['time_id'] for d in data.values()})

    static_embs = None
    if not is_temporal and time_ids:
        # One static table serves every snapshot: read it once instead of once per time_id.
        static_embs = _read(root / '{}_homo'.format(exp) / '2.embedding' / method / '00.csv')

    node_embedding = {}
    for time_id in time_ids:
        if is_temporal:
            embs = _read(root / exp / '2.embedding' / method / '{:02d}.csv'.format(time_id))
        else:
            embs = static_embs
        # Transpose so each column is one node; to_dict('list') gives {node_id: [values...]}.
        # Converting per snapshot keeps snapshots from sharing the same mutable lists.
        node_embedding[time_id] = embs.T.to_dict('list')
    return node_embedding

def get_training_data(data, node_embedding):
    """Flatten sample records into parallel lists ready for a classifier.

    Each value of ``data`` must provide ``'time_id'``, ``'node_id'``,
    ``'dataset'`` and ``'label'``. A record's feature vector is
    ``node_embedding[time_id][node_id]``.

    Returns ``(nids, features, labels, dataset, label_mapping)``:

    - ``labels`` are integer codes.
    - ``label_mapping`` maps each label string to its code; codes are assigned
      in sorted order of the label strings.
    """
    nids, features, dataset, raw_labels = [], [], [], []
    for record in data.values():
        snapshot = record['time_id']
        node = record['node_id']
        nids.append(node)
        dataset.append(record['dataset'])
        raw_labels.append(record['label'])
        features.append(node_embedding[snapshot][node])

    # Deterministic encoding: sort the distinct label names, then number them 0..k-1.
    label_mapping = {name: code for code, name in enumerate(sorted(set(raw_labels)))}
    labels = list(map(label_mapping.__getitem__, raw_labels))
    return nids, features, labels, dataset, label_mapping

def print_report(exp='yelp', methods=None, selected_methods=None, train_percents=(20, 40, 60, 80)):
    """Evaluate node classification for each embedding method and display report tables.

    Steps for every method in ``methods`` and every percentage in ``train_percents``:

    1. Load the split ``PROJ_PATH/data/<exp>/0.input/data_train=<p>.pkl``.
    2. Fetch embeddings with ``load_node_embedding``.
    3. Score a classifier with ``eval_node_classification``.

    Three tables are then displayed: the full report, the test rows only, and
    the test rows for ``selected_methods``.

    Parameters
    ----------
    exp : str
        Experiment / dataset folder name under ``PROJ_PATH / 'data'``.
    methods : list of str, optional
        Methods to evaluate. ``None`` or empty means every known method.
    selected_methods : list of str, optional
        Methods shown in the "Selected methods" table. ``None`` or empty uses
        the built-in default list below.
    train_percents : iterable of int, optional
        Training-split percentages to evaluate. Each needs a matching
        ``data_train=<p>.pkl`` file. Must not be empty.

    Returns
    -------
    df : pandas.DataFrame
        Metric rows from ``eval_node_classification``, with added ``dataset``
        (the split name taken from the results index), ``method`` and
        ``train_percent`` columns.
    model_ls : dict
        Maps each method to ``model['model']`` from its LAST entry in
        ``train_percents``; models from earlier splits are overwritten.

    Raises
    ------
    ValueError
        If ``train_percents`` is empty.
    """
    non_temporal_methods = [
        'GCN', 'GAT', 'SAGE', 'GIN',
        'TgGCN', 'TgGAT', 'TgSAGE', 'TgGIN',
    ]
    temporal_methods = [
        'GCRN', 'TIMERS', 'DynAE', 'DynRNN',
        'DynAERNN', 'DynGEM', 'DySAT',
        'VGRNN', 'EvolveGCN', 'CTGCN-C',
        'DyHATR', 'DHNE', 'MetaGraph2vec'
    ]
    # Samples from snapshot 0 are dropped for these methods.
    # NOTE(review): presumably they have no usable embedding for the first snapshot — confirm.
    skip_first_snapshot = {'DynAE', 'DynRNN', 'DynAERNN', 'DHNE'}

    # `None` defaults avoid the shared-mutable-default pitfall; empty lists still mean "use defaults".
    if not methods:
        methods = non_temporal_methods + temporal_methods
    if not selected_methods:
        selected_methods = [
            'GCN', 'TgGAT', 'TgSAGE', 'TgGIN',
            'GCRN', 'TIMERS', 'DynAE', 'DynRNN', 'DynAERNN', 'DynGEM', 'DySAT',
            'VGRNN', 'EvolveGCN', 'CTGCN-C', 'DHNE', 'DyHATR', 'MetaGraph2vec']
    train_percents = list(train_percents)
    if not train_percents:
        raise ValueError('train_percents must contain at least one split percentage')

    res = []
    model_ls = {}
    for method in methods:
        print(method)
        for train_percent in train_percents:
            # NOTE: read_pickle can execute arbitrary code; only load trusted local split files.
            data = pd.read_pickle(str(PROJ_PATH / 'data' / exp / '0.input' / f'data_train={train_percent}.pkl'))
            if method in skip_first_snapshot:
                data = {i: j for i, j in data.items() if j['time_id'] >= 1}
            node_embedding = load_node_embedding(exp, method, data, temporal_methods, non_temporal_methods)
            nids, features, labels, dataset, label_mapping = get_training_data(data, node_embedding)

            results, model = eval_node_classification(features, labels, train_val_test_index=dataset)
            split_report = pd.DataFrame(results)
            split_report['method'] = method
            split_report['train_percent'] = train_percent
            res.append(split_report)
            # Only the model from the last split survives per method (see docstring).
            model_ls[method] = model['model']
        # Mapping from the last evaluated split of this method.
        print('Label mapping:', label_mapping)

    # Build the frame once; the results index (split names) becomes the 'dataset' column.
    df = pd.concat(res)
    df = df.reset_index().rename(columns={'index': 'dataset'})

    print('Full report')
    display(df)

    print('Test report')
    display(df[df['dataset'] == 'test'])

    print('Selected methods')
    display(df[(df['dataset'] == 'test') & (df['method'].isin(selected_methods))][[
        'method', 'train_percent', 'accuracy', 'auc', 'f1_macro', 'f1_micro']])

    return df, model_ls

In [32]:
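# NOTE(review): stale scratch code — get_training_data now takes (data, node_embedding)
# and returns 5 values, so this block must be updated before it is re-enabled.
# exp = 'dblp_four_area'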
# for train_percent in [20, 40, 60, 80]:
#     nids, dataset, labels, time_id = get_training_data(exp, train_percent)
#     df = pd.DataFrame({'tvt': dataset, 'nid': nids, 'label': labels})
#     display(df.pivot_table(index='tvt', columns='label', values='nid', aggfunc='count'))

# Evaluate temporal baselines (VGRNN, DynGEM, DynAE, DynRNN, DynAERNN)

In [None]:
# Temporal baselines (all in print_report's temporal method list) on experiment 'yelp'.
methods = ['VGRNN', 'DynGEM', 'DynAE', 'DynRNN', 'DynAERNN']
df, model_ls = print_report(exp='yelp', methods=methods)

In [None]:
# Temporal baselines (all in print_report's temporal method list) on experiment 'yelp_s'.
methods = ['VGRNN', 'DynGEM', 'DynAE', 'DynRNN', 'DynAERNN']
df, model_ls = print_report(exp='yelp_s', methods=methods)

In [None]:
# Temporal baselines (all in print_report's temporal method list) on experiment 'dblp_four_area'.
methods = ['VGRNN', 'DynGEM', 'DynAE', 'DynRNN', 'DynAERNN']
df, model_ls = print_report(exp='dblp_four_area', methods=methods)

In [None]:
# Temporal baselines (all in print_report's temporal method list) on experiment 'dblp_four_area_s'.
methods = ['VGRNN', 'DynGEM', 'DynAE', 'DynRNN', 'DynAERNN']
df, model_ls = print_report(exp='dblp_four_area_s', methods=methods)

# Evaluate remaining baselines

In [35]:
# methods = [
#     'GCN', 'TgGCN', 'TgGAT', 'TgSAGE', 'TgGIN', 
#     'GCRN', 'TIMERS', 'EvolveGCN', 'CTGCN-C',
#     'DHNE', 'DyHATR',
# ]
# df, model_ls = print_report('yelp', methods)

GCN
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
TgGCN
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
TgGAT
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
TgSAGE
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
TgGIN
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
GCRN
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
TIMERS
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
EvolveGCN
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
CTGCN-C
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
DHNE
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
DyHATR
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
Full report


Unnamed: 0,dataset,accuracy,auc,f1_macro,f1_micro,method,train_percent
0,train,0.392060,0.760970,0.187760,0.392060,GCN,20
1,val,0.358209,0.621968,0.175824,0.358209,GCN,20
2,test,0.353432,0.649368,0.174092,0.353432,GCN,20
3,train,0.377919,0.725966,0.182845,0.377919,GCN,40
4,val,0.455224,0.675499,0.208547,0.455224,GCN,40
...,...,...,...,...,...,...,...
127,val,0.408922,0.614892,0.401558,0.408922,DyHATR,60
128,test,0.405380,0.619618,0.397476,0.405380,DyHATR,60
129,train,0.455703,0.658323,0.446198,0.455703,DyHATR,80
130,val,0.408922,0.630570,0.405996,0.408922,DyHATR,80


Test report


Unnamed: 0,dataset,accuracy,auc,f1_macro,f1_micro,method,train_percent
2,test,0.353432,0.649368,0.174092,0.353432,GCN,20
5,test,0.367347,0.666918,0.179104,0.367347,GCN,40
8,test,0.361781,0.672251,0.177112,0.361781,GCN,60
11,test,0.371058,0.66131,0.180424,0.371058,GCN,80
14,test,0.406308,0.673841,0.396444,0.406308,TgGCN,20
17,test,0.437848,0.67326,0.442041,0.437848,TgGCN,40
20,test,0.427644,0.67695,0.427044,0.427644,TgGCN,60
23,test,0.435993,0.673013,0.440311,0.435993,TgGCN,80
26,test,0.353432,0.57082,0.174092,0.353432,TgGAT,20
29,test,0.367347,0.587469,0.179104,0.367347,TgGAT,40


Selected methods


Unnamed: 0,method,train_percent,accuracy,auc,f1_macro,f1_micro
2,GCN,20,0.353432,0.649368,0.174092,0.353432
5,GCN,40,0.367347,0.666918,0.179104,0.367347
8,GCN,60,0.361781,0.672251,0.177112,0.361781
11,GCN,80,0.371058,0.66131,0.180424,0.371058
26,TgGAT,20,0.353432,0.57082,0.174092,0.353432
29,TgGAT,40,0.367347,0.587469,0.179104,0.367347
32,TgGAT,60,0.361781,0.594434,0.177112,0.361781
35,TgGAT,80,0.374768,0.556147,0.181736,0.374768
38,TgSAGE,20,0.353432,0.659373,0.174092,0.353432
41,TgSAGE,40,0.367347,0.651792,0.179104,0.367347


In [37]:
# Subset evaluation on experiment 'yelp_s'.
# Also listed in this cell but disabled: GCN, TgGCN, TgGAT, TgSAGE, TgGIN,
# EvolveGCN, DHNE, DyHATR.
methods = ['CTGCN-C', 'TIMERS', 'GCRN']
df, model_ls = print_report(exp='yelp_s', methods=methods)

CTGCN-C
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
TIMERS
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
GCRN
Label mapping: {'American (New)': 0, 'Fast Food': 1, 'Sushi Bars': 2}
Full report


Unnamed: 0,dataset,accuracy,auc,f1_macro,f1_micro,method,train_percent
0,train,0.751861,0.905553,0.755472,0.751861,CTGCN-C,20
1,val,0.410448,0.613162,0.413769,0.410448,CTGCN-C,20
2,test,0.46475,0.656391,0.471118,0.46475,CTGCN-C,20
3,train,0.650743,0.837128,0.655048,0.650743,CTGCN-C,40
4,val,0.477612,0.702087,0.490238,0.477612,CTGCN-C,40
5,test,0.461348,0.662624,0.469395,0.461348,CTGCN-C,40
6,train,0.619614,0.816333,0.624063,0.619614,CTGCN-C,60
7,val,0.468401,0.665755,0.47389,0.468401,CTGCN-C,60
8,test,0.474026,0.670202,0.481708,0.474026,CTGCN-C,60
9,train,0.592042,0.791198,0.597767,0.592042,CTGCN-C,80


Test report


Unnamed: 0,dataset,accuracy,auc,f1_macro,f1_micro,method,train_percent
2,test,0.46475,0.656391,0.471118,0.46475,CTGCN-C,20
5,test,0.461348,0.662624,0.469395,0.461348,CTGCN-C,40
8,test,0.474026,0.670202,0.481708,0.474026,CTGCN-C,60
11,test,0.495362,0.695328,0.500893,0.495362,CTGCN-C,80
14,test,0.375696,0.570489,0.282312,0.375696,TIMERS,20
17,test,0.416203,0.587957,0.398526,0.416203,TIMERS,40
20,test,0.415584,0.582422,0.396495,0.415584,TIMERS,60
23,test,0.410019,0.580549,0.396405,0.410019,TIMERS,80
26,test,0.41744,0.605218,0.41915,0.41744,GCRN,20
29,test,0.428571,0.61738,0.429332,0.428571,GCRN,40


Selected methods


Unnamed: 0,method,train_percent,accuracy,auc,f1_macro,f1_micro
2,CTGCN-C,20,0.46475,0.656391,0.471118,0.46475
5,CTGCN-C,40,0.461348,0.662624,0.469395,0.461348
8,CTGCN-C,60,0.474026,0.670202,0.481708,0.474026
11,CTGCN-C,80,0.495362,0.695328,0.500893,0.495362
14,TIMERS,20,0.375696,0.570489,0.282312,0.375696
17,TIMERS,40,0.416203,0.587957,0.398526,0.416203
20,TIMERS,60,0.415584,0.582422,0.396495,0.415584
23,TIMERS,80,0.410019,0.580549,0.396405,0.410019
26,GCRN,20,0.41744,0.605218,0.41915,0.41744
29,GCRN,40,0.428571,0.61738,0.429332,0.428571


In [39]:
# methods = [
#     'GCN', 'TgSAGE', 'TgGIN', 
#     'GCRN', 'TIMERS', 'EvolveGCN', 'CTGCN-C',
#     'DHNE', 'DyHATR', 'MetaGraph2vec'
# ]
# df, model_ls = print_report('dblp_four_area', methods)

In [40]:
# methods = [
#     'GCN', 'TgSAGE', 'TgGIN', 
# #     'GCRN', 
#     'TIMERS', 'EvolveGCN', 
# #     'CTGCN-C',
# #     'DHNE', 'DyHATR',
#     'MetaGraph2vec'
# ]
# df, model_ls = print_report('dblp_four_area_s', methods)