In [7]:
import esm
import torch

In [42]:
# Load ESM-2 model
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [43]:
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
data = [
    ("protein1", "BKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
    ("protein4", "ACDEFGHIKLMNPQRSTVWY<unk>"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

In [44]:
print(f"batch_labels:{batch_labels}")
print(f"batch_strs:{batch_strs}")
print(f"batch_tokens:{batch_tokens}")
print(f"batch_lens:{batch_lens}")

batch_labels:['protein1', 'protein2', 'protein2 with mask', 'protein3', 'protein4']
batch_strs:['BKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG', 'KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE', 'KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE', 'K A <mask> I S Q', 'ACDEFGHIKLMNPQRSTVWY<unk>']
batch_tokens:tensor([[ 0, 25, 15, 11,  7, 10, 16,  9, 10,  4, 15,  8, 12,  7, 10, 12,  4,  9,
         10,  8, 15,  9, 14,  7,  8,  6,  5, 16,  4,  5,  9,  9,  4,  8,  7,  8,
         10, 16,  7, 12,  7, 16, 13, 12,  5, 19,  4, 10,  8,  4,  6, 19, 17, 12,
          7,  5, 11, 14, 10,  6, 19,  7,  4,  5,  6,  6,  2,  1,  1,  1,  1,  1,
          1],
        [ 0, 15,  5,  4, 11,  5, 10, 16, 16,  9,  7, 18, 13,  4, 12, 10, 13, 21,
         12,  8, 16, 11,  6, 20, 14, 14, 11, 10,  5,  9, 12,  5, 16, 10,  4,  6,
         18, 10,  8, 14, 17,  5,  5,  9,  9, 21,  4, 15,  5,  4,  5, 10, 15,  6,
          7, 12,  9, 12,  7,  8,  6, 

In [46]:
# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens.to(torch.device('cuda')), repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

In [47]:
# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
    sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))

In [48]:
# Look at the unsupervised self-attention map contact predictions
import matplotlib.pyplot as plt
for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    plt.matshow(attention_contacts[: tokens_len, : tokens_len])
    plt.title(seq)
    plt.show()

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [8]:
from deeprpi.utils import RPIDataset
from deeprpi.model import embedding

In [9]:
RPIDataset = RPIDataset(data_path='./data/NPInter5.csv',
                            batch_size=32,
                            num_workers=4,
                            rna_col='RNA_aa_code',
                            protein_col='target_aa_code',
                            label_col='Y',
                            padding=True,
                            rna_max_length=1000,
                            protein_max_length=1000,
                            truncation=False,
                            val_ratio=0.1,
                            test_ratio=0.1)

In [10]:
RPIDataset.setup()
train_dataloader = RPIDataset.train_dataloader()

Shuffling RNA sequences: 100%|██████████| 1182/1182 [00:00<?, ?it/s]
Shuffling protein sequences: 100%|██████████| 1182/1182 [00:00<00:00, 1202733.46it/s]
Shuffling labels: 100%|██████████| 1182/1182 [00:00<?, ?it/s]
Selecting data: 100%|██████████| 1182/1182 [00:00<?, ?it/s]


Selected 313 samples from 1182 samples


Tokenizing RNA sequences: 100%|██████████| 313/313 [00:00<00:00, 19171.37it/s]
Padding RNA sequences: 100%|██████████| 313/313 [00:00<00:00, 313246.76it/s]
Tokenizing protein sequences: 100%|██████████| 313/313 [00:00<00:00, 34403.87it/s]
Padding protein sequences: 100%|██████████| 313/313 [00:00<00:00, 121861.80it/s]


In [11]:
for i in train_dataloader:
    print(len(i))
    print('='*50)
    print(i[0])
    print('='*50)
    print(i[1])
    break

5
tensor([[0, 4, 4,  ..., 1, 1, 1],
        [0, 5, 5,  ..., 1, 1, 1],
        [0, 5, 7,  ..., 1, 1, 1],
        ...,
        [0, 4, 6,  ..., 1, 1, 1],
        [0, 7, 7,  ..., 1, 1, 1],
        [0, 4, 4,  ..., 1, 1, 1]])
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])


In [51]:
for i in train_dataloader:
    print(RPIDataset.protein_tokenizer.decode(i[2]))
    print(embedding(RPIDataset.protein_tokenizer.decode(i[2][:2]))[1])
    break

['MSNGYEDHMAEDCRDDIGRTNLIVNYLPQNMTQEELRSLFSSIGEVESAKLIRDKVAGHSLGYGFVNYVTAKDAERAISTLNGLRLQSKTIKVSYARPSSEVIKDANLYISGLPRTMTQKDVEDMFSRFGRIINSRVLVDQTTGLSRGVAFIRFDKRSEAEEAITSFNGHKPPGSSEPITVKFAANPNQNKNMALLSQLYHSPARRFGGPVHHQAQRFRFSPMGVDHMSGISGVNVPGNASSGWCIFIYNLGQDADEGILWQMFGPFGAVTNVKVIRDFNTNKCKGFGFVTMTNYEEAAMAIASLNGYRLGDKILQVSFKTNKSHK', 'MTILFLTMVISYFGCMKAAPMKEANIRGQGGLAYPGVRTHGTLESVNGPKAGSRGLTSLADTFEHVIEELLDEDQKVRPNEENNKDADLYTSRVMLSSQVPLEPPLLFLLEEYKNYLDAANMSMRVRRHSDPARRGELSVCDSISEWVTAADKKTAVDMSGGTVTVLEKVPVSKGQLKQYFYETKCNPMGYTKEGCRGIDKRHWNSQCRTTQSYVRALTMDSKKRIGWRFIRIDTSCVCTLTIKRGR', 'LPNNTSSSP', 'MALTDGGWCLPKRFGAAGADASDSRAFPAREPSTPPSPISSSSSSCSRGGERGPGGASNCGTPQLDTEAAAGPPARSLLLSSYASHPFGAPHGPSAPGVAGPGGNLSSWEDLLLFTDLDQAATASKLLWSSRGAKLSPFAPEQPEEMYQTLAALSSQGPAAYDGAPGGFVHSAAAAAAAAAAASSPVYVPTTRVGSMLPGLPYHLQGSGSGPANHAGGAGAHPGWPQASADSPPYGSGGGAAGGGAAGPGGAGSAAAHVSARFPYSPSPPMANGAAREPGGYAAAGSGGAGGVSGGGSSLAAMGGREPQYSSLSAARPLNGTYHHHHHHHHHHPSPYSPYVGAPLTPAWPAGPFETPVLHSLQSRAGAPLPVPRGPSADLLEDLSESRECVNCGSIQTPLWRRD

In [5]:
import esm
import torch
import torch.nn as nn
from torch import device
from typing import Tuple, Any, List

In [18]:
def load_esm_model()-> Tuple[nn.Module, esm.Alphabet]:
    """
    Load the ESM-1b model.
    :return: The ESM-1b model.
    """
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()
    print("Model loaded successfully.")
    return model, alphabet

model, alphabet = load_esm_model()

Model loaded successfully.


In [49]:
from typing import Tuple, Any, List
class TokenESMEmbedding:
    """
    To generate protein embeddings using ESM-1b model.
    """
    def __init__(self, model, alphabet, device: device):
        super().__init__()
        self.device = device
        self.model = model.to(self.device)
        self.alphabet = alphabet
        self.batch_converter = alphabet.get_batch_converter()

    def __call__(self, seqs) -> tuple[Any, list[Any], Any]:
        """
        Generate embeddings for the given sequences. This step is done by a pretrained model.
        :param seqs: The sequences for which embeddings are to be generated.
        :return: The embeddings for the given sequences.
        """
        data = [(f"protein{i}", seq) for i, seq in enumerate(seqs)]
        batch_labels, batch_strs, batch_tokens = self.batch_converter(data)
        batch_tokens = batch_tokens.to(self.device)  # Move batch tokens to GPU
        batch_lens = (batch_tokens != self.alphabet.padding_idx).sum(1)
        with torch.no_grad():
            results = self.model(batch_tokens, repr_layers=[33], return_contacts=True)
        attention_contacts = []
        for i, (contact, seq_len) in enumerate(zip(results["contacts"], batch_lens)):
            attention_contacts.append(results["contacts"][i][:seq_len, :seq_len])
        return results["representations"][33], attention_contacts, batch_lens

In [50]:
embedding = TokenESMEmbedding(model, alphabet, device=torch.device('cuda'))


In [29]:
print(torch.cuda.is_available())

True
