In [1]:
%env TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub

env: TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub


In [2]:
import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scipy
from scipy import stats

import gc
import pickle
import logging

# Send every record from DEBUG upwards to a UTF-8 encoded log file in the
# current working directory; the rest of the notebook reports through it.
logging.basicConfig(
    level=logging.DEBUG,
    filename='PROSTATA_experiments_pearson.log',
    encoding='utf-8',
)

In [3]:
import torch.nn.functional as F

# Width of the hidden layer in the regression head of the concat model.
HIDDEN_UNITS_POS_CONTACT = 5


class ESMForSingleMutationPosConcat(nn.Module):
    """Regress a scalar from the concatenated site embeddings of two sequences.

    Both token tensors are passed through the same ESM-2 backbone, the
    layer-33 representation at index ``pos + 1`` is taken from each, the two
    vectors are concatenated along dimension 2 and fed to a small two-layer
    head (``fc1`` -> ReLU -> ``fc2``).
    """

    def __init__(self):
        super().__init__()
        backbone, _alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm2 = backbone
        self.fc1 = nn.Linear(2 * 1280, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def _site_embedding(self, token_ids, pos):
        """Return the layer-33 representation of ``token_ids`` at ``pos + 1``."""
        layer33 = self.esm2.forward(token_ids, repr_layers=[33])['representations'][33]
        # The +1 offset presumably skips a start token prepended by the
        # batch converter -- TODO confirm against the tokenizer.
        return layer33[:, pos + 1]

    def forward(self, token_ids1, token_ids2, pos):
        """Compute the prediction for one pair of tokenised sequences.

        Args:
            token_ids1: token tensor of the first sequence.
            token_ids2: token tensor of the second sequence.
            pos: index (or index tensor) of the site of interest, before the
                ``+ 1`` offset is applied.

        Returns:
            Output of ``fc2``; the last dimension has size 1.
        """
        first_site = self._site_embedding(token_ids1, pos)
        second_site = self._site_embedding(token_ids2, pos)
        # Concatenation on dim 2 requires the indexed result to be 3-D, which
        # is the case when ``pos`` is a 1-D index tensor.
        joined = torch.cat((first_site, second_site), 2)
        hidden = F.relu(self.fc1(joined))
        return self.fc2(hidden)

In [4]:
# Width of the hidden layer in the regression head of the outer-product model.
HIDDEN_UNITS_POS_OUTER = 5


class ESMForSingleMutationPosOuter(nn.Module):
    """Regress a scalar from the outer product of two site embeddings.

    Both token tensors go through the same ESM-2 backbone (most of whose
    parameters are frozen at construction). The layer-33 vectors at index
    ``pos + 1`` are combined with an outer product, flattened, and passed
    through ``fc1`` -> ReLU -> ``fc2``.
    """

    def __init__(self):
        super().__init__()
        backbone, _alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm2 = backbone
        self._freeze_esm2_layers()
        self.fc1 = nn.Linear(1280 * 1280, HIDDEN_UNITS_POS_OUTER)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_OUTER, 1)

    def _freeze_esm2_layers(self):
        """Disable gradients for the leading parameters of the backbone.

        The first ``2 + 16 * 30`` parameter tensors, in the order the backbone
        yields them, get ``requires_grad = False``.
        """
        # NOTE(review): the counts below assume a particular parameter layout
        # of esm2_t33_650M_UR50D (tensors before the first block, tensors per
        # block) -- confirm against list(self.esm2.named_parameters()).
        leading_tensors = 2
        tensors_per_block = 16
        frozen_blocks = 33 - 3
        cutoff = leading_tensors + tensors_per_block * frozen_blocks
        for tensor in list(self.esm2.parameters())[:cutoff]:
            tensor.requires_grad = False

    def _site_embedding(self, token_ids, pos):
        """Return the layer-33 representation of ``token_ids`` at ``pos + 1``."""
        layer33 = self.esm2.forward(token_ids, repr_layers=[33])['representations'][33]
        return layer33[:, pos + 1]

    def forward(self, token_ids1, token_ids2, pos):
        """Compute the prediction for one pair of tokenised sequences.

        Args:
            token_ids1: token tensor of the first sequence.
            token_ids2: token tensor of the second sequence.
            pos: index (or index tensor) of the site of interest, before the
                ``+ 1`` offset is applied.

        Returns:
            Output of ``fc2``; the last dimension has size 1.
        """
        first_site = self._site_embedding(token_ids1, pos)
        second_site = self._site_embedding(token_ids2, pos)
        # Column vector times row vector gives the pairwise-product matrix;
        # the unsqueeze dims (3 and 2) require 3-D site embeddings.
        pairwise = first_site.unsqueeze(3) @ second_site.unsqueeze(2)
        flat = pairwise.view(pairwise.shape[0], pairwise.shape[1], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [5]:
class ESMForSingleMutation_pos(nn.Module):
    """Linear regressor on a learnable weighted sum of two site embeddings.

    ``const1`` and ``const2`` are per-channel weights initialised to +1 and
    -1, so at initialisation the combined vector is the difference between
    the first and the second sequence's representation at ``pos + 1``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names say "esm1v" but the checkpoint loaded is ESM-2;
        # the names are kept because they determine the state-dict keys.
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def _layer33(self, token_ids):
        """Return the layer-33 representations of ``token_ids``."""
        return self.esm1v.forward(token_ids, repr_layers=[33])['representations'][33]

    def forward(self, token_ids1, token_ids2, pos):
        """Compute the prediction for one pair of tokenised sequences.

        Args:
            token_ids1: token tensor of the first sequence.
            token_ids2: token tensor of the second sequence.
            pos: index (or index tensor) of the site of interest, before the
                ``+ 1`` offset is applied.

        Returns:
            Output of ``classifier``; the last dimension has size 1.
        """
        site = pos + 1
        first_term = self.const1 * self._layer33(token_ids1)[:, site, :]
        second_term = self.const2 * self._layer33(token_ids2)[:, site, :]
        return self.classifier(first_term + second_term)

In [6]:
class ESMForSingleMutation_cls(nn.Module):
    """Linear regressor on a learnable weighted sum of two index-0 embeddings.

    ``const1`` and ``const2`` are per-channel weights initialised to +1 and
    -1, so at initialisation the combined vector is the difference between
    the two sequences' layer-33 representations at index 0 (presumably the
    start/CLS token -- TODO confirm against the tokenizer).
    """

    def __init__(self):
        super().__init__()
        # Attribute names say "esm1v" but the checkpoint loaded is ESM-2;
        # the names are kept because they determine the state-dict keys.
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def _first_token(self, token_ids):
        """Return the layer-33 representation of ``token_ids`` at index 0."""
        layer33 = self.esm1v.forward(token_ids, repr_layers=[33])['representations'][33]
        return layer33[:, 0, :]

    def forward(self, token_ids1, token_ids2, pos):
        """Compute the prediction for one pair of tokenised sequences.

        Args:
            token_ids1: token tensor of the first sequence.
            token_ids2: token tensor of the second sequence.
            pos: accepted for signature parity with the other models; unused.

        Returns:
            Output of ``classifier`` with an extra leading dimension of size 1.
        """
        first_term = self.const1 * self._first_token(token_ids1)
        second_term = self.const2 * self._first_token(token_ids2)
        combined = first_term + second_term
        # The leading axis is added so the output rank matches what the
        # position-based models produce.
        return self.classifier(combined.unsqueeze(0))

In [7]:
class ESMForSingleMutation_pos_cat_cls(nn.Module):
    """Linear regressor on index-0 and site features concatenated together.

    The same learnable per-channel weights (``const1`` = +1, ``const2`` = -1
    at initialisation) combine the two sequences' layer-33 representations
    twice: once at index 0 and once at ``pos + 1``. The two combined vectors
    are concatenated on the last axis and fed to ``classifier``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names say "esm1v" but the checkpoint loaded is ESM-2;
        # the names are kept because they determine the state-dict keys.
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.classifier = nn.Linear(1280 * 2, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1, 1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1, 1280)))

    def _weighted_sum(self, first, second):
        """Return ``const1 * first + const2 * second``."""
        return self.const1 * first + self.const2 * second

    def forward(self, token_ids1, token_ids2, pos):
        """Compute the prediction for one pair of tokenised sequences.

        Args:
            token_ids1: token tensor of the first sequence.
            token_ids2: token tensor of the second sequence.
            pos: index (or index tensor) of the site of interest, before the
                ``+ 1`` offset is applied.

        Returns:
            Output of ``classifier``; the last dimension has size 1.
        """
        layer33_first = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        layer33_second = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]

        site = pos + 1
        start_feat = self._weighted_sum(layer33_first[:, 0, :], layer33_second[:, 0, :])
        site_feat = self._weighted_sum(layer33_first[:, site, :], layer33_second[:, site, :])

        # The index-0 feature gets a leading axis so that both parts have the
        # same rank before being joined on the channel axis.
        features = torch.cat([start_feat.unsqueeze(0), site_feat], axis=-1)
        return self.classifier(features)

In [8]:
class ProteinDataset(Dataset):
    """Dataset yielding tokenised wild-type / mutant sequence pairs.

    Each row of ``df`` must provide the columns ``wt_seq``, ``mut_seq``,
    ``pos`` and ``ddg``.
    """

    # Sequences are cut to this many residues before tokenisation.
    MAX_SEQ_LEN = 1022

    def __init__(self, df):
        """Store the frame and build the tokeniser.

        Args:
            df: pandas DataFrame with one mutation per row.
        """
        self.df = df
        # Only the tokeniser is needed here. Calling
        # esm.pretrained.esm2_t33_650M_UR50D() would load (and immediately
        # discard) the full 650M-parameter checkpoint every time a dataset is
        # built. ESM-2 checkpoints use the "ESM-1b" alphabet, so it is built
        # directly instead.
        esm1v_alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

    def _tokenize(self, sequence):
        """Tokenise one (truncated) sequence with an empty label."""
        truncated = ''.join(sequence)[:self.MAX_SEQ_LEN]
        _, _, tokens = self.esm1v_batch_converter([('', truncated)])
        return tokens

    def __getitem__(self, idx):
        """Return ``(wt_tokens, mut_tokens, pos, ddg)`` for row ``idx``.

        ``ddg`` is returned as a float tensor of shape ``(1, 1)``; ``pos`` is
        passed through unchanged from the frame.
        """
        # Fetch the row once instead of re-indexing the frame per column.
        row = self.df.iloc[idx]
        esm1b_batch_tokens1 = self._tokenize(row['wt_seq'])
        esm1b_batch_tokens2 = self._tokenize(row['mut_seq'])
        pos = row['pos']
        target = torch.unsqueeze(torch.FloatTensor([row['ddg']]), 0)
        return esm1b_batch_tokens1, esm1b_batch_tokens2, pos, target

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.df)

In [9]:
def train(epoch):
    """Run one training epoch over the global ``training_loader``.

    Relies on the module-level ``model``, ``training_loader``, ``optimizer``,
    ``scheduler`` and ``device`` set up by the driver cell. The mean MSE loss
    over the epoch is written to the log.

    Args:
        epoch: index of the current epoch (not used in the computation).
    """
    model.train()
    running_loss = 0.0
    num_steps = 0

    for batch in training_loader:
        input_ids1, input_ids2, pos, labels = batch
        # Drop the batch axis added by the DataLoader; the dataset items
        # already carry their own leading axis.
        input_ids1 = input_ids1[0].to(device)
        input_ids2 = input_ids2[0].to(device)
        # Labels stay on the CPU, so the prediction is moved there as well.
        logits = model(token_ids1=input_ids1, token_ids2=input_ids2, pos=pos).to('cpu')
        loss = torch.nn.functional.mse_loss(logits, labels)
        running_loss += loss.item()
        num_steps += 1

        optimizer.zero_grad()
        loss.backward()
        # Clipping must come after backward() and before step(): clipping
        # before zero_grad()/backward() only touches the stale gradients of
        # the previous step and leaves the current ones unclipped.
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
        )
        optimizer.step()
        scheduler.step()

    if num_steps == 0:
        # Nothing was trained; avoid a ZeroDivisionError on an empty loader.
        logging.warning("Training loader is empty; no training step was run")
        return

    epoch_loss = running_loss / num_steps
    logging.info(f"Training loss epoch: {epoch_loss}")

In [10]:
def valid(model, testing_loader):
    """Evaluate ``model`` on ``testing_loader`` without tracking gradients.

    Uses the module-level ``device``. The mean MSE loss over all batches is
    written to the log.

    Args:
        model: module called as ``model(token_ids1=..., token_ids2=..., pos=...)``.
        testing_loader: iterable of ``(tokens1, tokens2, pos, labels)`` batches.

    Returns:
        ``(labels, predictions)`` -- two lists of Python floats in loader order.
    """
    model.eval()

    loss_sum = 0
    step_count = 0
    target_chunks = []
    output_chunks = []

    with torch.no_grad():
        for batch in testing_loader:
            tokens_a, tokens_b, pos, labels = batch
            # Drop the batch axis added by the DataLoader.
            tokens_a = tokens_a[0].to(device)
            tokens_b = tokens_b[0].to(device)
            labels = labels.to(device)

            logits = model(token_ids1=tokens_a, token_ids2=tokens_b, pos=pos)
            loss_sum += torch.nn.functional.mse_loss(logits, labels).item()
            step_count += 1

            # extend() iterates over the leading axis, so one single-element
            # tensor per sample is collected.
            target_chunks.extend(labels.cpu().detach())
            output_chunks.extend(logits.cpu().detach())

    labels = [chunk.item() for chunk in target_chunks]
    predictions = [chunk.item() for chunk in output_chunks]

    mean_loss = loss_sum / step_count
    logging.info(f"Validation Loss: {mean_loss}")

    return labels, predictions

In [11]:
# (train datasets, test datasets) for every experiment, in execution order.
# NOTE(review): the last pair repeats the eighth one (Q3421 -> ssym/ssym_r);
# it looks like an accidental duplicate -- confirm before removing it.
_TRAIN_TEST_PAIRS = (
    (("Q3488",),               ("ssym", "ssym_r")),
    (("Q3488",),               ("p53", "p53_r")),
    (("Q3488",),               ("myoglobin", "myoglobin_r")),
    (("S2648", "Vb", "Vb_r"),  ("ssym", "ssym_r")),
    (("S2648", "Vb", "Vb_r"),  ("p53", "p53_r")),
    (("S2648", "Vb", "Vb_r"),  ("L20", "L20_r")),
    (("S2648", "Vb", "Vb_r"),  ("myoglobin", "myoglobin_r")),
    (("Q3421",),               ("ssym", "ssym_r")),
    (("prostata",),            ("s669", "s669_r", "s669_dssp")),
    (("prostata",),            ("hem_test",)),
    (("prostata",),            ("oligo_test",)),
    (("Q3421",),               ("ssym", "ssym_r")),
)

# Each experiment is a dict with independent "train" and "test" lists.
experiments = [
    {"train": list(train_names), "test": list(test_names)}
    for train_names, test_names in _TRAIN_TEST_PAIRS
]
              

In [12]:
# Run configuration. ``device`` is also read by train() and valid().
lr = 1e-5
EPOCHS = 3
device = 'cuda:1'

# Names of the model classes in the ensemble; each is looked up in globals()
# and trained from scratch for every experiment.
models = ['ESMForSingleMutationPosOuter',
          'ESMForSingleMutationPosConcat',
          'ESMForSingleMutation_pos_cat_cls',
          'ESMForSingleMutation_pos',
          'ESMForSingleMutation_cls']

# The ensemble predictions are pickled into this directory below.
os.makedirs('PREDS', exist_ok=True)

for experiment in experiments:
    # Take the dataset names from the current experiment. Previously these
    # two names were only evaluated as bare expressions, which raises
    # NameError on a fresh kernel and otherwise silently reuses whatever
    # values an earlier run left behind for every experiment.
    train_datasets = experiment['train']
    test_datasets = experiment['test']

    logging.info('Train on ' + str(train_datasets))
    train_df = pd.concat([pd.read_csv(os.path.join('DATASETS', t)) for t in train_datasets])
    train_ds = ProteinDataset(train_df)

    # all_preds[test index][model index] -> list of predictions
    # all_true[test index]               -> list of labels
    all_preds = {}
    all_true = {}
    for model_no, model_name in enumerate(models):
        model_class = globals()[model_name]
        logging.info(f'Training model {model_name}')
        # model / optimizer / scheduler / training_loader must stay
        # module-level names: train() reads them as globals.
        model = model_class()
        model.to(device)
        optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
        training_loader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=True)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)

        for epoch in range(EPOCHS):
            train(epoch)

        for test_no, test_dataset in enumerate(test_datasets):
            logging.info('Test on ' + str(test_dataset))
            test_df = pd.read_csv(os.path.join('DATASETS', test_dataset))
            test_ds = ProteinDataset(test_df)
            testing_loader = DataLoader(test_ds, batch_size=1, num_workers=2)
            labels, predictions = valid(model, testing_loader)
            logging.info(f'MAE {np.mean(np.abs(np.array(labels) - np.array(predictions)))} Correlation {stats.pearsonr(labels, predictions)}')
            all_preds.setdefault(test_no, {})[model_no] = predictions
            all_true[test_no] = labels

        # Drop the references so the next model does not have to share the
        # device with this one.
        del model, optimizer, scheduler
        logging.info('')

    for test_idx in all_true.keys():
        logging.info(f'Ensemble result for dataset {test_datasets[test_idx]}')
        # Average the per-model predictions sample by sample.
        ens_preds = np.mean(np.stack([all_preds[test_idx][t] for t in all_preds[test_idx].keys()], axis=-1), axis=-1)

        logging.info(f'Correlation {stats.pearsonr(all_true[test_idx], ens_preds)}')
        logging.info(f'RMSE {np.sqrt(np.mean((np.array(ens_preds)-np.array(all_true[test_idx]))**2))}')
        logging.info(f'MAE {np.mean(np.abs(np.array(all_true[test_idx]) - np.array(ens_preds)))}')

        with open(os.path.join('PREDS', 'train_on_' + '_and_'.join(train_datasets) + 'test_on_' + test_datasets[test_idx] + '.preds.pkl'), "wb") as fh:
            pickle.dump(ens_preds, fh)

    logging.info('*****************')
    logging.info('')
