In [1]:
%env TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub

env: TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub


In [2]:
import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scipy
from scipy import stats

import gc

In [3]:
import torch.nn.functional as F

HIDDEN_UNITS_POS_CONTACT = 5
class ESMForSingleMutationPosConcat(nn.Module):

    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self.fc1 = nn.Linear(1280 * 2, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def forward(self, token_ids1, token_ids2, pos):
        outputs1 = self.esm2.forward(token_ids1, repr_layers=[33])[
            'representations'][33]
        outputs2 = self.esm2.forward(token_ids2, repr_layers=[33])[
            'representations'][33]
        outputs1_pos = outputs1[:, pos + 1]
        outputs2_pos = outputs2[:, pos + 1]
        outputs_pos_concat = torch.cat((outputs1_pos, outputs2_pos), 2)
        fc1_outputs = F.relu(self.fc1(outputs_pos_concat))
        logits = self.fc2(fc1_outputs)
        return logits

In [4]:
HIDDEN_UNITS_POS_OUTER = 5
class ESMForSingleMutationPosOuter(nn.Module):

    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self._freeze_esm2_layers()
        self.fc1 = nn.Linear(1280 * 1280, HIDDEN_UNITS_POS_OUTER)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_OUTER, 1)

    def _freeze_esm2_layers(self):
        total_blocks = 33
        initial_layers = 2
        layers_per_block = 16
        num_freeze_blocks = total_blocks - 3
        for _, param in list(self.esm2.named_parameters())[
            :initial_layers + layers_per_block * num_freeze_blocks]:
            param.requires_grad = False

    def forward(self, token_ids1, token_ids2, pos):
        outputs1 = self.esm2.forward(token_ids1, repr_layers=[33])[
            'representations'][33]
        outputs2 = self.esm2.forward(token_ids2, repr_layers=[33])[
            'representations'][33]
        outputs1_pos = outputs1[:, pos + 1]
        outputs2_pos = outputs2[:, pos + 1]
        outer_prod = outputs1_pos.unsqueeze(3) @ outputs2_pos.unsqueeze(2)
        outer_prod_view = outer_prod.view(outer_prod.shape[0], outer_prod.shape[1], -1)
        fc1_outputs = F.relu(self.fc1(outer_prod_view))
        logits = self.fc2(fc1_outputs)
        return logits

In [5]:
class ESMForSingleMutation_pos(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))
        

    def forward(self, token_ids1, token_ids2, pos):                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        outputs = self.const1 * outputs1[:,pos + 1,:] + self.const2 * outputs2[:,pos + 1,:]        
        logits = self.classifier(outputs)
        return logits

In [6]:
class ESMForSingleMutation_cls(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))
        

    def forward(self, token_ids1, token_ids2, pos):                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        outputs = self.const1 * outputs1[:,0,:] + self.const2 * outputs2[:,0,:]        
        logits = self.classifier(outputs.unsqueeze(0))
        return logits

In [7]:
class ESMForSingleMutation_pos_cat_cls(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280*2, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))
        

    def forward(self, token_ids1, token_ids2, pos):                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        cls_out = self.const1 * outputs1[:,0,:] + self.const2 * outputs2[:,0,:]
        pos_out = self.const1 * outputs1[:,pos+1,:] + self.const2 * outputs2[:,pos+1,:]
        outputs = torch.cat([cls_out.unsqueeze(0), pos_out], axis = -1)        
        logits = self.classifier(outputs)
        return logits

In [8]:
class ProteinDataset(Dataset):
    def __init__(self, df):
        self.df = df
        _, esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

    def __getitem__(self, idx):
        _, _, esm1b_batch_tokens1 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['wt_seq'])[:1022])])
        _, _, esm1b_batch_tokens2 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['mut_seq'])[:1022])])
        pos = self.df.iloc[idx]['pos']
        return esm1b_batch_tokens1, esm1b_batch_tokens2, pos, torch.unsqueeze(torch.FloatTensor([self.df.iloc[idx]['ddg']]), 0)

    def __len__(self):
        return len(self.df)

In [9]:
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    model.train()
    
    for idx, batch in enumerate(training_loader):
        input_ids1, input_ids2, pos, labels = batch            
        input_ids1 = input_ids1[0].to(device)
        input_ids2 = input_ids2[0].to(device)
        logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos).to('cpu')        
        loss = torch.nn.functional.mse_loss(logits, labels)
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)
        
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")

In [10]:
def valid(model, testing_loader):
    model.eval()
    
    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels, eval_scores = [], [], []
    
    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):
            
            input_ids1, input_ids2, pos, labels = batch            
            input_ids1 = input_ids1[0].to(device)
            input_ids2 = input_ids2[0].to(device)
            labels = labels.to(device)
            logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos)        
            loss = torch.nn.functional.mse_loss(logits, labels)
            eval_loss += loss.item()

            nb_eval_steps += 1
            nb_eval_examples += labels.size(0)
                     
            eval_labels.extend(labels.cpu().detach())
            eval_preds.extend(logits.cpu().detach())
            
  
    labels = [id.item() for id in eval_labels]
    predictions = [id.item() for id in eval_preds]
    
    eval_loss = eval_loss / nb_eval_steps
    print(f"Validation Loss: {eval_loss}")

    return labels, predictions

In [11]:
lr = 1e-5
EPOCHS = 3
device = 'cuda:3'

models = ['ESMForSingleMutationPosOuter',
          'ESMForSingleMutationPosConcat',
          'ESMForSingleMutation_pos_cat_cls',  
              'ESMForSingleMutation_pos', 
              'ESMForSingleMutation_cls']

full_df = pd.read_csv('DATASETS/new_ds_with_folds.csv')

preds = {n:[] for n in models} 
true = [None]*5

for fold_no in range(5):
    for model_name in models:
        model_class = globals()[model_name]
        print(f'Training model {model_name} on fold {fold_no}')
        train_df, test_df = full_df[full_df.fold!=fold_no], full_df[full_df.fold==fold_no]
        train_ds, test_ds = ProteinDataset(train_df), ProteinDataset(test_df)
        
        model = model_class()                        
        model.to(device) 
        optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
        training_loader = DataLoader(train_ds, batch_size=1, num_workers = 2, shuffle = True)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)
        
        for epoch in range(EPOCHS):
            train(epoch)
            
        testing_loader = DataLoader(test_ds, batch_size=1, num_workers = 2)
        labels, predictions = valid(model, testing_loader)
        print(f'MAE {np.mean(np.abs(np.array(labels) - np.array(predictions)))} Correlation {stats.spearmanr(labels, predictions)}')     
        preds[model_name].append(predictions)
        if true[fold_no] is None:
            true[fold_no] = labels

         
        model.to('cpu')
        del model
    print()
        
    

Training model ESMForSingleMutationPosOuter on fold 0
Training loss epoch: 2.635423413551298
Training loss epoch: 1.1322193411048722
Training loss epoch: 0.6197464244398997
Validation Loss: 3.172407191069145
MAE 1.1664041202748194 Correlation SpearmanrResult(correlation=0.6642448778966968, pvalue=1.8245666909273552e-268)
Training model ESMForSingleMutationPosConcat on fold 0
Training loss epoch: 3.6268955301182415
Training loss epoch: 1.5482527004603415
Training loss epoch: 0.6599749537010731
Validation Loss: 3.353341689513026
MAE 1.2339659470389288 Correlation SpearmanrResult(correlation=0.613092901597194, pvalue=7.370066835129871e-218)
Training model ESMForSingleMutation_pos_cat_cls on fold 0
Training loss epoch: 2.5174619745669298
Training loss epoch: 0.7621416057688709
Training loss epoch: 0.19051986832222012
Validation Loss: 3.228121882597974
MAE 1.1842531303281363 Correlation SpearmanrResult(correlation=0.6476683091905255, pvalue=5.402042441777921e-251)
Training model ESMForSingl

In [12]:
all_true = np.concatenate(true)
for model in models:
    all_pred = np.concatenate(preds[model])
    print(f'{model} RMSE {np.sqrt(np.mean((all_true-all_pred)**2))} MAE {np.mean(np.abs(all_true - all_pred))} Correlation {stats.spearmanr(all_true, all_pred)}')     

print()
ens_pred = np.mean(np.stack([np.concatenate(preds[model]) for model in models], axis = 0), axis = 0)
print(f'Ensemble RMSE {np.sqrt(np.mean((all_true-ens_pred)**2))} MAE {np.mean(np.abs(all_true - ens_pred))} Correlation {stats.spearmanr(all_true, ens_pred)}')     
    

ESMForSingleMutationPosOuter RMSE 1.6537080398020116 MAE 1.1164675282649332 Correlation SpearmanrResult(correlation=0.6460236725156133, pvalue=0.0)
ESMForSingleMutationPosConcat RMSE 1.636556025307776 MAE 1.103576905154647 Correlation SpearmanrResult(correlation=0.637497712577878, pvalue=0.0)
ESMForSingleMutation_pos_cat_cls RMSE 1.5999260928945713 MAE 1.0615672858765015 Correlation SpearmanrResult(correlation=0.6696735701582707, pvalue=0.0)
ESMForSingleMutation_pos RMSE 1.595591303549888 MAE 1.0619648237340489 Correlation SpearmanrResult(correlation=0.6673067138634392, pvalue=0.0)
ESMForSingleMutation_cls RMSE 1.6958673546219474 MAE 1.143649272722892 Correlation SpearmanrResult(correlation=0.6615472579470065, pvalue=0.0)

Ensemble RMSE 1.566587274006694 MAE 1.0343388491591312 Correlation SpearmanrResult(correlation=0.6945659312361261, pvalue=0.0)
