# ESM3

NOTE: Credit to ESM3 and the tutorial notebook they linked, which I used as the reference template for this notebook. It was really helpful in figuring this out.

ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.

ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.
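To make "iteratively sample masked positions until all positions are unmasked" concrete, here is a toy sketch in plain Python. It is not ESM3's sampler: a random residue stands in for the model's prediction, and the shuffled unmasking order and even split of positions across steps are assumptions made for illustration. `toy_iterative_unmask` is a hypothetical helper.

```python
import random

MASK = "_"

def toy_iterative_unmask(prompt: str, num_steps: int,
                         alphabet: str = "ACDEFGHIKLMNPQRSTVWY", seed: int = 0) -> str:
    """Fill every MASK position of `prompt` over `num_steps` rounds.

    A random residue stands in for the model's prediction; a real model would
    sample each chosen position from its predicted distribution instead.
    """
    rng = random.Random(seed)
    tokens = list(prompt)
    masked = [i for i, t in enumerate(tokens) if t == MASK]
    rng.shuffle(masked)  # hypothetical unmasking order
    for step in range(num_steps):
        # Spread the remaining masked positions evenly over the remaining steps
        remaining_steps = num_steps - step
        k = -(-len(masked) // remaining_steps)  # ceiling division
        to_fill, masked = masked[:k], masked[k:]
        for i in to_fill:
            tokens[i] = rng.choice(alphabet)
    return "".join(tokens)

print(toy_iterative_unmask("MK____GL__", num_steps=3))
```

Positions that are already given in the prompt (here `MK` and `GL`) are never touched; only masked positions get filled, a few per step.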

![ESM3 diagram](https://github.com/evolutionaryscale/esm/blob/main/_assets/esm3_diagram.png?raw=true)





# Imports

Make sure you are on a GPU runtime first (Runtime > Change runtime type > T4 GPU).


In [1]:
%set_env TOKENIZERS_PARALLELISM=false
!pip install esm biopython requests py3Dmol
import json
import os

import numpy as np
import py3Dmol
import requests
import torch
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain

env: TOKENIZERS_PARALLELISM=false



1. Go to Hugging Face and make an account if you don't already have one.
2. Accept the ESM3 license here: https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1
3. Create an access token (read access is enough to download the model weights) and pass it to `login()`.
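Rather than pasting the token into the notebook, you could read it from a Colab secret or an environment variable. A minimal sketch, assuming the secret or variable is named `HF_TOKEN` (the name Colab's login message mentions); `get_hf_token` is a hypothetical helper, and `google.colab.userdata` only exists inside Colab, so outside Colab the code falls back to the environment.

```python
import os

def get_hf_token(secret_name: str = "HF_TOKEN") -> str:
    """Return a Hugging Face token from Colab secrets or an environment variable."""
    try:
        # Only importable inside Google Colab; raises if the secret is missing or not shared
        from google.colab import userdata
        token = userdata.get(secret_name)
        if token:
            return token
    except Exception:
        pass  # not in Colab, or the secret is unavailable: fall back to the environment
    token = os.environ.get(secret_name)
    if not token:
        raise RuntimeError(f"Set the {secret_name} Colab secret or environment variable first.")
    return token
```

Usage: `login(token=get_hf_token())`, which keeps the token out of the saved notebook.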

In [2]:
from huggingface_hub import login

# Log in with your access token: login() prompts for it, or call login(token="<your token>")
# Never hard-code or share a real token in a notebook
login()
model = ESM3.from_pretrained("esm3_sm_open_v1", device=torch.device("cuda"))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.




In [3]:
# these are the pdb IDs for the proteins encoded by the filovirus strain Mayinga-76
filo_proteins_pdbs = ["5DVW", "4ZTA", "6VKM", "4LDB", "6EHL"]

# Now scaffolding a domain from the proteins to generate new ones


I walk through the first protein in detail so that every step is clear. After that, a single function runs the same pipeline for each protein and saves the PDB files for both the original protein and the generated protein that contains the conserved domain.


In [4]:
pdb_id = filo_proteins_pdbs[0]
# chain_id = "A"
chain = ProteinChain.from_rcsb(pdb_id)

The `ProteinChain` class is an object that makes it easy to work with protein structures. It has a `sequence` attribute that holds the amino acid sequence of the protein.


`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein.

The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**.
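As a small illustration of the atom37 layout and its `nan` convention, the sketch below builds a toy two-residue array in NumPy where only the backbone atoms have coordinates, then counts the atoms present per residue. It only relies on the convention described above (first three slots are N, C-alpha, C; missing atoms are `nan`); `present_atom_counts` is a hypothetical helper, not part of ESM.

```python
import numpy as np

# First three atom37 slots hold the backbone atoms
BACKBONE = {"N": 0, "CA": 1, "C": 2}

def present_atom_counts(atom37: np.ndarray) -> np.ndarray:
    """Number of atoms with coordinates for each residue in an (L, 37, 3) array."""
    # An atom counts as present if none of its x, y, z coordinates are nan
    present = ~np.isnan(atom37).any(axis=-1)  # shape (L, 37)
    return present.sum(axis=-1)

# Toy 2-residue structure: every atom missing except the backbone
coords = np.full((2, 37, 3), np.nan)
coords[:, [BACKBONE["N"], BACKBONE["CA"], BACKBONE["C"]]] = 0.0
print(present_atom_counts(coords))  # [3 3]

# Slicing one atom slot gives an (L, 3) trace, e.g. the C-alpha coordinates
ca_coords = coords[:, BACKBONE["CA"]]
```

The same expression works on a real chain, e.g. `present_atom_counts(chain.atom37_positions)`.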


In [5]:
print("atom37_positions shape: ", chain.atom37_positions.shape)

atom37_positions shape:  (132, 37, 3)


We can visualize the protein chain using the `py3Dmol` library:


In [6]:
# First we can create a `py3Dmol` view object
view = py3Dmol.view(width=500, height=500)
# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string
pdb_str = chain.to_pdb_string()
# Load the PDB string into the `py3Dmol` view object
view.addModel(pdb_str, "pdb")
# Set the style of the protein chain
view.setStyle({"cartoon": {"color": "spectrum"}})
# Zoom in on the protein chain
view.zoomTo()
# Display the protein chain
view.show()

In [7]:
# The PDB entry covers residues 142-272; the domain of interest is residues 202-237.
# Subtracting chain_start (the residue number of the first residue in `chain`) converts residue numbers to 0-based positions
domain_start, domain_end = 202, 237
chain_start = 147

residue_start = domain_start - chain_start
residue_end = domain_end - chain_start
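The offset arithmetic above assumes every residue from `chain_start` up to the domain is resolved in the structure; a run of unresolved residues would silently shift the mapping. A hedged alternative is to look residues up by their number. The sketch is written against a plain NumPy array so it runs on its own; `domain_positions` is a hypothetical helper, and applying it to ESM assumes `ProteinChain.residue_index` holds the PDB residue numbering (insertion codes are ignored).

```python
import numpy as np

def domain_positions(residue_index: np.ndarray, domain_start: int, domain_end: int) -> np.ndarray:
    """0-based positions of the residues numbered domain_start..domain_end (inclusive)."""
    positions = np.where((residue_index >= domain_start) & (residue_index <= domain_end))[0]
    expected = domain_end - domain_start + 1
    if len(positions) != expected:
        # Unresolved residues inside the domain: fewer positions than residue numbers
        print(f"warning: found {len(positions)} of {expected} residues in the domain")
    return positions

# Toy numbering: first resolved residue is 147, no gaps
toy_residue_index = np.arange(147, 279)
inds = domain_positions(toy_residue_index, 202, 237)
print(inds[0], inds[-1], len(inds))  # 55 90 36
```

Usage: `domain_inds = domain_positions(chain.residue_index, 202, 237)`. With no gaps this gives the same positions as the subtraction above; with gaps it still picks exactly the residues numbered 202-237.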

In [8]:
domain_inds = np.arange(residue_start, residue_end + 1)
# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues
domain_sequence = chain[domain_inds].sequence
domain_atom37_positions = chain[domain_inds].atom37_positions
print("Motif sequence: ", domain_sequence)
print("Motif atom37_positions shape: ", domain_atom37_positions.shape)

Motif sequence:  RREGLGQDQAEPVLEVYQRLHSDKGGSFEAALWQQW
Motif atom37_positions shape:  (36, 37, 3)


We can also visualize the motif within the original chain using `py3Dmol`. The domain is shown in cyan.


In [9]:
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
# py3Dmol's "resi" selects on the residue numbers in the PDB string; assuming to_pdb_string()
# writes chain.residue_index, these are the PDB residue numbers of the extracted domain
domain_res_inds = chain.residue_index[domain_inds].tolist()
view.addStyle({"resi": domain_res_inds}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()

Now we can use the `ESMProtein` class to construct a prompt that instructs ESM3 to scaffold the domain.


In [10]:
prompt_length = 200
# First, we can construct a sequence prompt of all masks
sequence_prompt = ["_"] * prompt_length
# Then, we place the motif sequence at an arbitrary position in the prompt (index 72 here)
sequence_prompt[72 : 72 + len(domain_sequence)] = list(domain_sequence)
sequence_prompt = "".join(sequence_prompt)
print("Sequence prompt: ", sequence_prompt)
print("Length of sequence prompt: ", len(sequence_prompt))

# Next, we can construct a structure prompt of all nan coordinates
structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72
structure_prompt[72 : 72 + len(domain_atom37_positions)] = torch.tensor(
    domain_atom37_positions
)
print("Structure prompt shape: ", structure_prompt.shape)
print(
    "Indices with structure conditioning: ",
    torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),
)

# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3
protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)

Sequence prompt:  ________________________________________________________________________RREGLGQDQAEPVLEVYQRLHSDKGGSFEAALWQQW____________________________________________________________________________________________
Length of sequence prompt:  200
Structure prompt shape:  torch.Size([200, 37, 3])
Indices with structure conditioning:  [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107]
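One subtlety in the prompt construction: slice assignment on a Python list silently grows the list if the motif runs past `prompt_length`, so the sequence prompt would end up longer than requested while the structure assignment fails. A hedged sketch of a hypothetical `build_motif_prompt` helper that checks the bounds first; it returns NumPy coordinates, which you would convert with `torch.tensor(...)` before building `ESMProtein` as in the cell above.

```python
import numpy as np

def build_motif_prompt(motif_sequence: str, motif_atom37: np.ndarray,
                       prompt_length: int, insert_ind: int):
    """Return (sequence_prompt, structure_prompt) with the motif placed at insert_ind."""
    motif_length = len(motif_sequence)
    if motif_atom37.shape != (motif_length, 37, 3):
        raise ValueError("motif_atom37 must have shape (len(motif_sequence), 37, 3)")
    if insert_ind < 0 or insert_ind + motif_length > prompt_length:
        raise ValueError("the motif does not fit inside the prompt at this insert index")
    # Masked sequence with the motif residues filled in
    sequence = ["_"] * prompt_length
    sequence[insert_ind : insert_ind + motif_length] = list(motif_sequence)
    # All-nan coordinates except the motif positions
    structure = np.full((prompt_length, 37, 3), np.nan, dtype=np.float32)
    structure[insert_ind : insert_ind + motif_length] = motif_atom37
    return "".join(sequence), structure

seq, coords = build_motif_prompt("ACD", np.zeros((3, 37, 3)), prompt_length=10, insert_ind=2)
print(seq)  # __ACD_____
```

With the notebook's variables this would be `build_motif_prompt(domain_sequence, domain_atom37_positions, 200, 72)`.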


Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. Under the hood, the model performs `num_steps` forward passes and samples a set of tokens at each step until the chosen track being generated is fully unmasked.
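A quick check of the step counts used in this notebook: the sequence track gets half as many steps as there are masked positions, and the structure track gets one step per 8 residues. The two helpers below are hypothetical and just restate those formulas; with a 200-residue prompt and a 36-residue motif they reproduce the 82/82 and 25/25 progress bars, i.e. about two sequence tokens unmasked per forward pass on average.

```python
def sequence_num_steps(prompt_length: int, motif_length: int) -> int:
    """Sequence-track steps used here: half the number of masked positions."""
    return (prompt_length - motif_length) // 2

def structure_num_steps(sequence_length: int) -> int:
    """Structure-track steps used here: one step per 8 residues."""
    return sequence_length // 8

masked = 200 - 36
steps = sequence_num_steps(200, 36)
print(masked, steps, masked / steps)  # 164 82 2.0
print(structure_num_steps(200))  # 25
```

Fewer steps means more tokens committed per forward pass, which is faster but gives the model less chance to condition on its own earlier choices.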


In [11]:
# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use
sequence_generation_config = GenerationConfig(
    track="sequence",  # We want ESM3 to generate tokens for the sequence track
    num_steps=sequence_prompt.count("_") // 2,  # We'll use num(mask tokens) // 2 steps to decode the sequence
    temperature=0.5,  # We'll use a temperature of 0.5 to control the randomness of the decoding process
)

# Now, we can use the `generate` method of the model to decode the sequence
sequence_generation = model.generate(protein_prompt, sequence_generation_config)
print("Sequence Prompt:\n\t", protein_prompt.sequence)
print("Generated sequence:\n\t", sequence_generation.sequence)

100%|██████████| 82/82 [00:23<00:00,  3.42it/s]


Sequence Prompt:
	 ________________________________________________________________________RREGLGQDQAEPVLEVYQRLHSDKGGSFEAALWQQW____________________________________________________________________________________________
Generated sequence:
	 MRGALEALREAVDAALSAAAPDALEQALLALELEREERAALLALARLRRGSAPAPVPALDRALAEALTKELLRREGLGQDQAEPVLEVYQRLHSDKGGSFEAALWQQWAEHEAALLRLLVRELARRAGRDPPAEWRLPLLALAAALVRPAAAAEQRRALAALRLLAARPRLRRALEVLLAAPGEAAAAAALPLALAAAAS


We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens.


In [12]:
structure_prediction_config = GenerationConfig(
    track="structure",  # use ESM3 to generate tokens for the structure track
    num_steps=len(sequence_generation) // 8,
    temperature=0.7,
)
structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
structure_prediction = model.generate(
    structure_prediction_prompt, structure_prediction_config
)

100%|██████████| 25/25 [00:07<00:00,  3.37it/s]


Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the domain was drawn. The domain residues are colored in cyan.


In [13]:
# Convert the generated structure back into a ProteinChain object
structure_prediction_chain = structure_prediction.to_protein_chain()
# Align the generated structure to the original structure using the motif residues
domain_inds_in_generation = np.arange(72, 72 + len(domain_sequence))
structure_prediction_chain.align(
    chain, mobile_inds=domain_inds_in_generation, target_inds=domain_inds
)
crmsd = structure_prediction_chain.rmsd(
    chain, mobile_inds=domain_inds_in_generation, target_inds=domain_inds
)
print(
    "cRMSD of the motif in the generated structure vs the original structure: ", crmsd
)

view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(pdb_str, "pdb", viewer=(0, 0))
view.addModel(structure_prediction_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
view.addStyle({"resi": domain_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
view.addStyle(
    {"resi": (domain_inds_in_generation + 1).tolist()},
    {"cartoon": {"color": "cyan"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()

cRMSD of the motif in the generated structure vs the original structure:  2.7580018106490227
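`align` and `rmsd` are called with the motif indices on both chains. As a hedged illustration of what such a metric computes, here is a standard Kabsch superposition followed by RMSD on two `(N, 3)` point sets. Which atoms `ProteinChain.rmsd` actually uses (C-alpha, backbone, or all atoms) is not stated in this notebook, so this generic sketch is an assumption and may not reproduce the reported numbers exactly; `kabsch_rmsd` is a hypothetical helper.

```python
import numpy as np

def kabsch_rmsd(mobile: np.ndarray, target: np.ndarray) -> float:
    """RMSD between two (N, 3) point sets after optimal rigid superposition (Kabsch)."""
    # Remove translation by centering both point sets
    mobile_c = mobile - mobile.mean(axis=0)
    target_c = target - target.mean(axis=0)
    # The SVD of the 3x3 covariance matrix gives the optimal rotation
    h = mobile_c.T @ target_c
    u, _, vt = np.linalg.svd(h)
    # Flip the last axis if needed so the result is a proper rotation, not a reflection
    d = 1.0 if np.linalg.det(vt.T @ u.T) > 0 else -1.0
    rotation = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    aligned = mobile_c @ rotation.T
    return float(np.sqrt(((aligned - target_c) ** 2).sum(axis=1).mean()))
```

For example, on C-alpha coordinates (slot 1 in atom37): `kabsch_rmsd(structure_prediction_chain[domain_inds_in_generation].atom37_positions[:, 1], chain[domain_inds].atom37_positions[:, 1])`.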


In [50]:
def save_highlighted(filename, residue_inds):
    # Save the residue numbers to highlight; `filename` should already end in .json
    with open(filename, "w") as f:
        json.dump({"highlighted_residues": [int(i) for i in residue_inds]}, f)

def save_crmsd(filename, crmsd):
    with open(filename, "w") as f:
        # float() makes sure numpy scalars are JSON-serializable
        json.dump({"crmsd": float(crmsd)}, f)

def generate_and_save(pdb_id, domain_start_end, chain_start, prompt_length, insert_ind, view=True):
    """Scaffold a domain of `pdb_id` with ESM3, fold the result, and save both structures.

    Uses the global `model`. Writes everything to `<pdb_id>_output/` and returns the domain cRMSD.
    """
    output_dir = f"{pdb_id}_output"
    os.makedirs(output_dir, exist_ok=True)

    # Convert PDB residue numbers to 0-based positions in the chain
    domain_start, domain_end = domain_start_end
    residue_start = domain_start - chain_start
    residue_end = domain_end - chain_start
    domain_inds = np.arange(residue_start, residue_end + 1)
    chain = ProteinChain.from_rcsb(pdb_id)
    domain_sequence = chain[domain_inds].sequence
    domain_atom37_positions = chain[domain_inds].atom37_positions

    # Build the sequence + structure prompt with the domain placed at `insert_ind`
    sequence_prompt = ["_"] * prompt_length
    sequence_prompt[insert_ind : insert_ind + len(domain_sequence)] = list(domain_sequence)
    sequence_prompt = "".join(sequence_prompt)
    structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
    structure_prompt[insert_ind : insert_ind + len(domain_atom37_positions)] = torch.tensor(
        domain_atom37_positions
    )
    protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)

    # Generate the sequence
    sequence_generation_config = GenerationConfig(
        track="sequence",
        num_steps=sequence_prompt.count("_") // 2,
        temperature=0.5,
    )
    sequence_generation = model.generate(protein_prompt, sequence_generation_config)

    # Predict the structure of the generated sequence
    structure_prediction_config = GenerationConfig(
        track="structure",  # use ESM3 to generate tokens for the structure track
        num_steps=len(sequence_generation) // 8,
        temperature=0.7,
    )
    structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
    structure_prediction = model.generate(structure_prediction_prompt, structure_prediction_config)

    # Convert the generated structure back into a ProteinChain and align it on the domain residues
    structure_prediction_chain = structure_prediction.to_protein_chain()
    domain_inds_in_generation = np.arange(insert_ind, insert_ind + len(domain_sequence))
    structure_prediction_chain.align(
        chain, mobile_inds=domain_inds_in_generation, target_inds=domain_inds
    )
    crmsd = structure_prediction_chain.rmsd(
        chain, mobile_inds=domain_inds_in_generation, target_inds=domain_inds
    )

    # Residue numbers to highlight: the original chain is assumed to keep its PDB numbering
    # (chain.residue_index); the generated chain is numbered from 1, as in the cells above
    domain_res_inds = chain.residue_index[domain_inds].tolist()
    generated_res_inds = (domain_inds_in_generation + 1).tolist()

    # Save results whether or not we visualize them
    save_crmsd(os.path.join(output_dir, "crmsd.json"), crmsd)
    with open(os.path.join(output_dir, "original_structure.pdb"), "w") as f:
        f.write(chain.to_pdb_string())
    with open(os.path.join(output_dir, "generated_structure.pdb"), "w") as f:
        f.write(structure_prediction_chain.to_pdb_string())
    save_highlighted(os.path.join(output_dir, "original_domain.json"), domain_res_inds)
    save_highlighted(os.path.join(output_dir, "generated_domain.json"), generated_res_inds)

    if view:
        print("cRMSD of the domain in the generated structure vs the original structure: ", crmsd)
        # Original (left, grey) and generated (right, green) structures, domain in cyan
        viewer = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
        viewer.addModel(chain.to_pdb_string(), "pdb", viewer=(0, 0))
        viewer.addModel(structure_prediction_chain.to_pdb_string(), "pdb", viewer=(0, 1))
        viewer.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
        viewer.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
        viewer.addStyle({"resi": domain_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
        viewer.addStyle({"resi": generated_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 1))
        viewer.zoomTo()
        viewer.show()

    return crmsd


In [48]:
# See the write-up or the static website for background on the following domains / motifs.
# Each entry maps a PDB ID to [(domain_start, domain_end), chain_start]: the first and last
# PDB residue numbers of the domain, and the residue number of the first residue in the chain.
pdbs_and_domain_residue_info = {
    "5DVW": [(202, 237), 142],
    "4ZTA": [(33, 48), 15],
    "6F6N": [(54, 201), 32],
    "4EJE": [(7, 10), 5],
    "5T3T": [(606, 611), 600],
}

for pdb_id, (domain_start_end, chain_start) in pdbs_and_domain_residue_info.items():
    generate_and_save(pdb_id, domain_start_end, chain_start, prompt_length=300, insert_ind=72)

100%|██████████| 132/132 [00:56<00:00,  2.33it/s]
100%|██████████| 37/37 [00:15<00:00,  2.39it/s]


cRMSD of the domain in the generated structure vs the original structure:  2.82105281163127


100%|██████████| 142/142 [01:00<00:00,  2.35it/s]
100%|██████████| 37/37 [00:15<00:00,  2.37it/s]


cRMSD of the domain in the generated structure vs the original structure:  3.1634624259527224


100%|██████████| 76/76 [00:32<00:00,  2.36it/s]
100%|██████████| 37/37 [00:16<00:00,  2.28it/s]


cRMSD of the domain in the generated structure vs the original structure:  9.86611509622985


100%|██████████| 148/148 [01:02<00:00,  2.36it/s]
100%|██████████| 37/37 [00:15<00:00,  2.36it/s]


cRMSD of the domain in the generated structure vs the original structure:  0.14482065255789855


100%|██████████| 147/147 [01:02<00:00,  2.37it/s]
100%|██████████| 37/37 [00:15<00:00,  2.37it/s]


cRMSD of the domain in the generated structure vs the original structure:  1.1307616428075473
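To compare the runs after the loop finishes, you could read back the `crmsd.json` files that `generate_and_save` writes into each `<pdb_id>_output/` folder. A minimal sketch, assuming that folder layout and the `{"crmsd": value}` format; `load_crmsd_summary` is a hypothetical helper.

```python
import json
import os

def load_crmsd_summary(pdb_ids, base_dir="."):
    """Read <pdb_id>_output/crmsd.json for each PDB ID and return (pdb_id, cRMSD) pairs, best first."""
    results = {}
    for pdb_id in pdb_ids:
        path = os.path.join(base_dir, f"{pdb_id}_output", "crmsd.json")
        if not os.path.exists(path):
            continue  # skip runs that did not finish
        with open(path) as f:
            results[pdb_id] = json.load(f)["crmsd"]
    return sorted(results.items(), key=lambda item: item[1])
```

Usage: `for pdb_id, crmsd in load_crmsd_summary(pdbs_and_domain_residue_info): print(f"{pdb_id}\t{crmsd:.2f}")`. Keep in mind that the motifs range from 4 residues (4EJE) to 148 residues (6F6N), and a low cRMSD is much easier to reach over a handful of residues, so the raw values are not directly comparable across proteins.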
