# ESM

In [1]:
import torch

In [2]:
# Load the "esm1b_t33_650M_UR50S" entry point from the facebookresearch/esm
# repository through torch.hub; the result is unpacked into `model` and `alphabet`.
# NOTE(review): the ":main" suffix looks like a moving branch reference, so a
# fresh run may pick up different code — consider pinning a tag/commit; confirm.
model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")

Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


In [6]:
# Switch to inference mode first so dropout is off and repeated runs agree.
model.eval()
batch_converter = alphabet.get_batch_converter()

# Input records as (label, sequence) pairs. Only one record is listed here; the
# original note attributes it to ESMStructuralSplitDataset superfamily / 4.
data = [("2gi9", "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE")]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Forward pass without gradient tracking, keeping only the layer-0 representations.
# No device transfer happens in this cell.
with torch.no_grad():
    results = model(batch_tokens, return_contacts=False, repr_layers=[0])
token_representations = results["representations"][0]

# Mean-pool the residue positions of every sequence into one vector per sequence.
# Position 0 holds the beginning-of-sequence token, so residues occupy
# positions 1 .. len(seq); the slice below leaves the special tokens out.
sequence_representations = []
for per_token, (_, seq) in zip(token_representations, data):
    residues = per_token[1 : len(seq) + 1]
    print(residues.shape)
    sequence_representations.append(residues.mean(0))

print(
    token_representations,
    token_representations.shape,
    sequence_representations,
    len(sequence_representations[0]),
)

torch.Size([56, 1280])
tensor([[[-0.0663, -0.0206,  0.0212,  ...,  0.0336,  0.3534,  0.3684],
         [ 0.2793, -0.0092,  0.0531,  ...,  0.0053, -0.0854,  0.2145],
         [ 0.8682,  0.0139, -0.0363,  ..., -0.0775, -0.0688,  0.1010],
         ...,
         [-0.1451,  0.0310, -0.0578,  ..., -0.1140, -0.0157,  0.2026],
         [-0.0656, -0.0076, -0.0974,  ...,  0.0546,  0.0260,  0.2645],
         [-0.2940, -0.0431,  0.0488,  ...,  0.0070,  0.0056,  0.2847]]]) torch.Size([1, 58, 1280]) [tensor([-0.1665, -0.0073,  0.0005,  ...,  0.0174,  0.0038,  0.1164])] 1280


In [8]:

# Re-run the layer-0 extraction without gradient tracking. This cell relies on
# `model`, `batch_tokens` and `data` created by an earlier cell; no device
# transfer happens here.
with torch.no_grad():
    results = model(batch_tokens, return_contacts=False, repr_layers=[0])
token_representations = results["representations"][0]

# Mean-pool the residue positions of every sequence into one vector per sequence.
# Position 0 holds the beginning-of-sequence token, so residues occupy
# positions 1 .. len(seq); the slice below leaves the special tokens out.
sequence_representations = []
for per_token, (_, seq) in zip(token_representations, data):
    residues = per_token[1 : len(seq) + 1]
    print(residues.shape)
    sequence_representations.append(residues.mean(0))

print(
    token_representations,
    token_representations.shape,
    sequence_representations,
    len(sequence_representations[0]),
)

torch.Size([56, 1280])
tensor([[[-0.0663, -0.0206,  0.0212,  ...,  0.0336,  0.3534,  0.3684],
         [ 0.2793, -0.0092,  0.0531,  ...,  0.0053, -0.0854,  0.2145],
         [ 0.8682,  0.0139, -0.0363,  ..., -0.0775, -0.0688,  0.1010],
         ...,
         [-0.1451,  0.0310, -0.0578,  ..., -0.1140, -0.0157,  0.2026],
         [-0.0656, -0.0076, -0.0974,  ...,  0.0546,  0.0260,  0.2645],
         [-0.2940, -0.0431,  0.0488,  ...,  0.0070,  0.0056,  0.2847]]]) torch.Size([1, 58, 1280]) [tensor([-0.1665, -0.0073,  0.0005,  ...,  0.0174,  0.0038,  0.1164])] 1280


In [9]:

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[1], return_contacts=False)
token_representations = results["representations"][1]

# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
sequence_representations = []
for i, (_, seq) in enumerate(data):
    print(token_representations[i, 1 : len(seq) + 1].shape)
    sequence_representations.append(token_representations[i, 1 : len(seq) + 1].mean(0))

print(token_representations, token_representations.shape, sequence_representations, len(sequence_representations[0]))

torch.Size([56, 1280])
tensor([[[-0.3899,  0.1210, -0.7449,  ..., -0.9544,  0.4704,  0.6206],
         [ 0.4447,  0.8576,  0.1733,  ..., -0.8428,  0.4788,  0.9012],
         [ 1.3972,  1.1721, -0.5245,  ..., -1.4560,  0.0084,  0.3406],
         ...,
         [ 0.4156,  0.8773,  0.4632,  ..., -0.9489, -0.3087,  1.6864],
         [ 0.7929,  0.9090, -0.9665,  ..., -0.6917,  0.5742,  0.9072],
         [-0.0441,  0.6803, -0.3759,  ..., -0.3495,  0.1413,  0.8631]]]) torch.Size([1, 58, 1280]) [tensor([ 0.1365,  0.7353, -0.1419,  ..., -0.5605,  0.1081,  0.6370])] 1280


In [19]:
# Request every layer index 0..33 in a single no-grad forward pass. This cell
# relies on `model` and `batch_tokens` created by an earlier cell.
requested_layers = list(range(33 + 1))
with torch.no_grad():
    results = model(batch_tokens, return_contacts=False, repr_layers=requested_layers)

# Show which layer indices came back, then inspect layer 1 specifically.
per_layer = results["representations"]
print(per_layer.keys())
token_representations = per_layer[1]
print(token_representations)


dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33])
tensor([[[-0.3899,  0.1210, -0.7449,  ..., -0.9544,  0.4704,  0.6206],
         [ 0.4447,  0.8576,  0.1733,  ..., -0.8428,  0.4788,  0.9012],
         [ 1.3972,  1.1721, -0.5245,  ..., -1.4560,  0.0084,  0.3406],
         ...,
         [ 0.4156,  0.8773,  0.4632,  ..., -0.9489, -0.3087,  1.6864],
         [ 0.7929,  0.9090, -0.9665,  ..., -0.6917,  0.5742,  0.9072],
         [-0.0441,  0.6803, -0.3759,  ..., -0.3495,  0.1413,  0.8631]]])


In [7]:
# Print the built-in help for the loaded model object to inspect its class and
# the signatures of its methods (e.g. forward's repr_layers / return_contacts).
help(model)

Help on ProteinBertModel in module esm.model object:

class ProteinBertModel(torch.nn.modules.module.Module)
 |  ProteinBertModel(args, alphabet)
 |  
 |  Method resolution order:
 |      ProteinBertModel
 |      torch.nn.modules.module.Module
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, args, alphabet)
 |      Initializes internal Module state, shared by both nn.Module and ScriptModule.
 |  
 |  forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False)
 |      Defines the computation performed at every call.
 |      
 |      Should be overridden by all subclasses.
 |      
 |      .. note::
 |          Although the recipe for forward pass needs to be defined within
 |          this function, one should call the :class:`Module` instance afterwards
 |          instead of this since the former takes care of running the
 |          registered hooks while the latter silently ignores them.
 |  
 |  predict_contacts(self, tokens)


In [6]:
# Switch to inference mode first so dropout is off and repeated runs agree.
model.eval()
batch_converter = alphabet.get_batch_converter()

# Input records as (label, sequence) pairs. Only one record is listed here; the
# original note attributes it to ESMStructuralSplitDataset superfamily / 4.
data = [("2gi9", "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE")]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Forward pass without gradient tracking, keeping only the layer-33 representations.
# No device transfer happens in this cell.
with torch.no_grad():
    results = model(batch_tokens, return_contacts=False, repr_layers=[33])
token_representations = results["representations"][33]

# Mean-pool the residue positions of every sequence into one vector per sequence.
# Position 0 holds the beginning-of-sequence token, so residues occupy
# positions 1 .. len(seq); the slice below leaves the special tokens out.
sequence_representations = [
    token_representations[row, 1 : len(seq) + 1].mean(0)
    for row, (_, seq) in enumerate(data)
]

print(
    token_representations,
    token_representations.shape,
    sequence_representations,
    len(sequence_representations[0]),
)

tensor([[[-0.0454,  0.2249,  0.1654,  ...,  0.0380,  0.0451, -0.0621],
         [-0.0642, -0.2746,  0.4808,  ..., -0.0545, -0.1764,  0.0235],
         [-0.1702,  0.0312, -0.2013,  ..., -0.0612, -0.2743,  0.3677],
         ...,
         [-0.0621,  0.2908,  0.6940,  ..., -0.2058, -0.0074,  0.4298],
         [-0.1076,  0.1801,  0.1223,  ..., -0.1654,  0.0073,  0.0412],
         [-0.0576, -0.0605,  0.0807,  ..., -0.0661,  0.1188,  0.1108]]]) torch.Size([1, 58, 1280]) [tensor([-0.0211,  0.1367, -0.0039,  ..., -0.0647, -0.0582,  0.1644])] 1280


In [6]:
# Shape of the layer-33 representations held in `results` (set by an earlier cell).
results["representations"][33].shape

torch.Size([1, 58, 1280])

In [7]:
# List the top-level entries of the forward-pass output held in `results`.
# NOTE(review): the recorded output includes 'attentions' and 'contacts' even
# though every model call visible in this notebook passes return_contacts=False;
# `results` may come from a cell run that is no longer shown — confirm with
# Restart & Run All.
results.keys()

dict_keys(['logits', 'representations', 'attentions', 'contacts'])

In [9]:
# List which layer indices are present under "representations" in `results`.
results["representations"].keys()

dict_keys([33])