# ESM3 - Model Walkthrough

![image.png](https://github.com/evolutionaryscale/esm/blob/main/_assets/esm3_diagram.png?raw=true)

# Imports

If you're running in Colab, you probably want to get a GPU runtime first (Runtime > Change runtime type > T4 GPU).

In [None]:
%set_env HF_TOKEN=

In [None]:
%set_env TOKENIZERS_PARALLELISM=false
!pip install esm
import numpy as np
import torch
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

!pip install py3Dmol
import py3Dmol
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain

# Load `esm-open-small` on GPU


In [None]:
from esm.utils.misc import huggingfacehub_login

huggingfacehub_login()  # will prompt you to get an API key and accept the ESM3 license.

In [None]:
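# Use the GPU when one is available; otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ESM3.from_pretrained("esm3_sm_open_v1", device=device)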

Alternatively, you can run the model remotely through the Forge API and call it through a local `client`, just as you would call a model running on your own machine:


In [5]:
# from getpass import getpass
# from esm.sdk import client
# token = getpass("Token from Forge console: ")
# model = client(
#     model="esm3-large-2024-03",
#     url="https://forge.evolutionaryscale.ai",
#     token=token,
# )

# Taking a Look at the Model

In [6]:
print(model)

ESM3(
  (encoder): EncodeInputs(
    (sequence_embed): Embedding(64, 1536)
    (plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_per_res_plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_tokens_embed): Embedding(4101, 1536)
    (ss8_embed): Embedding(11, 1536)
    (sasa_embed): Embedding(19, 1536)
    (function_embed): ModuleList(
      (0-7): 8 x Embedding(260, 192, padding_idx=0)
    )
    (residue_embed): EmbeddingBag(1478, 1536, mode='sum', padding_idx=0)
  )
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0): UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1536, out_features=4608, bias=False)
          )
          (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (q_ln): LayerNorm((1536,), eps=1e-05, el

# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein


First, we can use the `ProteinChain` class from the `esm` sdk to grab a protein structure from the PDB.
We'll work with a human renal (kidney) dipeptidase (a protein that breaks up two amino acids bound together). Renal dipeptidases are of particular interest because they metabolize certain antibiotics.


In [7]:
pdb_id = "1ITU"  # PDB ID corresponding to Renal Dipeptidase
chain_id = "A"  # Chain ID corresponding to Renal Dipeptidase in the PDB structure
renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)
# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file

The `ProteinChain` class makes it easy to work with protein structures. Its `sequence` attribute holds the amino acid sequence of the protein.


In [8]:
print(renal_dipep_chain.sequence), len(renal_dipep_chain.sequence)

DFFRDEAERIMRDSPVIDGHNDLPWQLLDMFNNRLQDERANLTTLAGTHTNIPKLRAGFVGGQFWSVYTPCDTQNKDAVRRTLEQMDVVHRMCRMYPETFLYVTSSAGIRQAFREGKVASLIGVEGGHSIDSSLGVLRALYQLGMRYLTLTHSCNTPWADNWLVDTGDSEPQSQGLSPFGQRVVKELNRLGVLIDLAHVSVATMKATLQLSRAPVIFSHSSAYSVCASRRNVPDDVLRLVKQTDSLVMVNFYNNYISCTNKANLSQVADHLDHIKEVAGARAVGFGGDFDGVPRVPEGLEDVSKYPDLIAELLRRNWTEAEVKGALADNLLRVFEAVEQASNLTQAPEEEPIPLDQLGGSCRTHYGYSS


(None, 369)

`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein.

The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**.
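
To make the `nan` convention concrete, here is a minimal sketch on a toy array (two made-up residues, not data from 1ITU). It shows how a presence mask falls out of the atom37 layout; the variable names are illustrative only.

```python
import numpy as np

# Toy atom37 array for 2 residues: every atom starts out as "missing" (nan).
atom37 = np.full((2, 37, 3), np.nan)
# Residue 0 gets its three backbone atoms (N, C-alpha, C) in slots 0-2.
atom37[0, :3] = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [2.0, 1.4, 0.0]]
# Residue 1 is left entirely unresolved.

# An atom is present when all three of its coordinates are finite.
atom_present = np.isfinite(atom37).all(axis=-1)    # shape (2, 37)
# A residue carries structure when at least one of its atoms is present.
residue_has_structure = atom_present.any(axis=-1)  # shape (2,)

print(atom_present.sum(axis=-1))  # number of atoms present per residue
print(residue_has_structure)
```

The same "finite on the last axis, then reduce over atoms" pattern works on `renal_dipep_chain.atom37_positions`.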


In [9]:
renal_dipep_chain.atom37_positions.shape

(369, 37, 3)

In [10]:
print("atom37_positions shape: ", renal_dipep_chain.atom37_positions.shape)
print(renal_dipep_chain.atom37_positions[:3])

atom37_positions shape:  (369, 37, 3)
[[[-40.525  -9.87   -2.643]
  [-39.79   -9.325  -3.825]
  [-38.765 -10.354  -4.294]
  [-39.096  -8.012  -3.45 ]
  [-37.878 -10.748  -3.53 ]
  [-38.41   -7.359  -4.629]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-39.105  -7.036  -5.617]
  [-37.177  -7.161  -4.562]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan

We can visualize the protein chain using the `py3Dmol` library:


In [None]:
# First we can create a `py3Dmol` view object
view = py3Dmol.view(width=500, height=500)
# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string
pdb_str = renal_dipep_chain.to_pdb_string()
# Load the PDB string into the `py3Dmol` view object
view.addModel(pdb_str, "pdb")
# Set the style of the protein chain
view.setStyle({"cartoon": {"color": "spectrum"}})
# Zoom in on the protein chain
view.zoomTo()
# Display the protein chain
view.show()

Now, let's try to scaffold a motif from this protein using ESM3. We'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif.


In [12]:
motif_inds = np.arange(123, 146)
# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues
motif_sequence = renal_dipep_chain[motif_inds].sequence
motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions
print("motif_inds len: ", len(motif_inds))
print("Motif sequence: ", motif_sequence)
print("Motif atom37_positions shape: ", motif_atom37_positions.shape)

motif_inds len:  23
Motif sequence:  VEGGHSIDSSLGVLRALYQLGMR
Motif atom37_positions shape:  (23, 37, 3)


We can also visualize the motif in the original chain using `py3Dmol`. We'll color the original chain in light grey and the motif in cyan.


In [None]:
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
# Residue indices are 1-indexed in PDB files, so we add 1 to the 0-indexed positions
motif_res_inds = (motif_inds + 1).tolist()
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()

Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif


In [14]:
prompt_length = 200

# First, we can construct a sequence prompt of all masks
sequence_prompt = ["_"] * prompt_length

# Then, we can insert the motif sequence at an arbitrary position in the prompt (we pick index 72 here)
sequence_prompt[72 : 72 + len(motif_sequence)] = list(motif_sequence)
sequence_prompt = "".join(sequence_prompt)

print("Sequence prompt: ", sequence_prompt)
print("Length of sequence prompt: ", len(sequence_prompt))

# Next, we can construct a structure prompt of all nan coordinates
structure_prompt = torch.full((prompt_length, 37, 3), np.nan)

# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72
structure_prompt[72 : 72 + len(motif_atom37_positions)] = torch.tensor(
    motif_atom37_positions
)
print("Structure prompt shape: ", structure_prompt.shape)
print(
    "Indices with structure conditioning: ",
    torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),
)

Sequence prompt:  ________________________________________________________________________VEGGHSIDSSLGVLRALYQLGMR_________________________________________________________________________________________________________
Length of sequence prompt:  200
Structure prompt shape:  torch.Size([200, 37, 3])
Indices with structure conditioning:  [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94]
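
The prompt construction above can be wrapped in a small helper. This is a sketch only: `build_motif_prompt` is a hypothetical name, not part of the `esm` package, and it uses NumPy arrays rather than torch tensors to stay self-contained.

```python
import numpy as np

def build_motif_prompt(motif_sequence, motif_coords, prompt_length, start):
    """Hypothetical helper: place a motif inside an otherwise fully masked prompt."""
    end = start + len(motif_sequence)
    if start < 0 or end > prompt_length:
        raise ValueError("motif does not fit inside the prompt")
    if motif_coords.shape[0] != len(motif_sequence):
        raise ValueError("sequence and coordinates disagree on motif length")
    # "_" marks a masked sequence position
    sequence = "_" * start + motif_sequence + "_" * (prompt_length - end)
    # nan marks a position without structure conditioning
    coords = np.full((prompt_length,) + motif_coords.shape[1:], np.nan)
    coords[start:end] = motif_coords
    return sequence, coords

# Toy usage with a 3-residue motif and all-zero coordinates
toy_sequence, toy_coords = build_motif_prompt("VEG", np.zeros((3, 37, 3)), 10, 4)
conditioned = np.flatnonzero(np.isfinite(toy_coords).all(axis=(-1, -2)))
print(toy_sequence)          # ____VEG___
print(conditioned.tolist())  # [4, 5, 6]
```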


`ESMProtein` is used to compose the sequence and structure prompts into a single prompt that can be passed to ESM3

In [15]:
# Compose the sequence and structure prompts into a single `ESMProtein` prompt
protein_prompt = ESMProtein(
    sequence=sequence_prompt,
    coordinates=structure_prompt,
)

# How Model Generation Works

Here's a walkthrough of what happens when we ask the model to generate the sequence

In [16]:
# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use
sequence_generation_config = GenerationConfig(
    track="sequence",  # We want ESM3 to generate tokens for the sequence track
    num_steps=sequence_prompt.count("_") // 2,  # We'll use num(mask tokens) // 2 steps to decode the sequence
    temperature=0.5,  # We'll use a temperature of 0.5 to control the randomness of the decoding process
)
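
As a quick check of the `num_steps` arithmetic, the sketch below rebuilds the same 200-position prompt as a plain string and counts its mask tokens. It does not model how many positions are unmasked at each individual step; that depends on the decoding schedule.

```python
prompt_length = 200
motif = "VEGGHSIDSSLGVLRALYQLGMR"  # the 23-residue motif from above
start = 72

sequence_prompt = "_" * start + motif + "_" * (prompt_length - start - len(motif))

num_masked = sequence_prompt.count("_")  # positions the model has to fill in
num_steps = num_masked // 2              # same rule as in the GenerationConfig
print(num_masked, num_steps)             # 177 88
```

With 177 masked positions and 88 steps, the model commits to about two positions per step on average.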

## From `ESMProtein` to `ESMProteinTensor` via `.encode`

The `.encode` method converts the `ESMProtein` into numerical tensors that the model can work with:
- The protein sequence is converted from a string to a sequence of token ids.
- Structure tokens are created: the coordinates are tokenized too, i.e. discretised.

This is done with the model's tokenizers, available through its `tokenizers` attribute.

In [26]:
input_tokens = [model.encode(protein_prompt)]

In [46]:
input_tokens[0]

ESMProteinTensor(sequence=tensor([ 0, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32,  7,  9,  6,  6, 21,  8, 12, 13,  8,  8,  4,  6,  7,  4, 10,  5,  4,
        19, 16,  4,  6, 20, 10, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
        32, 32, 32,  2]), structure=tensor([4098, 2246, 2246, 2246, 2246, 2246, 2246, 2246, 22
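
The sequence half of this output can be reproduced with a toy tokenizer. The ids below are read off the printed tensor (0 at the start, 2 at the end, 32 for `_`, and the motif residues in between). This is a partial vocabulary for illustration, not the full ESM3 sequence tokenizer, and `toy_encode` is a hypothetical name.

```python
# Partial vocabulary inferred from the printed `input_tokens[0]`;
# it only covers the residues that occur in the motif.
BOS, EOS, MASK = 0, 2, 32
AA_TO_ID = {
    "L": 4, "A": 5, "G": 6, "V": 7, "S": 8, "E": 9, "R": 10,
    "I": 12, "D": 13, "Q": 16, "Y": 19, "M": 20, "H": 21,
}

def toy_encode(sequence):
    """Map a prompt string to token ids, wrapped in start and end tokens."""
    ids = [BOS]
    for ch in sequence:
        ids.append(MASK if ch == "_" else AA_TO_ID[ch])
    ids.append(EOS)
    return ids

print(toy_encode("__VEGGH_"))  # [0, 32, 32, 7, 9, 6, 6, 21, 32, 2]
```

The two extra tokens explain why a 200-residue prompt shows up as length 202 in the tensor shapes later on.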

With the `ESMProteinTensor` at hand, we're still only reshaping our inputs (here, stacking them into a batch). The interesting bit happens next.

In [69]:
from esm.sdk.api import LogitsConfig
from esm.utils.constants import esm3 as C
from esm.utils.generation import _stack_protein_tensors
from esm.utils.structure.affine3d import build_affine3d_from_coordinates

In [52]:
sequence_lengths = [len(tokens) for tokens in input_tokens]
devices = {t.device for t in input_tokens}

In [55]:
batched_tokens = _stack_protein_tensors(
    input_tokens, 
    sequence_lengths, 
    model.tokenizers, 
    devices.pop()
)

## Under the hood of `.forward()`

We gave the model 88 iterations (`num_steps`, i.e. 177 masked positions // 2) to fill in all the blanks in the original prompt.

At each iteration, the current protein tensor is fed through the model:

```
    output = model.forward(
        sequence_tokens=batched_tokens.sequence,
        structure_tokens=batched_tokens.structure,
        ss8_tokens=batched_tokens.secondary_structure,
        ...
    )
```

`.forward()` does 3 things:

- First, it "mashes" all the protein token inputs together into a single tensor using a set of embedding layers.
- Then this tensor is fed through the Transformer, which is responsible for most of the magic.
- Lastly, the Transformer output is projected back into the protein's constituent tracks.
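
The three steps can be sketched with a toy NumPy model. Everything here is a stand-in under stated assumptions: the per-track embeddings are combined by summation, the "transformer" is a single random layer, and each output head is a single linear map. The vocabulary sizes are taken from `print(model)`; the hidden size is shrunk from 1536 to 8.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_model = 6, 8                                # toy sizes
vocab = {"sequence": 64, "ss8": 11, "sasa": 19}  # track -> vocabulary size

# 1) One embedding table per track; the looked-up vectors are summed
#    per position (assumption: the tracks are combined by addition).
tables = {name: rng.normal(size=(n, d_model)) for name, n in vocab.items()}
tokens = {name: rng.integers(0, n, size=L) for name, n in vocab.items()}
x = sum(tables[name][tokens[name]] for name in vocab)  # (L, d_model)

# 2) Stand-in for the transformer: any map (L, d_model) -> (L, d_model).
w = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
h = np.tanh(x @ w)

# 3) One head per track turns the shared representation into logits.
heads = {name: rng.normal(size=(d_model, n)) for name, n in vocab.items()}
logits = {name: h @ heads[name] for name in vocab}

print(x.shape, {name: out.shape for name, out in logits.items()})
```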

In [70]:
# Input prep
sequence_tokens = batched_tokens.sequence
structure_tokens = batched_tokens.structure
ss8_tokens = batched_tokens.secondary_structure
sasa_tokens = batched_tokens.sasa
function_tokens = batched_tokens.function
residue_annotation_tokens = batched_tokens.residue_annotations
average_plddt = torch.tensor(1.0, device=batched_tokens.device)
per_res_plddt = batched_tokens.coordinates.isfinite().all(dim=-1).any(dim=-1).float()
structure_coords = batched_tokens.coordinates
chain_id = None
sequence_id = None

L, device = next(
    (x.shape[1], x.device)
    for x in [
        sequence_tokens,
        structure_tokens,
        ss8_tokens,
        sasa_tokens,
        structure_coords,
        function_tokens,
        residue_annotation_tokens,
    ]
    if x is not None
)

t = model.tokenizers
defaults = lambda x, tok: (
    torch.full((1, L), tok, dtype=torch.long, device=device) if x is None else x
)

sequence_tokens = defaults(sequence_tokens, t.sequence.mask_token_id)
ss8_tokens = defaults(ss8_tokens, C.SS8_PAD_TOKEN)
sasa_tokens = defaults(sasa_tokens, C.SASA_PAD_TOKEN)
average_plddt = defaults(average_plddt, 1).float()
per_res_plddt = defaults(per_res_plddt, 0).float()
chain_id = defaults(chain_id, 0)

if residue_annotation_tokens is None:
    residue_annotation_tokens = torch.full(
        (1, L, 16), C.RESIDUE_PAD_TOKEN, dtype=torch.long, device=device
    )

if function_tokens is None:
    function_tokens = torch.full(
        (1, L, 8), C.INTERPRO_PAD_TOKEN, dtype=torch.long, device=device
    )

if structure_coords is None:
    structure_coords = torch.full(
        (1, L, 3, 3), float("nan"), dtype=torch.float, device=device
    )

structure_coords = structure_coords[
    ..., :3, :
]  # In case we pass in an atom14 or atom37 repr
affine, affine_mask = build_affine3d_from_coordinates(structure_coords)

structure_tokens = defaults(structure_tokens, C.STRUCTURE_MASK_TOKEN)
assert structure_tokens is not None
structure_tokens = (
    structure_tokens.masked_fill(structure_tokens == -1, C.STRUCTURE_MASK_TOKEN)
    .masked_fill(sequence_tokens == C.SEQUENCE_BOS_TOKEN, C.STRUCTURE_BOS_TOKEN)
    .masked_fill(sequence_tokens == C.SEQUENCE_PAD_TOKEN, C.STRUCTURE_PAD_TOKEN)
    .masked_fill(sequence_tokens == C.SEQUENCE_EOS_TOKEN, C.STRUCTURE_EOS_TOKEN)
    .masked_fill(
        sequence_tokens == C.SEQUENCE_CHAINBREAK_TOKEN,
        C.STRUCTURE_CHAINBREAK_TOKEN,
    )
)

In [104]:
structure_coords.shape

torch.Size([1, 202, 3, 3])

`.encoder()` converts the various input token tracks into a single tensor:

In [71]:
x = model.encoder(
    sequence_tokens,
    structure_tokens,
    average_plddt,
    per_res_plddt,
    ss8_tokens,
    sasa_tokens,
    function_tokens,
    residue_annotation_tokens,
)

In [73]:
x.shape

torch.Size([1, 202, 1536])

`.transformer()`: 47 of its 48 layers are ordinary transformer blocks. The first layer additionally uses ESM3's "special sauce", Geometric Attention, which makes use of affine transformations built from the backbone coordinates.
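
To give a feel for what an affine transformation of the coordinates means, here is a generic sketch that builds a rigid frame (rotation plus translation) from one residue's N, C-alpha and C atoms via Gram-Schmidt. This is a common construction in protein modelling; whether `build_affine3d_from_coordinates` uses exactly this axis convention is an assumption, and `backbone_frame` is a hypothetical name.

```python
import numpy as np

def backbone_frame(n, ca, c):
    """Rigid frame from backbone atoms: origin at C-alpha, axes via Gram-Schmidt."""
    e1 = c - ca
    e1 = e1 / np.linalg.norm(e1)
    u = n - ca
    e2 = u - np.dot(u, e1) * e1        # remove the component along e1
    e2 = e2 / np.linalg.norm(e2)
    e3 = np.cross(e1, e2)              # completes a right-handed basis
    rotation = np.stack([e1, e2, e3], axis=-1)  # columns are the frame axes
    return rotation, ca

# Backbone atoms of the first residue printed earlier (N, C-alpha, C)
n = np.array([-40.525, -9.870, -2.643])
ca = np.array([-39.790, -9.325, -3.825])
c = np.array([-38.765, -10.354, -4.294])

rotation, translation = backbone_frame(n, ca, c)
# Position of N expressed in the residue's own frame
local_n = rotation.T @ (n - translation)
print(local_n)
```

Coordinates expressed in such a frame do not change when the whole protein is rotated or translated, which is what makes them a convenient input for attention over 3D structure.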

In [87]:
model.transformer

TransformerStack(
  (blocks): ModuleList(
    (0): UnifiedTransformerBlock(
      (attn): MultiHeadAttention(
        (layernorm_qkv): Sequential(
          (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1536, out_features=4608, bias=False)
        )
        (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
        (q_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (k_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (rotary): RotaryEmbedding()
      )
      (geom_attn): GeometricReasoningOriginalImpl(
        (s_norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (proj): Linear(in_features=1536, out_features=3840, bias=False)
        (out_proj): Linear(in_features=768, out_features=1536, bias=False)
      )
      (ffn): Sequential(
        (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=1536, out_features=8192, bias=False)

In [88]:
# x is a normalised version of embedding
x, embedding, _ = model.transformer(
    x, sequence_id, affine, affine_mask, chain_id
)

In [90]:
x.shape, embedding.shape

(torch.Size([1, 202, 1536]), torch.Size([1, 202, 1536]))

In [92]:
x[0, 0]

tensor([ 0.1083,  0.2119, -0.0835,  ..., -0.0580, -0.0446,  0.1248],
       grad_fn=<SelectBackward0>)

In [91]:
embedding[0, 0]

tensor([ 67.7248, 126.6239, -50.0481,  ..., -36.7244, -24.8273,  79.9671],
       grad_fn=<SelectBackward0>)
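
The difference in scale between `x` and `embedding` is what a layer normalisation produces. The sketch below applies a plain layer norm, without the learned scale and bias a real `LayerNorm` module carries, to the six values visible in the output above. It illustrates the operation only and will not reproduce the model's actual `x`.

```python
import numpy as np

def layer_norm(v, eps=1e-5):
    """Normalise the last axis to zero mean and unit variance."""
    mean = v.mean(axis=-1, keepdims=True)
    var = v.var(axis=-1, keepdims=True)
    return (v - mean) / np.sqrt(var + eps)

# The six visible entries of `embedding[0, 0]`, used as a toy vector
embedding_row = np.array([67.7248, 126.6239, -50.0481, -36.7244, -24.8273, 79.9671])
x_row = layer_norm(embedding_row)

print(x_row)
print(x_row.mean(), x_row.std())
```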

`.output_heads()` turns the "mush" of `x` back into per-track logits:

In [94]:
out = model.output_heads(x, embedding)

In [99]:
out.sequence_logits.shape, out.structure_logits.shape

(torch.Size([1, 202, 64]), torch.Size([1, 202, 4096]))
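
Logits like these are unnormalised scores over the vocabulary at each position. A minimal sketch of turning them into tokens, assuming plain temperature-scaled softmax sampling at the masked positions only: this is not ESM3's actual sampling routine, which also decides which positions to unmask at each step. `sample_masked` is a hypothetical name, and random numbers stand in for real logits.

```python
import numpy as np

def sample_masked(logits, tokens, mask_id, temperature, rng):
    """Sample a token for every masked position from temperature-scaled logits."""
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)            # softmax per position
    out = tokens.copy()
    for i in np.flatnonzero(tokens == mask_id):
        out[i] = rng.choice(probs.shape[-1], p=probs[i])
    return out, probs

rng = np.random.default_rng(0)
MASK = 32
tokens = np.array([0, MASK, 7, MASK, 2])            # start, mask, V, mask, end
logits = rng.normal(size=(len(tokens), 64))         # stand-in for sequence logits
sampled, probs = sample_masked(logits, tokens, MASK, temperature=0.5, rng=rng)
print(sampled)
```

A temperature below 1 sharpens the distribution towards the highest-scoring tokens; a temperature above 1 flattens it.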

`.logits()` is a convenience wrapper for all these steps:

In [114]:
forward_out = model.logits(
    batched_tokens,
    LogitsConfig(
        sequence=True,
        structure=True,
        secondary_structure=True,
        sasa=True,
        function=True,
        residue_annotations=True,
        return_embeddings=True,
    ),
)

TODO: 
- explain sampling; `logits` gives us a probability distribution given 