In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
os.chdir("../src/models/results/")

In [None]:
# Column header for the headerless *_results.csv files, in on-disk order:
# run identifiers, train/test accuracy metrics, train/test uncertainty
# metrics, and finally the cross-validation index.
column_names = [
    # run identifiers
    "dataset", "split", "representation", "model", "uncertainty", "dropout",
    # accuracy metrics
    "train_rho", "train_rmse", "train_mae", "train_r2",
    "test_rho", "test_rmse", "test_mae", "test_r2",
    # uncertainty metrics (train)
    "train_rho_unc", "train_p_rho_unc", "train_percent_coverage",
    "train_average_width_range", "train_miscalibration_area",
    "train_average_nll", "train_average_optimal_nll", "train_average_nll_ratio",
    # uncertainty metrics (test)
    "test_rho_unc", "test_p_rho_unc", "test_percent_coverage",
    "test_average_width_range", "test_miscalibration_area",
    "test_average_nll", "test_average_optimal_nll", "test_average_nll_ratio",
    # cross-validation index
    "crossval_idx",
]

In [None]:
# The result files carry no header row, so the shared column_names schema is applied to each.
aav_results_df, gb1_results_df, meltome_results_df = (
    pd.read_csv(csv_path, header=None, names=column_names)
    for csv_path in ("aav_results.csv", "gb1_results.csv", "meltome_results.csv")
)

# Stack the three datasets and renumber the rows 0..n-1.
results_df = pd.concat(
    [aav_results_df, gb1_results_df, meltome_results_df], ignore_index=True
)
results_df

In [None]:
# A configuration/fold that was run more than once appears in several rows;
# keep only the last one (the most recent, assuming rows were appended in run order).
results_df = results_df.drop_duplicates(
    subset=[
        "dataset", "split", "representation", "model",
        "uncertainty", "dropout", "crossval_idx",
    ],
    keep="last",
)
results_df

In [None]:
# Keep, for each (dataset, split, representation, model, uncertainty, fold),
# only the dropout setting with the lowest train miscalibration area.
# - A stable sort (mergesort) breaks exact ties deterministically in favour of
#   the earlier row; the default quicksort does not preserve tie order.
# - Rows with a NaN area sort last, so they are kept only if no other dropout
#   value in the group has a finite area.
results_df = (
    results_df
    .sort_values('train_miscalibration_area', ascending=True, kind='mergesort')
    .drop_duplicates(
        subset=["dataset",
                "split",
                "representation",
                "model",
                "uncertainty",
                "crossval_idx"],
        keep="first",
    )
)
results_df

In [None]:
# Collapse the cross-validation folds: for each (dataset, split,
# representation, model, uncertainty) compute the mean and the standard
# deviation of the other columns.
# NOTE: groupby drops rows whose key contains NaN (pandas default dropna=True).
fold_groups = results_df.groupby(
    ["dataset", "split", "representation", "model", "uncertainty"]
)
results_df_mean = fold_groups.mean()
results_df_std = fold_groups.std()
results_df_std

In [None]:
# From here on, results_df holds one fold-averaged row per configuration, with
# the groupby keys moved from the index back into ordinary columns.
# NOTE: this rebinds results_df, so the dedup/aggregation cells above must
# not be re-run after this one without reloading the CSVs.
results_df = results_df_mean.reset_index(drop=False)

# Prep Data

In [None]:
def get_full_model_name(row):
    """Build the display name of a model from one results row.

    Args:
        row: object exposing ``model`` and ``uncertainty`` attributes, such as
            a DataFrame row passed by ``DataFrame.apply(..., axis=1)``.

    Returns:
        str: ``"CNN <method>"`` for ``model == "cnn"`` (method taken from
        ``uncertainty``), ``"GP Continuous"`` for ``"gp"`` and
        ``"Linear Bayesian Ridge"`` for ``"ridge"``. ``uncertainty`` is
        ignored for the GP and ridge models.

    Raises:
        ValueError: if ``model`` is unknown, or ``model`` is ``"cnn"`` with an
            unknown ``uncertainty``. The message names the offending value.
    """
    cnn_names = {
        "dropout": "CNN Dropout",
        "ensemble": "CNN Ensemble",
        "evidential": "CNN Evidential",
        "mve": "CNN MVE",
        "svi": "CNN SVI",
    }
    other_names = {
        "gp": "GP Continuous",
        "ridge": "Linear Bayesian Ridge",
    }

    if row.model == "cnn":
        name = cnn_names.get(row.uncertainty)
        if name is None:
            raise ValueError(
                f"not implemented: uncertainty {row.uncertainty!r} for model 'cnn'"
            )
        return name

    name = other_names.get(row.model)
    if name is None:
        raise ValueError(f"not implemented: model {row.model!r}")
    return name

# Human-readable model name; later cells use it for legend hue, ranking and pivot columns.
results_df = results_df.assign(Model=results_df.apply(get_full_model_name, axis=1))
results_df

In [None]:
# Human-readable labels for plot legends, heatmap axes and LaTeX tables.

dataset_names_dict = {
    'aav': 'AAV',
    'meltome': 'Meltome',
    'gb1': 'GB1',
}

# NOTE(review): not referenced by any visible cell of this notebook, and its
# keys ('CNN_dropout', ...) do not match the `model`/`uncertainty` values
# handled by get_full_model_name. Kept in case other code relies on it.
model_names_dict = {
    'CNN_dropout': 'CNN Dropout',
    'CNN_ensemble': 'CNN Ensemble',
    'CNN_evidential': 'CNN Evidential',
    'CNN_mve': 'CNN MVE',
    'CNN_svi': 'CNN SVI',
    'linearBayesianRidge': 'Linear Bayesian Ridge',
    'GPcontinuous': 'GP Continuous',
}

# 'sampled' and 'mixed_split' intentionally share the label 'Random'
# (presumably the random split of different datasets). If both ever occur
# within one dataset, the (Dataset, Split)-indexed pivots below would collide.
split_names_dict = {
    'sampled': 'Random',
    'seven_vs_many': '7 vs. Rest',
    'mut_des': 'Sampled vs. Designed',
    'mixed_split': 'Random',
    'three_vs_rest': '3 vs. Rest',
    'two_vs_rest': '2 vs. Rest',
    'one_vs_rest': '1 vs. Rest',
}

results_df['Dataset'] = results_df['dataset'].map(dataset_names_dict)
results_df['Split'] = results_df['split'].map(split_names_dict)

# .map yields NaN for keys missing from the dicts. Fail loudly here instead of
# letting such rows silently vanish from the `Dataset == ...` filters below
# or propagate as NaN split labels into the plots and tables.
unmapped_datasets = sorted(set(results_df.loc[results_df['Dataset'].isna(), 'dataset'].astype(str)))
unmapped_splits = sorted(set(results_df.loc[results_df['Split'].isna(), 'split'].astype(str)))
assert not unmapped_datasets, f"No display name for dataset(s): {unmapped_datasets}"
assert not unmapped_splits, f"No display name for split(s): {unmapped_splits}"
results_df

In [None]:
# Show every row of the table, highest train_rho first.
# NOTE: this display option stays in effect for the rest of the session.
pd.options.display.max_rows = None
results_df.sort_values(by='train_rho', ascending=False)

In [None]:
# Display orderings used for sorting: each dict maps a name to its position
# in the list it is built from (0 = shown first).
dataset_rank_dict = {name: rank for rank, name in enumerate(['aav', 'meltome', 'gb1'])}

split_rank_dict = {name: rank for rank, name in enumerate([
    'sampled',
    'mixed_split',
    'seven_vs_many',
    'mut_des',
    'three_vs_rest',
    'two_vs_rest',
    'one_vs_rest',
])}

model_rank_dict = {name: rank for rank, name in enumerate([
    'Linear Bayesian Ridge',
    'CNN Ensemble',
    'CNN MVE',
    'CNN Dropout',
    'GP Continuous',
    'CNN Evidential',
    'CNN SVI',
])}

rep_rank_dict = {name: rank for rank, name in enumerate(['ohe', 'esm'])}

results_df = results_df.assign(
    dataset_rank=results_df['dataset'].map(dataset_rank_dict),
    split_rank=results_df['split'].map(split_rank_dict),
    model_rank=results_df['Model'].map(model_rank_dict),
    rep_rank=results_df['representation'].map(rep_rank_dict),
)

# Display only; results_df itself keeps its current row order.
results_df.sort_values(by=['model_rank', 'dataset_rank', 'split_rank', 'rep_rank'])

In [None]:
# One frame per dataset, selected by its display name; plot_two_vars reads these.
aav_df, meltome_df, gb1_df = (
    results_df.loc[results_df['Dataset'] == label]
    for label in ('AAV', 'Meltome', 'GB1')
)

# Plots

In [None]:
def change_dataset_case(dataset_name):
    """Return the display-cased form of a dataset name.

    'aav' -> 'AAV', 'gb1' -> 'GB1', 'meltome' -> 'Meltome'. Any other name
    that contains the letter 'f' is returned unchanged.

    Args:
        dataset_name: lower-case dataset identifier (str).

    Returns:
        str: the display name.

    Raises:
        ValueError: if ``dataset_name`` matches none of the rules above; the
            message names the offending value.
    """
    if dataset_name in ("aav", "gb1"):
        return dataset_name.upper()
    if dataset_name == "meltome":
        return "Meltome"
    if 'f' in dataset_name:
        # NOTE(review): passthrough for any name containing 'f' — presumably
        # identifiers that are already in display form; confirm intent.
        return dataset_name
    raise ValueError(f"Unrecognized dataset name: {dataset_name!r}")

### OHE vs. ESM Summary Heatmaps

#### Figure S4

In [None]:
# Colour-bar labels for the two heatmap panels below.
col_to_label_dict = {
    'test_rho': r"Test $\rho$",
    'test_rho_unc': r"Test $\rho_{unc}$",
}

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# One panel per metric. Each cell is mean(ESM) - mean(OHE) for a
# (Dataset, Split) row and a Model column, so positive values favour ESM.
for ax, metric in zip(axs, ['test_rho', 'test_rho_unc']):
    ohe_vs_esm_df = pd.pivot_table(results_df,
                                   values=metric,
                                   index=['Dataset', 'Split', 'Model'],
                                   columns=['representation'],
                                   # string alias: passing the np.mean callable
                                   # is deprecated in recent pandas
                                   aggfunc='mean')

    ohe_vs_esm_df['diff'] = ohe_vs_esm_df.esm - ohe_vs_esm_df.ohe

    # diff is NaN where one representation is missing; count() skips those
    # entries and a NaN comparison counts as "not better".
    n_esm_better = int((ohe_vs_esm_df['diff'] > 0).sum())
    print(f"ESM performs better than OHE {n_esm_better} out of {ohe_vs_esm_df['diff'].count()} times for {metric}")

    df_pivot = ohe_vs_esm_df.reset_index().pivot(index=['Dataset', 'Split'],
                                                 columns='Model',
                                                 values='diff')

    sns.heatmap(df_pivot,
                annot=True,
                linewidth=0.5,
                cmap='vlag',
                ax=ax,
                cbar_kws={
                    'label': f"ESM - OHE Difference ({col_to_label_dict[metric]})",
                })

fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.5, hspace=1)
fig.savefig("esm_minus_ohe_heatmap.pdf", bbox_inches="tight")
plt.show()

### Accuracy vs. Calibration, Sharpness vs. Dispersion

In [None]:
def plot_two_vars(x, y):
    """Scatter one metric column against another, per landscape and representation.

    Draws one figure for each (landscape, representation) pair, taking rows
    from the module-level ``gb1_df``, ``aav_df`` and ``meltome_df`` frames.
    Points are coloured by ``Model`` and marked by ``Split``. Each figure is
    saved to ``{landscape}_{representation}_{y_tag}_{x_tag}.pdf`` in the
    current working directory, shown, and then closed.

    Args:
        x: column name for the x axis. Names ending in ``"rmse"`` or
            ``"average_width_range"`` get a friendly label; the latter also
            gets a log scale.
        y: column name for the y axis. Names ending in ``"percent_coverage"``
            (with a dashed 0.95 reference line) or ``"miscalibration_area"``
            get a friendly label and fixed limits.

    Returns:
        None
    """
    palette = {
        "CNN Dropout": "tab:blue",
        "CNN Ensemble": "tab:orange",
        "CNN Evidential": "tab:green",
        "CNN MVE": "tab:red",
        "CNN SVI": "tab:purple",
        "GP Continuous": "tab:brown",
        "Linear Bayesian Ridge": "tab:pink",
    }

    # Labels, filename tags, scaling and limits depend only on x / y, so they
    # are resolved once instead of on every loop iteration.
    log_x = False
    if x.endswith("rmse"):
        x_label, x_ = r"RMSE ($\leftarrow$)", "rmse"
    elif x.endswith("average_width_range"):
        x_label, x_ = r"Average Width / Range ($\leftarrow$)", "width"
        log_x = True
    else:
        x_label, x_ = x, x

    coverage_y = y.endswith("percent_coverage")
    if coverage_y:
        y_label, y_, y_lim = r"Percent Coverage ($\rightarrow$)", "coverage", (-0.05, 1.05)
    elif y.endswith("miscalibration_area"):
        y_label, y_, y_lim = r"Miscalibration Area ($\leftarrow$)", "area", (-0.05, 0.55)
    else:
        y_label, y_, y_lim = y, y, None

    for landscape, landscape_df_ in zip(['GB1', 'AAV', 'Meltome'], [gb1_df, aav_df, meltome_df]):
        for representation in ["ohe", "esm"]:
            landscape_df = (
                landscape_df_.loc[landscape_df_.representation == representation]
                .sort_values(['split_rank', 'Model'])
            )

            # A fresh figure per panel, so nothing is drawn onto whatever
            # figure happens to be current.
            fig, ax = plt.subplots()

            if coverage_y:
                # Reference line at 95% coverage.
                ax.axhline(0.95, ls='--', c='k', lw=1)

            sns.scatterplot(data=landscape_df,
                            x=x,
                            y=y,
                            hue='Model',
                            style='Split',
                            palette=palette,
                            s=100,
                            alpha=0.8,
                            ax=ax)

            ax.set_xlabel(x_label, fontsize=15)
            if log_x:
                ax.set_xscale('log')

            ax.set_ylabel(y_label, fontsize=15)
            if y_lim is not None:
                ax.set_ylim(y_lim)

            ax.set_title(f"{landscape} ({representation.upper()})", fontsize=15)
            ax.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
            ax.tick_params(axis='both', labelsize=12)

            fig.savefig(f"{landscape.lower()}_{representation}_{y_}_{x_}.pdf", bbox_inches="tight")
            plt.show()
            # Release the figure explicitly; non-inline backends otherwise keep
            # all six figures alive.
            plt.close(fig)

#### Figure 2, Figure S1

In [None]:
# Accuracy (test RMSE) vs. calibration (test miscalibration area).
plot_two_vars(x='test_rmse', y='test_miscalibration_area')

#### Figure 3, Figure S2

In [None]:
# Sharpness (test average width / range) vs. dispersion (test percent coverage).
plot_two_vars(x='test_average_width_range', y='test_percent_coverage')

# LaTeX Tables

#### Tables S1-S22

In [None]:
# Metrics emitted as LaTeX tables: one table per (representation, metric).
col_to_table_list = [
    'test_rmse', 'test_mae',  # can't compare across datasets with different units
    'test_r2',
    'test_rho', 'test_rho_unc',
    'test_percent_coverage', 'test_average_width_range',
    'test_miscalibration_area',
    'test_average_nll', 'test_average_optimal_nll', 'test_average_nll_ratio',
]

# NOTE(review): col_to_label_dict and col_to_cbar are not used by the table
# loop below (they look like settings for a heatmap rendering of the same
# metrics). Kept in case other cells rely on them.
col_to_label_dict = {
    'test_r2': r"Test $R^2$ ($\rightarrow$)",
    'test_rho': r"Test $\rho$ ($\rightarrow$)",
    'test_rho_unc': r"Test $\rho_{unc}$ ($\rightarrow$)",
    'test_percent_coverage': r"Test % Coverage ($\rightarrow$)",
    'test_average_width_range': r"Test $4\sigma/R$ ($\leftarrow$)",
    'test_miscalibration_area': r"Test Miscalibration Area ($\leftarrow$)",
    'test_average_nll': r"Test $\overline{NLL}$ ($\leftarrow$)",
    'test_average_optimal_nll': r"Test $\overline{NLL_{opt}}$",
    'test_average_nll_ratio': r"Test $\overline{NLL}$ / $\overline{NLL_{opt}} Ratio$ ($\leftarrow$)",
    'test_rmse': r"Test RMSE ($\leftarrow$)",
    'test_mae': r"Test MAE ($\leftarrow$)",
}

col_to_cbar = {
    'test_r2': {'fmt': ".3f", "norm": None},
    'test_rho': {'fmt': ".3f", "norm": None},
    'test_rho_unc': {'fmt': ".3f", "norm": None},
    'test_percent_coverage': {'fmt': ".3f", "norm": None},
    'test_average_width_range': {'fmt': ".1e", "norm": LogNorm(), "annot_kws": {"size": 8}},
    'test_miscalibration_area': {'fmt': ".3f", "norm": None},
    'test_average_nll': {'fmt': ".1e", "norm": LogNorm(), "annot_kws": {"size": 8}},
    'test_average_optimal_nll': {'fmt': ".3f", "norm": None},
    'test_average_nll_ratio': {'fmt': ".1e", "norm": LogNorm(), "annot_kws": {"size": 8}},
    'test_rmse': {'fmt': ".3f", "norm": None},
    'test_mae': {'fmt': ".3f", "norm": None},
}

for rep in ['ohe', 'esm']:
    # Filter once per representation rather than once per metric.
    rep_df = results_df.loc[results_df.representation == rep]
    for col in col_to_table_list:
        df_pivot = rep_df.pivot(index=['Dataset', 'Split'],
                                columns='Model',
                                values=col)
        print()
        # Tag each of the consecutive tables (Tables S1-S22) with a LaTeX
        # comment line, so each can be matched to its representation and metric
        # without changing the rendered output.
        print(f"% representation={rep}, metric={col}")
        print(df_pivot.to_latex(float_format="{:0.3f}".format))
        print()