# 5. Test set evaluation 
We now evaluate the top models selected by the grid search on the held-out test set.

In [None]:
%load_ext autoreload
%autoreload 2

Note: the code below is similar to the code in "protein-eval.py". Use the models selected in notebook #4 to hard-code the model strings.

In [None]:
# Evaluation configuration shared by the cells below.
encoder_list = ["COLLAPSE", "ESM", "AA"]
test_metrics = ["ap"]
metal = 'ZN'

# Selected model file and threshold per encoder.
# Hard coded for now from looking at the top of the ranked dataframe.
# NOTE(review): the 'AA' entry spells alpha with three decimals ('alpha0.010') while the
# other entries use four — confirm it matches the file name on disk.
encoder_top_models = {
    'COLLAPSE': ('k25_r4_cutoff8.00_alpha1.0000_tau4.00_lamnan.model', 0.95),
    # 'COLLAPSE': ('k15_r1_cutoff8.00_alpha0.0100_tau0.00_lamnan.model', 0.8),
    'ESM': ('k30_r2_cutoff6.00_alpha0.0100_tau0.00_lamnan.model', 0.1),
    'AA': ('k21_r1_cutoff8.00_alpha0.010_tau1.00_lamnan.model', 0.5),
}

# Selected baseline model identifier and threshold per encoder.
baseline_top_models = {
    'COLLAPSE': ('COLLAPSE-ZN-8.0-0.0001-100', 0.7),
    'ESM': ('ESM-ZN-6.0-0.001-500', 0.4),
    'AA': ('AA-ZN-6.0-0.001-200', 0.5),
}

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
def setup_figure(width=6, height=3):
    """Apply the seaborn 'white' style and 'paper' context, then open a new figure.

    Args:
        width: Figure width passed to matplotlib's ``figsize`` (inches).
        height: Figure height passed to matplotlib's ``figsize`` (inches).
    """
    # Order matters: sns.set() resets the context, so the context is set afterwards.
    sns.set(style='white')
    sns.set_context('paper')
    figure_size = (width, height)
    plt.figure(figsize=figure_size)

# seaborn's 'tab20' colour palette.
pal = sns.color_palette('tab20')

In [None]:
from evaluation import test_eval, extract_params, get_test_metrics
import utils
import pandas as pd

In [None]:
# Collect the test metrics of every (encoder, explanation baseline) combination
# into one long-format frame, tagged with a "method" column.
baselines = ['Attention', 'GNNExplainer', 'SHAP']
base_frames = []
for encoder in encoder_list:
    # Each encoder maps to exactly one (model, threshold) pair. Unpack it once here;
    # looping over the pair itself would run the body once per tuple element and
    # append every result frame twice (same means, but doubled counts and a too-small SEM).
    best_model, best_thresh = baseline_top_models[encoder]
    for baseline in baselines:
        results_dict = utils.deserialize(f'../data/baselines/{encoder}_{baseline}_test_results.pkl')
        df = get_test_metrics(results_dict, encoder, best_model, best_thresh, test_metrics)
        # add a "method" column to the df so baselines can be told apart from Prospector
        df["method"] = "GAT+"+baseline
        base_frames.append(df)
base_df = pd.concat(base_frames)

In [None]:
# Sanity check: number of rows per (encoder, method) group in the baseline results.
base_df.groupby(['encoder', 'method']).size()

In [None]:
# from evaluation import compute_seg_all_configs

# cache_dir = "/dfs/scratch1/gmachi/gcp_backup/k2/"
# Gs_dir = "/dfs/scratch1/gmachi/gcp_backup/data/tinycam/test/clean_Gs_"
# label_dict_path = "/dfs/scratch1/gmachi/gcp_backup/k2/refined_label_dicts/refined_test_labeldict-" 
# gts_path = "/dfs/scratch1/gmachi/gcp_backup/data/tinycam/test/gt_graphs_"

# test_df = compute_seg_all_configs(encoder_top_models, cache_dir, Gs_dir, gts_path, label_dict_path)

Note: the cell below takes about 8 minutes to run.

In [None]:
# Evaluate the selected model of every encoder on the test graphs and stack the
# per-encoder result frames into `test_df`, tagged with method == 'Prospector'.
test_frames = []
for encoder, (model_str, threshold) in encoder_top_models.items():
    gridsearch_dir = f"../data/{encoder}_{metal}_gridsearch_results"
    model_cache_dir = f"{gridsearch_dir}/{encoder}-fitted_k2_models"
    processor_cache_dir = f"{gridsearch_dir}/{encoder}-fitted_k2_processors"

    # Only the third field returned by extract_params (the cutoff) is needed here;
    # it is part of the test graph directory name.
    _, _, cutoff, _, _, _ = extract_params(model_str)

    # The 'AA' encoder reads its test graphs from the COLLAPSE graph directory.
    g_encoder = 'COLLAPSE' if encoder == 'AA' else encoder
    G_dir = f"../data/{g_encoder}_{metal}_{cutoff}_test_graphs_2"

    df = test_eval(model_str, threshold, test_metrics, model_cache_dir, processor_cache_dir, G_dir, gt_dir=None, label_dict=None, modality="graph", arm="test")
    df['method'] = 'Prospector'
    test_frames.append(df)
test_df = pd.concat(test_frames)

In [None]:
# Stack Prospector and baseline results, then summarise them as mean / SEM tables
# indexed by (encoder, method, regime) with one column per metric.
combined_df = pd.concat([test_df, base_df])

group_keys = ['encoder', 'method', 'regime', 'metric']
pivot_index = ['encoder', 'method', 'regime']
grouped_values = combined_df.groupby(group_keys)['value']

mean_df = grouped_values.mean().reset_index()
sem_df = grouped_values.sem().reset_index()

# Keep the metric columns in the order given by test_metrics.
mean_pvt = mean_df.pivot(index=pivot_index, columns='metric', values='value')[test_metrics]
sem_pvt = sem_df.pivot(index=pivot_index, columns='metric', values='value')[test_metrics]

# Save dfs (disabled)
# mean_pvt.to_csv(f'../data/all_test_results_mean-{test_metrics[0]}.csv')
# sem_pvt.to_csv(f'../data/all_test_results_sem-{test_metrics[0]}.csv')
# combined_df.to_csv(f'../data/all_test_results_points-{test_metrics[0]}.csv') # graph-level results
# test_df.to_csv(f'../data/k2_test_results_points-{test_metrics[0]}.csv')  # k2 only


# Properties vs performance

In [None]:
from evaluation import compute_test_mrds, compute_test_rps, compute_test_mcs, compute_test_ccs

# Per-graph ground-truth properties computed from the COLLAPSE / ZN / 8.0 test graphs.
# The resulting dicts are later used to map test_df['datum_id'] to a property value.
# NOTE(review): rp / mrd / mcs appear elsewhere in this notebook as "Region Prevalence",
# "Mean Region Dispersion" and "Mean Component Size"; the meaning of ccs is not visible here.
G_dir = "../data/COLLAPSE_ZN_8.0_test_graphs_2"
GT_KEY = 'gt'
rps_dict = compute_test_rps(G_dir, gt_key=GT_KEY)
mrds_dict = compute_test_mrds(G_dir, gt_key=GT_KEY)
ccs_dict = compute_test_ccs(G_dir, gt_key=GT_KEY)
mcs_dict = compute_test_mcs(G_dir, gt_key=GT_KEY)

In [None]:
# Attach the ground-truth properties to each result row via its datum_id, plus
# srp = rp / ccs. Columns are added in the order rp, mrd, ccs, mcs, srp.
datum_ids = test_df['datum_id']
test_df = test_df.assign(
    rp=datum_ids.map(rps_dict),
    mrd=datum_ids.map(mrds_dict),
    ccs=datum_ids.map(ccs_dict),
    mcs=datum_ids.map(mcs_dict),
    # Evaluated after rp and ccs exist on the new frame.
    srp=lambda frame: frame['rp'] / frame['ccs'],
)

In [None]:
# Replace the index with 0..n-1, then drop every row that has a NaN in any column
# (for example a datum_id that is missing from one of the property dicts maps to NaN).
# The index is not renumbered again after the drop.
test_df = test_df.reset_index(drop=True).dropna()

In [None]:
import numpy as np
def movingaverage(interval, window_size):
    """Running mean of ``interval`` over a window of ``window_size`` samples.

    The result keeps the length of a 'full' convolution,
    ``len(interval) + int(window_size) - 1``. Entry ``i`` is the mean of the input
    samples covered by the window ending at position ``i``. Where the window
    hangs over the start or end of the input, the mean is taken over the samples
    that actually overlap, rather than dividing a partial sum by the full window
    size (which drags the edge values toward zero).

    Args:
        interval: 1-D sequence of numeric values (list, array or Series).
        window_size: Window length; truncated to an integer.

    Returns:
        numpy.ndarray of floats with ``len(interval) + int(window_size) - 1`` entries.

    Raises:
        ValueError: Propagated from ``np.convolve`` when the input or window is empty.
    """
    width = int(window_size)
    values = np.asarray(interval, dtype=float)
    kernel = np.ones(width)
    window_sums = np.convolve(values, kernel, 'full')
    # Number of real samples under the window at each output position
    # (equals `width` everywhere except near the two ends).
    window_counts = np.convolve(np.ones(len(values)), kernel, 'full')
    return window_sums / window_counts

In [None]:
# Per-encoder running mean (window 20) of the metric value, with rows ordered by rp.
rp_sorted = test_df.sort_values('rp')
running_means = {
    enc: movingaverage(group['value'], 20)
    for enc, group in rp_sorted.groupby('encoder')
}
mean_data = []
for enc, df in rp_sorted.groupby('encoder'):
    # movingaverage returns more values than there are rows; zip stops at len(df).
    mean_data.extend(zip(df['rp'], running_means[enc], [enc] * len(df)))
mean_data = pd.DataFrame(mean_data, columns=['rp', 'value', 'encoder'])

In [None]:
# Per-graph metric value against region prevalence (log x-axis), point size by ccs,
# with the per-encoder running mean drawn on top.
met = test_metrics[0]
encoder_order = ['COLLAPSE', 'ESM', 'AA']
setup_figure(5,3.5)
# Use the same explicit hue order for points and lines; otherwise the scatter colours
# follow the row order of test_df and only match the lines by coincidence.
g = sns.scatterplot(data=test_df, x="rp", y="value", hue="encoder", hue_order=encoder_order, size="ccs", sizes=(10, 100), alpha=0.3)
sns.lineplot(data=mean_data.reset_index(), x='rp', y='value', hue='encoder', hue_order=encoder_order, ax=g)
g.set_xscale("log")
# Place the legend outside the axes, to the right.
sns.move_legend(g, "upper left", bbox_to_anchor=(1, 1))
g.set_xlabel('Region Prevalence', fontsize=13)
g.set_ylabel('Localization AUPRC', fontsize=13)
g.tick_params(labelsize=13)
g.set_title("MetalPDB", fontsize=15)
plt.tight_layout()
plt.savefig("../data/figures/enc_vs_prevalence_scatter_" + met + ".png", dpi=300, format='png')

In [None]:
# Running mean of the metric value against region prevalence (log x-axis), one line per encoder.
met = test_metrics[0]
encoder_order = ['COLLAPSE', 'ESM', 'AA']
label_size = 13

setup_figure(5,3.5)
ax = sns.lineplot(data=mean_data.reset_index(), x='rp', y='value', hue='encoder', hue_order=encoder_order)
ax.set_xscale("log")
# Place the legend outside the axes, to the right.
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
ax.set_xlabel('Region Prevalence', fontsize=label_size)
ax.set_ylabel('Localization AUPRC', fontsize=label_size)
ax.tick_params(labelsize=label_size)
ax.set_title("MetalPDB", fontsize=15)
plt.tight_layout()
plt.savefig("../data/figures/enc_vs_prevalence_line_" + met + ".png", dpi=300, format='png')

In [None]:
# Per-encoder running mean (window 20) of the metric value, with rows ordered by mrd.
mrd_sorted = test_df.sort_values('mrd')
running_means = {
    enc: movingaverage(group['value'], 20)
    for enc, group in mrd_sorted.groupby('encoder')
}
mean_data = []
for enc, df in mrd_sorted.groupby('encoder'):
    # movingaverage returns more values than there are rows; zip stops at len(df).
    mean_data.extend(zip(df['mrd'], running_means[enc], [enc] * len(df)))
mean_data = pd.DataFrame(mean_data, columns=['mrd', 'value', 'encoder'])

In [None]:
# Per-graph metric value against mean region dispersion (log x-axis), point size by ccs,
# with the per-encoder running mean drawn on top.
# Set `met` here so the output file name does not depend on a value left over from another cell.
met = test_metrics[0]
encoder_order = ['COLLAPSE', 'ESM', 'AA']
setup_figure(5,3.5)
# Use the same explicit hue order for points and lines; otherwise the scatter colours
# follow the row order of test_df and only match the lines by coincidence.
g = sns.scatterplot(data=test_df, x="mrd", y="value", hue="encoder", hue_order=encoder_order, size="ccs", sizes=(10, 100), alpha=0.3)
sns.lineplot(data=mean_data.reset_index(), x='mrd', y='value', hue='encoder', hue_order=encoder_order, ax=g)
g.set_xscale("log")
# Fixed x-range; anything outside 0.1-10 is not shown.
g.set_xlim(0.1, 10)
# Place the legend outside the axes, to the right.
sns.move_legend(g, "upper left", bbox_to_anchor=(1, 1))
g.set_xlabel('Mean Region Dispersion', fontsize=13)
# Same y quantity as the region-prevalence figures, so it carries the same label.
g.set_ylabel('Localization AUPRC', fontsize=13)
g.tick_params(labelsize=13)
g.set_title("MetalPDB", fontsize=15)
plt.tight_layout()
plt.savefig("../data/figures/enc_vs_mrd_scatter_" + met + ".png", dpi=300, format='png')

In [None]:
# Running mean of the metric value against mean region dispersion (log x-axis), one line per encoder.
# Set `met` here so the output file name does not depend on a value left over from another cell.
met = test_metrics[0]
encoder_order = ['COLLAPSE', 'ESM', 'AA']
setup_figure(5,3.5)
g = sns.lineplot(data=mean_data.reset_index(), x='mrd', y='value', hue='encoder', hue_order=encoder_order)
g.set_xscale("log")
# Fixed x-range; anything outside 0.1-10 is not shown.
g.set_xlim(0.1, 10)
# Place the legend outside the axes, to the right.
sns.move_legend(g, "upper left", bbox_to_anchor=(1, 1))
g.set_xlabel('Mean Region Dispersion', fontsize=13)
# Same y quantity as the region-prevalence figures, so it carries the same label.
g.set_ylabel('Localization AUPRC', fontsize=13)
g.tick_params(labelsize=13)
g.set_title("MetalPDB", fontsize=15)
plt.tight_layout()
plt.savefig("../data/figures/enc_vs_mrd_line_" + met + ".png", dpi=300, format='png')

In [None]:
# g = sns.scatterplot(data=test_df, x="mcs", y="value", hue="encoder", size="ccs", sizes=(40, 300), alpha=0.4)
# # g = sns.scatterplot(data=test_df, x="rp", y="value", hue="encoder", alpha=0.4, s=100)
# g.set_xscale("log")
# g.set_yscale("log")
# _ = g.set(xlabel="Mean Component Size", ylabel="Segmentation AUPRC")
# g.set_title("Impact of test-set MCS for segmentation")

In [None]:
# g = sns.scatterplot(data=test_df, x="rp", y="value", hue="encoder", size="mrd", sizes=(40, 300), alpha=0.4)
# # g = sns.scatterplot(data=test_df, x="rp", y="value", hue="encoder", alpha=0.4, s=100)
# g.set_xscale("log")
# g.set_yscale("log")
# _ = g.set(xlabel="Region Prevalence", ylabel="Segmentation AUPRC")
# g.set_title("Impact of test-set RP for segmentation")

# plotting test set results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import utils

# combined_df = pd.read_csv("/home/k2/K2/src/outputs/k2-test/all_test_results_points.csv")
# combined_df = pd.read_csv("/dfs/scratch1/gmachi/gcp_backup/k2/k2-test/all_test_results_points.csv")

# test_df['method'] = ['Prospector']*len(test_df) #+ ['GAT+Explainer']*len(base_df)
# NOTE(review): this rebuilds combined_df from the *current* test_df. When the notebook is
# run top to bottom, test_df has by now gained the property columns (rp, mrd, ccs, mcs, srp)
# and lost every row containing a NaN, while base_df is untouched. The Prospector rows may
# therefore be a subset of those in the earlier combined_df, and the baseline rows get NaN
# in the property columns — confirm that comparing on these differing row sets is intended.
combined_df = pd.concat([test_df, base_df])

In [None]:
# List the distinct method labels present in the combined results.
combined_df.method.unique()

In [None]:
# One grouped bar chart per metric: mean value (with standard-error bars) per encoder and
# method, with the individual data points overlaid as a strip plot.
# colors at: https://xkcd.com/color/rgb/
palette_names = ["cerulean","lavender","celadon","sage","mahogany","goldenrod","violet","fuchsia"]
hue_order = ['GAT+Attention', 'Prospector', 'GAT+GNNExplainer', 'GAT+SHAP']

for met in test_metrics:
    # `subdf` is read again by a later summary cell, so it keeps its name and is
    # assigned before any metric can be skipped.
    subdf = combined_df[combined_df.metric == met].reset_index()

    # Per-metric figure settings: size, which regime's rows to plot, y label, legend position.
    if met in ['auprc', 'ap', 'auroc']:
        fig_size = (4.5, 3)
        regime = 'all'
        ylabel = 'Localization Average Precision' if met == 'ap' else f'Localization {met}'
        legend_loc = (1.01, 0.7)
    elif met == 'precision':
        fig_size = (3, 3)
        regime = 'class-1'
        ylabel = 'Precision'
        legend_loc = 'upper left'
    else:
        # No plot is defined for this metric; skip it rather than saving a blank image.
        continue

    bar_palette = sns.xkcd_palette(palette_names)
    plot_df = subdf[subdf.regime == regime]

    # setup_figure opens a fresh figure, so no plt.clf() is needed beforehand
    # (calling it with no open figure would only create an extra, empty one).
    setup_figure(*fig_size)
    ax = sns.barplot(data=plot_df, palette=bar_palette, x='encoder', y='value', hue='method', hue_order=hue_order, orient='vertical', errorbar='se', capsize=0.05, errwidth=1.0, linewidth=1, edgecolor="k")
    sns.stripplot(data=plot_df, palette=bar_palette, x='encoder', y='value',  hue='method', hue_order=hue_order, orient='vertical', dodge=True, alpha=0.1, linewidth=0.5, ax=ax, legend=False)
    plt.title("MetalPDB", fontsize=15)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_xlabel('')
    ax.tick_params(labelsize=11)
    # Print the height of every bar patch (the plotted means).
    for bar in ax.patches:
        print(bar.get_height())
    plt.legend(loc=legend_loc)

    plt.tight_layout()

    # NOTE(review): dpi=2000 is far higher than the dpi=300 used for the other figures
    # in this notebook — confirm it is intentional.
    plt.savefig("../data/figures/k2-vs-baseline_" + met + ".png", dpi=2000, format='png')
    plt.show()

In [None]:
# Mean value per (encoder, method) over the 'all' regime.
# `subdf` is the frame left behind by the plotting loop, i.e. the rows of the last metric it processed.
subdf[subdf.regime == 'all'].groupby(['encoder','method'])['value'].mean()

In [None]:
# Standard error of the mean per (encoder, method) over the 'all' regime.
# NOTE(review): combined_df is not filtered by metric here, so this only lines up with the
# per-metric means when a single metric is present — confirm.
combined_df[combined_df.regime == 'all'].groupby(['encoder','method'])['value'].sem()


In [None]:
# NOTE(review): hard-coded difference between two numbers, presumably copied by hand from a
# printed summary table — verify, and recompute from the DataFrame if this value is reported.
0.640309 - 0.376692