In [None]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from scipy import stats

from matplotlib import pyplot as plt
import seaborn as sns
from statannot import add_stat_annotation

sns.set()

In [None]:
def select_best_n2v(result_df):
    """Keep, for each (Network, Method, Task), the row with the highest 'Validation score'.

    ``groupby(...).idxmax()`` returns index *labels*, so rows are selected with
    ``.loc``; ``.iloc`` expects positions and only coincides with ``.loc`` when
    the frame still has a default RangeIndex.
    """
    best_labels = result_df.groupby(['Network', 'Method', 'Task'])['Validation score'].idxmax()
    return result_df.loc[best_labels.to_numpy()]


# Study-bias holdout: node2vec(+) rows reduced to the best hyperparameter
# setting per task (by validation score); GNN results are used as loaded.
n2v_sb_raw_df = pd.read_csv('../result/gene_classification_n2v.csv')
gnn_result_df = pd.read_csv('../result/gene_classification_gnn.csv')
optim_n2v_result_df = select_best_n2v(n2v_sb_raw_df)
sb_result_df = pd.concat((optim_n2v_result_df, gnn_result_df))

# Cross validation: node2vec(+) results only, reduced the same way.
n2v_cv_raw_df = pd.read_csv('../result/gene_classification_n2v_cv.csv')
cv_result_df = select_best_n2v(n2v_cv_raw_df)

In [None]:
# Significance level for the pairwise Wilcoxon tests, and the matching
# statannot label table: p < pthresh is drawn as "*", otherwise no label.
pthresh = 0.05
pvt = [[pthresh, "*"], [1, ""]]

# Panel layout: one subplot column per network, one box group per dataset.
network_list = ['GIANT-TN', 'GIANT-TN-c01', 'STRING']
dataset_list = ['GOBP', 'KEGGBP', 'DisGeNet']

# Methods shown in the study-bias holdout row vs. the cross-validation row
# (the cross-validation results are loaded from the node2vec file only).
full_method_list = ['GCN', 'GraphSAGE', 'Node2vec+', 'Node2vec']
n2v_method_list = ['Node2vec+', 'Node2vec']
# Raw method codes (presumably from the GNN results file) -> display names.
method_name_dict = {'gcn': 'GCN', 'sage': 'GraphSAGE'}

color_dict = {
    "GCN": 'lightskyblue',
    "GraphSAGE": 'slateblue',
    "Node2vec+": 'orangered',
    "Node2vec": 'lightsalmon'
}

# Rename method codes to display names by building new frames.
# ``df['Method'].replace(..., inplace=True)`` is chained assignment: pandas 2.x
# deprecates it, and under copy-on-write it leaves ``df`` unchanged. The
# FutureWarning is hidden by the global warnings filter, so the rename would
# silently stop happening.
cv_result_df, sb_result_df = (
    result_df.assign(Method=result_df['Method'].replace(method_name_dict))
    for result_df in (cv_result_df, sb_result_df)
)

result_df_list = [cv_result_df, sb_result_df]
result_name_list = ['Cross Validation', 'Study-bias Holdout']

In [None]:
def get_stat_annot(group, method_list, alpha=None):
    """Find method pairs whose testing scores differ significantly, per dataset.

    For every dataset in ``group`` and every pair of methods
    ``(method_list[i], method_list[j])`` with ``i < j``, run a paired Wilcoxon
    signed-rank test (``scipy.stats.wilcoxon`` defaults) on the two methods'
    'Testing score' values, pairing rows by 'Task'. Only pairs with
    ``p < alpha`` are returned.

    Parameters
    ----------
    group : pandas.DataFrame
        Must contain 'Dataset', 'Method', 'Task' and 'Testing score' columns.
        Each (Dataset, Method, Task) for the compared methods must appear at
        most once; ``DataFrame.pivot`` raises ValueError otherwise.
    method_list : sequence of str
        Methods to compare, in the order pairs should be formed.
    alpha : float, optional
        Significance threshold; defaults to the notebook-level ``pthresh``.

    Returns
    -------
    pval_list : list of float
        p-values of the significant pairs.
    box_pair_list : list of tuple
        ``((dataset, method1), (dataset, method2))`` entries aligned with
        ``pval_list``, in the form statannot expects for ``box_pairs``.
    """
    threshold = pthresh if alpha is None else alpha
    pval_list = []
    box_pair_list = []
    for dataset, subgroup in group.groupby('Dataset'):
        # Task x Method score table, built once per dataset so each pair is
        # aligned on Task rather than on sort position.
        wide = subgroup[subgroup['Method'].isin(method_list)].pivot(
            index='Task', columns='Method', values='Testing score')
        for idx1, method1 in enumerate(method_list[:-1]):
            for method2 in method_list[idx1 + 1:]:
                if method1 not in wide.columns or method2 not in wide.columns:
                    continue  # a method has no results for this dataset
                # Only tasks scored by both methods take part in the paired test.
                paired = wide[[method1, method2]].dropna()
                if paired.empty:
                    continue
                try:
                    pval = stats.wilcoxon(paired[method1].to_numpy(),
                                          paired[method2].to_numpy())[1]
                except ValueError:
                    # Some SciPy versions raise when every paired difference
                    # is zero; treat that as "no detectable difference".
                    continue
                # A NaN p-value compares False and is therefore skipped.
                if pval < threshold:
                    pval_list.append(pval)
                    box_pair_list.append(((dataset, method1), (dataset, method2)))

    return pval_list, box_pair_list

In [None]:
def _draw_panel(ax, group, hue_methods):
    """Draw one (evaluation scheme, network) panel: boxes per dataset, hued by method."""
    sns.boxplot(data=group, x='Dataset', y='Testing score', order=dataset_list,
                hue='Method', hue_order=hue_methods, palette=color_dict,
                notch=True, ax=ax)

    # Significance markers are computed only between the node2vec variants,
    # even on panels that also show the GNN baselines.
    pvals, box_pairs = get_stat_annot(group, n2v_method_list)
    if pvals:
        add_stat_annotation(ax, data=group, x='Dataset', y='Testing score',
                            order=dataset_list, hue='Method', hue_order=hue_methods,
                            verbose=0, perform_stat_test=False, box_pairs=box_pairs,
                            loc='inside', pvalue_thresholds=pvt, pvalues=pvals)

    ax.set(xlabel='', ylabel='')
    # Per-panel legends are dropped; one shared legend is added after the loop.
    ax.get_legend().remove()


# Rows: evaluation scheme; columns: network (order taken from network_list).
fig, axes = plt.subplots(2, 3, figsize=(11, 7), sharex=True, sharey='row')

for row, (result_df, result_name) in enumerate(zip(result_df_list, result_name_list)):
    hue_methods = (n2v_method_list if result_name == 'Cross Validation'
                   else full_method_list)
    for network, group in result_df.groupby('Network'):
        col = network_list.index(network)
        ax = axes[row, col]
        _draw_panel(ax, group, hue_methods)

        if row == 0:
            ax.set_title(network, fontsize=14)
        if col == 0:
            ax.set_ylabel(result_name, fontsize=14)

plt.tight_layout()

# Leave room at the bottom for the shared legend under the middle column.
plt.subplots_adjust(bottom=0.1)
axes[1, 1].legend(bbox_to_anchor=(1.4, -0.1), ncol=4, fontsize=12)
# plt.savefig("fig_bio_nc_eval.png", dpi=300) # uncomment to save

plt.show()