In [1]:
import torch
import esm

In [2]:
# Load the pretrained ESM-2 checkpoint (esm2_t33_650M_UR50D) together with its alphabet.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# The batch converter is applied later to the list of (label, sequence) pairs.
batch_converter = alphabet.get_batch_converter()
# As the last expression of the cell, this also displays the module structure.
model.eval()  # disables dropout for deterministic results

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [4]:
# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4),
# followed by a copy of protein2 with one residue replaced by <mask> and a short
# space-separated sequence that also contains a <mask> token.
data = [
    (
        "protein1",
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
    ),
    (
        "protein2",
        "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE",
    ),
    (
        "protein2 with mask",
        "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE",
    ),
    (
        "protein3",
        "K A <mask> I S Q",
    ),
]

In [5]:
# Length (in residues) of the first sequence in `data`.
first_sequence = data[0][1]
print(len(first_sequence))

65


In [6]:
# Convert the (label, sequence) pairs into labels, raw strings and a token tensor.
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Count the non-padding tokens of every sequence (summing along dimension 1).
is_real_token = batch_tokens.ne(alphabet.padding_idx)
batch_lens = is_real_token.sum(dim=1)

In [7]:
# Number of non-padding tokens in each sequence of the batch.
batch_lens

tensor([67, 73, 73,  8])

In [8]:
# Extract per-residue representations (on CPU).
# The index of the final layer is read from the model instead of being hard-coded
# as 33, so this cell keeps working if a checkpoint with a different depth is loaded.
final_layer = len(model.layers)
with torch.no_grad():  # inference only: no gradients are needed
    results = model(batch_tokens, repr_layers=[final_layer], return_contacts=True)
# `results["representations"]` is keyed by layer index; pick the requested layer.
token_representations = results["representations"][final_layer]

In [9]:
# Shape of the extracted per-token representations tensor.
token_representations.shape

torch.Size([4, 73, 1280])

In [None]:
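# 33 is used because this model has 33 transformer layers, so index 33 selects the output of the last one.
# Hard-coding this number is fragile: it should be derived from the loaded model
# (for example with `len(model.layers)`) rather than assumed for a given checkpoint.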

In [10]:
# Boolean mask of padded positions; only the disabled per-layer call below uses it.
padding_mask = batch_tokens == model.padding_idx

# Print the index of every transformer layer held by the model.
for layer_idx in range(len(model.layers)):
    print(layer_idx)
    # Disabled experiment, kept for reference: run a single layer by hand and
    # inspect its output and attention.
    # x, attn = model.layers[layer_idx](
    #     x,
    #     self_attn_padding_mask=padding_mask,
    #     need_head_weights=False,
    # )
    # print(x)
    # print(attn)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


In [12]:
# List the layer indices for which representations were returned.
for layer_key in results["representations"]:
    print(layer_key)

33


In [11]:
# List the names of all outputs returned by the model call.
for output_name in results:
    print(output_name)

logits
representations
attentions
contacts


In [13]:
# Display the representations stored under layer index 33 in the model output.
results["representations"][33]

tensor([[[ 0.0744, -0.0747,  0.0824,  ..., -0.2394,  0.1661, -0.0306],
         [ 0.0826, -0.2050, -0.0171,  ...,  0.1044,  0.1343, -0.0945],
         [-0.0528, -0.0635, -0.2515,  ..., -0.0461,  0.2368, -0.1214],
         ...,
         [ 0.0904, -0.0895,  0.0698,  ..., -0.2790,  0.1822, -0.0566],
         [ 0.0960, -0.0847,  0.0674,  ..., -0.2860,  0.1888, -0.0324],
         [ 0.0830, -0.0842,  0.0592,  ..., -0.2927,  0.1853, -0.0333]],

        [[ 0.0819, -0.0513,  0.0804,  ..., -0.3185,  0.1573,  0.0690],
         [-0.0232,  0.0039, -0.2268,  ...,  0.0346,  0.0886,  0.3810],
         [ 0.0146, -0.0700, -0.1314,  ..., -0.0757,  0.3837,  0.0932],
         ...,
         [ 0.0198, -0.1917,  0.1589,  ..., -0.2205, -0.0567,  0.3618],
         [-0.0704, -0.2010,  0.0613,  ..., -0.1972,  0.0557,  0.2920],
         [ 0.1067, -0.0337,  0.0860,  ..., -0.4020,  0.1861,  0.0410]],

        [[ 0.0822, -0.0485,  0.0707,  ..., -0.3180,  0.1576,  0.0736],
         [-0.0306,  0.0159, -0.2224,  ...,  0

In [16]:
# Re-check the shape of the per-token representations before pooling them.
token_representations.shape

torch.Size([4, 73, 1280])

In [14]:
# Build one vector per sequence by averaging its per-residue representations.
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
# The slice 1 : n_tokens - 1 therefore drops the first and the last counted token
# before the mean is taken along dimension 0.
sequence_representations = [
    per_token[1 : n_tokens - 1].mean(0)
    for per_token, n_tokens in zip(token_representations, batch_lens)
]

In [15]:
# Display the list of averaged per-sequence representation vectors.
sequence_representations

[tensor([ 0.0614, -0.0687,  0.0430,  ..., -0.1642, -0.0678,  0.0446]),
 tensor([ 0.0553, -0.0757,  0.0414,  ..., -0.3117, -0.0026,  0.1683]),
 tensor([ 0.0618, -0.0769,  0.0405,  ..., -0.3037, -0.0013,  0.1741]),
 tensor([ 0.0084,  0.1425,  0.0506,  ...,  0.0403, -0.1063,  0.0079])]

In [14]:
# Shape of the averaged representation vector of the first sequence.
sequence_representations[0].shape

torch.Size([1280])

In [21]:
# Print the non-padding token count of each sequence, one per line.
for tokens_len in batch_lens:
    print(tokens_len)

tensor(67)
tensor(73)
tensor(73)
tensor(8)
