In [1]:
# https://github.com/facebookresearch/esm/blob/main/README.md
import torch
import esm

import sys; sys.path.append("..")

In [7]:
# Load an ESM-2 model. Several checkpoint sizes are listed below; exactly one
# `model, alphabet = ...` / `model_repr_layer = ...` pair should be uncommented.
# `model_repr_layer` is the layer index requested through `repr_layers` when the
# embeddings are extracted, so it must be switched together with the checkpoint.

# ** 3B parameters **
# model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
# model_repr_layer = 36

# ** 650M parameters ** (currently active)
# This model checkpoint is about 2.4 GB, and the embeddings are retrieved from layer #33.
# Embeddings are 1280-dimensional.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # from README
model_repr_layer = 33

# ** 150M parameters **
# model, alphabet = esm.pretrained.esm2_t30_150M_UR50D() # from README
# model_repr_layer = 30

# ** 35M parameters **
# This model checkpoint is about 128 MB, and the embeddings are retrieved from layer #12.
# Embeddings are 480-dimensional.
# model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
# model_repr_layer = 12

# ** 8M parameters ** (smallest checkpoint listed here)
# This model checkpoint is about 29 MB, and the embeddings are retrieved from layer #6.
# Embeddings are 320-dimensional.
# model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
# model_repr_layer = 6

# Converter obtained from the alphabet; called on the list of (label, sequence) pairs
# to produce the token batch fed to the model.
batch_converter = alphabet.get_batch_converter()
# Kept as the last expression, so the cell also displays the module summary.
model.eval()  # disables dropout for deterministic results

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [3]:
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4).
# Each entry is a (label, sequence) pair. The active batch below uses only the first
# 17 residues of protein1 and protein2; the full-length sequences and the `<mask>`
# variants are kept here for reference:
# data = [
#     ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
#     ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
#     ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
#     ("protein3",  "K A <mask> I S Q"),
# ]
data = [
    ("protein1", "MKTVRQERLKSIVRILE"),
    ("protein2", "KALTARQQEVFDLIRDH"),
]

# Turn the (label, sequence) pairs into labels, raw strings and a token batch.
batch_converter = alphabet.get_batch_converter()
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Number of non-padding tokens in each row of the batch.
batch_lens = batch_tokens.ne(alphabet.padding_idx).sum(dim=1)

In [4]:
# Extract per-residue representations (on CPU)
# Inference only: gradient tracking is switched off for the forward pass.
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[model_repr_layer], return_contacts=True)
token_representations = results["representations"][model_repr_layer]

# A notebook cell only displays its final expression, so a bare `.shape` in the
# middle of the cell is silently discarded. Print it to make it visible.
print(token_representations[0].shape)

# First 10 embedding features at token position 0 for every sequence in the batch.
token_representations[:, 0, :10]

tensor([[ 0.0300, -0.0054,  0.0006,  0.0563, -0.0085,  0.0144,  0.0689, -0.0529,
         -0.1287,  0.0063],
        [ 0.0474, -0.0161, -0.0309, -0.0067, -0.0259,  0.0428,  0.0714, -0.0636,
         -0.1214, -0.0116]])

In [5]:
# Build one embedding per sequence by mean-pooling its per-token embeddings.
# NOTE: position 0 always holds the beginning-of-sequence token, so the first residue
# sits at position 1; the slice also stops one short of the sequence's token count,
# which drops the final counted token from the average.
sequence_representations = [
    per_token[1 : n_tokens - 1].mean(0)
    for per_token, n_tokens in zip(token_representations, batch_lens)
]

# NOTE(milo): This model produces 1280-dimensional embeddings.
sequence_representations[0].shape

torch.Size([1280])

In [None]:
# Look at the unsupervised self-attention map contact predictions
import matplotlib.pyplot as plt

for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    # `batch_lens` counts every non-padding token, i.e. the residues plus the
    # beginning- and end-of-sequence tokens. The ESM contact head drops those two
    # special tokens, so the contact map has one row/column per residue only.
    # Slicing with `tokens_len` would therefore pull two padding rows/columns into
    # the plot for any sequence shorter than the longest one in the batch.
    seq_len = int(tokens_len) - 2
    plt.matshow(attention_contacts[:seq_len, :seq_len])
    plt.title(seq)
    plt.show()

In [None]:
import gvpgnn.embeddings as embeddings

In [None]:
# Try the project's embedding helper with the 8M-parameter ESM-2 checkpoint.
# The same key selects both the model factory and the embedding layer to read from.
model_name = "esm2_t6_8M_UR50D"

# Test sequence ending in a literal "<unk>" marker.
seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG<unk>"

model_factory = embeddings.esm2_model_dictionary[model_name]
layer = embeddings.esm2_embedding_layer[model_name]

# NOTE(review): this rebinds the notebook-level `model` and `alphabet`, replacing the
# checkpoint loaded earlier; re-running earlier cells afterwards will use this model.
model, alphabet = model_factory()

out = embeddings.extract_embedding_single(model, alphabet, layer, seq)

# Compare the embedding shape against the raw string length. `len(seq)` counts
# characters, so the trailing "<unk>" marker contributes 5 to it.
# NOTE(review): presumably the tokenizer treats "<unk>" as a single token - confirm.
print(out.shape)
print(len(seq))

out