# Reproduce the ablation study table, the k-means agreement score table, and the flow prediction performance boxplots

In [68]:
import pandas as pd
import copy
import os
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
import seaborn as sns
import tikzplotlib


In [69]:
# The bundled seaborn style sheets were renamed to 'seaborn-v0_8-*' in newer
# matplotlib releases; use the new name when it exists and fall back to the
# legacy name on older installations.
whitegrid_style = ('seaborn-v0_8-whitegrid'
                   if 'seaborn-v0_8-whitegrid' in plt.style.available
                   else 'seaborn-whitegrid')
plt.style.use(whitegrid_style)

# Figure defaults shared by every plot in this notebook.
plt.rcParams['font.size'] = 14
plt.rcParams['figure.figsize'] = [6.5, 4.4]
plt.rcParams['legend.fontsize'] = 14
plt.rcParams['font.family'] = 'serif'

# Colours of the active style's property cycle, in order.
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [82]:
# Synthetic flow setting analysed by this run; only the two values below are valid.
synthetic_type = 'multimodal'

assert synthetic_type in ('unimodal', 'multimodal')

# Output locations for figures and LaTeX tables (created on demand).
figures_save_folder = f"figures/ablation_{synthetic_type}/"
tables_save_folder = f"tables/ablation_{synthetic_type}/"
for save_folder in (figures_save_folder, tables_save_folder):
    os.makedirs(save_folder, exist_ok=True)

# compare_inx = 5 if synthetic_type == 'unimodal' else 2

# Result key -> label used in the 'models' column (line breaks are intentional).
model_dict = dict(
    dnn2_engi_feat='f.e.+\ndnn2',
    dnn2_node2vec='n2v+\ndnn2',
    dnn2_both='both+\ndnn2',
    gated='gated',
    grad='grad',
    fairness_goodness="Kumar\net.al.",
)

# Result key -> label used in the 'representations' column.
repr_dict = dict(
    dnn2_engi_feat='engi. feat.',
    dnn2_node2vec='node2vec',
    dnn2_both='both',
    gated='gated',
    grad='grad',
)

# Multi-line model label -> the same label flattened onto one line.
model_dict_wo_enter = {
    label: label.replace("+\n", "+").replace("\n", " ")
    for label in model_dict.values()
}

In [83]:
def load_results(graph_name, synthetic_type):
    """Read all result CSV files of one graph / synthetic-type experiment.

    Files are read from ``results/{graph_name}_{synthetic_type}/``.

    Parameters
    ----------
    graph_name : str
        Graph identifier; used in the folder name and as file-name prefix.
    synthetic_type : str
        Synthetic setting; used in the folder name.

    Returns
    -------
    dict
        Result name -> DataFrame, inserted in this order: ``"gated"``,
        ``"grad"``, one entry per distinct ``baseline_name`` found in the
        baseline file, ``"fairness_goodness"``, ``"init"``, ``"flow_info"``.
    """
    folder = "results/{}_{}/".format(graph_name, synthetic_type)

    def read(file_name):
        return pd.read_csv(os.path.join(folder, file_name))

    results = dict()
    results["gated"] = read("{}_ablation_joint_results.csv".format(graph_name))
    results["grad"] = read("{}_ablation_grad_baseline_results.csv".format(graph_name))

    baseline_df = read("{}_ablation_baseline_results.csv".format(graph_name))
    for bl in baseline_df.baseline_name.unique():
        # Store an independent copy: callers add columns to these frames, and
        # writing into a boolean-mask slice raises SettingWithCopyWarning.
        results[bl] = baseline_df.loc[baseline_df.baseline_name == bl, :].copy()

    results["fairness_goodness"] = read("{}_ablation_fairness_goodness.csv".format(graph_name))
    results["init"] = read("{}_ablation_init_results.csv".format(graph_name))
    results["flow_info"] = read("flow_info.csv")

    return results
    

In [84]:
def add_models_repr_columns(results):
    """Label every result frame with its model and representation name.

    ``results`` maps graph name -> {result key -> DataFrame}. Each frame gets
    two constant columns, ``'models'`` and ``'representations'``, looked up in
    ``model_dict`` / ``repr_dict``; keys missing from a lookup table are used
    as the label unchanged. Frames are modified in place and the same
    ``results`` object is returned.
    """
    for per_graph in results.values():
        for name, frame in per_graph.items():
            frame['models'] = model_dict.get(name, name)
            frame['representations'] = repr_dict.get(name, name)
    return results

In [85]:
def get_best_abl_inx(results, column="val_median_mag_error", select="min"):
    """Pick, per graph and result key, the ablation index with the best mean score.

    Parameters
    ----------
    results : dict
        Graph name -> {result key -> DataFrame}.
    column : str
        Score column that is averaged per ``ablation_idx``.
    select : {"min", "max"}
        Whether the smallest or the largest mean counts as best.

    Returns
    -------
    dict
        Graph name -> {result key -> best ``ablation_idx``}. Frames without an
        ``ablation_idx`` column are skipped.

    Raises
    ------
    ValueError
        If ``select`` is neither ``"min"`` nor ``"max"``.
    """
    pickers = {"min": "idxmin", "max": "idxmax"}
    if select not in pickers:
        raise ValueError(f'select must be "min" or "max", got {select!r}')

    abl_inx_dict = dict()
    for graph_name, results_ in results.items():
        best_per_key = dict()
        for key, df in results_.items():
            if "ablation_idx" not in df.columns:
                continue
            grouped = df.groupby("ablation_idx").agg({column: "mean"})
            # Previously idxmin was applied unconditionally, ignoring `select`.
            best_per_key[key] = getattr(grouped, pickers[select])().item()
        abl_inx_dict[graph_name] = best_per_key
    return abl_inx_dict

In [86]:
# Graphs to analyse and the result keys that enter the comparison.
graph_names = ["cora", "bitcoin", "complete"]
selected_models = ['gated', 'grad', "dnn2_engi_feat", "dnn2_node2vec", "fairness_goodness"]

# Load everything, label the frames, and find each model's best ablation index.
results = {graph_name: load_results(graph_name, synthetic_type) for graph_name in graph_names}
results = add_models_repr_columns(results)
abl_inx_dict = get_best_abl_inx(results)

# Per graph: keep only the rows of each selected model's best ablation index
# and stack them into a single frame.
best_results = dict()
for graph_name in graph_names:
    best_rows = [
        frame.loc[(frame['ablation_idx'] == abl_inx_dict[graph_name][key]), :]
        for key, frame in results[graph_name].items()
        if key in selected_models
    ]
    best_results[graph_name] = pd.concat(best_rows, ignore_index=True)

In [87]:
# To create ablation study table
def make_init_improvement_table(grad_results, col='median_mag_error', new_name='error*'):
    """Summarise train/val scores per initialisation and regularisation setting.

    Takes the ``train_{col}``, ``val_{col}`` and ``ablation_idx`` columns of
    ``grad_results``, replaces the ablation index (0-5) by a descriptive label,
    and returns the mean and standard deviation of the validation and training
    score per label, in order of first appearance. The input is not modified.
    """
    group_label = 'init and reg.'
    train_src, val_src = f'train_{col}', f'val_{col}'
    train_dst, val_dst = f'train {new_name}', f'val {new_name}'

    ablation_labels = {
        0: 'normal noise',
        1: 'LSQR',
        2: 'LSQR+',
        3: 'LSQR+, L1($u$)',
        4: 'LSQR+, L1($z$)',
        5: 'LSQR+, L1($u$), L1($z$)',
    }

    labelled = (
        grad_results.loc[:, [train_src, val_src, 'ablation_idx']]
        .assign(ablation_idx=lambda frame: frame['ablation_idx'].map(ablation_labels))
        .rename({train_src: train_dst, val_src: val_dst, 'ablation_idx': group_label},
                axis='columns')
    )

    summary_spec = {val_dst: ['mean', 'std'], train_dst: ['mean', 'std']}
    return labelled.groupby([group_label], sort=False).agg(summary_spec)

In [88]:
# To create train and validation box plot
def make_comparison(best_results, score='median_mag_error', split="val"):
    """Return a two-column frame: ``'models'`` and the chosen split's score.

    The score is read from the ``{split}_{score}`` column of ``best_results``
    and stored under the plain ``score`` name. The input is not modified.
    """
    split_column = f'{split}_{score}'
    return best_results.loc[:, ['models']].assign(**{score: best_results[split_column]})
    
    

In [89]:
def make_kmeans_agreement_score_comparison(best_results, score='multimodal_score'):
    """Mean and standard deviation of the agreement score per representation.

    Parameters
    ----------
    best_results : pandas.DataFrame
        Must contain a ``'representations'`` column and the ``score`` column.
    score : str
        Name of the column holding the agreement score.

    Returns
    -------
    pandas.DataFrame
        Indexed by representation (order of first appearance) with columns
        ``('k-means agreement', 'mean')`` and ``('k-means agreement', 'std')``.
    """
    label = 'k-means agreement'
    res_val = best_results.loc[:, ['representations']]
    # Store the score directly under its display name. The previous rename was
    # hard-coded to 'multimodal_score' and broke for any other `score` value.
    res_val = res_val.assign(**{label: best_results[score]})
    out = res_val.groupby('representations', sort=False).agg({label: ['mean', 'std']})
    return out
    

In [90]:
# To create a dense comparison of training and validation results for all models
def setup_df_for_ablation_plot_all_train_val(results_dict, score='loss', remove_index=False):
    """Stack train and validation scores of all models into one long frame.

    Every frame in ``results_dict`` first receives a ``'models'`` column (in
    place): the frame's ``'baseline_name'`` column for the keys ``'baseline'``
    / ``'baselines'``, otherwise the key itself. The frames are concatenated
    and the result holds the columns ``'models'``, ``'ablation_idx'``,
    ``score`` and ``'split'``, with all ``'train'`` rows followed by all
    ``'val'`` rows.

    If ``remove_index`` is truthy the stacked frame is passed through
    ``remove_abl_inx(frame, remove_index)``.
    NOTE(review): ``remove_abl_inx`` is not defined in the visible part of
    this notebook - confirm it exists before using ``remove_index``.
    """
    baseline_keys = ('baseline', 'baselines')
    for name, frame in results_dict.items():
        frame['models'] = frame['baseline_name'] if name in baseline_keys else name
    combined = pd.concat(list(results_dict.values()))

    def one_split(split):
        part = combined.loc[:, ['models', 'ablation_idx']]
        part[score] = combined[f'{split}_{score}']
        part['split'] = split
        return part

    stacked = pd.concat((one_split('train'), one_split('val')))
    if remove_index:
        stacked = remove_abl_inx(stacked, remove_index)
    return stacked

In [91]:
def init_improvement(results, graph_name):
    """Write and print the init/regularisation ablation table of the gated model.

    Builds the table from ``results["gated"]``, saves it as
    ``{graph_name}_init_improvement.tex`` inside ``tables_save_folder`` and
    prints it.
    """
    table = make_init_improvement_table(results["gated"])
    tex_path = os.path.join(tables_save_folder, f"{graph_name}_init_improvement.tex")
    table.to_latex(tex_path, escape=False, float_format="%.2f")
    print(table)

def agreement_score(results, graph_name):
    """Write and print the k-means agreement table (multimodal runs only).

    When ``synthetic_type`` is ``'multimodal'`` the table is saved as
    ``{graph_name}_kmeans_agreement.tex`` inside ``tables_save_folder`` and
    printed; for any other setting only a notice is printed.
    """
    if synthetic_type == 'multimodal':
        table = make_kmeans_agreement_score_comparison(results)
        tex_path = os.path.join(tables_save_folder, f"{graph_name}_kmeans_agreement.tex")
        table.to_latex(tex_path, escape=False, float_format="%.2f")
        print(table)
    else:
        print("agreement score only for mulimodal")
    
def plot_boxes(results, graph_name, split):
    """Draw a per-model box plot of the median absolute error and save it.

    Parameters
    ----------
    results : pandas.DataFrame
        Frame with a ``'models'`` column and a ``{split}_MeAE`` column.
    graph_name : str
        Used as prefix of the output file names.
    split : str
        Split prefix of the score column (e.g. ``"val"`` or ``"train"``).

    Returns
    -------
    The seaborn grid returned by ``sns.catplot``.

    The figure is written twice into ``figures_save_folder``: as TikZ
    (``.tex``, with line breaks in the tick labels turned into ``\\\\``) and
    as PDF (with the original tick labels).
    """
    # Alternative score: 'median_mag_error', labelled r'log$_{10}$ median rel. error'.
    score = 'MeAE'
    score_name = 'median abs. error'
    data = make_comparison(results, score=score, split=split)
    g = sns.catplot(x="models", y=score,
                    data=data,
                    kind='box',
                    linewidth=1.5)
    g.set(ylabel=score_name)
    sns.despine(left=True)
    # Read the labels from the grid created above. The code used to read them
    # from the notebook-global `g_val`, which does not exist on a fresh kernel
    # and otherwise refers to a previous figure.
    xticklabels = [label.get_text() for label in g.axes[0][0].get_xticklabels()]
    # TikZ needs '\\' instead of a newline inside a tick label.
    xticklabels_latex = [text.replace("\n", "\\\\") for text in xticklabels]
    g.set_xticklabels(xticklabels_latex)
    tikzplotlib.save(os.path.join(figures_save_folder, f"{graph_name}_{split}_{score}.tex"),
                     extra_axis_parameters=["xticklabel style={align=center}"])
    # Restore the multi-line labels for the PDF export.
    g.set_xticklabels(xticklabels)
    g.fig.savefig(os.path.join(figures_save_folder, f"{graph_name}_{split}_{score}.pdf"), bbox_inches='tight')
    return g

In [80]:
# Per graph: write the init/regularisation ablation table (from all runs) and
# the k-means agreement table (from the best runs only).
for graph_name in graph_names:
    all_runs, best_runs = results[graph_name], best_results[graph_name]
    init_improvement(all_runs, graph_name)
    agreement_score(best_runs, graph_name)
    # Box plots are currently disabled; to re-enable:
    #     g_val = plot_boxes(best_results[graph_name], graph_name, "val")
    #     g_val = plot_boxes(best_results[graph_name], graph_name, "train")

                        val error*           train error*          
                              mean       std         mean       std
init and reg.                                                      
normal noise             -0.042623  0.003290    -1.234667  0.014708
LSQR                     -0.042221  0.002488    -0.449679  0.010975
LSQR+                    -0.032853  0.004197    -1.241268  0.012833
LSQR+, L1($u$)           -0.067426  0.005261    -0.951209  0.008656
LSQR+, L1($z$)           -0.050759  0.003698    -0.486623  0.005285
LSQR+, L1($u$), L1($z$)  -0.073036  0.003498    -0.388167  0.005591
agreement score only for mulimodal
                        val error*           train error*          
                              mean       std         mean       std
init and reg.                                                      
normal noise             -0.041706  0.014977    -1.844500  0.108128
LSQR                     -0.033476  0.010215    -0.960839  0.052054
LSQR+        

In [92]:

# Row order of the summary table, using the single-line model labels.
model_order = ["gated", "grad", "f.e.+dnn2", "n2v+dnn2", "Kumar et.al."]

# One LaTeX table per split: rows are models, columns are (graph, mean/std).
for split in ("val", "train"):
    per_graph_stats = []
    for graph_name in graph_names:
        comparison = make_comparison(best_results[graph_name], score="median_mag_error", split=split)
        stats = comparison.groupby("models").agg({"median_mag_error": ["mean", "std"]})
        per_graph_stats.append(stats.rename(columns={"median_mag_error": graph_name}))
    side_by_side = pd.concat(per_graph_stats, axis="columns")
    result_tbl = side_by_side.rename(index=model_dict_wo_enter).loc[model_order, :]
    tex_path = os.path.join(tables_save_folder, f"{split}_performance.tex")
    result_tbl.to_latex(tex_path, escape=False, float_format="%.2f")