In [1]:
from typing import Tuple, Any
import esm
import torch
import torch.nn as nn
from torch import device
from deeprpi.utils import RPIDataset
from deeprpi.config import glob

In [34]:
def load_esm()-> Tuple[nn.Module, esm.Alphabet]:
    """
    Load the ESM-1b model.
    :return: The ESM-1b model.
    """
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()
    print("Model loaded successfully.")
    return model, alphabet

class ESMEmbedding:
    """
    To generate protein embeddings using ESM-1b model.
    """
    def __init__(self, model, alphabet, device: device):
        super().__init__()
        self.device = device
        self.model = model.to(self.device)
        self.alphabet = alphabet
        self.batch_converter = alphabet.get_batch_converter()

    def __call__(self, raw_seqs) -> tuple[Any, list[Any], Any]:
        """
        Generate embeddings for the given sequences. This step is done by a pretrained model.
        :param raw_seqs: The sequences for which embeddings are to be generated.
        :return: The embeddings for the given sequences.
        """
        # Extract start and end token indices
        start_token = glob.AMINO_ACIDS['<bos>']
        end_token = glob.AMINO_ACIDS['<eos>']
        idx_to_token = {v: k for k, v in glob.AMINO_ACIDS.items()}

        # Convert sequences to strings and retain only the part between start and end tokens
        seqs = []
        for seq in raw_seqs:
            start_idx = list(seq).index(start_token) + 1
            end_idx = list(seq).index(end_token)
            seq_str = ''.join([idx_to_token[int(idx)] for idx in seq[start_idx:end_idx]])
            seqs.append(seq_str)

        # Existing code to generate embeddings and contacts
        data = [(f"protein{i}", seq) for i, seq in enumerate(seqs)]
        batch_labels, batch_strs, batch_tokens = self.batch_converter(data)
        batch_tokens = batch_tokens.to(self.device)
        batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1)
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[33], return_contacts=True)
        attention_contacts = []
        # Extract the attention contacts for each sequence.
        for i, (contact, seq_len) in enumerate(zip(results["contacts"], batch_lens)):
            attention_contacts.append(results["contacts"][i][:seq_len, :seq_len])
        return results["representations"][33], attention_contacts, batch_lens

In [3]:
RPIDataset = RPIDataset(data_path='./data/NPInter5.csv',
                            batch_size=32,
                            num_workers=4,
                            rna_col='RNA_aa_code',
                            protein_col='target_aa_code',
                            label_col='Y',
                            padding=True,
                            rna_max_length=1000,
                            protein_max_length=1000,
                            truncation=False,
                            val_ratio=0.1,
                            test_ratio=0.1)

In [4]:
RPIDataset.setup()
train_dataloader = RPIDataset.train_dataloader()

Shuffling RNA sequences: 100%|██████████| 1182/1182 [00:00<00:00, 112403.47it/s]
Shuffling protein sequences: 100%|██████████| 1182/1182 [00:00<?, ?it/s]
Shuffling labels: 100%|██████████| 1182/1182 [00:00<?, ?it/s]
Selecting data: 100%|██████████| 1182/1182 [00:00<00:00, 219521.22it/s]


Selected 313 samples from 1182 samples


Tokenizing RNA sequences: 100%|██████████| 313/313 [00:00<00:00, 18632.88it/s]
Padding RNA sequences: 100%|██████████| 313/313 [00:00<?, ?it/s]
Tokenizing protein sequences: 100%|██████████| 313/313 [00:00<00:00, 8711.52it/s]
Padding protein sequences: 100%|██████████| 313/313 [00:00<00:00, 36663.70it/s]


In [5]:
for i in train_dataloader:
    print(len(i))
    print('='*50)
    print(i[0])
    print('='*50)
    print(i[1])
    break

5
tensor([[0, 4, 4,  ..., 1, 1, 1],
        [0, 5, 5,  ..., 1, 1, 1],
        [0, 5, 7,  ..., 1, 1, 1],
        ...,
        [0, 4, 6,  ..., 1, 1, 1],
        [0, 7, 7,  ..., 1, 1, 1],
        [0, 4, 4,  ..., 1, 1, 1]])
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])


In [23]:
model, alphabet = load_esm()

Model loaded successfully.


In [35]:
embedding = ESMEmbedding(model, alphabet, device=torch.device('cpu'))

In [36]:
for i in train_dataloader:
    # print(RPIDataset.protein_tokenizer.decode(i[2]))
    print(embedding(i[2][:2]))
    break

(tensor([[[ 0.0455,  0.0185,  0.1103,  ..., -0.2935,  0.1578,  0.0456],
         [ 0.1249, -0.0747, -0.0016,  ..., -0.0209,  0.0364,  0.0414],
         [ 0.0895,  0.1725,  0.1456,  ...,  0.1058,  0.0619, -0.0703],
         ...,
         [ 0.1062, -0.0367, -0.0535,  ..., -0.1001,  0.1504, -0.4176],
         [-0.0432, -0.0424,  0.0300,  ..., -0.4647,  0.0893, -0.0100],
         [-0.0372, -0.0478,  0.0847,  ..., -0.4383,  0.3191,  0.0205]],

        [[ 0.0770, -0.0544,  0.0573,  ..., -0.2624,  0.1406, -0.0483],
         [-0.1049, -0.0689, -0.1309,  ..., -0.0552, -0.1758, -0.0479],
         [-0.1652,  0.0649,  0.0227,  ...,  0.1014, -0.0166, -0.0879],
         ...,
         [ 0.0098, -0.2441,  0.0756,  ..., -0.1425,  0.1530, -0.1203],
         [ 0.0074, -0.2351,  0.0823,  ..., -0.1438,  0.1477, -0.1171],
         [-0.0096, -0.2292,  0.0815,  ..., -0.1409,  0.1358, -0.1224]]]), [tensor([[0.3840, 0.6624, 0.1103,  ..., 0.0074, 0.0068, 0.0066],
        [0.6624, 0.9731, 0.5003,  ..., 0.0062, 0.