In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pandas as pd
import numpy as np
import statistics
from tqdm import tqdm

import matplotlib
import matplotlib.pyplot as plt
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later releases
# dropped the old names, so use the new name when this matplotlib ships it and fall back
# to the legacy name on older installations.
_paper_style = 'seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available else 'seaborn-paper'
plt.style.use(_paper_style)

matplotlib.rc('font', family='sans-serif')
# NOTE(review): this sets the *serif* font list while the active family is sans-serif, so
# Arial may not actually be selected; `**{'sans-serif': 'Arial'}` looks like the intent -- confirm.
matplotlib.rc('font', serif='Arial')
# Pass a real boolean rather than the string 'false'.
matplotlib.rc('text', usetex=False)

In [None]:
DELQSAR_ROOT = os.getcwd() + '/../../'

# Directory (relative to the working directory) that receives every saved figure.
PLOT_DIR = 'multiple_thresholds_bin_plots'
# exist_ok=True replaces the isdir()/mkdir() pair: no check-then-create race, and the
# cell can be re-run safely.
os.makedirs(PLOT_DIR, exist_ok=True)


def pathify(fname):
    """Return the path of ``fname`` inside the plot output directory."""
    return os.path.join(PLOT_DIR, fname)

In [None]:
def get_avg_AUCs_stdevs(dataset, model_type, metric, data=None, percents=None):
    """Summarise one metric per 'top percent' threshold for a dataset/model pair.

    Parameters
    ----------
    dataset, model_type : values matched (after ``str()``) against the
        'dataset' and 'model type' columns.
    metric : name (after ``str()``) of the column holding the values to summarise.
    data : optional DataFrame with 'dataset', 'model type', 'top percent' and
        metric columns. Defaults to the notebook-level ``df_data``.
    percents : optional iterable of 'top percent' values to summarise, in
        output order. Defaults to the notebook-level ``top_percents``.

    Returns
    -------
    (avg_AUCs, stdevs) : two lists aligned with ``percents`` holding the mean
        and the sample standard deviation of the selected values.

    Raises
    ------
    statistics.StatisticsError
        If any threshold selects fewer than two values (the sample standard
        deviation is undefined); the message names the offending group.
    """
    # Fall back to the notebook globals so existing three-argument calls keep working.
    frame = df_data if data is None else data
    thresholds = top_percents if percents is None else percents

    # The dataset/model filter does not depend on the threshold, so build it once.
    group_mask = frame['dataset'].isin([str(dataset)]) & frame['model type'].isin([str(model_type)])

    avg_AUCs, stdevs = [], []
    for top_percent in thresholds:
        values = np.array(frame[group_mask & frame['top percent'].isin([top_percent])][str(metric)])
        if len(values) < 2:
            # Same exception type statistics.mean/stdev would raise, with a usable message.
            raise statistics.StatisticsError(
                f'need at least two {metric} values for dataset={dataset!r}, '
                f'model type={model_type!r}, top percent={top_percent!r}; found {len(values)}'
            )
        avg_AUCs.append(statistics.mean(values))
        stdevs.append(statistics.stdev(values))
    return avg_AUCs, stdevs

In [None]:
def make_plot_AUCs(img_name, dataset, metric, model_name, NLL_AUCs, NLL_stdevs, pt_AUCs, pt_stdevs, xsize, ysize,
                   ylabel=True, xlabel=True, title=True):
    """Draw, save and display one metric-vs-'top percent' panel with error bars.

    Three curves are drawn against the notebook-level ``top_percents``: the
    notebook-level ``random_guess_AUCs``/``random_guess_stdevs`` baseline, the
    ``pt`` series and the ``NLL`` series (labelled with ``model_name``).

    img_name   -- file name passed through ``pathify`` for ``plt.savefig``.
    dataset    -- used as the panel title when ``title`` is true.
    metric     -- used as the y-axis label when ``ylabel`` is true; a metric
                  containing 'PR' puts the legend at 'upper left', otherwise
                  'center left'.
    xsize, ysize -- figure size handed to ``plt.figure`` (dpi is fixed at 300).
    xlabel     -- when false, x ticks and x tick labels are switched off
                  instead of labelling the axis.
    """
    fig = plt.figure(figsize=(xsize, ysize), dpi=300)

    # (values, error bars, line colour, legend label), listed in drawing order.
    curves = (
        (random_guess_AUCs, random_guess_stdevs, '#7f7f7f', 'random guess'),
        (pt_AUCs, pt_stdevs, '#ff7f0e', f'{model_name} pt'),
        (NLL_AUCs, NLL_stdevs, '#1f77b4', f'{model_name}'),
    )
    for values, errors, colour, legend_label in curves:
        plt.errorbar(top_percents, values, yerr=errors, ecolor='k', color=colour, linewidth=1,
                     elinewidth=0.5, capsize=1, capthick=0.5, label=legend_label)

    fig.canvas.draw()
    axes = plt.gca()
    axes.grid(zorder=1)

    if ylabel:
        axes.set_ylabel(str(metric), fontsize=8)
    axes.set_xscale('log')
    if not xlabel:
        axes.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    else:
        axes.tick_params(labelsize=8)
        axes.set_xlabel('top percent', fontsize=8)
    if title:
        axes.set_title(str(dataset), fontsize=8)

    # The legend lists the curves in the reverse of the drawing order.
    handles, labels = axes.get_legend_handles_labels()
    legend_order = (2, 1, 0)
    legend_loc = 'upper left' if 'PR' in metric else 'center left'
    plt.legend([handles[i] for i in legend_order], [labels[i] for i in legend_order],
               fontsize=6, loc=legend_loc)

    plt.tight_layout()
    plt.savefig(pathify(img_name))
    plt.show()

In [None]:
# Results table read by get_avg_AUCs_stdevs; the cells below use its 'dataset',
# 'model type' and 'top percent' columns plus one column per metric.
df_data = pd.read_csv(
    os.path.join(DELQSAR_ROOT, 'experiments', 'AUCs_multiple_thresholds.csv')
)

# PR AUC

In [None]:
# DD1S CAIX
# Thresholds present for this dataset, then (means, stdevs) of PR AUC for every model type.
top_percents = df_data.loc[df_data['dataset'].isin(['DD1S CAIX']), 'top percent'].unique()
(
    (random_guess_AUCs, random_guess_stdevs),
    (OH_FFNN, OH_FFNN_stdevs),
    (FP_FFNN, FP_FFNN_stdevs),
    (D_MPNN, D_MPNN_stdevs),
    (OH_FFNN_pt, OH_FFNN_pt_stdevs),
    (FP_FFNN_pt, FP_FFNN_pt_stdevs),
    (D_MPNN_pt, D_MPNN_pt_stdevs),
) = [
    get_avg_AUCs_stdevs('DD1S CAIX', model_type, 'PR AUC')
    for model_type in ('Random guess', 'OH-FFNN', 'FP-FFNN', 'D-MPNN',
                       'OH-FFNN pt', 'FP-FFNN pt', 'D-MPNN pt')
]

In [None]:
# DD1S CAIX / OH-FFNN, PR AUC panel (x-axis ticks and label hidden).
make_plot_AUCs(
    'DD1S_CAIX_OH-FFNN_20_thresholds_PR_AUCs.png', 'DD1S CAIX', 'PR AUC', 'OH-FFNN',
    NLL_AUCs=OH_FFNN, NLL_stdevs=OH_FFNN_stdevs,
    pt_AUCs=OH_FFNN_pt, pt_stdevs=OH_FFNN_pt_stdevs,
    xsize=2.33, ysize=1.9, xlabel=False,
)

In [None]:
# DD1S CAIX / FP-FFNN, PR AUC panel (x-axis ticks, x label and title hidden).
make_plot_AUCs(
    'DD1S_CAIX_FP-FFNN_20_thresholds_PR_AUCs.png', 'DD1S CAIX', 'PR AUC', 'FP-FFNN',
    NLL_AUCs=FP_FFNN, NLL_stdevs=FP_FFNN_stdevs,
    pt_AUCs=FP_FFNN_pt, pt_stdevs=FP_FFNN_pt_stdevs,
    xsize=2.33, ysize=1.75, xlabel=False, title=False,
)

In [None]:
# DD1S CAIX / D-MPNN, PR AUC panel (title hidden).
make_plot_AUCs(
    'DD1S_CAIX_D-MPNN_20_thresholds_PR_AUCs.png', 'DD1S CAIX', 'PR AUC', 'D-MPNN',
    NLL_AUCs=D_MPNN, NLL_stdevs=D_MPNN_stdevs,
    pt_AUCs=D_MPNN_pt, pt_stdevs=D_MPNN_pt_stdevs,
    xsize=2.33, ysize=2.2, title=False,
)

In [None]:
# triazine sEH
# Thresholds present for this dataset, then (means, stdevs) of PR AUC for every model type.
top_percents = df_data.loc[df_data['dataset'].isin(['triazine sEH']), 'top percent'].unique()
(
    (random_guess_AUCs, random_guess_stdevs),
    (OH_FFNN, OH_FFNN_stdevs),
    (FP_FFNN, FP_FFNN_stdevs),
    (D_MPNN, D_MPNN_stdevs),
    (OH_FFNN_pt, OH_FFNN_pt_stdevs),
    (FP_FFNN_pt, FP_FFNN_pt_stdevs),
    (D_MPNN_pt, D_MPNN_pt_stdevs),
) = [
    get_avg_AUCs_stdevs('triazine sEH', model_type, 'PR AUC')
    for model_type in ('Random guess', 'OH-FFNN', 'FP-FFNN', 'D-MPNN',
                       'OH-FFNN pt', 'FP-FFNN pt', 'D-MPNN pt')
]

In [None]:
# triazine sEH / OH-FFNN, PR AUC panel (x-axis ticks, x label and y label hidden).
make_plot_AUCs(
    'triazine_sEH_OH-FFNN_20_thresholds_PR_AUCs.png', 'triazine sEH', 'PR AUC', 'OH-FFNN',
    NLL_AUCs=OH_FFNN, NLL_stdevs=OH_FFNN_stdevs,
    pt_AUCs=OH_FFNN_pt, pt_stdevs=OH_FFNN_pt_stdevs,
    xsize=2.2, ysize=1.9, xlabel=False, ylabel=False,
)

In [None]:
# triazine sEH / FP-FFNN, PR AUC panel (x-axis ticks, both axis labels and title hidden).
make_plot_AUCs(
    'triazine_sEH_FP-FFNN_20_thresholds_PR_AUCs.png', 'triazine sEH', 'PR AUC', 'FP-FFNN',
    NLL_AUCs=FP_FFNN, NLL_stdevs=FP_FFNN_stdevs,
    pt_AUCs=FP_FFNN_pt, pt_stdevs=FP_FFNN_pt_stdevs,
    xsize=2.2, ysize=1.75, xlabel=False, ylabel=False, title=False,
)

In [None]:
# triazine sEH / D-MPNN, PR AUC panel (y label and title hidden).
make_plot_AUCs(
    'triazine_sEH_D-MPNN_20_thresholds_PR_AUCs.png', 'triazine sEH', 'PR AUC', 'D-MPNN',
    NLL_AUCs=D_MPNN, NLL_stdevs=D_MPNN_stdevs,
    pt_AUCs=D_MPNN_pt, pt_stdevs=D_MPNN_pt_stdevs,
    xsize=2.23, ysize=2.2, ylabel=False, title=False,
)

In [None]:
# triazine SIRT2
# Thresholds present for this dataset, then (means, stdevs) of PR AUC for every model type.
top_percents = df_data.loc[df_data['dataset'].isin(['triazine SIRT2']), 'top percent'].unique()
(
    (random_guess_AUCs, random_guess_stdevs),
    (OH_FFNN, OH_FFNN_stdevs),
    (FP_FFNN, FP_FFNN_stdevs),
    (D_MPNN, D_MPNN_stdevs),
    (OH_FFNN_pt, OH_FFNN_pt_stdevs),
    (FP_FFNN_pt, FP_FFNN_pt_stdevs),
    (D_MPNN_pt, D_MPNN_pt_stdevs),
) = [
    get_avg_AUCs_stdevs('triazine SIRT2', model_type, 'PR AUC')
    for model_type in ('Random guess', 'OH-FFNN', 'FP-FFNN', 'D-MPNN',
                       'OH-FFNN pt', 'FP-FFNN pt', 'D-MPNN pt')
]

In [None]:
# triazine SIRT2 / OH-FFNN, PR AUC panel (x-axis ticks, x label and y label hidden).
make_plot_AUCs(
    'triazine_SIRT2_OH-FFNN_20_thresholds_PR_AUCs.png', 'triazine SIRT2', 'PR AUC', 'OH-FFNN',
    NLL_AUCs=OH_FFNN, NLL_stdevs=OH_FFNN_stdevs,
    pt_AUCs=OH_FFNN_pt, pt_stdevs=OH_FFNN_pt_stdevs,
    xsize=2.2, ysize=1.9, xlabel=False, ylabel=False,
)

In [None]:
# triazine SIRT2 / FP-FFNN, PR AUC panel (x-axis ticks, both axis labels and title hidden).
make_plot_AUCs(
    'triazine_SIRT2_FP-FFNN_20_thresholds_PR_AUCs.png', 'triazine SIRT2', 'PR AUC', 'FP-FFNN',
    NLL_AUCs=FP_FFNN, NLL_stdevs=FP_FFNN_stdevs,
    pt_AUCs=FP_FFNN_pt, pt_stdevs=FP_FFNN_pt_stdevs,
    xsize=2.2, ysize=1.75, xlabel=False, ylabel=False, title=False,
)

In [None]:
# triazine SIRT2 / D-MPNN, PR AUC panel (y label and title hidden).
make_plot_AUCs(
    'triazine_SIRT2_D-MPNN_20_thresholds_PR_AUCs.png', 'triazine SIRT2', 'PR AUC', 'D-MPNN',
    NLL_AUCs=D_MPNN, NLL_stdevs=D_MPNN_stdevs,
    pt_AUCs=D_MPNN_pt, pt_stdevs=D_MPNN_pt_stdevs,
    xsize=2.23, ysize=2.2, ylabel=False, title=False,
)

# ROC AUC

In [None]:
# DD1S CAIX
# Thresholds present for this dataset, then (means, stdevs) of ROC AUC for every model type.
top_percents = df_data.loc[df_data['dataset'].isin(['DD1S CAIX']), 'top percent'].unique()
(
    (random_guess_AUCs, random_guess_stdevs),
    (OH_FFNN, OH_FFNN_stdevs),
    (FP_FFNN, FP_FFNN_stdevs),
    (D_MPNN, D_MPNN_stdevs),
    (OH_FFNN_pt, OH_FFNN_pt_stdevs),
    (FP_FFNN_pt, FP_FFNN_pt_stdevs),
    (D_MPNN_pt, D_MPNN_pt_stdevs),
) = [
    get_avg_AUCs_stdevs('DD1S CAIX', model_type, 'ROC AUC')
    for model_type in ('Random guess', 'OH-FFNN', 'FP-FFNN', 'D-MPNN',
                       'OH-FFNN pt', 'FP-FFNN pt', 'D-MPNN pt')
]

In [None]:
# DD1S CAIX / OH-FFNN, ROC AUC panel (x-axis ticks and label hidden).
make_plot_AUCs(
    'DD1S_CAIX_OH-FFNN_20_thresholds_ROC_AUCs.png', 'DD1S CAIX', 'ROC AUC', 'OH-FFNN',
    NLL_AUCs=OH_FFNN, NLL_stdevs=OH_FFNN_stdevs,
    pt_AUCs=OH_FFNN_pt, pt_stdevs=OH_FFNN_pt_stdevs,
    xsize=2.33, ysize=1.9, xlabel=False,
)

In [None]:
# DD1S CAIX / FP-FFNN, ROC AUC panel (x-axis ticks, x label and title hidden).
make_plot_AUCs(
    'DD1S_CAIX_FP-FFNN_20_thresholds_ROC_AUCs.png', 'DD1S CAIX', 'ROC AUC', 'FP-FFNN',
    NLL_AUCs=FP_FFNN, NLL_stdevs=FP_FFNN_stdevs,
    pt_AUCs=FP_FFNN_pt, pt_stdevs=FP_FFNN_pt_stdevs,
    xsize=2.33, ysize=1.75, xlabel=False, title=False,
)

In [None]:
# DD1S CAIX / D-MPNN, ROC AUC panel (title hidden).
make_plot_AUCs(
    'DD1S_CAIX_D-MPNN_20_thresholds_ROC_AUCs.png', 'DD1S CAIX', 'ROC AUC', 'D-MPNN',
    NLL_AUCs=D_MPNN, NLL_stdevs=D_MPNN_stdevs,
    pt_AUCs=D_MPNN_pt, pt_stdevs=D_MPNN_pt_stdevs,
    xsize=2.33, ysize=2.2, title=False,
)

In [None]:
# triazine sEH
# Thresholds present for this dataset, then (means, stdevs) of ROC AUC for every model type.
top_percents = df_data.loc[df_data['dataset'].isin(['triazine sEH']), 'top percent'].unique()
(
    (random_guess_AUCs, random_guess_stdevs),
    (OH_FFNN, OH_FFNN_stdevs),
    (FP_FFNN, FP_FFNN_stdevs),
    (D_MPNN, D_MPNN_stdevs),
    (OH_FFNN_pt, OH_FFNN_pt_stdevs),
    (FP_FFNN_pt, FP_FFNN_pt_stdevs),
    (D_MPNN_pt, D_MPNN_pt_stdevs),
) = [
    get_avg_AUCs_stdevs('triazine sEH', model_type, 'ROC AUC')
    for model_type in ('Random guess', 'OH-FFNN', 'FP-FFNN', 'D-MPNN',
                       'OH-FFNN pt', 'FP-FFNN pt', 'D-MPNN pt')
]

In [None]:
# triazine sEH / OH-FFNN, ROC AUC panel (x-axis ticks, x label and y label hidden).
make_plot_AUCs(
    'triazine_sEH_OH-FFNN_20_thresholds_ROC_AUCs.png', 'triazine sEH', 'ROC AUC', 'OH-FFNN',
    NLL_AUCs=OH_FFNN, NLL_stdevs=OH_FFNN_stdevs,
    pt_AUCs=OH_FFNN_pt, pt_stdevs=OH_FFNN_pt_stdevs,
    xsize=2.2, ysize=1.9, xlabel=False, ylabel=False,
)

In [None]:
# triazine sEH / FP-FFNN, ROC AUC panel (x-axis ticks, both axis labels and title hidden).
make_plot_AUCs(
    'triazine_sEH_FP-FFNN_20_thresholds_ROC_AUCs.png', 'triazine sEH', 'ROC AUC', 'FP-FFNN',
    NLL_AUCs=FP_FFNN, NLL_stdevs=FP_FFNN_stdevs,
    pt_AUCs=FP_FFNN_pt, pt_stdevs=FP_FFNN_pt_stdevs,
    xsize=2.2, ysize=1.75, xlabel=False, ylabel=False, title=False,
)

In [None]:
# triazine sEH / D-MPNN, ROC AUC panel (y label and title hidden).
make_plot_AUCs(
    'triazine_sEH_D-MPNN_20_thresholds_ROC_AUCs.png', 'triazine sEH', 'ROC AUC', 'D-MPNN',
    NLL_AUCs=D_MPNN, NLL_stdevs=D_MPNN_stdevs,
    pt_AUCs=D_MPNN_pt, pt_stdevs=D_MPNN_pt_stdevs,
    xsize=2.23, ysize=2.2, ylabel=False, title=False,
)

In [None]:
# triazine SIRT2
# Thresholds present for this dataset, then (means, stdevs) of ROC AUC for every model type.
top_percents = df_data.loc[df_data['dataset'].isin(['triazine SIRT2']), 'top percent'].unique()
(
    (random_guess_AUCs, random_guess_stdevs),
    (OH_FFNN, OH_FFNN_stdevs),
    (FP_FFNN, FP_FFNN_stdevs),
    (D_MPNN, D_MPNN_stdevs),
    (OH_FFNN_pt, OH_FFNN_pt_stdevs),
    (FP_FFNN_pt, FP_FFNN_pt_stdevs),
    (D_MPNN_pt, D_MPNN_pt_stdevs),
) = [
    get_avg_AUCs_stdevs('triazine SIRT2', model_type, 'ROC AUC')
    for model_type in ('Random guess', 'OH-FFNN', 'FP-FFNN', 'D-MPNN',
                       'OH-FFNN pt', 'FP-FFNN pt', 'D-MPNN pt')
]

In [None]:
# triazine SIRT2 / OH-FFNN, ROC AUC panel (x-axis ticks, x label and y label hidden).
make_plot_AUCs(
    'triazine_SIRT2_OH-FFNN_20_thresholds_ROC_AUCs.png', 'triazine SIRT2', 'ROC AUC', 'OH-FFNN',
    NLL_AUCs=OH_FFNN, NLL_stdevs=OH_FFNN_stdevs,
    pt_AUCs=OH_FFNN_pt, pt_stdevs=OH_FFNN_pt_stdevs,
    xsize=2.2, ysize=1.9, xlabel=False, ylabel=False,
)

In [None]:
# triazine SIRT2 / FP-FFNN, ROC AUC panel (x-axis ticks, both axis labels and title hidden).
make_plot_AUCs(
    'triazine_SIRT2_FP-FFNN_20_thresholds_ROC_AUCs.png', 'triazine SIRT2', 'ROC AUC', 'FP-FFNN',
    NLL_AUCs=FP_FFNN, NLL_stdevs=FP_FFNN_stdevs,
    pt_AUCs=FP_FFNN_pt, pt_stdevs=FP_FFNN_pt_stdevs,
    xsize=2.2, ysize=1.75, xlabel=False, ylabel=False, title=False,
)

In [None]:
# triazine SIRT2 / D-MPNN, ROC AUC panel (y label and title hidden).
make_plot_AUCs(
    'triazine_SIRT2_D-MPNN_20_thresholds_ROC_AUCs.png', 'triazine SIRT2', 'ROC AUC', 'D-MPNN',
    NLL_AUCs=D_MPNN, NLL_stdevs=D_MPNN_stdevs,
    pt_AUCs=D_MPNN_pt, pt_stdevs=D_MPNN_pt_stdevs,
    xsize=2.23, ysize=2.2, ylabel=False, title=False,
)