<a href="https://colab.research.google.com/github/mitiau/nlp_epitopes/blob/main/epitope_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install dependencies and download weights

In [None]:
# Code provided under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License

In [2]:
# Runtime dependencies for the cells below: transformers (SequenceClassifierOutput),
# fair-esm (ESM-1v backbone and alphabet), biopython (FASTA parsing) and gdown
# (Google Drive downloads of the weights).
# NOTE(review): versions are unpinned, so behavior may drift as packages update;
# consider pinning them (and `%pip` targets the running kernel's environment).
!pip install transformers
!pip install fair-esm
!pip install biopython
!pip install gdown



In [3]:
!gdown --id 1993ZcxUx_PcRlZb015yexJa1CZzfBqog
!gdown --id 1RU_hkcOMmQYP1lQV8zCxso3zrTT16BKj
!gdown --id 1e6k4q023WTrwpdBOl4hV_HqNTy6EwUrJ
!gdown --id 1U065cgX7FVugZVs09ecr5N1hyq_5P9If
!gdown --id 1yyQRvvv4a1ZoVuv0veAwe8__rInM_fKB
!gdown --id 15-WU5YuslcYPS76C8LVUGBBBHSURTCkR

Downloading...
From: https://drive.google.com/uc?id=1VaOtlK389p4XDfPfiEb8Xt1EMdGghJDI
To: /content/esm1v_ft_epitopes_0.pth
100% 2.61G/2.61G [00:29<00:00, 87.4MB/s]
Downloading...
From: https://drive.google.com/uc?id=1doP6d6K6O5DHzi0EhXYqeNVHChVDB9cu
To: /content/esm1v_ft_epitopes_1.pth
100% 2.61G/2.61G [00:52<00:00, 49.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=1KYS8y28I1Ul4sXInWufwxi560-h2nahD
To: /content/esm1v_ft_epitopes_2.pth
100% 2.61G/2.61G [00:52<00:00, 50.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Z5wEwnZgn3oUGubBVUbOy6fOuiA2F5Dh
To: /content/esm1v_ft_epitopes_3.pth
100% 2.61G/2.61G [00:54<00:00, 48.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1P0zm_AcRFCAE3Y8w_25yoyyME4q1eqCQ
To: /content/esm1v_ft_epitopes_4.pth
100% 2.61G/2.61G [00:54<00:00, 47.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=15-WU5YuslcYPS76C8LVUGBBBHSURTCkR
To: /content/dummy.fasta
100% 26.8k/26.8k [00:00<00:00, 22.0MB/s]


In [1]:
from google.colab import drive, files

import torch
from torch.utils.data import Dataset
from torch import nn

import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

import pandas as pd
import numpy as np

import esm
from esm import ProteinBertModel
from esm.pretrained import load_model_and_alphabet_hub

from Bio import SeqIO
from io import StringIO, BytesIO
from tqdm import tqdm

In [2]:
class ESM1vForTokenClassification(nn.Module):
    """Per-position classifier on top of the pretrained ESM-1v backbone.

    The backbone ("esm1v_t33_650M_UR90S_1") is fetched through fair-esm's hub
    loader; a single linear head maps each position's layer-33 representation
    to ``num_labels`` logits.
    """

    def __init__(self, num_labels = 2):
        super().__init__()
        self.num_labels = num_labels
        backbone, alphabet = load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")
        # Attribute names matter: they are the prefixes of the state_dict keys
        # used when fine-tuned checkpoints are loaded.
        self.esm1v = backbone
        self.esm1v_alphabet = alphabet
        # 1280 must equal the width of the layer-33 representations fed to the head.
        self.classifier = nn.Linear(1280, self.num_labels)

    def forward(self, token_ids):
        """Return a SequenceClassifierOutput with per-position logits.

        The first and last token positions (presumably the BOS/EOS special
        tokens added by the batch converter) are dropped, so the logits line up
        with the residues in between.
        """
        layer = 33
        hidden = self.esm1v.forward(token_ids, repr_layers=[layer])['representations'][layer]
        residue_hidden = hidden[:, 1:-1, :]
        return SequenceClassifierOutput(logits=self.classifier(residue_hidden))
    
class pred_Dataset(Dataset):
    """Map-style dataset yielding ESM token ids for one protein sequence per item.

    Args:
        df: DataFrame whose column at position 1 holds the sequences (a str or
            any iterable of one-letter residue codes, e.g. a Bio.Seq).
        max_len: number of residues kept per sequence. Longer sequences are
            truncated, so downstream predictions cover only their first
            ``max_len`` residues. The default 1022 presumably leaves room for the
            two special tokens within ESM's context window -- TODO confirm.
    """

    def __init__(self, df, max_len=1022):
        self.df = df
        self.max_len = max_len
        # NOTE(review): assumed to match the alphabet the hub loader returns for
        # esm1v_t33_650M_UR90S_1 -- confirm if the backbone is ever changed.
        esm1v_alphabet = esm.Alphabet.from_architecture('roberta_large')
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return ``{'token_ids': tokens}`` (a batch of one)."""
        sequence = ''.join(self.df.iloc[idx, 1])[:self.max_len]
        # The converter takes (label, sequence) pairs and returns a 3-tuple whose
        # last element is the token tensor; the label is unused here.
        _, _, token_ids = self.esm1v_batch_converter([('', sequence)])
        return {'token_ids': token_ids}

    def __len__(self):
        return len(self.df)

In [None]:
# Build the classifier; its constructor fetches the pretrained ESM-1v backbone via
# fair-esm's hub loader. Fine-tuned weights are loaded later, once per ensemble member.
model = ESM1vForTokenClassification()

# Compute epitopes

In [None]:
uploaded = files.upload()
# Echo each uploaded file's name together with its size in bytes.
for fname, content in uploaded.items():
  print(f'User uploaded file "{fname}" with length {len(content)} bytes')

In [None]:
# Run the fine-tuned ensemble over every uploaded FASTA file and pickle the
# averaged per-residue scores as "<upload name>.preds.pkl".
N_ENSEMBLE = 5  # number of fine-tuned checkpoints averaged together
# Fall back to CPU when no GPU is attached (much slower, but does not crash).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for key in uploaded.keys():
  # files.upload() yields raw bytes; decode them and parse as FASTA.
  identifiers = []
  seqs = []
  for seq_record in SeqIO.parse(StringIO(uploaded[key].decode('utf-8')), 'fasta'):
    identifiers.append(seq_record.id)
    seqs.append(str(seq_record.seq))

  pred_df = pd.DataFrame({'id': identifiers, 'seq': seqs})
  pred_ds = pred_Dataset(pred_df)

  # res[ens_idx][seq_idx] -> per-residue class-1 logits for one sequence.
  res = []
  for ens_idx in range(N_ENSEMBLE):
    # NOTE(review): these names differ from the ones in the saved download log
    # (esm1v_ft_epitopes_<i>.pth); confirm they match what gdown fetched.
    # map_location='cpu' keeps the large checkpoint off the GPU while it is
    # copied into the (possibly already GPU-resident) model.
    state_dict = torch.load('esm1v_ft_epitopes_regr' + str(ens_idx) + '.pth', map_location='cpu')
    model.load_state_dict(state_dict)
    del state_dict
    model.eval()
    model.to(DEVICE)

    preds = []
    with torch.no_grad():
      for idx in tqdm(range(len(pred_ds))):
        logits = model(pred_ds[idx]['token_ids'].to(DEVICE)).logits
        # Batch of one: take row 0 and keep the class-1 column for every residue.
        preds.append(logits[0, :, 1].cpu().numpy())
    res.append(preds)

  # Average each sequence's per-residue scores across the ensemble members.
  merged = [np.mean(np.stack(member_preds, axis=0), axis=0) for member_preds in zip(*res)]

  pred_df['epitope_prediction'] = merged
  pred_df[['id', 'epitope_prediction']].to_pickle(key + '.preds.pkl')

In [None]:
from google.colab import files

# Send each prediction pickle produced above to the browser as a download.
for fasta_name in uploaded:
  files.download(f'{fasta_name}.preds.pkl')