In [20]:
import os
import sys
import pickle
import tqdm
import torch
import textwrap
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from torch import nn
from typing import Union
from prettytable import PrettyTable
from collections import defaultdict
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
from torch_geometric.data import Data

import os
import sys
import time
import datetime
import random
from datetime import date
from os import path
from typing import Union
import logging
import numpy as np
import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from collections import defaultdict, OrderedDict
from pnlp.db.dataset import SeqDataset, initialize_db
from pnlp.embedding.tokenizer import ProteinTokenizer, token_to_index, index_to_token
from pnlp.embedding.nlp_embedding import NLPEmbedding
from pnlp.model.language import ProteinLM
from pnlp.model.bert import BERT
from runner_util import plot_run

In [2]:
def count_parameters(model):
    """
    Print a per-tensor table of trainable parameters and return their total.

    Only parameters with ``requires_grad=True`` are listed and counted.
    Adapted from: https://stackoverflow.com/a/62508086/1992369

    Args:
        model: a ``torch.nn.Module`` whose ``named_parameters()`` are summarised.

    Returns:
        int: total number of trainable scalar parameters.
    """
    summary = PrettyTable(["Modules", "Parameters"])
    trainable = [(name, tensor.numel())
                 for name, tensor in model.named_parameters()
                 if tensor.requires_grad]
    for name, count in trainable:
        summary.add_row([name, count])
    grand_total = sum(count for _, count in trainable)
    print(summary)
    print(f"Total Trainable Params: {grand_total}\n")
    return grand_total

### FCN (fully connected network)

In [13]:
from torch import nn

class FCN(nn.Module):
    """ Fully Connected Network regression head.

    A two-layer MLP (Linear -> ReLU -> Linear) applied to the feature vector at
    the last sequence position, producing a single scalar per batch element.

    Args:
        fcn_input_size: number of input features (last dimension of ``x``).
        fcn_hidden_size: number of features in the hidden layer.
        device: device the prediction is moved to before being returned
            ('cpu' or 'cuda'). This does not move the module's own parameters;
            call ``.to(device)`` on the model for that.
    """

    def __init__(self,
                 fcn_input_size,     # The number of input features
                 fcn_hidden_size,    # The number of features in hidden layer of FCN.
                 device):            # Device ('cpu' or 'cuda')
        super().__init__()
        self.device = device

        # FCN layers; the final Linear has a single output (one regression target).
        self.fcn = nn.Sequential(nn.Linear(fcn_input_size, fcn_hidden_size),
                                 nn.ReLU(),
                                 nn.Linear(fcn_hidden_size, 1))

    def forward(self, x):
        """Predict one value per batch element from the last sequence position.

        Args:
            x: tensor indexed as ``x[:, -1, :]``; presumably
               (batch, seq_len, fcn_input_size). TODO confirm against caller.

        Returns:
            Tensor of shape (batch, 1) on ``self.device``.
        """
        # The MLP acts independently on each position. Selecting the last position
        # first gives the same result as running the MLP on every position and
        # then discarding all but the last, without the (seq_len - 1) wasted passes.
        last_position = x[:, -1, :]
        prediction = self.fcn(last_position)
        return prediction.to(self.device)

In [14]:
class EmbeddedDMSDataset(Dataset):
    """ Binding or Expression DMS Dataset """

    def __init__(self, pickle_file: str):
        """
        Load a list of per-sequence records from a pickle file.

        Each record must provide:
        - ``seq_id``: the sequence label,
        - the numerical target: ``log10Ka`` for binding files or ``ML_meanF``
          for expression files, and
        - ``embedding``: a tensor whose ``size(0)`` is used as the number of
          graph nodes (residues) in ``__getitem__``.

        The target key is chosen from the file *name*, not the full path. Files
        whose name contains "binding" use ``log10Ka``; all others use ``ML_meanF``.

        NOTE(review): ``pickle.load`` can execute arbitrary code. Only load
        pickles from trusted sources.
        """
        with open(pickle_file, 'rb') as f:
            dms_list = pickle.load(f)

        # Inspect only the file name, so a directory whose name contains
        # "binding" cannot switch an expression file to the binding target.
        target_key = "log10Ka" if "binding" in os.path.basename(pickle_file) else "ML_meanF"

        self.labels = [entry['seq_id'] for entry in dms_list]
        self.numerical = [entry[target_key] for entry in dms_list]
        self.embeddings = [entry['embedding'] for entry in dms_list]

    def __len__(self) -> int:
        return len(self.labels)

    @staticmethod
    def _chain_edge_index(num_nodes: int) -> torch.Tensor:
        """Return an int64 edge index of shape (2, max(num_nodes - 1, 0)) linking i -> i + 1.

        The result is always 2-D. A one-node sequence therefore yields an empty
        (2, 0) tensor, rather than the 1-D (0,) tensor that
        ``torch.tensor([]).t()`` would produce.
        """
        src = torch.arange(max(num_nodes - 1, 0), dtype=torch.int64)
        return torch.stack([src, src + 1]).contiguous()

    def __getitem__(self, idx):
        """Return sample ``idx`` as a PyTorch Geometric ``Data`` path graph.

        Nodes are the rows of the stored embedding (``x``). Edges are directed
        i -> i + 1 only, and ``y`` has shape (1, 1).
        """
        embedding = self.embeddings[idx]
        edge_index = self._chain_edge_index(embedding.size(0))
        y = torch.tensor([self.numerical[idx]], dtype=torch.float32).view(-1, 1)

        return Data(x=embedding, edge_index=edge_index, y=y)

# Fall back to CPU so the cell also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
embedded_train_pkl = '../../../data/pickles/dms_mutation_binding_Kds_train_esm_embedded.pkl'
train_dataset = EmbeddedDMSDataset(embedded_train_pkl)

#  FCN input: the feature size is taken from dim 1 of the first stored embedding
fcn_input_size = train_dataset.embeddings[0].size(1)
fcn_hidden_size = fcn_input_size
# Move the parameters too. FCN.forward only moves its *output* to `device`.
model = FCN(fcn_input_size, fcn_hidden_size, device).to(device)

# Run
count_parameters(model)

+--------------+------------+
|   Modules    | Parameters |
+--------------+------------+
| fcn.0.weight |   102400   |
|  fcn.0.bias  |    320     |
| fcn.2.weight |    320     |
|  fcn.2.bias  |     1      |
+--------------+------------+
Total Trainable Params: 103041



103041

### GNN (GraphSAGE)

In [4]:
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

class GraphSAGE(nn.Module):
    """ GraphSAGE graph-level regressor.

    The pipeline is conv1 (in_channels -> hidden_channels), then ReLU, then
    conv2 (hidden_channels -> out_channels). Node features are then mean-pooled
    per graph of the batch.

    Args:
        in_channels: size of each node feature vector.
        out_channels: size of the per-graph output (1 for scalar regression).
        hidden_channels: width of the hidden SAGEConv layer. Defaults to 16,
            the value that was previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=16):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Run two SAGEConv layers and pool node features per graph.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity passed to both SAGEConv layers.
            batch: per-node graph assignment used by ``global_mean_pool``.

        Returns:
            One pooled ``out_channels`` vector per graph in the batch.
        """
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return global_mean_pool(x, batch)

In [9]:
class EmbeddedDMSDataset(Dataset):
    """ Binding or Expression DMS Dataset """

    # NOTE(review): this class is also defined in the FCN cell above, and this
    # definition shadows that one. Keep both copies identical, or define it once.

    def __init__(self, pickle_file: str):
        """
        Load a list of per-sequence records from a pickle file.

        Each record must provide:
        - ``seq_id``: the sequence label,
        - the numerical target: ``log10Ka`` for binding files or ``ML_meanF``
          for expression files, and
        - ``embedding``: a tensor whose ``size(0)`` is used as the number of
          graph nodes (residues) in ``__getitem__``.

        The target key is chosen from the file *name*, not the full path. Files
        whose name contains "binding" use ``log10Ka``; all others use ``ML_meanF``.

        NOTE(review): ``pickle.load`` can execute arbitrary code. Only load
        pickles from trusted sources.
        """
        with open(pickle_file, 'rb') as f:
            dms_list = pickle.load(f)

        # Inspect only the file name, so a directory whose name contains
        # "binding" cannot switch an expression file to the binding target.
        target_key = "log10Ka" if "binding" in os.path.basename(pickle_file) else "ML_meanF"

        self.labels = [entry['seq_id'] for entry in dms_list]
        self.numerical = [entry[target_key] for entry in dms_list]
        self.embeddings = [entry['embedding'] for entry in dms_list]

    def __len__(self) -> int:
        return len(self.labels)

    @staticmethod
    def _chain_edge_index(num_nodes: int) -> torch.Tensor:
        """Return an int64 edge index of shape (2, max(num_nodes - 1, 0)) linking i -> i + 1.

        The result is always 2-D. A one-node sequence therefore yields an empty
        (2, 0) tensor, rather than the 1-D (0,) tensor that
        ``torch.tensor([]).t()`` would produce.
        """
        src = torch.arange(max(num_nodes - 1, 0), dtype=torch.int64)
        return torch.stack([src, src + 1]).contiguous()

    def __getitem__(self, idx):
        """Return sample ``idx`` as a PyTorch Geometric ``Data`` path graph.

        Nodes are the rows of the stored embedding (``x``). Edges are directed
        i -> i + 1 only, and ``y`` has shape (1, 1).
        """
        embedding = self.embeddings[idx]
        edge_index = self._chain_edge_index(embedding.size(0))
        y = torch.tensor([self.numerical[idx]], dtype=torch.float32).view(-1, 1)

        return Data(x=embedding, edge_index=edge_index, y=y)

# Fall back to CPU so the cell also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
embedded_train_pkl = '../../../data/pickles/dms_mutation_binding_Kds_train_esm_embedded.pkl'
train_dataset = EmbeddedDMSDataset(embedded_train_pkl)

# GraphSAGE input
input_channels = train_dataset.embeddings[0].size(1) # number of input channels (dimensions of the embeddings)
out_channels = 1  # For regression output
model = GraphSAGE(input_channels, out_channels).to(device)

# Run
print(input_channels)
count_parameters(model)

320
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
| conv1.lin_l.weight |    5120    |
|  conv1.lin_l.bias  |     16     |
| conv1.lin_r.weight |    5120    |
| conv2.lin_l.weight |     16     |
|  conv2.lin_l.bias  |     1      |
| conv2.lin_r.weight |     16     |
+--------------------+------------+
Total Trainable Params: 10289



10289

### BLSTM

In [7]:
""" BLSTM model with FCN layer. """

import torch
from torch import nn

class BLSTM(nn.Module):
    """ (Bi)directional LSTM with an FCN regression head.

    Args:
        lstm_input_size: number of features per time step of the input.
        lstm_hidden_size: number of features in the LSTM hidden state h.
        lstm_num_layers: number of stacked recurrent layers.
        lstm_bidirectional: if True, use a bidirectional LSTM.
        fcn_hidden_size: number of features in the hidden layer of the FCN head.
    """

    def __init__(self,
                 lstm_input_size,    # The number of expected features.
                 lstm_hidden_size,   # The number of features in hidden state h.
                 lstm_num_layers,    # Number of recurrent layers in LSTM.
                 lstm_bidirectional, # Bidirectional LSTM.
                 fcn_hidden_size):   # The number of features in hidden layer of FCN.
        super().__init__()

        # LSTM layer (inputs are batch-first)
        self.lstm = nn.LSTM(input_size=lstm_input_size,
                            hidden_size=lstm_hidden_size,
                            num_layers=lstm_num_layers,
                            bidirectional=lstm_bidirectional,
                            batch_first=True)

        # FCN: when bidirectional, its input is the final forward and backward
        # states concatenated.
        num_directions = 2 if lstm_bidirectional else 1
        self.fcn = nn.Sequential(nn.Linear(num_directions * lstm_hidden_size, fcn_hidden_size),
                                 nn.ReLU())

        # FCN output layer
        self.out = nn.Linear(fcn_hidden_size, 1)

    def forward(self, x):
        """Predict one value per sequence.

        Args:
            x: batch-first input of shape (batch, seq_len, lstm_input_size).

        Returns:
            Tensor of shape (batch, 1).
        """
        # nn.LSTM starts from zero h_0/c_0 by default, with x's dtype and device.
        _, (h_n, _) = self.lstm(x)

        # h_n has shape (num_layers * num_directions, batch, hidden). In the last
        # layer, h_n[-2] is the forward direction after the whole sequence and
        # h_n[-1] is the backward direction after the whole sequence.
        # Using lstm_out[:, -1, :] instead would pair the forward summary with a
        # backward state that has only seen the final time step.
        if self.lstm.bidirectional:
            final_state = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            final_state = h_n[-1]

        fcn_out = self.fcn(final_state)
        prediction = self.out(fcn_out)

        return prediction

In [8]:
# BLSTM hyperparameters
lstm_input_size, lstm_hidden_size, fcn_hidden_size = 320, 320, 320
lstm_num_layers = 1
lstm_bidrectional = True  # (sic) name kept because later cells use the same spelling
model = BLSTM(lstm_input_size,
              lstm_hidden_size,
              lstm_num_layers,
              lstm_bidrectional,
              fcn_hidden_size)

# Summarise trainable parameters
count_parameters(model)

+---------------------------+------------+
|          Modules          | Parameters |
+---------------------------+------------+
|     lstm.weight_ih_l0     |   409600   |
|     lstm.weight_hh_l0     |   409600   |
|      lstm.bias_ih_l0      |    1280    |
|      lstm.bias_hh_l0      |    1280    |
| lstm.weight_ih_l0_reverse |   409600   |
| lstm.weight_hh_l0_reverse |   409600   |
|  lstm.bias_ih_l0_reverse  |    1280    |
|  lstm.bias_hh_l0_reverse  |    1280    |
|        fcn.0.weight       |   204800   |
|         fcn.0.bias        |    320     |
|         out.weight        |    320     |
|          out.bias         |     1      |
+---------------------------+------------+
Total Trainable Params: 1848961



1848961

### ESM-BLSTM

In [10]:
from transformers import AutoTokenizer, EsmModel 

class ESM_BLSTM(nn.Module):
    """ESM encoder followed by a BLSTM regression head.

    Args:
        esm: encoder called as ``esm(**tokenized_seqs)``. It must return an object
            with a ``last_hidden_state`` attribute, as the Hugging Face ``EsmModel`` does.
        blstm: head applied to ``last_hidden_state``.
    """

    def __init__(self, esm, blstm):
        super().__init__()
        self.esm = esm
        self.blstm = blstm

    def forward(self, tokenized_seqs):
        """Encode the tokenized sequences and apply the BLSTM head.

        Args:
            tokenized_seqs: mapping of keyword arguments for ``self.esm``
                (e.g. tokenizer output).

        Returns:
            The output of ``self.blstm``.
        """
        # Autograd is enabled only in training mode (model.train()); in eval
        # mode this runs with gradients disabled.
        with torch.set_grad_enabled(self.training):
            # last_hidden_state is (batch, seq_len, hidden) for Hugging Face encoders.
            # Keep the batch axis: the previous squeeze(0) reduced a batch of size 1
            # to 2-D, which the batch-first BLSTM head cannot consume.
            esm_output = self.esm(**tokenized_seqs).last_hidden_state
            output = self.blstm(esm_output)
        return output

In [12]:
# ESM input (`device` is defined in an earlier cell)
esm = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D").to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# BLSTM input: read the LSTM input width from the encoder config so it always
# matches the width of ESM's last_hidden_state.
lstm_input_size = esm.config.hidden_size
lstm_hidden_size = 320
lstm_num_layers = 1
lstm_bidrectional = True
fcn_hidden_size = 320
blstm = BLSTM(lstm_input_size, lstm_hidden_size, lstm_num_layers, lstm_bidrectional, fcn_hidden_size)

# Move the whole model, including the new BLSTM head, to the encoder's device.
# Otherwise the head's parameters stay on CPU while the ESM outputs are on `device`.
model = ESM_BLSTM(esm, blstm).to(device)

# Run
count_parameters(model)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


+---------------------------------------------------+------------+
|                      Modules                      | Parameters |
+---------------------------------------------------+------------+
|       esm.embeddings.word_embeddings.weight       |   10560    |
|     esm.embeddings.position_embeddings.weight     |   328320   |
|  esm.encoder.layer.0.attention.self.query.weight  |   102400   |
|   esm.encoder.layer.0.attention.self.query.bias   |    320     |
|   esm.encoder.layer.0.attention.self.key.weight   |   102400   |
|    esm.encoder.layer.0.attention.self.key.bias    |    320     |
|  esm.encoder.layer.0.attention.self.value.weight  |   102400   |
|   esm.encoder.layer.0.attention.self.value.bias   |    320     |
| esm.encoder.layer.0.attention.output.dense.weight |   102400   |
|  esm.encoder.layer.0.attention.output.dense.bias  |    320     |
|   esm.encoder.layer.0.attention.LayerNorm.weight  |    320     |
|    esm.encoder.layer.0.attention.LayerNorm.bias   |    320  

9689082

### BERT-BLSTM

In [17]:
""" BLSTM with FCN layer, MLM, and BERT. """

import torch
import torch.nn as nn
from pnlp.model.language import ProteinMaskedLanguageModel, BERT

class BERT_BLSTM(nn.Module):
    """ BERT encoder with a masked-language-model head and a BLSTM regression head.

    Args:
        bert: encoder. Its ``hidden`` attribute is passed to the MLM head.
        blstm: regression head applied to the encoder output.
        vocab_size: vocabulary size for the masked-language-model head.
    """

    def __init__(self, bert: BERT, blstm: BLSTM, vocab_size: int):
        super().__init__()

        # Registration order (bert, mlm, blstm) determines the order of
        # named_parameters() and state_dict keys.
        self.bert = bert
        self.mlm = ProteinMaskedLanguageModel(self.bert.hidden, vocab_size)
        self.blstm = blstm

    def forward(self, x: torch.Tensor) -> tuple:
        """Encode ``x`` once and feed the shared encoding to both heads.

        Returns:
            tuple ``(mlm_output, blstm_output)``: the masked-language-model head
            output and the BLSTM regression output. Losses are computed by the
            caller.
        """
        x = self.bert(x)
        mlm_output = self.mlm(x)      # masked-language-model head
        blstm_output = self.blstm(x)  # regression head

        return mlm_output, blstm_output

In [22]:
# BERT input
max_len = 280
mask_prob = 0.15
embedding_dim = 320
dropout = 0.1
n_transformer_layers = 12
n_attn_heads = 10
tokenizer = ProteinTokenizer(max_len, mask_prob)
bert = BERT(embedding_dim, dropout, max_len, mask_prob, n_transformer_layers, n_attn_heads)

# BLSTM input: the BLSTM consumes the BERT output, so take its input width from
# the encoder (the same `bert.hidden` that BERT_BLSTM passes to the MLM head)
# instead of hard-coding a matching number.
lstm_input_size = bert.hidden
lstm_hidden_size = 320
lstm_num_layers = 1
lstm_bidrectional = True
fcn_hidden_size = 320
blstm = BLSTM(lstm_input_size, lstm_hidden_size, lstm_num_layers, lstm_bidrectional, fcn_hidden_size)

# BERT_BLSTM input
vocab_size = len(token_to_index)
model = BERT_BLSTM(bert, blstm, vocab_size)

# Run
count_parameters(model)

+-----------------------------------------------------------+------------+
|                          Modules                          | Parameters |
+-----------------------------------------------------------+------------+
|           bert.embedding.token_embedding.weight           |    8960    |
|    bert.transformer_blocks.0.attention.linears.0.weight   |   102400   |
|     bert.transformer_blocks.0.attention.linears.0.bias    |    320     |
|    bert.transformer_blocks.0.attention.linears.1.weight   |   102400   |
|     bert.transformer_blocks.0.attention.linears.1.bias    |    320     |
|    bert.transformer_blocks.0.attention.linears.2.weight   |   102400   |
|     bert.transformer_blocks.0.attention.linears.2.bias    |    320     |
|    bert.transformer_blocks.0.attention.linears.3.weight   |   102400   |
|     bert.transformer_blocks.0.attention.linears.3.bias    |    320     |
|  bert.transformer_blocks.0.attention.output_linear.weight |   102400   |
|   bert.transformer_bloc

17895069

### ESM

In [23]:
# Stand-alone ESM-2 encoder, for comparison with the combined models above
esm_checkpoint = "facebook/esm2_t6_8M_UR50D"
esm = EsmModel.from_pretrained(esm_checkpoint)
esm = esm.to(device)
count_parameters(esm)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


+-----------------------------------------------+------------+
|                    Modules                    | Parameters |
+-----------------------------------------------+------------+
|       embeddings.word_embeddings.weight       |   10560    |
|     embeddings.position_embeddings.weight     |   328320   |
|  encoder.layer.0.attention.self.query.weight  |   102400   |
|   encoder.layer.0.attention.self.query.bias   |    320     |
|   encoder.layer.0.attention.self.key.weight   |   102400   |
|    encoder.layer.0.attention.self.key.bias    |    320     |
|  encoder.layer.0.attention.self.value.weight  |   102400   |
|   encoder.layer.0.attention.self.value.bias   |    320     |
| encoder.layer.0.attention.output.dense.weight |   102400   |
|  encoder.layer.0.attention.output.dense.bias  |    320     |
|   encoder.layer.0.attention.LayerNorm.weight  |    320     |
|    encoder.layer.0.attention.LayerNorm.bias   |    320     |
|   encoder.layer.0.intermediate.dense.weight   |   409

7840121