In [6]:
%env TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub

env: TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub


In [7]:
import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scipy
from scipy import stats

In [8]:
class ESMForSingleMutation(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))

        

    def forward(self, token_ids1, token_ids2, pos):
                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        
        outputs = self.const1 * outputs1[:,pos,:] + self.const2 * outputs2[:,pos,:]
        
        #outputs = self.merge(torch.stack(, outputs2[:,0,:]]))
        #outputs = torch.cat([torch.mean(outputs, dim = 1), torch.max(outputs, dim = 1).values], dim = 1)
        logits = self.classifier(outputs)

        return logits

In [10]:
class ProteinDataset(Dataset):
    def __init__(self, df):
        self.df = df
        _, esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

    def __getitem__(self, idx):
        _, _, esm1b_batch_tokens1 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['wt_seq'])[:1022])])
        _, _, esm1b_batch_tokens2 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['mut_seq'])[:1022])])
        pos = self.df.iloc[idx]['pos']
        return esm1b_batch_tokens1, esm1b_batch_tokens2, pos, torch.unsqueeze(torch.FloatTensor([self.df.iloc[idx]['ddg']]), 0)

    def __len__(self):
        return len(self.df)

In [12]:
def valid(model, testing_loader):
    # put model in evaluation mode
    model.eval()
    
    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels, eval_scores = [], [], []
    
    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):
            
            input_ids1, input_ids2, pos, labels = batch            
            input_ids1 = input_ids1[0].to(device)
            input_ids2 = input_ids2[0].to(device)
            labels = labels.to(device)
            #print(model.device, input_ids.device, labels.device)
            #print(input_ids)
            #print(input_ids.size())
            logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos)        
            #print(outputs)
            loss = torch.nn.functional.mse_loss(logits, labels)
            #print(model(input_ids=ids, attention_mask=mask, labels=labels))
            eval_loss += loss.item()

            nb_eval_steps += 1
            nb_eval_examples += labels.size(0)
        
            #if idx % 10==0:
            #    loss_step = eval_loss/nb_eval_steps
            #    print(f"Validation loss per 100 evaluation steps: {loss_step}")
             
            eval_labels.extend(labels.cpu().detach())
            eval_preds.extend(logits.cpu().detach())
            
  
    labels = [id.item() for id in eval_labels]
    predictions = [id.item() for id in eval_preds]
    
    eval_loss = eval_loss / nb_eval_steps
    print(f"Validation Loss: {eval_loss}")

    return labels, predictions

In [6]:
data = pd.read_csv('ddg_v1.csv') 

train_df = data[data.dataset == 's2648']
test_df = data[data.dataset == 's669']

In [9]:
lr = 1e-5
EPOCHS = 3
device = 'cuda:2'

spearmanrs=[]
                        
for random_state in range(5):                        
    
    train_ds, test_ds = ProteinDataset(train_df), ProteinDataset(test_df)
    
    training_loader = DataLoader(train_ds, batch_size=1, num_workers = 2, shuffle = True)
    testing_loader = DataLoader(test_ds, batch_size=1, num_workers = 2)
    
    model = ESMForSingleMutation()
    model.to(device)
    #model.const1 = model.const1.to(device)
    #model.const2 = model.const2.to(device)
    
    
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)
    
    for epoch in range(EPOCHS):
        train(epoch)
        labels, predictions = valid(model, testing_loader)
    print('*****************')
    print(f'random state: {random_state} Correlation {stats.spearmanr(labels, predictions)}') 
    spearmanrs.append(stats.spearmanr(labels, predictions)[0])
    del optimizer, scheduler, model
    
    #model.to('cpu')
print(np.mean(spearmanrs), np.std(spearmanrs))

Training loss epoch: 2.639705097934781
Validation Loss: 2.3335013400976914
Training loss epoch: 1.1393646466857565
Validation Loss: 2.1834155305653247
Training loss epoch: 0.46742047694741345
Validation Loss: 2.1480892694312606
*****************
random state: 0 Correlation SpearmanrResult(correlation=0.5081393279348888, pvalue=3.3224215903824654e-45)
Training loss epoch: 2.692595804584385
Validation Loss: 2.358534636169595
Training loss epoch: 1.170267046882815
Validation Loss: 2.1826430898481406
Training loss epoch: 0.4346870369894481
Validation Loss: 2.1457223346142653
*****************
random state: 1 Correlation SpearmanrResult(correlation=0.5046421972690767, pvalue=1.6380419683838082e-44)
Training loss epoch: 2.6936730900981045
Validation Loss: 2.33386447071174
Training loss epoch: 1.0799071757727707
Validation Loss: 2.1721827564609932
Training loss epoch: 0.3799153778997508
Validation Loss: 2.1661947834665205
*****************
random state: 2 Correlation SpearmanrResult(correlati

In [11]:
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    # put model in training mode
    model.train()
    
    for idx, batch in enumerate(training_loader):
        input_ids1, input_ids2, pos, labels = batch            
        input_ids1 = input_ids1[0].to(device)
        input_ids2 = input_ids2[0].to(device)
        #labels = labels.to(device)
        #print(model.device, input_ids.device, labels.device)
        #print(input_ids)
        #print(input_ids.size())
        logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos).to('cpu')        
        #print(outputs)
        loss = torch.nn.functional.mse_loss(logits, labels)
        #print(model(input_ids=ids, attention_mask=mask, labels=labels))
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)
        
        #if idx % 10==0:
        #    loss_step = tr_loss/nb_tr_steps
        #    print(f"Training loss per 10 training steps: {loss_step}")
               
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
        )
        
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")