## This notebook assumes that you have already extracted the embeddings (following the procedure in the 0_ESM_Embeddings_Extractor.ipynb notebook) and stored them as zip archives.

## Since we used Google Colab, we copy the embeddings from Google Drive before training the model; a similar procedure can be used to run the notebook locally.

## An example is shown below:

In [None]:
####################################################################################################### TCR_beta_90
#### Download the embeddings for TCR_beta_90 training set
!cp -r '/content/drive/MyDrive/TCR-pMHC-results/netTCR/TCR_beta_90/ESM1v_cdr3b.zip' '/content'
!cp -r '/content/drive/MyDrive/TCR-pMHC-results/netTCR/TCR_beta_90/ESM1v_peptide.zip' '/content'
# eval for MIRA
!cp -r '/content/drive/MyDrive/TCR-pMHC-results/netTCR/TCR_beta_90/ESM1v_mira_cdr3b.zip' '/content'
!cp -r '/content/drive/MyDrive/TCR-pMHC-results/netTCR/TCR_beta_90/ESM1v_mira_peptide.zip' '/content'
## actual csv for the class labels
!cp -r '/content/drive/MyDrive/TCR-pMHC-results/netTCR/TCR_beta_90/train_beta_90.csv' '/content'
!cp -r '/content/drive/MyDrive/TCR-pMHC-results/netTCR/TCR_beta_90/mira_eval_threshold90.csv' '/content'

## Unzip the embeddings into the folders used to build the train and test sets

In [None]:

#### unzip train
!unzip -q /content/ESM1v_cdr3b.zip -d /content/train_cdr3b
!unzip -q /content/ESM1v_peptide.zip -d /content/train_peptide

### unzip eval
!unzip -q /content/ESM1v_mira_cdr3b.zip -d /content/mira_cdr3b
!unzip -q /content/ESM1v_mira_peptide.zip -d /content/mira_peptide

## For running locally, the commands should look like this:

In [None]:
#### unzip train
!unzip -q <path_to_train_cdr3b.zip> -d  train_cdr3b
!unzip -q <path_to_train_peptide.zip -d train_peptide

### unzip eval
!unzip -q <path_to_test_cdr3b.zip> -d  test_cdr3b
!unzip -q <path_to_test_peptide.zip -d test_peptide

In [None]:
import sys
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from natsort import natsorted
import matplotlib.pyplot as plt
# The 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6 and removed in 3.8;
# use whichever name the installed matplotlib provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
%matplotlib inline
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Layer,Input, Dense, Dropout, Activation, Concatenate, Flatten, BatchNormalization
from tensorflow.keras.regularizers import l2,l1
from tensorflow.keras.optimizers import SGD,Adam,RMSprop
#from tensorflow.compat.v1 import InteractiveSession
import tensorflow.keras.backend as K
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array, load_img
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint,ReduceLROnPlateau
from tensorflow.keras.models import load_model
import tensorflow.keras.metrics
from sklearn.preprocessing import LabelEncoder
import sklearn
import os
from natsort import natsorted
from sklearn.metrics import *
import tensorflow as tf
tf.random.set_seed(1)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.utils import to_categorical, plot_model
from tqdm import tqdm
import torch

In [None]:
### train files
path_train_cdr3b = <'path_to_train_cdr3b_embeddings'>                  # Example ./train_cdr3b
path_train_pepti = <'path_to_train_peptide_embeddings'>                # Example ./train_peptide

mat_cdr3b = os.listdir(path_train_cdr3b) 
mat_pepti = os.listdir(path_train_pepti)

### eval files
path_eval_cdr3b = <'path_to_test_cdr3b_embeddings'>           # Example ./test_cdr3b
path_eval_pepti = <'path_to_test_peptide_embeddings'>         # Example ./test_peptide

evalmat_cdr3b = os.listdir(path_eval_cdr3b) 
evalmat_pepti = os.listdir(path_eval_pepti)


## natsort sorts the file names in natural order (e.g. `2` before `10`) so that the pairs appear in the same order as the rows of the .csv (assuming the files are named by row index); this is what lets us map each pair to its label later.

In [None]:
###train
mat_cdr3b = natsorted(mat_cdr3b)
mat_pepti = natsorted(mat_pepti)


###eval
evalmat_cdr3b = natsorted(evalmat_cdr3b)
evalmat_pepti = natsorted(evalmat_pepti)

## The following step loads the embeddings and stores them in NumPy matrices

In [None]:
# Each embedding file is read with torch.load; its layer-33 entry of
# 'mean_representations' (a 1280-d vector, as the matrices below expect) becomes one row.
EMB_DIM = 1280
ESM_LAYER = 33

# Rows of the CDR3b and peptide matrices are paired purely by index, so both folders
# must hold the same number of files.
assert len(mat_cdr3b) == len(mat_pepti), 'train CDR3b/peptide file counts differ'
assert len(evalmat_cdr3b) == len(evalmat_pepti), 'eval CDR3b/peptide file counts differ'


def load_mean_repr(directory, fname, layer=ESM_LAYER):
    """Return the `layer` mean representation stored in the file `directory/fname`.

    os.path.join is used so `directory` works with or without a trailing slash
    (plain string concatenation broke for paths like './train_cdr3b').
    """
    return torch.load(os.path.join(directory, fname))['mean_representations'][layer]


orig_matmat_cdr3b = np.zeros((len(mat_cdr3b), EMB_DIM))
orig_matmat_pepti = np.zeros((len(mat_pepti), EMB_DIM))
orig_evalmatmat_cdr3b = np.zeros((len(evalmat_cdr3b), EMB_DIM))
orig_evalmatmat_pepti = np.zeros((len(evalmat_pepti), EMB_DIM))

### load train samples
for i in tqdm(range(len(mat_cdr3b))):
    orig_matmat_cdr3b[i] = load_mean_repr(path_train_cdr3b, mat_cdr3b[i])
    orig_matmat_pepti[i] = load_mean_repr(path_train_pepti, mat_pepti[i])

### load eval samples
for j in tqdm(range(len(evalmat_cdr3b))):
    orig_evalmatmat_cdr3b[j] = load_mean_repr(path_eval_cdr3b, evalmat_cdr3b[j])
    orig_evalmatmat_pepti[j] = load_mean_repr(path_eval_pepti, evalmat_pepti[j])



## Load the .csv files that contain the class labels

In [None]:
df_train = pd.read_csv('/content/train_beta_90.csv')
df_eval = pd.read_csv('/content/mira_eval_threshold90.csv')

# Labels are matched to embeddings purely by row order, so the counts must agree.
assert len(df_train) == len(orig_matmat_cdr3b), 'train .csv rows != number of train embeddings'
assert len(df_eval) == len(orig_evalmatmat_cdr3b), 'eval .csv rows != number of eval embeddings'

y_train = df_train['binder'].values
y_eval = df_eval['binder'].values
# Column vectors: one label per embedding row.
orig_y_train = y_train.reshape(-1, 1)
orig_y_eval = y_eval.reshape(-1, 1)

## `mode` switches between cross-validation on the training partitions (`'cv'`) and testing on the MIRA dataset (`'eval'`)

In [None]:
# 'cv'  : train on the partitions in train_list, evaluate on the held-out partitions in test_list
# 'eval': train on the full training set, evaluate on the MIRA set
#mode = 'cv'
mode = 'eval'


######## get CV splits from original data

if mode == 'cv':
    train_list = [3, 4, 1, 2]
    test_list = [5]

    # Compute each boolean mask once; .to_numpy() gives plain arrays for NumPy indexing.
    train_mask = df_train['partition'].isin(train_list).to_numpy()
    test_mask = df_train['partition'].isin(test_list).to_numpy()
    # Partition ids that do not occur in the .csv would otherwise yield an empty split silently.
    assert train_mask.any() and test_mask.any(), \
        'a CV split is empty; check the partition ids in train_list/test_list'

    matmat_cdr3b = orig_matmat_cdr3b[train_mask]
    matmat_pepti = orig_matmat_pepti[train_mask]
    y_train = orig_y_train[train_mask]

    #### internal eval
    evalmatmat_cdr3b = orig_matmat_cdr3b[test_mask]
    evalmatmat_pepti = orig_matmat_pepti[test_mask]
    y_eval = orig_y_train[test_mask]

elif mode == 'eval':
    matmat_cdr3b = orig_matmat_cdr3b
    matmat_pepti = orig_matmat_pepti
    y_train = orig_y_train

    #### on test
    evalmatmat_cdr3b = orig_evalmatmat_cdr3b
    evalmatmat_pepti = orig_evalmatmat_pepti
    y_eval = orig_y_eval

else:
    raise ValueError(f"mode must be 'cv' or 'eval', got {mode!r}")

## Check that the dimensions of the train and test data are consistent

In [None]:
### data
print('Training', matmat_cdr3b.shape, matmat_pepti.shape, y_train.shape)
print('Evaluation', evalmatmat_cdr3b.shape, evalmatmat_pepti.shape, y_eval.shape)

# Each model input needs exactly one row per label.
assert matmat_cdr3b.shape[0] == matmat_pepti.shape[0] == y_train.shape[0], 'training row counts differ'
assert evalmatmat_cdr3b.shape[0] == evalmatmat_pepti.shape[0] == y_eval.shape[0], 'evaluation row counts differ'

## The following code clears the Keras session; it is used to reset the network between CV runs

In [None]:
### model
def clear_sess(namespace=None):
    """Drop the previous model/history objects and reset the Keras backend state.

    A plain `del model` inside a function only targets a *local* name (and always
    fails), so the objects are removed from `namespace` instead.

    Args:
        namespace: mapping to remove 'model' and 'history' from; defaults to this
            module's globals (the notebook namespace when run in the notebook).

    Returns:
        None
    """
    if namespace is None:
        namespace = globals()
    for name in ('model', 'history'):
        namespace.pop(name, None)  # absent on the first call — that's fine
    from tensorflow.keras import backend as K
    K.clear_session()
    import gc
    gc.collect()
    return None



In [None]:
clear_sess()
!rm *.hdf5 ### remove old saved model

# MLP model for CDR3a + CDR3b + peptide

In [None]:
clear_sess()
# Three-branch MLP. Inputs follow the order of the CDR3a/CDR3b/peptide fit cell:
#   i_1 <- CDR3a, i_2 <- CDR3b, i_3 <- peptide embeddings (1280-d each).
# Each branch: Dense(128, relu) + BatchNormalization; the branches are concatenated
# and passed through Dense(512) -> Dense(256) -> Dense(1, sigmoid).
# NOTE: this notebook only loads CDR3b and peptide embeddings; CDR3a embeddings have
# to be prepared separately before this variant can be trained.

#input_1 (CDR3a)
input_1 = Input(shape = (1280,), name='i_1')
dense1_1 = Dense(128, activation = 'relu')(input_1)
bn1_1 = BatchNormalization()(dense1_1)

#input_2 (CDR3b)
input_2 = Input(shape = (1280,), name='i_2')
dense2_1 = Dense(128, activation = 'relu')(input_2)
bn2_1 = BatchNormalization()(dense2_1)

#input_3 (peptide)
input_3 = Input(shape = (1280,), name='i_3')
dense3_1 = Dense(128, activation = 'relu')(input_3)
bn3_1 = BatchNormalization()(dense3_1)

# concatenate the three branches
concat   = Concatenate()([bn1_1, bn2_1, bn3_1])
fc_1   = Dense(512, activation = 'relu')(concat)
#drop_1 = Dropout(0.5)(fc_1)
fc_2   = Dense(256, activation = 'relu')(fc_1)
#classification output
output  = Dense(1, activation = 'sigmoid')(fc_2)

# create model with three inputs
model = Model(inputs=[input_1,input_2, input_3], outputs=output)

## MLP model for CDR3b + peptide (the same architecture is used for CDR3a + peptide)

In [None]:
clear_sess()
# Two-branch MLP: two 1280-d embeddings (e.g. CDR3b and peptide) are each projected to
# 128 units and batch-normalised, then concatenated and fed through a 512 -> 256 -> 1 head.

def _embedding_branch(input_name):
    """Build Input(1280) -> Dense(128, relu) -> BatchNormalization for one embedding.

    Returns the input tensor, the dense output and the batch-normalised output.
    """
    branch_in = Input(shape=(1280,), name=input_name)
    branch_dense = Dense(128, activation='relu')(branch_in)
    return branch_in, branch_dense, BatchNormalization()(branch_dense)

input_1, dense1_1, bn1_1 = _embedding_branch('i_1')
input_2, dense2_1, bn2_1 = _embedding_branch('i_2')

# join both branches and classify
concat = Concatenate()([bn1_1, bn2_1])
fc_1 = Dense(512, activation='relu')(concat)
#drop_1 = Dropout(0.5)(fc_1)
fc_2 = Dense(256, activation='relu')(fc_1)
output = Dense(1, activation='sigmoid')(fc_2)  # sigmoid score for the 'binder' label

# create model with two inputs
model = Model(inputs=[input_1, input_2], outputs=output)

In [None]:
def keras_mcc(y_true, y_pred):
    """Matthews correlation coefficient of rounded predictions, as a Keras metric.

    Predictions are rounded to 0/1 before counting TP/TN/FP/FN, and the value is
    computed over whatever batch of tensors Keras passes in (Keras presumably
    reports the mean over batches, not an epoch-level MCC).

    Args:
        y_true: tensor of labels (expected to be 0/1).
        y_pred: tensor of predicted probabilities with the same shape as y_true.

    Returns:
        Scalar tensor; K.epsilon() in the denominator avoids division by zero.
    """
    def _rounded_count(a, b):
        # Number of elements whose product a*b rounds to 1.
        return K.sum(K.round(K.clip(a * b, 0, 1)))

    not_true, not_pred = 1 - y_true, 1 - y_pred
    tp = _rounded_count(y_true, y_pred)
    tn = _rounded_count(not_true, not_pred)
    fp = _rounded_count(not_true, y_pred)
    fn = _rounded_count(y_true, not_pred)

    numerator = tp * tn - fp * fn
    denominator = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    return numerator / K.sqrt(denominator + K.epsilon())

In [None]:
# ROC-AUC, PR-AUC and the keras_mcc function are tracked during training; their names
# ('auc_roc', 'auc_pr', 'keras_mcc') produce the 'val_*' keys the callbacks monitor.
metrics_c = [
    tensorflow.keras.metrics.AUC(name="auc_roc", curve="ROC"),
    tensorflow.keras.metrics.AUC(name="auc_pr", curve="PR"),
    keras_mcc,
]


In [None]:
# Binary cross-entropy for the single sigmoid output; Adam starting at learning rate 0.008.
optimizer_adam = tensorflow.keras.optimizers.Adam(learning_rate=0.008)
model.compile(
    loss='binary_crossentropy',
    optimizer=optimizer_adam,
    metrics=metrics_c,
)

In [None]:
# Stop once the training loss has not decreased for 10 epochs, rolling back to the best weights.
# NOTE(review): this callback is not included in the callbacks lists of the fit cells.
early_stop = EarlyStopping(
    monitor='loss',
    mode='min',
    min_delta=0,
    patience=10,
    restore_best_weights=True,
    verbose=0,
)

In [None]:
# Multiply the learning rate by 0.99 (floor 0.005) when validation MCC has not improved
# for 20 epochs. mode='max' is required: with the default mode='auto', Keras treats any
# monitored name that does not contain "acc" as a quantity to *minimise*, so a rising
# MCC would be counted as "no improvement".
reduce_lr = ReduceLROnPlateau(monitor='val_keras_mcc', mode='max', factor=0.99, patience=20, min_lr=0.005, verbose=1)

In [None]:
# Keep two "best" full-model checkpoints from the same run: one chosen by validation
# PR-AUC and one by validation MCC (both maximised).
# NOTE(review): Keras 3 (TF >= 2.16) appears to reject a '.hdf5' suffix for full-model
# checkpoints; confirm against the installed version before relying on these names.
checkpoint_filepath_1 = 'weights-improvement-val-auc-pr.hdf5'
checkpoint_filepath_2 = 'weights-improvement-val-keras-mcc.hdf5'

model_checkpoint_callback_1 = ModelCheckpoint(
    filepath=checkpoint_filepath_1, monitor='val_auc_pr', mode='max',
    save_best_only=True, save_weights_only=False,
)
model_checkpoint_callback_2 = ModelCheckpoint(
    filepath=checkpoint_filepath_2, monitor='val_keras_mcc', mode='max',
    save_best_only=True, save_weights_only=False,
)

## Uncomment the following cell to train the CDR3a + CDR3b + peptide model

In [None]:
# # fit the keras model on the dataset CDR3a, CDR3b, peptide (three-input model)
# # NOTE: matmat_cdr3a is not built in this notebook; create it from the CDR3a
# # embeddings the same way matmat_cdr3b is built before uncommenting.
# history=model.fit([matmat_cdr3a, matmat_cdr3b, matmat_pepti], y_train,
#                   batch_size=1024, epochs=500,
#                   verbose=1,
#                   validation_split=0.1,
#                   callbacks=[model_checkpoint_callback_1, model_checkpoint_callback_2, reduce_lr]
#                   )
# print('done')

## Uncomment the following cell to train the CDR3b + peptide model

In [None]:
# # fit the keras model on the dataset CDR3b, peptide
# # Two-input model: i_1 <- CDR3b embeddings, i_2 <- peptide embeddings.
# # validation_split=0.1 holds out a fraction of the training rows as the validation
# # data the checkpoint and learning-rate callbacks monitor.
# # NOTE(review): Keras takes the *last* rows (before shuffling) for validation_split;
# # confirm the .csv is not ordered by label or partition.
# history=model.fit([matmat_cdr3b,matmat_pepti],y_train,
#                   batch_size=1024, epochs=500,
#                   verbose=0,
#                   validation_split=0.1,
#                   callbacks=[model_checkpoint_callback_1, model_checkpoint_callback_2, reduce_lr ]
#                   )
# print('done')

## Uncomment the following cell to train the CDR3a + peptide model

In [None]:
# # fit the keras model on the dataset CDR3a, peptide (two-input model)
# # NOTE: matmat_cdr3a is not built in this notebook; create it from the CDR3a
# # embeddings the same way matmat_cdr3b is built before uncommenting.
# history=model.fit([matmat_cdr3a,matmat_pepti],y_train,
#                   batch_size=1024, epochs=500,
#                   verbose=0,
#                   validation_split=0.1,
#                   callbacks=[model_checkpoint_callback_1, model_checkpoint_callback_2, reduce_lr ]
#                   )
# print('done')

## Once the model is trained, load the best checkpoints and evaluate them on the evaluation dataset

In [None]:
####model load max aucpr
# Checkpoint with the best validation PR-AUC. The path is relative, matching where
# ModelCheckpoint wrote it (the working directory, which is /content on Colab).
model_loaded = 'weights-improvement-val-auc-pr.hdf5'
# compile=False: only inference is needed, so the custom keras_mcc metric is not required.
model = tensorflow.keras.models.load_model(model_loaded, compile=False)

# Inputs in the same order as in training: CDR3b, then peptide.
y_pred = model.predict([evalmatmat_cdr3b, evalmatmat_pepti]).flatten()  # predicted probabilities
y_act = y_eval.flatten()
y_pred_c = np.where(y_pred > 0.5, 1, 0)                                 # hard labels at threshold 0.5

print('ROC-AUC', roc_auc_score(y_act, y_pred),
      'AP', average_precision_score(y_act, y_pred),
      'MCC', matthews_corrcoef(y_act, y_pred_c),
      'Kappa', cohen_kappa_score(y_act, y_pred_c))

In [None]:
####model load max mcc
# Checkpoint with the best validation MCC. The path is relative, matching where
# ModelCheckpoint wrote it (the working directory, which is /content on Colab).
model_loaded = 'weights-improvement-val-keras-mcc.hdf5'
# compile=False: only inference is needed, so the custom keras_mcc metric is not required.
model = tensorflow.keras.models.load_model(model_loaded, compile=False)

# Inputs in the same order as in training: CDR3b, then peptide.
y_pred = model.predict([evalmatmat_cdr3b, evalmatmat_pepti]).flatten()  # predicted probabilities
y_act = y_eval.flatten()
y_pred_c = np.where(y_pred > 0.5, 1, 0)                                 # hard labels at threshold 0.5

print('ROC-AUC', roc_auc_score(y_act, y_pred),
      'AP', average_precision_score(y_act, y_pred),
      'MCC', matthews_corrcoef(y_act, y_pred_c),
      'Kappa', cohen_kappa_score(y_act, y_pred_c))

## Peptide-wise MCC on the evaluation set

In [None]:
# MCC per peptide, using the hard predictions (y_pred_c) and labels (y_act) from the
# most recently run evaluation cell. The peptide mask comes from df_eval, so this is
# only meaningful in mode == 'eval' (in 'cv' mode y_act is built from df_train).
pep_list = ['GILGFVFTL', 'GLCTLVAML']

assert len(y_act) == len(df_eval), \
    "per-peptide MCC needs mode == 'eval' (predictions must align with df_eval rows)"

for pep_f in pep_list:
    pep_mask = (df_eval['peptide'] == pep_f).to_numpy()
    # y_pred_c is already thresholded to 0/1, so no second thresholding is needed.
    print('MCC', pep_f, matthews_corrcoef(y_act[pep_mask], y_pred_c[pep_mask]))