In [1]:
%reload_ext autoreload
%autoreload 2
import sys
sys.path.append('/home/sebastian/masters/') # add my repo to python path
import os
import torch
import numpy as np
import pandas as pd
import modules
import matplotlib.pyplot as plt

from sklearn.metrics import *
from pathlib import Path

from modules.dataset import *
from modules.utils import *
from modules.models import *
from modules.lstm_utils import *

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f6870187530>

In [2]:
def compute_auc_mcc(pred_paths, threshold):
    overall_pred = list()
    overall_true = list()
    overall_thres_pred = list()
    individual_auc = list()
    individual_mcc = list()

    for i in range(len(pred_paths)):
        data = torch.load(pred_paths[i])
        pred = data["y_pred"]
        true = data["y_true"]

        try:
            auc = roc_auc_score(true, pred)
            individual_auc.append(auc)
            fpr, tpr, _ = roc_curve(true, pred)

            thresh_pred = torch.zeros(len(pred))
            thresh_pred[pred >= threshold] = 1
            mcc = matthews_corrcoef(true, thresh_pred)
            individual_mcc.append(mcc)

            overall_pred.extend(pred)
            overall_true.extend(true)
            overall_thres_pred.extend(thresh_pred)
        except ValueError:
            print(f"Fail for idx {i}")
            individual_auc.append(None)
            individual_mcc.append(None)
        
    overall_auc = roc_auc_score(overall_true, overall_pred)
    overall_fpr, overall_tpr, _ = roc_curve(true, pred, pos_label=1)
    overall_mcc = matthews_corrcoef(overall_true, overall_thres_pred)
    
    return overall_auc, overall_fpr, overall_tpr, individual_auc, overall_mcc, individual_mcc

def plot_roc_curve(unique_peptide, excluded, tpr, fpr, individual_auc):
    # ROC plot
    cm = plt.get_cmap('tab20')  # https://matplotlib.org/stable/tutorials/colors/colormaps.html

    fig = plt.figure(figsize=(12, 7))
    ax = fig.add_subplot(111)
    ax.set_prop_cycle(color=[cm(1*i/n_splits) for i in range(n_splits)])
    for i, pep in enumerate(unique_peptides):
        if pep not in excluded:
            ax.plot(
                fpr[i], 
                tpr[i], 
                label=f"{pep}, AUC = {round(individual_auc[i], 3)}",
            )
    plt.legend()
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("LOO validation ROC curve")# (peptides with count < 40 left out for visual clarity (drop on CDR3b))")
    plt.show()

In [3]:
root = Path("/home/sebastian/masters/data/")
data_root = root / "neat_data"
metadata_path = data_root / "metadata.csv"
processed_dir = data_root / "processed"
state_file = root / "state_files" / "e53-s1952148-d93703104.state"
out_dir = root / "results_2"

model_dir = data_root / "raw" / "tcrpmhc"

paths = list(model_dir.glob("*"))
join_key = [int(x.name.split("_")[0]) for x in paths]
path_df = pd.DataFrame({'#ID': join_key, 'path': paths})

metadata = pd.read_csv(metadata_path)
metadata = metadata.join(path_df.set_index("#ID"), on="#ID", how="inner")  # filter to non-missing data
metadata = metadata.reset_index(drop=True)
unique_peptides = metadata["peptide"].unique()

In [40]:
#validation_pep = "KTWGQYWQV"
#unique_peptides = np.delete(unique_peptides, np.where(unique_peptides == validation_pep))

no_pos_peptides = ["CLGGLLTMV", "ILKEPVHGV"]
for pep in no_pos_peptides:
    unique_peptides = np.delete(unique_peptides, np.where(unique_peptides == pep))

threshold = 0.15
cv_fold = 5
excluded = [] # ["KLQCVDLHV", "KVAELVHFL", "YLLEMLWRL", "SLLMWITQV"]


overall_CV_results_df = pd.DataFrame(columns=[
    "Model", 
    "AUC",
    f"MCC (threshold = {threshold})",
])

### Embedding verification

In [72]:
save_dir = Path(out_dir / "lstm_embedding_verification/")
pred_paths = [save_dir / f"pred_{i}.pt" for i in range(cv_fold)]
overall_pred = list()
overall_true = list()
for i in range(len(pred_paths)):
    data = torch.load(pred_paths[i], map_location="cpu")
    pred = data["y_pred"]
    true = data["y_true"]

    auc = roc_auc_score(true, pred , multi_class="ovr", average="weighted")
    print(auc)
    overall_pred.extend(pred)
    overall_true.extend(true)
overall_auc = roc_auc_score(true, pred , multi_class="ovr", average="weighted")
print(overall_auc)

1.0
1.0
1.0
1.0
1.0
1.0


### LOO validation

In [41]:
overall_LOO_results_df = pd.DataFrame(columns=[
    "Model", 
    "AUC",
    f"MCC (threshold = {threshold})",
])

individual_AUC_results_df = pd.DataFrame(columns=["Model"] + list(unique_peptides))

individual_MCC_results_df = pd.DataFrame(columns=["Model"] + list(unique_peptides))


model_names = [
    "Single LSTM PS",
    "Single LSTM PS (no swapped negatives)",
    "Single LSTM energy terms",
    "Quad LSTM BLOSUM50",
    "Quad LSTM BLOSUM50 (no swapped negatives)",
    "Quad LSTM PS",
    "Quad LSTM PS (no swapped negatives)",
    "Quad LSTM PS+ESM",
    "Quad LSTM PS+ESM (no swapped negatives)",
    "Quad LSTM ESM",
    "Quad LSTM ESM (no swapped negatives)",
    "Quad LSTM PS (FoldX)",
    "Quad LSTM PS+ESM (FoldX)",
    "Quad LSTM PS (Rosetta)",
    "Quad LSTM PS+ESM (Rosetta)",
    "Quad LSTM PS nested LOO",
    "Quad LSTM PS+ESM nested LOO",
    "Quad LSTM ESM nested LOO",
    "CDR3+peptide LSTM PS",
    "CDR3+peptide LSTM PS (no swapped negatives)",
    "Modified ProteinSolver",
    "Modified ProteinSolver (no swapped negatives)",
    "Modified ProteinSolver (FoldX)",
    "Modified ProteinSolver (Rosetta)",
]

paths = [
    "lstm_ps_single",
    "lstm_ps_single_no_swapped",
    "lstm_single_energy",
    "lstm_blosum",
    "lstm_blosum_no_swapped",
    "lstm_ps",
    "lstm_ps_no_swapped",
    "lstm_esm_ps",
    "lstm_esm_ps_no_swapped",
    "lstm_esm",
    "lstm_esm_no_swapped",
    "lstm_ps_foldx",
    "lstm_esm_ps_foldx",
    "lstm_ps_rosetta",
    "lstm_esm_ps_rosetta",
    "lstm_ps_nested_loo",
    "lstm_esm_ps_nested_loo",
    "lstm_esm_nested_loo",
    "lstm_ps_cdr_pep_only",
    "lstm_ps_cdr_pep_only_no_swapped",
    "proteinsolver_finetune",
    "proteinsolver_finetune_no_swapped",
    "proteinsolver_finetune_foldx",
    "proteinsolver_finetune_rosetta"
]
for model_name, path in zip(model_names, paths):
    print(model_name, path)
    try:
        save_dir = next(Path(out_dir / f"{path}").glob("*"))
        pred_paths = [save_dir / f"pred_{i}.pt" for i in range(len(unique_peptides))]
        overall_auc, overall_fpr, overall_tpr, individual_auc, overall_mcc, individual_mcc = compute_auc_mcc(pred_paths, threshold)
        overall_row = [model_name, overall_auc, overall_mcc]
        overall_LOO_results_df = overall_LOO_results_df.append(
            pd.Series(overall_row, index=overall_LOO_results_df.columns), ignore_index=True
        )
        individual_AUC_results_df = individual_AUC_results_df.append(
            pd.Series([model_name] + individual_auc, index=individual_AUC_results_df.columns), ignore_index=True
        )
        individual_MCC_results_df = individual_MCC_results_df.append(
            pd.Series([model_name] + individual_mcc, index=individual_MCC_results_df.columns), ignore_index=True
        )
    except:
        print(model_name, "failed!")

Single LSTM PS lstm_ps_single
Single LSTM PS failed!
Single LSTM PS (no swapped negatives) lstm_ps_single_no_swapped
Single LSTM PS (no swapped negatives) failed!
Single LSTM energy terms lstm_single_energy
Single LSTM energy terms failed!
Quad LSTM BLOSUM50 lstm_blosum
Quad LSTM BLOSUM50 failed!
Quad LSTM BLOSUM50 (no swapped negatives) lstm_blosum_no_swapped
Quad LSTM BLOSUM50 (no swapped negatives) failed!
Quad LSTM PS lstm_ps


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Quad LSTM PS (no swapped negatives) lstm_ps_no_swapped


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Quad LSTM PS+ESM lstm_esm_ps


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Quad LSTM PS+ESM (no swapped negatives) lstm_esm_ps_no_swapped
Quad LSTM ESM lstm_esm
Quad LSTM ESM (no swapped negatives) lstm_esm_no_swapped
Quad LSTM ESM (no swapped negatives) failed!
Quad LSTM PS (FoldX) lstm_ps_foldx
Fail for idx 14
Fail for idx 15


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Quad LSTM PS+ESM (FoldX) lstm_esm_ps_foldx
Fail for idx 14
Fail for idx 15


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Quad LSTM PS (Rosetta) lstm_ps_rosetta
Fail for idx 14
Fail for idx 15


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Quad LSTM PS+ESM (Rosetta) lstm_esm_ps_rosetta
Fail for idx 14
Fail for idx 15




Quad LSTM PS nested LOO lstm_ps_nested_loo
Quad LSTM PS nested LOO failed!
Quad LSTM PS+ESM nested LOO lstm_esm_ps_nested_loo
Quad LSTM PS+ESM nested LOO failed!
Quad LSTM ESM nested LOO lstm_esm_nested_loo
Quad LSTM ESM nested LOO failed!
CDR3+peptide LSTM PS lstm_ps_cdr_pep_only
CDR3+peptide LSTM PS (no swapped negatives) lstm_ps_cdr_pep_only_no_swapped
Modified ProteinSolver proteinsolver_finetune
Modified ProteinSolver failed!
Modified ProteinSolver (no swapped negatives) proteinsolver_finetune_no_swapped
Modified ProteinSolver (no swapped negatives) failed!
Modified ProteinSolver (FoldX) proteinsolver_finetune_foldx
Modified ProteinSolver (FoldX) failed!
Modified ProteinSolver (Rosetta) proteinsolver_finetune_rosetta
Modified ProteinSolver (Rosetta) failed!


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


In [42]:
overall_LOO_results_df

Unnamed: 0,Model,AUC,MCC (threshold = 0.15)
0,Quad LSTM PS,0.553454,0.059061
1,Quad LSTM PS (no swapped negatives),0.564843,0.036234
2,Quad LSTM PS+ESM,0.534286,0.033582
3,Quad LSTM PS+ESM (no swapped negatives),0.593484,0.083686
4,Quad LSTM ESM,0.532073,-0.024482
5,Quad LSTM PS (FoldX),0.616252,0.148333
6,Quad LSTM PS+ESM (FoldX),0.49732,-0.067958
7,Quad LSTM PS (Rosetta),0.617995,0.092145
8,Quad LSTM PS+ESM (Rosetta),0.473753,0.0574
9,CDR3+peptide LSTM PS,0.478532,-0.069805


In [21]:
individual_AUC_results_df

Unnamed: 0,Model,NLVPMVATV,SLFNTVATL,KTWGQYWQV,RTLNAWVKV,KVLEYVIKV,YLLEMLWRL,IMDQVPFSV,MLDLQPETT,FLYALALLL,SLLMWITQV,LLFGYPVYV,GLCTLVAML,RMFPNAPYL,GILGFVFTL,KLQCVDLHV,KVAELVHFL
0,Quad LSTM PS,0.625934,0.469388,0.498701,0.495544,0.462963,0.323232,0.585648,0.545455,0.327019,0.388235,0.480533,0.603979,0.592593,0.53074,1.0,0.0
1,Quad LSTM PS (no swapped negatives),0.560408,0.232653,0.488312,0.58066,0.388889,0.646465,0.425926,0.519886,0.637888,0.647059,0.3848,0.517953,0.404558,0.590271,1.0,0.0
2,Quad LSTM PS+ESM,0.632675,0.461224,0.457143,0.459893,0.444444,0.30303,0.543981,0.46875,0.564286,0.482353,0.5328,0.5059,0.472934,0.537914,0.75,0.5
3,Quad LSTM PS+ESM (no swapped negatives),0.562191,0.302041,0.548052,0.53164,0.462963,0.484848,0.546296,0.568182,0.431677,0.705882,0.390933,0.526442,0.438746,0.624908,1.0,0.5
4,Quad LSTM ESM,0.526214,0.510204,0.431169,0.519162,0.185185,0.535354,0.576389,0.426136,0.77795,0.494118,0.472267,0.593058,0.623932,0.514647,0.75,0.0
5,Quad LSTM PS (FoldX),0.566243,0.533333,0.419913,0.515038,0.74359,0.141026,0.37013,0.587302,0.660965,0.541667,0.464286,0.627928,0.392857,0.640495,,
6,Quad LSTM PS+ESM (FoldX),0.615496,0.280952,0.467532,0.485589,0.435897,0.307692,0.50974,0.595238,0.478947,0.5,0.659555,0.522959,0.321429,0.455369,,
7,Quad LSTM PS (Rosetta),0.588697,0.314286,0.519481,0.573935,0.602564,0.461538,0.555195,0.543651,0.776754,0.708333,0.434137,0.619156,0.46875,0.6335,,
8,Quad LSTM PS+ESM (Rosetta),0.60595,0.32381,0.4329,0.505013,0.551282,0.487179,0.50974,0.376984,0.72193,0.75,0.419759,0.566075,0.553571,0.381663,,
9,CDR3+peptide LSTM PS,0.515583,0.485714,0.566234,0.425579,0.472222,0.464646,0.548611,0.482955,0.756832,0.329412,0.367733,0.570635,0.569801,0.450719,1.0,0.5


In [36]:
t = ["/home/sebastian/masters/data/results_2/lstm_esm_ps_nested_loo/b3edafc0112356cefbdde3ca0ec5b396/pred_0.pt",
    "/home/sebastian/masters/data/results_2/lstm_esm_ps_nested_loo/b3edafc0112356cefbdde3ca0ec5b396/pred_1.pt"]
overall_auc, overall_fpr, overall_tpr, individual_auc, overall_mcc, individual_mcc = compute_auc_mcc(t, threshold)

  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


### 5 fold CV

In [43]:
model_names = [
    #"Single LSTM PS",
    #"Single LSTM PS (swapped negatives)",
    "Single LSTM energy terms",
    "Quad LSTM BLOSUM50",
    "Quad LSTM BLOSUM50 (no swapped negatives)",
    "Quad LSTM PS",
    "Quad LSTM PS (no swapped negatives)",
    "Quad LSTM PS+ESM",
    "Quad LSTM PS+ESM (no swapped negatives)",
    "Quad LSTM ESM",
    "Quad LSTM ESM (no swapped negatives)",
    #"Quad LSTM PS (FoldX)",
    #"Quad LSTM PS+ESM (FoldX)",
    #"Quad LSTM PS (Rosetta)",
    #"Quad LSTM PS+ESM (Rosetta)",
    #"CDR3+peptide LSTM PS",
    #"CDR3+peptide LSTM PS (swapped negatives)",
    "Modified ProteinSolver",
    "Modified ProteinSolver (no swapped negatives)",
    #"Modified ProteinSolver (FoldX)",
    #"Modified ProteinSolver (Rosetta)"
    
    
]

paths = [
    #"lstm_ps_single",
    #"lstm_ps_single_swapped",
    "lstm_single_energy_80_cv",
    "lstm_blosum_80_cv",
    "lstm_blosum_80_cv_no_swapped",
    "lstm_ps_80_cv",
    "lstm_ps_80_no_swapped",
    "lstm_esm_ps_80_cv_no_swapped",
    "lstm_esm_80_cv",
    "lstm_esm_80_cv_no_swapped",
    #"lstm_ps_foldx",
    #"lstm_esm_ps_foldx",
    #"lstm_ps_rosetta",
    #"lstm_esm_ps_rosetta",
    #"lstm_ps_cdr_pep_only",
    #"lstm_ps_cdr_pep_only_swapped",
    "proteinsolver_finetune_80_cv",
    "proteinsolver_finetune_80_cv (no swapped negatives)",
    #"proteinsolver_finetune_foldx",
    #"proteinsolver_finetune_rosetta"
    
]
for model_name, path in zip(model_names, paths):
    print(model_name, path)
    try:
        save_dir = Path(out_dir / f"{path}/b3edafc0112356cefbdde3ca0ec5b396")
        pred_paths = [save_dir / f"pred_{i}.pt" for i in range(cv_fold)]
        overall_auc, overall_fpr, overall_tpr, individual_auc, overall_mcc, individual_mcc = compute_auc_mcc(pred_paths, threshold)
        overall_row = [model_name, overall_auc, overall_mcc]
        overall_CV_results_df = overall_CV_results_df.append(
            pd.Series(overall_row, index=overall_CV_results_df.columns), ignore_index=True
        )
    except:
        print(model_name, "failed!")

Quad LSTM PS lstm_ps_80_cv
Quad LSTM PS (no swapped negatives) lstm_ps_80_no_swapped
Quad LSTM PS (no swapped negatives) failed!
Quad LSTM PS+ESM lstm_esm_ps_80_cv_no_swapped
Quad LSTM PS+ESM failed!
Quad LSTM PS+ESM (no swapped negatives) lstm_esm_80_cv
Quad LSTM PS+ESM (no swapped negatives) failed!
Quad LSTM ESM lstm_esm_80_cv_no_swapped
Quad LSTM ESM (no swapped negatives) proteinsolver_finetune_80_cv
Quad LSTM ESM (no swapped negatives) failed!
Modified ProteinSolver proteinsolver_finetune_80_cv (no swapped negatives)
Modified ProteinSolver failed!


In [44]:
overall_CV_results_df

Unnamed: 0,Model,AUC,MCC (threshold = 0.15)
0,Quad LSTM PS,0.862486,0.557956
1,Quad LSTM ESM,0.876277,0.495974
