In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
os.chdir("../src/models/results/")

In [None]:
# Column layout of the headerless *_results.csv files: run identifiers, then the
# train/test performance metrics, then the train/test uncertainty metrics, and
# finally the cross-validation fold index.
_id_columns = ["dataset", "split", "representation", "model", "uncertainty", "dropout"]
_performance_metrics = ["rho", "rmse", "mae", "r2"]
_uncertainty_metrics = [
    "rho_unc",
    "p_rho_unc",
    "percent_coverage",
    "average_width_range",
    "miscalibration_area",
    "average_nll",
    "average_optimal_nll",
    "average_nll_ratio",
]
column_names = (
    _id_columns
    + [
        f"{subset}_{metric}"
        for metric_group in (_performance_metrics, _uncertainty_metrics)
        for subset in ("train", "test")
        for metric in metric_group
    ]
    + ["crossval_idx"]
)

In [None]:
# Load each dataset's result log (no header row; columns named by `column_names`)
# and stack them into a single frame with a fresh 0..n-1 index.
results_by_dataset = {
    name: pd.read_csv(f"{name}_results.csv", header=None, names=column_names)
    for name in ("aav", "gb1", "meltome")
}
aav_results_df = results_by_dataset["aav"]
gb1_results_df = results_by_dataset["gb1"]
meltome_results_df = results_by_dataset["meltome"]

results_df = pd.concat(
    [aav_results_df, gb1_results_df, meltome_results_df], ignore_index=True
)
results_df

In [None]:
# A run configuration (including dropout rate and fold) may be logged more than once;
# later rows are treated as the most recent, so keep only the last occurrence.
run_key_columns = [
    "dataset",
    "split",
    "representation",
    "model",
    "uncertainty",
    "dropout",
    "crossval_idx",
]
results_df = results_df.drop_duplicates(subset=run_key_columns, keep="last")
results_df

In [None]:
# Number of remaining rows per cross-validation fold index.
results_df["crossval_idx"].value_counts()

In [None]:
# Once the dropout rate is ignored, the only configurations that may still repeat are
# those of the "dropout" uncertainty method (runs that differ only in dropout rate).
config_columns = ["dataset", "split", "representation", "model", "uncertainty", "crossval_idx"]
duplicated_uncertainties = set(
    results_df.loc[results_df.duplicated(subset=config_columns), "uncertainty"]
)
# Subset check: fails if any non-dropout method still has duplicates, and (unlike the
# previous "exactly one method" check) does not fail when there are no duplicates at all.
assert duplicated_uncertainties <= {"dropout"}, (
    f"unexpected duplicated configurations for uncertainty methods: {duplicated_uncertainties}"
)

In [None]:
# How many rows repeat an earlier configuration once the dropout rate is ignored.
duplicate_config_mask = results_df.duplicated(
    subset=["dataset", "split", "representation", "model", "uncertainty", "crossval_idx"]
)
results_df.loc[duplicate_config_mask, "dataset"].count()

In [None]:
# Keep, per configuration, the run with the lowest train miscalibration area:
# sorting ascending (NaN areas sort last) makes keep="first" pick that run.
results_df = (
    results_df
    .sort_values("train_miscalibration_area", ascending=True)
    .drop_duplicates(
        subset=["dataset", "split", "representation", "model", "uncertainty", "crossval_idx"],
        keep="first",
    )
)
results_df

In [None]:
# Aggregate over cross-validation folds: mean and standard deviation per configuration.
config_group_keys = ["dataset", "split", "representation", "model", "uncertainty"]
grouped_results = results_df.groupby(config_group_keys)
results_df_mean = grouped_results.mean()
results_df_std = grouped_results.std()
results_df_std

# Prep Data

In [None]:
def get_full_model_name(row):
    """Return the display label for a results row.

    CNN rows are labelled by their uncertainty method (e.g. "CNN Dropout",
    "CNN MVE"); GP and ridge rows get one fixed label each, regardless of
    ``row.uncertainty``.

    Args:
        row: object exposing ``model`` and ``uncertainty`` attributes (in this
            notebook, a row passed by ``DataFrame.apply(..., axis=1)``).

    Returns:
        The human-readable model name.

    Raises:
        ValueError: for an unknown model, or an unknown uncertainty method on a CNN.
    """
    cnn_labels = {
        "dropout": "CNN Dropout",
        "ensemble": "CNN Ensemble",
        "evidential": "CNN Evidential",
        "mve": "CNN MVE",
        "svi": "CNN SVI",
    }
    fixed_labels = {
        "gp": "GP Continuous",
        "ridge": "Linear Bayesian Ridge",
    }

    if row.model != "cnn":
        if row.model in fixed_labels:
            return fixed_labels[row.model]
        raise ValueError("not implemented")
    if row.uncertainty not in cnn_labels:
        raise ValueError("not implemented")
    return cnn_labels[row.uncertainty]

# Human-readable model label, used as the plot hue and as the key for model ranks.
results_df = results_df.assign(Model=results_df.apply(get_full_model_name, axis=1))
results_df

In [None]:
# Display labels for plot legends and tick labels.
dataset_names_dict = {'aav': 'AAV', 'meltome': 'Meltome', 'gb1': 'GB1'}

# NOTE(review): not used by the cells in this section; the Model column is built by
# get_full_model_name instead.
model_names_dict = {
    'CNN_dropout': 'CNN Dropout',
    'CNN_ensemble': 'CNN Ensemble',
    'CNN_evidential': 'CNN Evidential',
    'CNN_mve': 'CNN MVE',
    'CNN_svi': 'CNN SVI',
    'linearBayesianRidge': 'Linear Bayesian Ridge',
    'GPcontinuous': 'GP Continuous',
}

# Both random-split identifiers ('sampled' and 'mixed_split') share the label 'Random'.
split_names_dict = {
    'sampled': 'Random',
    'seven_vs_many': '7 vs. Rest',
    'mut_des': 'Sampled vs. Designed',
    'mixed_split': 'Random',
    'three_vs_rest': '3 vs. Rest',
    'two_vs_rest': '2 vs. Rest',
    'one_vs_rest': '1 vs. Rest',
}

results_df = results_df.assign(
    Dataset=results_df['dataset'].map(dataset_names_dict),
    Split=results_df['split'].map(split_names_dict),
)
results_df

In [None]:
# Show every row (no truncation), best train rho first.
pd.set_option('display.max_rows', None)
results_df.sort_values(by='train_rho', ascending=False)

In [None]:
# Ordinal positions used to order rows (and plots) by dataset, split, model and
# representation; each rank is the item's position in the list below.
dataset_rank_dict = {name: rank for rank, name in enumerate(['aav', 'meltome', 'gb1'])}

split_rank_dict = {
    name: rank
    for rank, name in enumerate([
        'sampled',
        'mixed_split',
        'seven_vs_many',
        'mut_des',
        'three_vs_rest',
        'two_vs_rest',
        'one_vs_rest',
    ])
}

model_rank_dict = {
    name: rank
    for rank, name in enumerate([
        'Linear Bayesian Ridge',
        'CNN Ensemble',
        'CNN MVE',
        'CNN Dropout',
        'GP Continuous',
        'CNN Evidential',
        'CNN SVI',
    ])
}

rep_rank_dict = {name: rank for rank, name in enumerate(['ohe', 'esm'])}

results_df = results_df.assign(
    dataset_rank=results_df['dataset'].map(dataset_rank_dict),
    split_rank=results_df['split'].map(split_rank_dict),
    model_rank=results_df['Model'].map(model_rank_dict),
    rep_rank=results_df['representation'].map(rep_rank_dict),
)

results_df.sort_values(['model_rank', 'dataset_rank', 'split_rank', 'rep_rank'])

In [None]:
# Per-dataset slices of the results, selected by display label.
aav_df, meltome_df, gb1_df = (
    results_df.loc[results_df['Dataset'] == label] for label in ('AAV', 'Meltome', 'GB1')
)

# Plots

In [None]:
def change_dataset_case(dataset_name):
    """Return the display form of a dataset identifier.

    "aav" and "gb1" are upper-cased, "meltome" becomes "Meltome", and any other
    string containing "f" (the "f1"/"f2" spacer rows added for bar-plot spacing)
    is returned unchanged.

    Args:
        dataset_name: raw dataset identifier from the ``dataset`` column.

    Returns:
        The display name.

    Raises:
        ValueError: if ``dataset_name`` is not a recognised name (including
            non-string values such as NaN, which previously crashed with TypeError).
    """
    if dataset_name in ("aav", "gb1"):
        return dataset_name.upper()
    if dataset_name == "meltome":
        return "Meltome"
    if isinstance(dataset_name, str) and 'f' in dataset_name:
        return dataset_name
    raise ValueError(f"unrecognised dataset name: {dataset_name!r}")

### Rank Correlation Bar Plots

#### Figure 4, Figure S3

In [None]:
# Fixed model -> color mapping passed to every bar plot, so each model keeps the
# same color across figures.
palette = dict(zip(
    [
        "CNN Dropout",
        "CNN Ensemble",
        "CNN Evidential",
        "CNN MVE",
        "CNN SVI",
        "GP Continuous",
        "Linear Bayesian Ridge",
    ],
    [
        "tab:blue",
        "tab:orange",
        "tab:green",
        "tab:red",
        "tab:purple",
        "tab:brown",
        "tab:pink",
    ],
))

In [None]:
results_df_ = results_df.copy()

# hack row that doesn't show on plot to make legend order match previous plots (NaN value in AAV / sampled / GP)
hack_row = {
    'dataset': 'aav',
    'split': 'sampled',
    'model': 'gp',
    'uncertainty': 'gp',
    'Model': 'GP Continuous',
    'Dataset': 'AAV',
    'Split': 'Random',
    'dataset_rank': 0,
    'split_rank': 0,
    'model_rank': 3,
}
hack_row['representation'] = 'ohe'
results_df_ = results_df_.append(hack_row, ignore_index=True)
hack_row['representation'] = 'esm'
results_df_ = results_df_.append(hack_row, ignore_index=True)

# hack rows 2 and 3 create space in between each landscape
hack_row2 = {
    'dataset': 'f1',
    'split': '',
    'model': 'gp',
    'uncertainty': 'gp',
    'Model': 'GP Continuous',
    'Dataset': 'f1',
    'Split': '',
    'dataset_rank': 0.5,
    'split_rank': 0,
    'model_rank': 3,
}
hack_row2['representation'] = 'ohe'
results_df_ = results_df_.append(hack_row2, ignore_index=True)
hack_row2['representation'] = 'esm'
results_df_ = results_df_.append(hack_row2, ignore_index=True)
hack_row3 = {
    'dataset': 'f2',
    'split': '',
    'model': 'gp',
    'uncertainty': 'gp',
    'Model': 'GP Continuous',
    'Dataset': 'f2',
    'Split': '',
    'dataset_rank': 1.5,
    'split_rank': 0,
    'model_rank': 3,
}
hack_row3['representation'] = 'ohe'
results_df_ = results_df_.append(hack_row3, ignore_index=True)
hack_row3['representation'] = 'esm'
results_df_ = results_df_.append(hack_row3, ignore_index=True)

results_df_ = results_df_.sort_values(['dataset_rank','split_rank','Model'])
results_df_['dataset_split'] = results_df_.dataset.apply(lambda name: change_dataset_case(name)) + \
                                ' / ' + results_df_.Split

for representation in ["ohe", "esm"]:
    rep_results_df_ = results_df_.loc[results_df_.representation==representation]

    #plt.axvspan(2.5, 3.5, facecolor='k', alpha=0.1) # gray bar over meltome
    sns.barplot(
        x="dataset_split", 
        y="test_rho", 
        hue="Model",
        palette=palette,
        data=rep_results_df_,
        errcolor='k',
        errwidth=1,
        #capsize=0.001,
    )
    plt.xticks(rotation=90)
    xticks = plt.gca().xaxis.get_major_ticks()
    xticks[3].set_visible(False)
    xticks[5].set_visible(False)
    plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
    plt.xlabel("Dataset / Split", fontsize=15)
    plt.ylabel(r"Test $\rho$ ($\rightarrow$)", fontsize=15)
    plt.ylim(-0.2,1.0)
    plt.xticks(fontsize=12, rotation=45, ha='right')
    plt.yticks(fontsize=12)
    plt.title(representation.upper())
    plt.savefig(f"{representation}_rho_bar.pdf", bbox_inches="tight")
    plt.show()

In [None]:
results_df_ = results_df.copy()

# hack row that doesn't show on plot to make legend order match previous plots (NaN value in AAV / sampled / GP)
hack_row = {
    'dataset': 'aav',
    'split': 'sampled',
    'model': 'gp',
    'uncertainty': 'gp',
    'Model': 'GP Continuous',
    'Dataset': 'AAV',
    'Split': 'Random',
    'dataset_rank': 0,
    'split_rank': 0,
    'model_rank': 3,
}
hack_row['representation'] = 'ohe'
results_df_ = results_df_.append(hack_row, ignore_index=True)
hack_row['representation'] = 'esm'
results_df_ = results_df_.append(hack_row, ignore_index=True)

# hack rows 2 and 3 create space in between each landscape
hack_row2 = {
    'dataset': 'f1',
    'split': '',
    'model': 'gp',
    'uncertainty': 'gp',
    'Model': 'GP Continuous',
    'Dataset': 'f1',
    'Split': '',
    'dataset_rank': 0.5,
    'split_rank': 0,
    'model_rank': 3,
}
hack_row2['representation'] = 'ohe'
results_df_ = results_df_.append(hack_row2, ignore_index=True)
hack_row2['representation'] = 'esm'
results_df_ = results_df_.append(hack_row2, ignore_index=True)
hack_row3 = {
    'dataset': 'f2',
    'split': '',
    'model': 'gp',
    'uncertainty': 'gp',
    'Model': 'GP Continuous',
    'Dataset': 'f2',
    'Split': '',
    'dataset_rank': 1.5,
    'split_rank': 0,
    'model_rank': 3,
}
hack_row3['representation'] = 'ohe'
results_df_ = results_df_.append(hack_row3, ignore_index=True)
hack_row3['representation'] = 'esm'
results_df_ = results_df_.append(hack_row3, ignore_index=True)

results_df_ = results_df_.sort_values(['dataset_rank','split_rank','Model'])
results_df_['dataset_split'] = results_df_.dataset.apply(lambda name: change_dataset_case(name)) + \
                                ' / ' + results_df_.Split

for representation in ["ohe", "esm"]:
    rep_results_df_ = results_df_.loc[results_df_.representation==representation]
    
    # plt.axvspan(2.5, 3.5, facecolor='k', alpha=0.1) # gray bar over meltome
    sns.barplot(
        x="dataset_split", 
        y="test_rho_unc", 
        hue="Model",
        palette=palette,
        data=rep_results_df_,
        errcolor='k',
        errwidth=1,
    )
    plt.xticks(rotation=90)
    xticks = plt.gca().xaxis.get_major_ticks()
    xticks[3].set_visible(False)
    xticks[5].set_visible(False)
    plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
    plt.xlabel("Dataset / Split", fontsize=15)
    plt.ylabel(r"Test $\rho_{unc}$ ($\rightarrow$)", fontsize=15)
    plt.ylim(-0.7,0.8)
    plt.xticks(fontsize=12, rotation=45, ha='right')
    plt.yticks(fontsize=12)
    plt.title(representation.upper())
    plt.savefig(f"{representation}_rho_unc_bar.pdf", bbox_inches="tight")
    plt.show()