<a href="https://colab.research.google.com/github/mitiau/PROSTATA/blob/HSE_seminar/PROSTATA_tool.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install dependencies and download weights

In [None]:
# Install the third-party packages this notebook relies on.
# gdown is the only one pinned to an exact version (4.5.4).
!pip install transformers
!pip install fair-esm
!pip install biopython
!pip install gdown==4.5.4

In [None]:
from google.colab import drive, files

import torch
from torch.utils.data import Dataset
from torch import nn

import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

import pandas as pd
import numpy as np
import random

import esm
from esm import ProteinBertModel
from esm.pretrained import load_model_and_alphabet_hub

from Bio import SeqIO
from io import StringIO, BytesIO
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Download the five checkpoint files into the working directory.
# Their file names match the entries of `model_names`, which are later
# passed to torch.load as relative paths.
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutationPosConcat
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutationPosOuter
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutation_cls
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutation_pos
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutation_pos_cat_cls

In [None]:
# Fetch the PROSTATA repository and switch it to the HSE_seminar branch;
# the test CSV read further down lives under PROSTATA/cross_validation_datasets.
!git clone https://github.com/mitiau/PROSTATA.git
!git -C PROSTATA checkout HSE_seminar
!git -C PROSTATA pull

In [None]:
import torch.nn.functional as F

# Width of the hidden layer in the concat-based regression head.
HIDDEN_UNITS_POS_CONTACT = 5


class ESMForSingleMutationPosConcat(nn.Module):
    """Regression head on concatenated per-position ESM-2 embeddings.

    Both token sequences are run through the same ESM-2 network, the layer-33
    representation at index ``pos + 1`` is taken from each, the two vectors are
    concatenated along the feature axis and passed through ``fc1`` -> ReLU ->
    ``fc2``.
    """

    def __init__(self):
        super().__init__()
        # The alphabet returned together with the model is not stored here.
        backbone, _alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm2 = backbone
        # 1280 features per sequence, two sequences side by side.
        self.fc1 = nn.Linear(2 * 1280, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def _site_embedding(self, token_ids, pos):
        """Return the layer-33 representation of ``token_ids`` at ``pos + 1``."""
        # NOTE(review): the +1 presumably skips a token prepended by the batch
        # converter -- confirm against the tokenizer.
        result = self.esm2.forward(token_ids, repr_layers=[33])
        return result['representations'][33][:, pos + 1]

    def forward(self, token_ids1, token_ids2, pos):
        """Return the output of ``fc2`` for the pair of token sequences.

        ``pos`` must index so that the selected embeddings are 3-D, because
        the concatenation below runs along dimension 2 (e.g. a 1-D index
        tensor).
        """
        pair = (
            self._site_embedding(token_ids1, pos),
            self._site_embedding(token_ids2, pos),
        )
        features = torch.cat(pair, dim=2)
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)
    
# Width of the hidden layer in the outer-product regression head.
HIDDEN_UNITS_POS_OUTER = 5


class ESMForSingleMutationPosOuter(nn.Module):
    """Regression head on the outer product of per-position ESM-2 embeddings.

    Both token sequences are run through the same ESM-2 network, the layer-33
    representation at index ``pos + 1`` is taken from each, their outer
    product is flattened and passed through ``fc1`` -> ReLU -> ``fc2``.
    """

    def __init__(self):
        super().__init__()
        # The alphabet returned together with the model is not stored here.
        backbone, _alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm2 = backbone
        self._freeze_esm2_layers()
        # The flattened outer product of two 1280-d vectors has 1280 * 1280 entries.
        self.fc1 = nn.Linear(1280 * 1280, HIDDEN_UNITS_POS_OUTER)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_OUTER, 1)

    def _freeze_esm2_layers(self):
        """Disable gradients for the leading entries of ``esm2.named_parameters()``.

        The cut-off is 2 + 16 * (33 - 3) = 482 parameters, counted in the
        order ``named_parameters()`` yields them.
        """
        total_blocks, trainable_blocks = 33, 3
        leading_params, params_per_block = 2, 16
        cutoff = leading_params + params_per_block * (total_blocks - trainable_blocks)
        for index, (_name, tensor) in enumerate(self.esm2.named_parameters()):
            if index >= cutoff:
                break
            tensor.requires_grad = False

    def _site_embedding(self, token_ids, pos):
        """Return the layer-33 representation of ``token_ids`` at ``pos + 1``."""
        # NOTE(review): the +1 presumably skips a token prepended by the batch
        # converter -- confirm against the tokenizer.
        result = self.esm2.forward(token_ids, repr_layers=[33])
        return result['representations'][33][:, pos + 1]

    def forward(self, token_ids1, token_ids2, pos):
        """Return the output of ``fc2`` for the pair of token sequences.

        ``pos`` must index so that the selected embeddings are 3-D (e.g. a 1-D
        index tensor); the unsqueeze calls below rely on that rank.
        """
        column = self._site_embedding(token_ids1, pos).unsqueeze(3)
        row = self._site_embedding(token_ids2, pos).unsqueeze(2)
        outer = column @ row
        # Collapse the two feature axes into one, keeping the first two axes.
        n_first, n_second = outer.shape[0], outer.shape[1]
        flattened = outer.view(n_first, n_second, -1)
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden)
    
class ESMForSingleMutation_pos(nn.Module):
    """Linear head on a learned per-feature combination of two position embeddings.

    The layer-33 representations of both sequences at index ``pos + 1`` are
    combined as ``const1 * emb1 + const2 * emb2`` and fed to ``classifier``.
    ``const1`` starts at +1 and ``const2`` at -1, so the initial combination is
    the difference of the two embeddings.
    """

    def __init__(self):
        super().__init__()
        # The attributes are called ``esm1v`` although the network loaded on
        # this line is ESM-2; the names are read elsewhere in the notebook
        # (``model.esm1v_alphabet``), so they are kept.
        encoder, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v = encoder
        self.esm1v_alphabet = alphabet
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def forward(self, token_ids1, token_ids2, pos):
        """Return the output of ``classifier`` for the pair of token sequences."""
        # NOTE(review): the +1 presumably skips a token prepended by the batch
        # converter -- confirm against the tokenizer.
        site = pos + 1
        first = self.esm1v.forward(token_ids1, repr_layers=[33])
        second = self.esm1v.forward(token_ids2, repr_layers=[33])
        emb1 = first['representations'][33][:, site, :]
        emb2 = second['representations'][33][:, site, :]
        combined = self.const1 * emb1 + self.const2 * emb2
        return self.classifier(combined)
    
class ESMForSingleMutation_cls(nn.Module):
    """Linear head on a learned per-feature combination of two index-0 embeddings.

    The layer-33 representations of both sequences at token index 0 are
    combined as ``const1 * emb1 + const2 * emb2`` and fed to ``classifier``.
    ``pos`` is accepted for signature compatibility with the sibling models
    but is not used.
    """

    def __init__(self):
        super().__init__()
        # The attributes are called ``esm1v`` although the network loaded on
        # this line is ESM-2; ``esm1v_alphabet`` is read elsewhere in the
        # notebook, so the names are kept.
        encoder, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v = encoder
        self.esm1v_alphabet = alphabet
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def forward(self, token_ids1, token_ids2, pos):
        """Return the output of ``classifier`` for the pair of token sequences."""
        first = self.esm1v.forward(token_ids1, repr_layers=[33])
        second = self.esm1v.forward(token_ids2, repr_layers=[33])
        emb1 = first['representations'][33][:, 0, :]
        emb2 = second['representations'][33][:, 0, :]
        combined = self.const1 * emb1 + self.const2 * emb2
        # NOTE(review): the extra axis is inserted in FRONT of the batch axis,
        # so a batch of B inputs yields shape (1, B, 1) rather than (B, 1, 1);
        # confirm that this is intended before batching.
        return self.classifier(combined.unsqueeze(0))
    
class ESMForSingleMutation_pos_cat_cls(nn.Module):
    """Linear head on concatenated index-0 and per-position embedding combinations.

    For both the index-0 token and the token at ``pos + 1`` the layer-33
    representations of the two sequences are combined as
    ``const1 * emb1 + const2 * emb2`` (the same ``const1``/``const2`` are used
    for both). The two combinations are concatenated along the feature axis
    and fed to ``classifier``.
    """

    def __init__(self):
        super().__init__()
        # The attributes are called ``esm1v`` although the network loaded on
        # this line is ESM-2; the names are kept so that previously saved
        # instances of this class still expose them.
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.classifier = nn.Linear(1280 * 2, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def forward(self, token_ids1, token_ids2, pos):
        """Return the output of ``classifier`` for the pair of token sequences.

        Args:
            token_ids1: token tensor of the first sequence(s), batch first.
            token_ids2: token tensor of the second sequence(s), same layout.
            pos: 1-D integer index tensor of positions; ``pos + 1`` is used to
                index the token axis.

        Returns:
            Tensor with one value per (batch item, position), shape
            ``(batch, len(pos), 1)``.
        """
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        cls_out = self.const1 * outputs1[:, 0, :] + self.const2 * outputs2[:, 0, :]
        # NOTE(review): the +1 presumably skips a token prepended by the batch
        # converter -- confirm against the tokenizer.
        pos_out = self.const1 * outputs1[:, pos + 1, :] + self.const2 * outputs2[:, pos + 1, :]
        # Insert the position axis at dim 1 (not dim 0) and repeat the index-0
        # features for every requested position. The previous unsqueeze(0)
        # placed the new axis on the batch dimension, so the concatenation
        # only worked for a single sequence with a single position; for that
        # case the result is unchanged.
        cls_out = cls_out.unsqueeze(1).expand(-1, pos_out.shape[1], -1)
        outputs = torch.cat([cls_out, pos_out], dim=-1)
        logits = self.classifier(outputs)
        return logits
    

In [None]:
# Checkpoint file names; each one is loaded in turn by predict_ddg and the
# individual predictions are averaged.
model_names = [
    'ESMForSingleMutationPosOuter',
    'ESMForSingleMutationPosConcat',
    'ESMForSingleMutation_pos_cat_cls',
    'ESMForSingleMutation_pos',
    'ESMForSingleMutation_cls',
]

# Compute DDG for the test set and compare with experimental data

In [None]:
# Load one checkpoint on the CPU only to get at the alphabet stored on it;
# the batch converter built from that alphabet tokenizes sequences for
# predict_ddg.
model = torch.load('ESMForSingleMutation_cls', map_location='cpu')
esm2_alphabet = model.esm1v_alphabet
esm2batch_converter = esm2_alphabet.get_batch_converter()

def predict_ddg(seqs, mutation_codes, poss=None):
    """Predict a value per single-point mutation, averaged over all models.

    For every (sequence, mutation code) pair the mutated sequence is built,
    both sequences are tokenized with ``esm2batch_converter``, and every
    checkpoint listed in ``model_names`` is evaluated on the pair. The
    per-model outputs are averaged.

    Args:
        seqs: wild-type sequences (strings).
        mutation_codes: codes such as ``'V1N'``: first character is the
            wild-type residue, last character the new residue, the middle is
            the 1-based position.
        poss: optional 0-based positions, one per sequence. A ``None`` entry
            (or omitting the argument) means the position is parsed from the
            mutation code instead.

    Returns:
        1-D numpy array with one averaged prediction per input pair.

    Raises:
        IndexError: if a position is negative, beyond the sequence, or beyond
            the 1022-residue window passed to the models.
        AssertionError: if the residue at the position differs from the
            wild-type residue given in the mutation code.
    """
    max_len = 1022  # sequences are cut to this many residues before tokenizing
    if poss is None:
        poss = [None] * len(seqs)
    # Fall back to the CPU when no GPU is present instead of failing in .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    inp = []
    for seq, mutation_code, pos in zip(seqs, mutation_codes, poss):
        print(mutation_code)
        wt_aa = mutation_code[0]
        mut_aa = mutation_code[-1]
        # Compare against None explicitly: position 0 is a valid index and
        # must not be mistaken for "no position given".
        if pos is not None:
            mut_pos = pos
        else:
            mut_pos = int(mutation_code[1:-1]) - 1

        # A negative index would silently wrap around, and an index past the
        # truncation window would address a token that was cut off.
        if not 0 <= mut_pos < min(len(seq), max_len):
            raise IndexError(
                f'Mutation {mutation_code}: position {mut_pos} is outside the '
                f'usable range 0..{min(len(seq), max_len) - 1}')
        assert seq[mut_pos] == wt_aa, (
            f'Mutation {mutation_code}: expected {wt_aa!r} at index {mut_pos}, '
            f'found {seq[mut_pos]!r}')

        wt = seq
        mut = seq[:mut_pos] + mut_aa + seq[mut_pos + 1:]

        _, _, esm2_batch_tokens1 = esm2batch_converter([('', wt[:max_len])])
        _, _, esm2_batch_tokens2 = esm2batch_converter([('', mut[:max_len])])
        inp.append((esm2_batch_tokens1.to(device),
                    esm2_batch_tokens2.to(device),
                    mut_pos))

    res = []
    for model_name in model_names:
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code -- only use checkpoints from a source you trust.
        model = torch.load(model_name, map_location=torch.device('cpu'))
        model.eval()
        model.to(device)

        with torch.no_grad():
            res.append([model(token_ids1=t1, token_ids2=t2,
                              pos=torch.LongTensor([p])).cpu().numpy()
                        for t1, t2, p in inp])
        # Drop the reference so this model can be freed before the next load.
        del model
    res = np.mean(res, axis=0)
    return res.ravel()
    

In [None]:
# Read the test table from the cloned repository and store one prediction per
# row in a new 'ddg_pred' column. The 'pos' column is passed through as the
# explicit positions argument of predict_ddg.
test_df = pd.read_csv('PROSTATA/cross_validation_datasets/test_1LNIA.csv')
test_df['ddg_pred'] = predict_ddg(
    *(test_df[column].tolist() for column in ('wt_seq', 'mut_info', 'pos'))
)

In [None]:
# Scatter plot: predictions ('ddg_pred') on the x axis against the table's
# 'ddg' column on the y axis.
x = test_df.ddg_pred.to_list()
y = test_df.ddg.to_list()
plt.scatter(x, y, alpha=0.5)
plt.show()

In [None]:
# Example input: one sequence and one mutation code for it.
seqs = [
    "VINTFDGVADYLQTYHKLPDNYITKSEAQALGWVASKGNLADVAPGKSIGGDIFSNREGKLPGKSGRTWREADINYTSGFRNSDRILYSSDWLIYKTTDHYQTFTKIR",
]
mutation_codes = ['V1N'] #@param {type:"string"}

In [None]:
# No explicit positions are given, so they are parsed from the mutation codes.
predict_ddg(seqs=seqs, mutation_codes=mutation_codes)

# Find best mutation

In [None]:
# Starting sequence for the search, plus the distinct residues that occur in
# it (the order of wt_aas is whatever the set iteration yields).
wildtype = "VINTFDGVADYLQTYHKLPDNYITKSEAQALGWVASKGNLADVAPGKSIGGDIFSNREGKLPGKSGRTWREADINYTSGFRNSDRILYSSDWLIYKTTDHYQTFTKIR"
wt_aas = [*set(wildtype)]

In [None]:
# Random search over single-point mutations of `wildtype`.
# Every round draws 100 random (position, residue) pairs, scores them with
# predict_ddg and remembers the mutation with the lowest score seen so far.
# The loop has no exit condition: interrupt the cell to stop it.
best_seq = ''
best_score = 0
best_mut = None

while True:
    print('Processing batch')
    # Positions are 1-based, as expected by the mutation-code parser in
    # predict_ddg. Starting at 1 avoids position 0, which would wrap around
    # to the last residue and produce an invalid code such as 'R0A'.
    muts = [random.randint(1, len(wildtype)) for _ in range(100)]
    muts = [wildtype[m - 1] + str(m) + random.choice(wt_aas) for m in muts]

    res = predict_ddg([wildtype] * len(muts), muts)
    # predict_ddg returns a numpy array, whose list conversion is `tolist`.
    for mut, score in zip(muts, res.tolist()):
        if score < best_score:
            best_score = score
            best_mut = mut
            # Keep the mutated sequence that belongs to the best code.
            site = int(mut[1:-1]) - 1
            best_seq = wildtype[:site] + mut[-1] + wildtype[site + 1:]
            print(f'Mutation {best_mut} gives score of {best_score}')