In [2]:
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, roc_auc_score, average_precision_score, accuracy_score
from scipy.special import expit

from ast import literal_eval
import json

import matplotlib.pyplot as plt
import seaborn as sns
import torch

%matplotlib inline
%config InlineBackend.figure_format='retina'
plt.rcParams['figure.dpi']=150
plt.rcParams['savefig.dpi'] = 600

# Load datasets

In [3]:

chemblv2_500_path = "/nasa/datasets/kyodai_federated/proj_202111_202203/activity/prepared/chembl_v2_above_500.csv"
splits_path = "/nasa/datasets/kyodai_federated/proj_202111_202203/activity/prepared/splits.json"
ood_path = "/nasa/shared_homes/loic/kyodai-kmol/kmol_internal_new/kmol-internal/data/datasets/chemblv2_OOD_set.csv"
superOOD_path = "/nasa/shared_homes/loic/kyodai-kmol/kmol_internal_new/kmol-internal/data/datasets/chemblv2_superOOD.csv"


original_df = pd.read_csv(chemblv2_500_path)

with open(splits_path) as file:
    splits = json.load(file)

test_df = original_df.iloc[splits['test']]

# Load predictions

In [4]:
logs_path = "/nasa/shared_homes/loic/kyodai-kmol/kmol_main2/kmol-internal/data/logs/"

evidential_preds = logs_path+"bc4_edl_nologits/predictions.csv"
mcdropout_preds = logs_path+"bc4_mcdropout/predictions.csv"
ensemble_preds = logs_path+"bc4_ensemble/predictions.csv"
ood_edl_preds = logs_path+"bc4_edl_nologits_OOD/2022-11-16_09-23/predictions.csv"
superOOD_edl_preds = logs_path+"bc4_edl_nologits_superOOD/2022-11-16_09-39/predictions.csv"

lrodd_preds = logs_path+"lrodd_test_full/2023-04-04_14-48/predictions.csv"
ood_lrodd_preds = logs_path+"lrodd_OOD/2023-04-05_16-00/predictions.csv"
superOOD_lrodd_preds = logs_path+"lrodd_superOOD/2023-04-05_16-04/predictions.csv"

experiments_dict = {
    "evidential": {"preds": evidential_preds, "ood_preds": ood_edl_preds, "superOOD_preds": superOOD_edl_preds},
    "mc_dropout": {"preds": mcdropout_preds, "ood_preds": None, "superOOD_preds": None},
    "ensemble": {"preds": ensemble_preds, "ood_preds": None, "superOOD_preds": None},
    "LRODD": {"preds": lrodd_preds, "ood_preds": ood_lrodd_preds, "superOOD_preds": superOOD_lrodd_preds}
}

#inf_type = "mc_dropout"
inf_type = "LRODD"
#inf_type = "ensemble"
#inf_type = "LROOD"
ood_tests = False

# Creating the predictions dataframe
predictions = pd.read_csv(experiments_dict[inf_type]["preds"])
predictions.set_index('id', inplace=True)

test_df = test_df.loc[predictions.index]

def read_and_set_index(filename):
    df = pd.read_csv(filename)
    df.set_index('id', inplace=True)
    return df

if experiments_dict[inf_type]["ood_preds"] != None:
    ood_tests = True
    predictions_ood = read_and_set_index(experiments_dict[inf_type]["ood_preds"])
    ood_df = pd.read_csv(ood_path)
    ood_df = ood_df.loc[predictions_ood.index]
    predictions_ood["target"] = ood_df["target_sequence"]

    predictions_superOOD = read_and_set_index(experiments_dict[inf_type]["superOOD_preds"])
    superOOD_df = pd.read_csv(superOOD_path)
    superOOD_df = superOOD_df.loc[predictions_superOOD.index]
    predictions_superOOD["target"] = superOOD_df["target_sequence"]


In [6]:
def process_predictions(preds_df, inf_type, threshold=0.5):
    if inf_type == "evidential":
        preds_df["t_100n_error"] = abs((preds_df["t_100n_ground_truth"] == preds_df["t_100n"]).astype(int)-preds_df["t_100n_softmax"])
        preds_df["t_1u_error"] = abs((preds_df["t_1u_ground_truth"] == preds_df["t_1u"]).astype(int)-preds_df["t_1u_softmax"])
        preds_df["t_10u_error"] = abs((preds_df["t_10u_ground_truth"] == preds_df["t_10u"]).astype(int)-preds_df["t_10u_softmax"])

        preds_df["t_100n_logits_aleatoric"] = abs(1 - preds_df["t_100n_softmax"])
        preds_df["t_1u_logits_aleatoric"] = abs(1 - preds_df["t_1u_softmax"])
        preds_df["t_10u_logits_aleatoric"] = abs(1 - preds_df["t_10u_softmax"])
    else:
        preds_df["t_100n_error"] = abs(preds_df["t_100n_ground_truth"] - preds_df["t_100n"])
        preds_df["t_1u_error"] = abs(preds_df["t_1u_ground_truth"] - preds_df["t_1u"])
        preds_df["t_10u_error"] = abs(preds_df["t_10u_ground_truth"] - preds_df["t_10u"])

        preds_df["t_100n_logits_aleatoric"] = abs(preds_df["t_100n"] - (preds_df["t_100n"] > threshold).astype(int))
        preds_df["t_1u_logits_aleatoric"] = abs(preds_df["t_1u"] - (preds_df["t_1u"] > threshold).astype(int))
        preds_df["t_10u_logits_aleatoric"] = abs(preds_df["t_10u"] - (preds_df["t_10u"] > threshold).astype(int))

    preds_df["t_100n_error_thresh"] = abs(preds_df["t_100n_ground_truth"] - (preds_df["t_100n"] > threshold).astype(int))
    preds_df["t_1u_error_thresh"] = abs(preds_df["t_1u_ground_truth"] - (preds_df["t_1u"] > threshold).astype(int))
    preds_df["t_10u_error_thresh"] = abs(preds_df["t_10u_ground_truth"] - (preds_df["t_10u"] > threshold).astype(int))

    preds_df["cumulated_error"] = preds_df["t_100n_error"] + preds_df["t_1u_error"] + preds_df["t_10u_error"]
    preds_df["cumulated_error_thresh"] = preds_df["t_100n_error_thresh"] + preds_df["t_1u_error_thresh"] + preds_df["t_10u_error_thresh"]
    if "likelihood_ratio" in preds_df.columns:
        preds_df["t_100n_logits_var"], preds_df["t_1u_logits_var"], preds_df["t_10u_logits_var"] = preds_df["likelihood_ratio"], preds_df["likelihood_ratio"], preds_df["likelihood_ratio"]
        preds_df["cumulated_uncertainty"] = preds_df["likelihood_ratio"]
    else:
        preds_df["cumulated_uncertainty"] = preds_df["t_100n_logits_var"] + preds_df["t_1u_logits_var"] + preds_df["t_10u_logits_var"]
    preds_df["cumulated_uncertainty_aleatoric"] = preds_df["t_100n_logits_aleatoric"] + preds_df["t_1u_logits_aleatoric"] + preds_df["t_10u_logits_aleatoric"]

    return preds_df

predictions = process_predictions(predictions, inf_type)

if experiments_dict[inf_type]["ood_preds"] != None:
    predictions_ood = process_predictions(predictions_ood, inf_type)
    predictions_superOOD = process_predictions(predictions_superOOD, inf_type)



In [16]:
def plot_unc_to_error(preds, title='uncertainty to error', save_path="./mt_unc_to_error.png", error_col="cumulated_error_thresh", unc_col="cumulated_uncertainty", save=False, preds_ood=[], preds_superOOD=[]):
    fig, ax = plt.subplots(figsize=(10,6))

    # Get the error and uncertainty values from the preds dataframe
    error = preds[error_col].values
    uncertainty = preds[unc_col].values # .abs()

    # size of points in plots
    point_size = 1

    # Plot the points for the preds dataframe
    ax.scatter(uncertainty, error, s=point_size)

    # Check if there are OOD or superOOD preds dataframes
    if len(preds_ood) > 0:
        # Get the error and uncertainty values from the OOD preds dataframe
        error_ood = preds_ood[error_col].values
        uncertainty_ood = preds_ood[unc_col].values #abs

        # Get the error and uncertainty values from the superOOD preds dataframe
        error_superOOD = preds_superOOD[error_col].values
        uncertainty_superOOD = preds_superOOD[unc_col].values #abs

        # Plot the points for the OOD and superOOD preds dataframes
        ax.scatter(uncertainty_ood, error_ood, s=point_size, c="green", label="OOD")
        ax.scatter(uncertainty_superOOD, error_superOOD, s=point_size, c="red", label="superOOD")
    
    # Set the plot title, x-axis label, and y-axis label
    ax.set_title(title)
    ax.set_xlabel('likelihood ratio')
    ax.set_ylabel('error')
    
    # Save the plot if specified
    if save:
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)

In [None]:

plot_unc_to_error(predictions, save_path="./mt_unc_to_error_thresh_("+inf_type+").png")
plot_unc_to_error(predictions, error_col="t_100n_error_thresh")
plot_unc_to_error(predictions, error_col="cumulated_error_thresh", unc_col="cumulated_uncertainty")

"""
if ood_tests:
    plot_unc_to_error(predictions, save_path="./mt_unc_to_error_oods_("+inf_type+").png",thresh=False, preds_ood=predictions_ood, preds_superOOD=predictions_superOOD, save=True)
    plot_unc_to_error(predictions, save_path="./mt_unc_to_error_thresh_oods_("+inf_type+").png",thresh=True, preds_ood=predictions_ood, preds_superOOD=predictions_superOOD, save=True)
"""

In [9]:
predictions["protein_id"] = test_df["target_id"]

In [10]:
per_prot_avg_df = predictions.groupby(["protein_id"]).mean()
if ood_tests:
    per_prot_avg_ood_df = predictions_ood.groupby(["target"]).mean()
    per_prot_avg_superOOD_df = predictions_superOOD.groupby(["target"]).mean()

In [11]:

if ood_tests:
    accuracy = 1-(predictions["cumulated_error_thresh"].mean()/3.)
    accuracy_ood = 1-(predictions_ood["cumulated_error_thresh"].mean()/3.)
    accuracy_superOOD = 1-(predictions_superOOD["cumulated_error_thresh"].mean()/3.)

    avg_error = predictions["cumulated_error_thresh"].mean()
    avg_error_ood = predictions_ood["cumulated_error_thresh"].mean()
    avg_error_superOOD = predictions_superOOD["cumulated_error_thresh"].mean()

    avg_unc = predictions["cumulated_uncertainty"].mean()
    avg_unc_ood = predictions_ood["cumulated_uncertainty"].mean()
    avg_unc_superOOD = predictions_superOOD["cumulated_uncertainty"].mean()

    from sklearn.metrics import precision_recall_curve, auc
    normal_scores = predictions["cumulated_uncertainty"].tolist()
    ood_scores = predictions_ood["cumulated_uncertainty"].tolist() + predictions_superOOD["cumulated_uncertainty"].tolist()
    y_scores = normal_scores + ood_scores
    y_norm = [0] * len(normal_scores)
    y_ood = [1] * len(ood_scores)
    y_labels = y_norm+y_ood
    print(len(y_ood), len(y_norm))
    precision, recall, thresholds = precision_recall_curve(y_labels, y_scores)
    auc_precision_recall = auc(recall, precision)
    f1_scores = 2*recall*precision/(recall+precision)
    print('Best threshold: ', thresholds[np.argmax(f1_scores)])
    print('Best F1-Score: ', np.max(f1_scores))
    print('auprc', round(auc_precision_recall, 3))

print("accuracy norm - ood - superOOD", round(accuracy, 3), round(accuracy_ood, 3), round(accuracy_superOOD, 3))
print("error_avg norm - ood - superOOD", round(avg_error, 3), round(avg_error_ood, 3), round(avg_error_superOOD, 3))
print("unc_avg norm - ood - superOOD", round(avg_unc, 3), round(avg_unc_ood,3), round(avg_unc_superOOD,3))

384798 75703
Best threshold:  3.5762784e-08
Best F1-Score:  nan
auprc 0.749
accuracy norm - ood - superOOD 0.887 0.739 0.705
error_avg norm - ood - superOOD 0.339 0.784 0.885
unc_avg norm - ood - superOOD -0.0 -0.0 -0.036


  f1_scores = 2*recall*precision/(recall+precision)


In [None]:

#plot_unc_to_error(per_prot_avg_df, title='Scatter plot uncertainty to error avg per protein', save_path='./mt_unc_to_error_avg_prot_('+inf_type+').png')
#plot_unc_to_error(per_prot_avg_df, title='Scatter plot uncertainty to error avg per protein', save_path='./mt_unc_to_error_avg_prot_thresh_('+inf_type+').png')
#plot_unc_to_error(per_prot_avg_df, title='Scatter plot uncertainty to error avg per protein', save_path='./mt_unc_to_error_avg_prot_thresh_aloatoric('+inf_type+').png', aleatoric=True)

if ood_tests:
    #plot_unc_to_error(per_prot_avg_df, title='Scatter plot uncertainty to error avg per protein', save_path='./mt_unc_to_error_avg_prot_thresh_ood_('+inf_type+').png', preds_ood=per_prot_avg_ood_df, preds_superOOD=per_prot_avg_superOOD_df, save=True)
    for col in ["t_100n", "t_1u", "t_10u"]:
        plot_unc_to_error(per_prot_avg_df, title='likelihood ratio to error per protein with OOD samples at threshold '+col, save_path='./mt_unc_to_error_avg_prot_thresh_ood_('+inf_type+col+').png', error_col=col+"_error_thresh", unc_col=col+"_logits_var", preds_ood=per_prot_avg_ood_df, preds_superOOD=per_prot_avg_superOOD_df, save=True)
    plot_unc_to_error(per_prot_avg_df, title='likelihood ratio to error per protein with OOD samples for cumulated thresholds', save_path='./mt_unc_to_error_avg_prot_thresh_ood_('+inf_type+').png', preds_ood=per_prot_avg_ood_df, preds_superOOD=per_prot_avg_superOOD_df, save=True)



In [14]:
def plot_lr_density(preds, title='likehood ratio density', save_path="./likehood_ratio_density.png", unc_col="likelihood_ratio", save=False, preds_ood=[], preds_superOOD=[]):
    fig, ax = plt.subplots(figsize=(10,6))

    # Get the error and uncertainty values from the preds dataframe
    uncertainty = preds[unc_col].values # .abs()

    # Plot the distribution of uncertainties for normal samples
    sns.distplot(uncertainty, ax=ax, color="blue", bins=200, hist_kws=dict(edgecolor="k", linewidth=1, alpha=0.5), kde_kws={"color": "blue", "lw": 4, "label": "Normal"})

    # Check if there are OOD or superOOD preds dataframes
    if len(preds_ood) > 0:
        # Get the error and uncertainty values from the OOD preds dataframe
        uncertainty_ood = preds_ood[unc_col].values #abs

        # Get the error and uncertainty values from the superOOD preds dataframe
        uncertainty_superOOD = preds_superOOD[unc_col].values #abs

         # Plot the distribution of uncertainties for OOD and superOOD samples
        sns.distplot(uncertainty_superOOD, ax=ax, color="red", bins=200, hist_kws=dict(edgecolor="k", linewidth=1, alpha=0.5), kde_kws={"color": "red", "lw": 2, "label": "superOOD"})
        sns.distplot(uncertainty_ood, ax=ax, color="green", bins=200, hist_kws=dict(edgecolor="k", linewidth=1, alpha=0.5), kde_kws={"color": "green", "lw": 2, "label": "OOD"})
       
    
    # Set the plot title, x-axis label, and y-axis label
    ax.set_title(title)
    ax.set_xlabel('likelihood ratio')
    ax.set_ylabel('density')

    # Set logarithmic scale for the y-axis
    #ax.set_yscale('log')
    ax.set_ylim(top=100)
    
    # Save the plot if specified
    if save:
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)

In [None]:
if ood_tests:
     plot_lr_density(per_prot_avg_df, preds_ood=per_prot_avg_ood_df, preds_superOOD=per_prot_avg_superOOD_df, save=True)

In [None]:
#plot_unc_to_error(per_prot_avg_df, title='Scatter plot uncertainty to error avg per protein', save_path='./mt_unc_to_error_avg_prot_thresh_ood_('+inf_type+').png', thresh=True, preds_ood=per_prot_avg_ood_df, preds_superOOD=per_prot_avg_superOOD_df, save=False, aleatoric=True)

In [None]:
prot_df = pd.read_csv('/nasa/datasets/kyodai_federated/proj_202111_202203/activity/raw/protein-classification-all.csv')
prot_df = prot_df.loc[:, ['pref_name', 'short_name', 'sequence']]

prot_df

In [None]:
target_sequences = []
for id in per_prot_avg_df.index:
    row = test_df[test_df["target_id"] == id]
    target_sequence = row["target_sequence"].iloc[0]
    target_sequences.append(target_sequence)

per_prot_avg_df["target_sequence"] = target_sequences

per_prot_avg_df
    

In [None]:
pref_names = []
short_names = []

for id in per_prot_avg_df.index:
    sequence = per_prot_avg_df.loc[id]["target_sequence"]
    try:
        sequence_info = prot_df[prot_df["sequence"] == sequence].iloc[0]
    except IndexError:
        sequence_info = {"pref_name": "unknown", "short_name":"unknown"}
    pref_names.append(sequence_info["pref_name"])
    short_names.append(sequence_info["short_name"])


per_prot_avg_df["pref_name"] = pref_names
per_prot_avg_df["short_name"] = pref_names


per_prot_avg_df
    

In [None]:
per_prot_families_avg_df = per_prot_avg_df.groupby(["short_name"]).mean()
per_prot_families_avg_df

In [None]:
plot_unc_to_error(per_prot_families_avg_df, title='Scatter plot uncertainty to error avg per prot families', save_path='./mt_unc_to_error_avg_prot_families_('+inf_type+').png')
plot_unc_to_error(per_prot_families_avg_df, title='Scatter plot uncertainty to error avg per prot families', save_path='./mt_unc_to_error_avg_prot_families_thersh_('+inf_type+').png', thresh=True)
plot_unc_to_error(per_prot_families_avg_df, title='Scatter plot uncertainty to error avg per prot families', save_path='./mt_unc_to_error_avg_prot_families_thresh_aleatoric_('+inf_type+').png', thresh=True, aleatoric=True)

In [None]:
R_per_sample = predictions["cumulated_uncertainty"].corr(predictions["cumulated_error"])
R_per_protein = per_prot_avg_df["cumulated_uncertainty"].corr(per_prot_avg_df["cumulated_error"])
R_per_protein_families = per_prot_families_avg_df["cumulated_uncertainty"].corr(per_prot_families_avg_df["cumulated_error"])

R_per_sample_thresh = predictions["cumulated_uncertainty"].corr(predictions["cumulated_error_thresh"])
R_per_protein_thresh = per_prot_avg_df["cumulated_uncertainty"].corr(per_prot_avg_df["cumulated_error_thresh"])
R_per_protein_families_thresh = per_prot_families_avg_df["cumulated_uncertainty"].corr(per_prot_families_avg_df["cumulated_error_thresh"])

R_per_sample_thresh_aleatoric = predictions["cumulated_uncertainty_aleatoric"].corr(predictions["cumulated_error_thresh"])
R_per_protein_thresh_aleatoric = per_prot_avg_df["cumulated_uncertainty_aleatoric"].corr(per_prot_avg_df["cumulated_error_thresh"])
R_per_protein_families_thresh_aleatoric = per_prot_families_avg_df["cumulated_uncertainty_aleatoric"].corr(per_prot_families_avg_df["cumulated_error_thresh"])

accuracy = 1-(predictions["cumulated_error_thresh"].mean()/3.)
print(R_per_sample, R_per_protein, R_per_protein_families, accuracy, R_per_sample_thresh_aleatoric, R_per_protein_thresh_aleatoric, R_per_protein_families_thresh_aleatoric)


with open(inf_type+"_metrics_recap.txt", "w") as file:
    
    file.write("pearson correlation (R) epistemic uncertainty to error per samples = " + str(R_per_sample) + "\n")
    file.write("pearson correlation (R) epistemic uncertainty to error per protein = " + str(R_per_protein) + "\n")
    file.write("pearson correlation (R) epistemic uncertainty to error per protein families = " + str(R_per_protein_families) + "\n")

    file.write("accuracy = " + str(accuracy) + "\n")

    file.write("pearson correlation (R) epistemic uncertainty to error (thresholded) per samples = " + str(R_per_sample_thresh) + "\n")
    file.write("pearson correlation (R) epistemic uncertainty to error (thresholded) per protein = " + str(R_per_protein_thresh) + "\n")
    file.write("pearson correlation (R) epistemic uncertainty to error (thresholded) per protein families = " + str(R_per_protein_families_thresh) + "\n")

    file.write("pearson correlation (R) aleatoric uncertainty to error (thresholded) per samples = " + str(R_per_sample_thresh_aleatoric) + "\n")
    file.write("pearson correlation (R) aleatoric uncertainty to error (thresholded) per protein = " + str(R_per_protein_thresh_aleatoric) + "\n")
    file.write("pearson correlation (R) aleatoric uncertainty to error (thresholded) per protein families = " + str(R_per_protein_families_thresh_aleatoric) + "\n")

