# OOD figures
Create the figures (ROC and precision-recall curves) and summary metrics for the out-of-distribution (OOD) detection experiments.

## Utility functions

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from tqdm import trange
import random
import math
from scipy import interp
import statistics

from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve, roc_curve, auc
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from matplotlib import collections
from matplotlib import colors
from numpy.random import normal

In [None]:
# Display names of the OOD metrics. The order matters: get_scores() pairs
# these labels positionally with the scores it computes.
metrics = ['FPR (95% TPR)', 'AUROC', 'AUPR', 'Detection error']

def fpr_at_x_tpr(true_values, predicted_values, cutoff=0.95):
    """Return the false positive rate at the first ROC point with TPR >= cutoff.

    Args:
        true_values: Ground-truth binary labels, passed to ``roc_curve``.
        predicted_values: Scores for the positive class, passed to ``roc_curve``.
        cutoff: Minimum true positive rate that must be reached (default 0.95).

    Returns:
        The FPR of the first ROC point whose TPR is at least ``cutoff``, or
        ``None`` if no point reaches it (e.g. ``cutoff`` > 1 or an undefined TPR).
    """
    fpr, tpr, _ = roc_curve(true_values, predicted_values)

    # Indices of all ROC points that satisfy the TPR requirement; the first
    # one is the same point the original sequential scan would have returned.
    reached = np.flatnonzero(tpr >= cutoff)
    if reached.size == 0:
        return None
    return fpr[reached[0]]

def detection_error_x_tpr(true_values, predicted_values, cutoff=0.95):
    """Calculate the detection error Pe at the first ROC point with TPR >= cutoff.

    Measures the misclassification probability at that operating point
    (95% TPR with the default cutoff):

        Pe = 0.5 * (1 - TPR) + 0.5 * FPR

    Args:
        true_values: Ground-truth binary labels, passed to ``roc_curve``.
        predicted_values: Scores for the positive class, passed to ``roc_curve``.
        cutoff: Minimum true positive rate that must be reached (default 0.95).

    Returns:
        The detection error at the first ROC point whose TPR is at least
        ``cutoff``, or ``None`` if no point reaches it.
    """
    fpr, tpr, _ = roc_curve(true_values, predicted_values)

    # First ROC point that satisfies the TPR requirement, if any.
    reached = np.flatnonzero(tpr >= cutoff)
    if reached.size == 0:
        return None

    first = reached[0]
    return 0.5 * (1 - tpr[first]) + 0.5 * fpr[first]

def pr_auc(y_true, y_prob):
    """Return the area under the precision-recall curve.

    The curve comes from ``precision_recall_curve(y_true, y_prob)`` and is
    integrated with ``auc``, using recall as x and precision as y.
    """
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    return auc(recall, precision)

def get_scores(y_true, y_prob):
    """Build a DataFrame with one row per OOD metric.

    Column ``score`` holds the metric value and column ``metrics`` the
    matching label from the module-level ``metrics`` list.
    """
    # Same order as the labels in `metrics`.
    scorers = (fpr_at_x_tpr, roc_auc_score, pr_auc, detection_error_x_tpr)
    values = [scorer(y_true, y_prob) for scorer in scorers]
    return pd.DataFrame({'score': values, 'metrics': metrics})

In [None]:
import os
import getpass

# os.getlogin() needs a controlling terminal and raises OSError without one,
# which is common for notebook kernels started by a service. Fall back to
# getpass.getuser(), which consults the environment and the password database.
try:
    login = os.getlogin()
except OSError:
    login = getpass.getuser()

# Project directories under the current user's home; every path ends in "/"
# so file names can be appended directly.
DATA_BASE = f"/home/{login}/Git/tcr/data/"
RESULTS_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.ood/results/"
FIGURES_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.ood/figures/"

In [None]:
# One (label, runs) pair per OOD detection method: `label` is the legend
# text and `runs` the list of five prediction tables (rep-0 ... rep-4)
# read from RESULTS_BASE.
predictions_files = list(zip(
    (
        'MSP',
        'ODIN (ε=0.001, T=1000)',
        'AVIB-R',
        'AVIB-Maha',
    ),
    (
        [pd.read_csv(RESULTS_BASE + f"mhc.mvib.msp.aoe.rep-{i}.csv") for i in range(5)],
        [pd.read_csv(RESULTS_BASE + f"mhc.mvib.odin.aoe.T-1000.epsilon-0.001.rep-{i}.csv") for i in range(5)],
        [pd.read_csv(RESULTS_BASE + f"mhc.mvib.kld-joint.aoe.rep-{i}.csv") for i in range(5)],
        [pd.read_csv(RESULTS_BASE + f"mhc.mvib.aoe.maha.layer-0.epsilon-0.rep-{i}.csv") for i in range(5)],
    ),
))

In [None]:
import matplotlib.pyplot as plt
# Font sizes shared by all figures in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 16, 18, 24

# Default text size.
plt.rc('font', size=SMALL_SIZE)
# Axes title and x/y label sizes.
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)
# Tick label sizes on both axes.
plt.rc('xtick', labelsize=SMALL_SIZE)
plt.rc('ytick', labelsize=SMALL_SIZE)
# Legend text size.
plt.rc('legend', fontsize=SMALL_SIZE)
# Figure-level title size.
plt.rc('figure', titlesize=BIGGER_SIZE)

import seaborn as sns

# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8-*"
# and later releases dropped the old names. Use the new name when this
# Matplotlib provides it and fall back to the legacy one otherwise.
_WHITE_STYLE = 'seaborn-v0_8-white'
plt.style.use(_WHITE_STYLE if _WHITE_STYLE in plt.style.available else 'seaborn-white')

# One palette colour per OOD method, so every curve gets a distinct colour.
sns.set_palette('magma', len(predictions_files))


def make_roc_curve_plot(ax, true_values_list, predicted_values_list, cutoff, model_label):
    """Draw the mean ROC curve over several runs on ``ax``.

    Each run's ROC curve is resampled onto a common grid of 101 false
    positive rates in [0, 1]; the resampled true positive rates are then
    averaged point-wise and plotted as a single line.

    Args:
        ax: Matplotlib axes to draw on.
        true_values_list: One sequence of ground-truth binary labels per run.
        predicted_values_list: One sequence of scores per run, aligned with
            ``true_values_list``.
        cutoff: Currently unused; kept so existing callers keep working.
        model_label: Legend label for the plotted curve.

    Raises:
        ValueError: If no runs are supplied.
    """
    base_fpr = np.linspace(0, 1, 101)

    tprs = []
    for true_values, predicted_values in zip(true_values_list, predicted_values_list):
        fpr, tpr, _ = roc_curve(true_values, predicted_values)
        # np.interp is the function the deprecated `scipy.interp` alias pointed to.
        tpr = np.interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0  # force the curve through the origin
        tprs.append(tpr)

    if not tprs:
        raise ValueError("make_roc_curve_plot needs at least one run to average")

    mean_tprs = np.array(tprs).mean(axis=0)

    ax.plot(base_fpr, mean_tprs, label=model_label, linewidth=3)

    ax.set_title("ID: Human TCR set | OOD: Human MHC set", y=1.04)
    ax.set_ylabel("True Positive Rate")
    ax.set_xlabel("False Positive Rate")
    ax.legend()
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.0])


def make_uninformative_roc(ax):
    """Add the dashed grey chance-level diagonal to a ROC plot.

    Also clamps both axes to [0, 1] and (re)draws the legend with a white
    background. The diagonal itself has no label.
    """
    ax.plot([0, 1], [0, 1], c='grey', linestyle='dashed', alpha=0.5, linewidth=3)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    # A single legend call is enough: each call replaces the previous legend.
    ax.legend(facecolor="white")


# --- ROC figure: one mean curve per OOD method, plus the chance diagonal ---
fig, ax = plt.subplots()

for model_label, run_frames in predictions_files:
    # Run i stores its labels in column 'sign' and its scores in 'prediction_<i>'.
    true_values_list = [run_frames[i]['sign'].to_numpy() for i in range(5)]
    predicted_values_list = [
        run_frames[i][f'prediction_{i}'].to_numpy() for i in range(5)
    ]
    make_roc_curve_plot(ax, true_values_list, predicted_values_list, 0.9, model_label)

make_uninformative_roc(ax)
ax.tick_params(axis='x', pad=15)

# Final legend with an opaque white frame.
ax.legend(loc='best')
legend = plt.legend(frameon = 1)
legend.get_frame().set_facecolor('white')

ax.grid(axis='y')
ax.grid(axis='x')

# Save both a vector and a raster version of the figure.
save_options = dict(dpi=300, bbox_inches='tight')
plt.savefig(FIGURES_BASE + "roc.mhc.svg", format='svg', **save_options)
plt.savefig(FIGURES_BASE + "roc.mhc.png", format='png', **save_options)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def make_prc_curve_plot(ax, true_values, predicted_values, model_label):
    """Draw the mean precision-recall curve over several runs on ``ax``.

    Each run's precision-recall curve is resampled onto a common grid of 100
    recall values running from 1 down to 0; the resampled precisions are then
    averaged point-wise and plotted as a single line.

    Args:
        ax: Matplotlib axes to draw on.
        true_values: One sequence of ground-truth binary labels per run.
        predicted_values: One sequence of scores per run, aligned with
            ``true_values``.
        model_label: Legend label for the plotted curve.

    Raises:
        ValueError: If no runs are supplied.
    """
    base_recall = np.linspace(1, 0, 100)

    # Iterate over the function's own arguments (not same-named globals), so
    # the plot always reflects the data that was passed in.
    precision_curves = []
    for run_true, run_predicted in zip(true_values, predicted_values):
        precision, recall, _ = precision_recall_curve(run_true, run_predicted)
        # precision_recall_curve returns recall in decreasing order, while
        # np.interp needs increasing x-coordinates, so reverse both arrays.
        precision_curves.append(
            np.interp(base_recall, recall[::-1], precision[::-1])
        )

    if not precision_curves:
        raise ValueError("make_prc_curve_plot needs at least one run to average")

    # Average over however many runs were given instead of assuming five.
    mean_precision = np.mean(precision_curves, axis=0)

    ax.plot(base_recall, mean_precision, label=model_label, linewidth=3)

    ax.set_title("ID: Human TCR set | OOD: Human MHC set", y=1.04)
    ax.set_ylabel("Precision")
    ax.set_xlabel("Recall")
    ax.legend()
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.01])


# --- Precision-recall figure: one mean curve per OOD method ---
fig, ax = plt.subplots()

for model_label, run_frames in predictions_files:
    # Run i stores its labels in column 'sign' and its scores in 'prediction_<i>'.
    true_values_list = [run_frames[i]['sign'].to_numpy() for i in range(5)]
    predicted_values_list = [
        run_frames[i][f'prediction_{i}'].to_numpy() for i in range(5)
    ]
    make_prc_curve_plot(ax, true_values_list, predicted_values_list, model_label)

ax.tick_params(axis='x', pad=15)

# Final legend with an opaque white frame.
ax.legend(loc='best')
legend = plt.legend(frameon = 1)
legend.get_frame().set_facecolor('white')

ax.grid(axis='y')
ax.grid(axis='x')

# Save both a vector and a raster version of the figure.
save_options = dict(dpi=300, bbox_inches='tight')
plt.savefig(FIGURES_BASE + "prc.mhc.svg", format='svg', **save_options)
plt.savefig(FIGURES_BASE + "prc.mhc.png", format='png', **save_options)

In [None]:
# Collect the per-run metric tables of every OOD method into one long
# DataFrame with columns 'score', 'metrics' and 'method'.
df_list = []

for method_name, run_frames in predictions_files:
    # Enumerate the runs that were actually loaded rather than assuming five.
    for rep, prediction_df in enumerate(run_frames):
        y_true = prediction_df['sign'].to_numpy()
        # Run `rep` stores its scores in column 'prediction_<rep>'.
        y_prob = prediction_df[f'prediction_{rep}'].to_numpy()

        df = get_scores(y_true, y_prob)
        df['method'] = method_name

        df_list.append(df)

results = pd.concat(df_list)

In [None]:
# Mean of each metric across runs, reported per (method, metric) pair.
scores_by_method_metric = results.groupby(['method', 'metrics'])
scores_by_method_metric.mean()

In [None]:
# Spread of each metric across runs, reported as the standard error of the
# mean: sample standard deviation divided by sqrt(number of runs).
# NOTE(review): this previously divided the standard deviation by 5 (the run
# count itself), which is not a standard statistic — confirm that the
# standard error is what the figures/tables are meant to report.
grouped_scores = results.groupby(['method', 'metrics'])
std_df = grouped_scores.std()
# The per-group count is indexed like std_df, so the division aligns by group.
std_df['score'] = std_df['score'] / np.sqrt(grouped_scores['score'].count())
std_df