In [1]:
%env TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub

env: TORCH_HOME=/ayb/vol2/home/dumerenkov/torch_hub


In [2]:
import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scipy
from scipy import stats

import gc

In [3]:
import torch.nn.functional as F

HIDDEN_UNITS_POS_CONTACT = 5
class ESMForSingleMutationPosConcat(nn.Module):

    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self.fc1 = nn.Linear(1280 * 2, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def forward(self, token_ids1, token_ids2, pos):
        outputs1 = self.esm2.forward(token_ids1, repr_layers=[33])[
            'representations'][33]
        outputs2 = self.esm2.forward(token_ids2, repr_layers=[33])[
            'representations'][33]
        outputs1_pos = outputs1[:, pos + 1]
        outputs2_pos = outputs2[:, pos + 1]
        outputs_pos_concat = torch.cat((outputs1_pos, outputs2_pos), 2)
        fc1_outputs = F.relu(self.fc1(outputs_pos_concat))
        logits = self.fc2(fc1_outputs)
        return logits

In [4]:
HIDDEN_UNITS_POS_OUTER = 5
class ESMForSingleMutationPosOuter(nn.Module):

    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self._freeze_esm2_layers()
        self.fc1 = nn.Linear(1280 * 1280, HIDDEN_UNITS_POS_OUTER)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_OUTER, 1)

    def _freeze_esm2_layers(self):
        total_blocks = 33
        initial_layers = 2
        layers_per_block = 16
        num_freeze_blocks = total_blocks - 3
        for _, param in list(self.esm2.named_parameters())[
            :initial_layers + layers_per_block * num_freeze_blocks]:
            param.requires_grad = False

    def forward(self, token_ids1, token_ids2, pos):
        outputs1 = self.esm2.forward(token_ids1, repr_layers=[33])[
            'representations'][33]
        outputs2 = self.esm2.forward(token_ids2, repr_layers=[33])[
            'representations'][33]
        outputs1_pos = outputs1[:, pos + 1]
        outputs2_pos = outputs2[:, pos + 1]
        outer_prod = outputs1_pos.unsqueeze(3) @ outputs2_pos.unsqueeze(2)
        outer_prod_view = outer_prod.view(outer_prod.shape[0], outer_prod.shape[1], -1)
        fc1_outputs = F.relu(self.fc1(outer_prod_view))
        logits = self.fc2(fc1_outputs)
        return logits

In [5]:
class ESMForSingleMutation_pos(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))
        

    def forward(self, token_ids1, token_ids2, pos):                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        outputs = self.const1 * outputs1[:,pos + 1,:] + self.const2 * outputs2[:,pos + 1,:]        
        logits = self.classifier(outputs)
        return logits

In [6]:
class ESMForSingleMutation_cls(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))
        

    def forward(self, token_ids1, token_ids2, pos):                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        outputs = self.const1 * outputs1[:,0,:] + self.const2 * outputs2[:,0,:]        
        logits = self.classifier(outputs.unsqueeze(0))
        return logits

In [7]:
class ESMForSingleMutation_pos_cat_cls(nn.Module):

    def __init__(self):
        super().__init__() 
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()        
        self.classifier = nn.Linear(1280*2, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))
        

    def forward(self, token_ids1, token_ids2, pos):                
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        cls_out = self.const1 * outputs1[:,0,:] + self.const2 * outputs2[:,0,:]
        pos_out = self.const1 * outputs1[:,pos+1,:] + self.const2 * outputs2[:,pos+1,:]
        outputs = torch.cat([cls_out.unsqueeze(0), pos_out], axis = -1)        
        logits = self.classifier(outputs)
        return logits

In [8]:
class ProteinDataset(Dataset):
    def __init__(self, df):
        self.df = df
        _, esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

    def __getitem__(self, idx):
        _, _, esm1b_batch_tokens1 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['wt_seq'])[:1022])])
        _, _, esm1b_batch_tokens2 = self.esm1v_batch_converter([('' , ''.join(self.df.iloc[idx]['mut_seq'])[:1022])])
        pos = self.df.iloc[idx]['pos']
        return esm1b_batch_tokens1, esm1b_batch_tokens2, pos, torch.unsqueeze(torch.FloatTensor([self.df.iloc[idx]['ddg']]), 0)

    def __len__(self):
        return len(self.df)

In [9]:
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    model.train()
    
    for idx, batch in enumerate(training_loader):
        input_ids1, input_ids2, pos, labels = batch            
        input_ids1 = input_ids1[0].to(device)
        input_ids2 = input_ids2[0].to(device)
        logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos).to('cpu')        
        loss = torch.nn.functional.mse_loss(logits, labels)
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)
        
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")

In [10]:
lr = 1e-5
EPOCHS = 3
device = 'cuda:1'

models = ['ESMForSingleMutationPosOuter',
          'ESMForSingleMutationPosConcat',
          'ESMForSingleMutation_pos_cat_cls',  
              'ESMForSingleMutation_pos', 
              'ESMForSingleMutation_cls']

full_df = pd.read_csv('DATASETS/new_ds_with_folds.csv')

preds = {n:[] for n in models} 
true = [None]*5

for model_name in models:
    model_class = globals()[model_name]
    print(f'Training model {model_name}')
    train_df = full_df
    train_ds = ProteinDataset(train_df)
        
    model = model_class()                        
    model.to(device) 
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    training_loader = DataLoader(train_ds, batch_size=1, num_workers = 2, shuffle = True)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)
        
    for epoch in range(EPOCHS):
        train(epoch)
         
    model.to('cpu')
    
    torch.save(model, 'weights/' + model_name)
    
    del model

Training model ESMForSingleMutationPosOuter
Training loss epoch: 3.876894674672059
Training loss epoch: 3.1980062041550696
Training loss epoch: 2.915139241010014
Training model ESMForSingleMutationPosConcat
Training loss epoch: 3.860952251269573
Training loss epoch: 1.469412682013314
Training loss epoch: 0.661039534414614
Training model ESMForSingleMutation_pos_cat_cls
Training loss epoch: 2.6397629074271167
Training loss epoch: 0.8607995854492856
Training loss epoch: 0.3215935788703277
Training model ESMForSingleMutation_pos
Training loss epoch: 2.5457445398761536
Training loss epoch: 0.8276246989439202
Training loss epoch: 0.2824712599465907
Training model ESMForSingleMutation_cls
Training loss epoch: 3.6326927324780582
Training loss epoch: 2.246442528126786
Training loss epoch: 0.9332105714671484
