In [1]:
%env TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub

env: TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub


In [2]:
import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scipy
from scipy import stats

In [3]:
class ESMForSingleMutation(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))

        

    def forward(self, token_ids1, token_ids2, pos):
                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        
        outputs = self.const1 * outputs1[:,pos,:] + self.const2 * outputs2[:,pos,:]
        
        #outputs = self.merge(torch.stack(, outputs2[:,0,:]]))
        #outputs = torch.cat([torch.mean(outputs, dim = 1), torch.max(outputs, dim = 1).values], dim = 1)
        logits = self.classifier(outputs)

        return logits

In [4]:
class ProteinDataset(Dataset):
    def __init__(self, df):
        self.df = df
        _, esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

    def __getitem__(self, idx):
        _, _, esm1b_batch_tokens1 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['wt_seq'])[:1022])])
        _, _, esm1b_batch_tokens2 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['mut_seq'])[:1022])])
        pos = self.df.iloc[idx]['pos']
        return esm1b_batch_tokens1, esm1b_batch_tokens2, pos, torch.unsqueeze(torch.FloatTensor([self.df.iloc[idx]['ddg']]), 0)

    def __len__(self):
        return len(self.df)

In [5]:
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    # put model in training mode
    model.train()
    
    for idx, batch in enumerate(training_loader):
        input_ids1, input_ids2, pos, labels = batch            
        input_ids1 = input_ids1[0].to(device)
        input_ids2 = input_ids2[0].to(device)
        #labels = labels.to(device)
        #print(model.device, input_ids.device, labels.device)
        #print(input_ids)
        #print(input_ids.size())
        logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos).to('cpu')        
        #print(outputs)
        loss = torch.nn.functional.mse_loss(logits, labels)
        #print(model(input_ids=ids, attention_mask=mask, labels=labels))
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)
        
        #if idx % 10==0:
        #    loss_step = tr_loss/nb_tr_steps
        #    print(f"Training loss per 10 training steps: {loss_step}")
               
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
        )
        
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")

In [6]:
def valid(model, testing_loader):
    # put model in evaluation mode
    model.eval()
    
    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels, eval_scores = [], [], []
    
    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):
            
            input_ids1, input_ids2, pos, labels = batch            
            input_ids1 = input_ids1[0].to(device)
            input_ids2 = input_ids2[0].to(device)
            labels = labels.to(device)
            #print(model.device, input_ids.device, labels.device)
            #print(input_ids)
            #print(input_ids.size())
            logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos)        
            #print(outputs)
            loss = torch.nn.functional.mse_loss(logits, labels)
            #print(model(input_ids=ids, attention_mask=mask, labels=labels))
            eval_loss += loss.item()

            nb_eval_steps += 1
            nb_eval_examples += labels.size(0)
        
            #if idx % 10==0:
            #    loss_step = eval_loss/nb_eval_steps
            #    print(f"Validation loss per 100 evaluation steps: {loss_step}")
             
            eval_labels.extend(labels.cpu().detach())
            eval_preds.extend(logits.cpu().detach())
            
  
    labels = [id.item() for id in eval_labels]
    predictions = [id.item() for id in eval_preds]
    
    eval_loss = eval_loss / nb_eval_steps
    print(f"Validation Loss: {eval_loss}")

    return labels, predictions

In [7]:
experiments = [
    (['S3488.csv'], 'ssym.csv'),
    (['S3488.csv'], 'ssym_r.csv'),
    (['S2648.csv', 'ACDC_varibench.csv'], 'ssym.csv'),
    (['S2648.csv', 'ACDC_varibench.csv'], 'ssym_r.csv'),
    (['S2648.csv', 'ACDC_varibench.csv'], 'myoglobin.csv'),
    (['S2648.csv', 'ACDC_varibench.csv'], 'p53.csv'),    
    (['S2648.csv'], 'ssym.csv'),
    (['S2648.csv'], 'ssym_r.csv'),
    (['S2648.csv'], 'myoglobin.csv'),
    (['S2648.csv'], 'myoglobin_r.csv'),
    (['S3421.csv'], 'ssym.csv'),
    (['S3421.csv'], 'ssym_r.csv'),
    (['S3421.csv'], 'myoglobin.csv'),
    (['S3421.csv'], 'myoglobin_r.csv'), 
]

In [8]:
lr = 1e-5
EPOCHS = 3
device = 'cuda:2'
                        
for train_datasets, test_dataset in experiments:
    print('Train on ', train_datasets)
    print('Test on ', test_dataset)
    
    train_df = pd.concat([pd.read_csv(os.path.join('DATASETS', t)) for t in train_datasets])
    test_df = pd.read_csv(os.path.join('DATASETS', test_dataset))
    
    train_ds, test_ds = ProteinDataset(train_df), ProteinDataset(test_df)
    
    training_loader = DataLoader(train_ds, batch_size=1, num_workers = 2, shuffle = True)
    testing_loader = DataLoader(test_ds, batch_size=1, num_workers = 2)
    
    model = ESMForSingleMutation()
    model.to(device)
    #model.const1 = model.const1.to(device)
    #model.const2 = model.const2.to(device)
    
    
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)
    
    for epoch in range(EPOCHS):
        train(epoch)
        labels, predictions = valid(model, testing_loader)
    
    print(f'MAE {np.mean(np.abs(np.array(labels) - np.array(predictions)))} Correlation {stats.spearmanr(labels, predictions)}')     
    print('*****************')
    del optimizer, scheduler, model
    
    #model.to('cpu')


Train on  ['S3488.csv']
Test on  ssym.csv
Training loss epoch: 2.3765579482604955
Validation Loss: 6.514057166844965
Training loss epoch: 0.790747747586766
Validation Loss: 7.606570858323418
Training loss epoch: 0.19476715091208327
Validation Loss: 7.941786421074331
MAE 2.030546653857836 Correlation SpearmanrResult(correlation=-0.5590952175483324, pvalue=1.6315954600914708e-29)
*****************
Train on  ['S3488.csv']
Test on  ssym_r.csv
Training loss epoch: 2.2893451735836168
Validation Loss: 7.074502755399578
Training loss epoch: 0.6171288764961305
Validation Loss: 7.9161272492797226
Training loss epoch: 0.15353070672759242
Validation Loss: 7.9343863616425425
MAE 2.0373377312214402 Correlation SpearmanrResult(correlation=-0.5512651511893392, pvalue=1.4013630361929296e-28)
*****************
Train on  ['S2648.csv', 'ACDC_varibench.csv']
Test on  ssym.csv
Training loss epoch: 2.7818096098169347
Validation Loss: 7.866931681285076
Training loss epoch: 1.187477685575817
Validation Loss: 1

In [9]:
print(f'MAE {np.mean(np.abs(np.array(labels) - np.array(predictions)))} Correlation {stats.spearmanr(labels, predictions)}') 

MAE 1.3417977737662479 Correlation SpearmanrResult(correlation=-0.6343017898334876, pvalue=1.9049380611657633e-16)


In [11]:
experiments = [
    (['deepddg_train.csv'], 'deepddg_test.csv')]

In [12]:
lr = 1e-5
EPOCHS = 3
device = 'cuda:2'
                        
for train_datasets, test_dataset in experiments:
    print('Train on ', train_datasets)
    print('Test on ', test_dataset)
    
    train_df = pd.concat([pd.read_csv(os.path.join('DATASETS', t)) for t in train_datasets])
    test_df = pd.read_csv(os.path.join('DATASETS', test_dataset))
    
    train_ds, test_ds = ProteinDataset(train_df), ProteinDataset(test_df)
    
    training_loader = DataLoader(train_ds, batch_size=1, num_workers = 2, shuffle = True)
    testing_loader = DataLoader(test_ds, batch_size=1, num_workers = 2)
    
    model = ESMForSingleMutation()
    model.to(device)
    #model.const1 = model.const1.to(device)
    #model.const2 = model.const2.to(device)
    
    
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)
    
    for epoch in range(EPOCHS):
        train(epoch)
        labels, predictions = valid(model, testing_loader)
    
    print(f'MAE {np.mean(np.abs(np.array(labels) - np.array(predictions)))} Correlation {stats.spearmanr(labels, predictions)}')     
    print('*****************')
    del optimizer, scheduler, model

Train on  ['deepddg_train.csv']
Test on  deepddg_test.csv


ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/ayb/vol2/home/dumerenkov/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/ayb/vol2/home/dumerenkov/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/ayb/vol2/home/dumerenkov/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_21833/1540760637.py", line 11, in __getitem__
    return esm1b_batch_tokens1, esm1b_batch_tokens2, pos, torch.unsqueeze(torch.FloatTensor([self.df.iloc[idx]['ddg']]), 0)
ValueError: too many dimensions 'str'
