In [1]:
#-------- Import Libraries --------#
import torchvision
import torch
import os
import time
import esm
import sys
import random
import pickle
import mlflow
import gc
import collections
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sn
import matplotlib.pyplot as plt
from datetime import date
from sklearn.metrics import matthews_corrcoef
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve, auc
gc.collect()

0

In [2]:
#-------- Import Modules from project--------#
import functions as func
#import encoding as enc

# Load pre-trained ESM-MSA-1b model
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
# Inference only: switch off dropout so the same sequence always yields the
# same embedding (a freshly constructed torch module starts in training mode).
model.eval()
batch_converter = alphabet.get_batch_converter()
    
## Embed one sequence; zero-padding is ON by default (pass add_padding=False to disable it):
def esm_MSA(peptide, pooling=False, add_padding=True, max_len=420):
    """Embed a single sequence with the module-level ESM-MSA-1b ``model``.

    Uses the module-level ``model`` and ``batch_converter`` objects.

    Parameters
    ----------
    peptide : str
        Sequence to embed.
    pooling : bool, default False
        If True, average the per-position vectors into a single vector.
    add_padding : bool, default True
        Only used when ``pooling`` is False: append all-zero rows until the
        result has ``max_len`` rows.
    max_len : int, default 420
        Number of rows of the padded result.

    Returns
    -------
    numpy.ndarray
        1-D feature vector when ``pooling`` is True; otherwise a 2-D array of
        shape (positions, features), with ``max_len`` rows when padded.

    Raises
    ------
    ValueError
        If padding is requested and the sequence has more than ``max_len``
        positions.
    """
    # One-element batch: the converter takes (label, sequence) pairs.
    data = [("", peptide)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # Only the layer-12 representations are read below, so contact prediction
    # is not requested from the model.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12])
    # [0][0] picks the first (only) entry along the two leading axes.
    token_representations = results["representations"][12].numpy()[0][0]

    del results, batch_labels, batch_strs, batch_tokens
    gc.collect()

    # Position 0 is skipped -- presumably the start token that the batch
    # converter prepends (TODO confirm against the ESM alphabet).
    per_position = token_representations[1:, ]

    if pooling:
        return per_position.mean(0)

    if add_padding:
        pad = max_len - per_position.shape[0]
        if pad < 0:
            raise ValueError(
                "Sequence has {} positions, which exceeds max_len={}".format(
                    per_position.shape[0], max_len))
        # Zero rows are appended at the end; the feature axis is untouched.
        return np.pad(per_position, ((0, pad), (0, 0)), 'constant')

    return per_position

In [3]:
#-------- Set Device --------#

# Prefer CUDA whenever torch can see a GPU; otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
else:
    print('No GPUs available. Using CPU instead.')

No GPUs available. Using CPU instead.


In [4]:
#-------- Seeds --------#

seed_val = 42

# Seed every random number generator in use with the same value:
# Python's `random`, NumPy's global RNG, torch (CPU) and torch (all CUDA devices).
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(seed_val)

# Ask torch to use deterministic algorithms only.
torch.use_deterministic_algorithms(True)

In [5]:
#-------- Import Modules from project--------#

import encoding as enc
from model import Net_project
import functions as func


In [6]:
#-------- Import Dataset --------#

import glob


def load_partition(fp):
    """Load one partition from an ``*input*.npz`` file and its labels file.

    The labels file sits in the same directory and has the same file name with
    "input" replaced by "labels". Only the file name is rewritten, so a
    directory that happens to contain "input" does not break the lookup.

    Parameters
    ----------
    fp : str
        Path of the input archive.

    Returns
    -------
    tuple
        ``(data, targets)``: the ``"arr_0"`` entries of both archives.
    """
    folder, filename = os.path.split(fp)
    labels_fp = os.path.join(folder, filename.replace("input", "labels"))
    # The archives are closed as soon as the arrays have been read out.
    with np.load(fp) as inputs, np.load(labels_fp) as labels:
        return inputs["arr_0"], labels["arr_0"]


data_list = []
target_list = []

# Training partitions, visited in index order; matches within one index are
# sorted so the order does not depend on the filesystem.
for index in range(5):
    for fp in sorted(glob.glob("../data/train/*{}*input.npz".format(index+1))):
        print("Read file", fp)
        data, targets = load_partition(fp)
        data_list.append(data)
        target_list.append(targets)

# The validation partition is appended after the training partitions.
for fp in sorted(glob.glob("../data/validation/*5*input.npz")):
    print("Read file", fp)
    data, targets = load_partition(fp)
    data_list.append(data)
    target_list.append(targets)

print("\n")

print("data_list:", len(data_list))
print("target_list", len(target_list))

print("\n")

data_partitions = len(data_list)
for i in range(data_partitions):
    print("Size of file", i+1, len(data_list[i]))

Read file ../data/train/P1_input.npz
Read file ../data/train/P2_input.npz
Read file ../data/train/P3_input.npz
Read file ../data/train/P4_input.npz
Read file ../data/validation/P5_input.npz


data_list: 5
target_list 5


Size of file 1 1526
Size of file 2 1168
Size of file 3 1480
Size of file 4 1532
Size of file 5 1207


In [None]:
#embedding of data
embedding = 'MSA'          # which branch of the embedding step runs: "baseline" or "MSA"
files_complete = False     # True skips the MSA encoding step entirely
merge = False              # forwarded to func.extract_sequences; only merge=False is implemented for MSA

data_list_enc = list()

start = time.time()        # reference point for the elapsed-time progress messages
mhc_bool = False           # flips to True once the MHC sequence has been encoded

# Prefix (directory plus trailing slash) that the embedding pickle file names are appended to.
# NOTE(review): this name was used but never defined in the notebook, which
# raised NameError on a fresh kernel; the value below is a placeholder --
# confirm the intended location.
embedding_dir = '../data/embeddedFiles/'

if embedding == "baseline":
    # Baseline: hand the raw partitions on unchanged.
    print('baseline')
    data_list_enc = data_list

elif embedding == "MSA":
    # files_complete=True skips the whole encoding step.
    if not files_complete:
        print("MSA")
        count = 0

        for count, dataset in enumerate(data_list, start=1):

            print("\nWorking on file", count)

            mhc_enc = list()
            pep_enc = list()
            tcr_enc = list()

            x_enc = func.extract_sequences(dataset, merge = merge)

            print("Sequences are extracted")

            # merge=True is not implemented for this embedding: nothing is
            # encoded or written in that case.
            if not merge:
                print("Merge false ( means Separate True )")

                # Convert each column once instead of on every loop iteration.
                mhc_seqs = x_enc['MHC'].tolist()
                pep_seqs = x_enc['peptide'].tolist()
                tcr_seqs = x_enc['tcr'].tolist()
                print(len(dataset), len(mhc_seqs), len(pep_seqs), len(tcr_seqs))

                # Make sure the output location exists *before* the slow
                # encoding starts, so a bad path fails immediately.
                os.makedirs(os.path.dirname(embedding_dir) or '.', exist_ok=True)

                # The MHC sequence is encoded a single time (first entry of the
                # first file) and reused for every complex -- this assumes all
                # complexes share the same MHC; TODO confirm.
                if mhc_bool == False:
                    print("\nEncode MHC")
                    mhc_enc_1 = esm_MSA(mhc_seqs[0], pooling=False, add_padding=False)
                    print("Done:", len(mhc_enc_1))
                    mhc_bool = True

                for i in range(len(mhc_seqs)):
                    if i % 10 == 0:
                        print("\nFlag", i, "peptides are encoded - Time:", round((time.time()-start)/60,2), "mins.")

                    mhc_enc.append(mhc_enc_1)
                    pep_enc.append(esm_MSA(pep_seqs[i], pooling=False, add_padding=False))
                    tcr_enc.append(esm_MSA(tcr_seqs[i], pooling=False, add_padding=False))

                # Plain nested lists are what gets pickled.
                mhc_enc = [x.tolist() for x in mhc_enc]
                pep_enc = [x.tolist() for x in pep_enc]
                tcr_enc = [x.tolist() for x in tcr_enc]

                print("ESM_1B is done.\n")
                print("Wriring in files.\n")

                # save separate encodings (each file is closed even if pickling fails):
                separate_outputs = (
                    ('dataset-{}-file{}-mhc-MSA.pkl', mhc_enc),
                    ('dataset-{}-file{}-pep-MSA.pkl', pep_enc),
                    ('dataset-{}-file{}-tcr-MSA.pkl', tcr_enc),
                )
                for name_template, encodings in separate_outputs:
                    with open(embedding_dir + name_template.format(embedding, count), 'wb') as outfile:
                        pickle.dump(encodings, outfile)

                print("Wriring in files is done.\n")


                ## PREPARE df to paste here

                print("create df_aa")
                df_aa = func.extract_aa_and_energy_terms(dataset)
                print("df_aa is done")

                ### paste energy terms here

                data_list_PASTE = list()

                for cmplx in range(len(dataset)):

                    df = pd.DataFrame( dataset[cmplx] )
                    new_df = pd.DataFrame( df_aa[cmplx] )

                    df_emb_mhc = pd.DataFrame(mhc_enc[cmplx])
                    df_emb_pep = pd.DataFrame(pep_enc[cmplx])
                    df_emb_tcr = pd.DataFrame(tcr_enc[cmplx])

                    # Rows whose column 34 holds 'X' are treated as padding rows.
                    pad_index_list = sorted(new_df[new_df.iloc[:,34]=='X'].index.tolist())

                    # Locate the TCR rows afresh for every complex, so values
                    # from the previous complex can never leak into this one.
                    padding1_len = 0
                    tcr_start = tcr_end = None
                    if pad_index_list and pad_index_list[-1] > 230:
                        # Padding reaches past row 230: the TCR rows sit in the
                        # first gap of more than 100 rows between two
                        # consecutive padding rows.
                        for pad in range(len(pad_index_list) - 1):
                            if pad_index_list[pad+1]-pad_index_list[pad] > 100:
                                padding1_len = pad + 1
                                tcr_start = pad_index_list[pad] + 1
                                tcr_end = pad_index_list[pad+1]
                                break
                    elif pad_index_list:
                        # No padding past row 230: the TCR rows run from the
                        # last padding row to row 420.
                        padding1_len = len(pad_index_list)
                        tcr_start = pad_index_list[-1] + 1
                        tcr_end = 420

                    if tcr_start is None:
                        raise ValueError(
                            "Could not locate the TCR rows of complex {} in file {}".format(cmplx, count))

                    padding2_len = len(pad_index_list) - padding1_len

                    # Each part = embedding columns + columns 20.. of the matching
                    # rows of the original complex, minus the very last column.
                    mhc = pd.concat([df_emb_mhc.reset_index(drop=True), df.iloc[:179,20:].reset_index(drop=True)], axis=1).iloc[:, :-1].reset_index(drop=True)
                    pep = pd.concat([df_emb_pep.reset_index(drop=True), df.iloc[179:188,20:].reset_index(drop=True)], axis=1).iloc[:, :-1].reset_index(drop=True)
                    padding1 = pd.DataFrame(0, index=np.arange(padding1_len), columns=pep.columns)
                    tcr = pd.concat([df_emb_tcr.reset_index(drop=True), (df.iloc[tcr_start:tcr_end,20:]).reset_index(drop=True)], axis=1).iloc[:, :-1].reset_index(drop=True)
                    padding2 = pd.DataFrame(0, index=np.arange(padding2_len), columns=pep.columns)

                    final_cmplx = pd.concat([mhc,pep,padding1,tcr,padding2]).reset_index(drop=True).values

                    data_list_PASTE.append( final_cmplx )

                    if cmplx % 100 == 0:
                        print("Flag - pasting:", cmplx)
                        print(len(final_cmplx))

                # One pickle per input file, written to the working directory.
                with open('esm-energies-file-MSA-{}.pkl'.format(count),'wb') as o:
                    pickle.dump(data_list_PASTE, o)

print("Done")