In [1]:
%env TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub

env: TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub


In [2]:
import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scipy
from scipy import stats

In [3]:
class ESMForSingleMutation_pos(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))
        

    def forward(self, token_ids1, token_ids2, pos):                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        outputs = self.const1 * outputs1[:,pos + 1,:] + self.const2 * outputs2[:,pos + 1,:]        
        logits = self.classifier(outputs)
        return logits

In [4]:
class ESMForSingleMutation_cls(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))
        

    def forward(self, token_ids1, token_ids2, pos):                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        outputs = self.const1 * outputs1[:,0,:] + self.const2 * outputs2[:,0,:]        
        logits = self.classifier(outputs)
        return logits

In [5]:
class ProteinDataset(Dataset):
    def __init__(self, df):
        self.df = df
        _, esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

    def __getitem__(self, idx):
        _, _, esm1b_batch_tokens1 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['wt_seq'])[:1022])])
        _, _, esm1b_batch_tokens2 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['mut_seq'])[:1022])])
        pos = self.df.iloc[idx]['pos']
        return esm1b_batch_tokens1, esm1b_batch_tokens2, pos, torch.unsqueeze(torch.FloatTensor([self.df.iloc[idx]['ddg']]), 0)

    def __len__(self):
        return len(self.df)

In [6]:
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    model.train()
    
    for idx, batch in enumerate(training_loader):
        input_ids1, input_ids2, pos, labels = batch            
        input_ids1 = input_ids1[0].to(device)
        input_ids2 = input_ids2[0].to(device)
        logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos).to('cpu')        
        loss = torch.nn.functional.mse_loss(logits, labels)
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)
        
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")

In [7]:
def valid(model, testing_loader):
    model.eval()
    
    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels, eval_scores = [], [], []
    
    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):
            
            input_ids1, input_ids2, pos, labels = batch            
            input_ids1 = input_ids1[0].to(device)
            input_ids2 = input_ids2[0].to(device)
            labels = labels.to(device)
            logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos)        
            loss = torch.nn.functional.mse_loss(logits, labels)
            eval_loss += loss.item()

            nb_eval_steps += 1
            nb_eval_examples += labels.size(0)
                     
            eval_labels.extend(labels.cpu().detach())
            eval_preds.extend(logits.cpu().detach())
            
  
    labels = [id.item() for id in eval_labels]
    predictions = [id.item() for id in eval_preds]
    
    eval_loss = eval_loss / nb_eval_steps
    print(f"Validation Loss: {eval_loss}")

    return labels, predictions

In [8]:
experiments = [
    (['S2648.csv'], ['ssym.csv', 'ssym_r.csv', 'myoglobin.csv', 'myoglobin_r.csv']),
    (['S3421.csv'], ['ssym.csv', 'ssym_r.csv', 'myoglobin.csv', 'myoglobin_r.csv']),
    (['S3488.csv'], ['ssym.csv', 'ssym_r.csv']),
    (['S2648.csv', 'ACDC_varibench.csv'], ['ssym.csv', 'ssym_r.csv', 'myoglobin.csv', 'p53.csv']),
    (['deepddg_train.csv'], 'deepddg_test.csv')
]

In [None]:
lr = 1e-5
EPOCHS = 3
device = 'cuda:2'
                        
for train_datasets, test_datasets in experiments:
    print('Train on ', train_datasets)
    train_df = pd.concat([pd.read_csv(os.path.join('DATASETS', t)) for t in train_datasets])
    train_ds = ProteinDataset(train_df)
    
    models = [ESMForSingleMutation_pos(), ESMForSingleMutation_cls()]
    all_preds = {}
    all_true = {}
    for model_no, model in enumerate(models):    
        model.to(device) 
        optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
        training_loader = DataLoader(train_ds, batch_size=1, num_workers = 2, shuffle = True)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)
        

        print(f'Training model {model_no}')
        for epoch in range(EPOCHS):
            train(epoch)
        print()
        
        for test_no, test_dataset in enumerate(test_datasets):
            if test_no not in all_preds:
                all_preds[test_no] = {}
            print('Test on ', test_dataset)
            test_df = pd.read_csv(os.path.join('DATASETS', test_dataset))
            test_ds = ProteinDataset(test_df)
            testing_loader = DataLoader(test_ds, batch_size=1, num_workers = 2)
            labels, predictions = valid(model, testing_loader)
            print(f'MAE {np.mean(np.abs(np.array(labels) - np.array(predictions)))} Correlation {stats.spearmanr(labels, predictions)}')     
            all_preds[test_no][model_no] = predictions
            all_true[test_no] = labels
            
            
        print()
    
    for test_idx in all_true.keys():
        print(f'Ensemble result for dataset {test_datasets[test_idx]}')
        ens_preds = np.mean(np.stack([all_preds[test_idx][t] for t in all_preds[test_idx].keys()], axis = -1), axis = -1)

        print(f'Correlation {stats.spearmanr(all_true[test_idx], ens_preds)}')
        print(f'MSE {np.sqrt(np.mean((np.array(ens_preds)-np.array(all_true[test_idx]))**2))}')        
        print(f'MAE {np.mean(np.abs(np.array(all_true[test_idx]) - np.array(ens_preds)))}')    
    
    print('*****************')
    print()
    del optimizer, scheduler, model
    
    #model.to('cpu')


Train on  ['S2648.csv']
Training model 0
Training loss epoch: 2.5486238524853078
Training loss epoch: 0.9919666604479591
Training loss epoch: 0.38118088802880873

Test on  ssym.csv
Validation Loss: 0.9462949894599406
MAE 0.6437292271307316 Correlation SpearmanrResult(correlation=0.7837955995257001, pvalue=2.49183792994039e-72)
Test on  ssym_r.csv
Validation Loss: 0.9584759757189668
MAE 0.6488666684374932 Correlation SpearmanrResult(correlation=0.7836012845967981, pvalue=2.8504199243205333e-72)
Test on  myoglobin.csv
Validation Loss: 0.7291574193528808
MAE 0.5926445554941893 Correlation SpearmanrResult(correlation=0.644766556664865, pvalue=4.205138302846179e-17)
Test on  myoglobin_r.csv
Validation Loss: 0.7290051383196774
MAE 0.5900852801486739 Correlation SpearmanrResult(correlation=0.6446169174823382, pvalue=4.298742671163432e-17)

Training model 1


  loss = torch.nn.functional.mse_loss(logits, labels)


Training loss epoch: 2.965374656697283
Training loss epoch: 1.9291575581249363
Training loss epoch: 1.000102159788676

Test on  ssym.csv


  loss = torch.nn.functional.mse_loss(logits, labels)


Validation Loss: 1.0555514228278529
MAE 0.706959012574489 Correlation SpearmanrResult(correlation=0.7574394428100879, pvalue=6.3780581430158415e-65)
Test on  ssym_r.csv


  loss = torch.nn.functional.mse_loss(logits, labels)


Validation Loss: 1.2880453322543342
MAE 0.7998354901526973 Correlation SpearmanrResult(correlation=0.7501494068982905, pvalue=4.87740599967887e-63)
Test on  myoglobin.csv


  loss = torch.nn.functional.mse_loss(logits, labels)


Validation Loss: 0.6544752783552958
MAE 0.568952531067293 Correlation SpearmanrResult(correlation=0.735814517273325, pvalue=4.212772686578414e-24)
Test on  myoglobin_r.csv


  loss = torch.nn.functional.mse_loss(logits, labels)


Validation Loss: 0.8323934726527589
MAE 0.6367423246239325 Correlation SpearmanrResult(correlation=0.7356174923496647, pvalue=4.3933549488328215e-24)

Ensemble result for dataset ssym.csv
Correlation SpearmanrResult(correlation=0.8028878290535527, pvalue=2.2207225171015702e-78)
MSE 0.9426828407500886
MAE 0.6247719830815675
Ensemble result for dataset ssym_r.csv
Correlation SpearmanrResult(correlation=0.8020883619174985, pvalue=4.1022157161219596e-78)
MSE 0.9982676618564813
MAE 0.6673528479040486
Ensemble result for dataset myoglobin.csv
Correlation SpearmanrResult(correlation=0.7675579758600201, pvalue=2.863625760803683e-27)
MSE 0.7939569765127532
MAE 0.5520670894783601
Ensemble result for dataset myoglobin_r.csv
Correlation SpearmanrResult(correlation=0.7679420497618389, pvalue=2.603340618590685e-27)
MSE 0.8213995229064927
MAE 0.554327963013897
*****************

Train on  ['S3421.csv']
Training model 0
