In [2]:
import py3Dmol
import numpy as np
import torch
import pandas as pd
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.esm3 import ESM3
from huggingface_hub import login
from esm.sdk import client
from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    GenerationConfig,
)
# `login()` prompts for a Hugging Face access token (create one with "Read" permission on the
# Hugging Face Hub). It is called before the "esm3-open" weights are requested below.
login()
# Use the GPU when torch can see one and fall back to the CPU otherwise, so this cell also
# runs on machines without CUDA instead of failing inside `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model: ESM3InferenceClient = ESM3.from_pretrained("esm3-open").to(device)


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [18]:
# PDB identifier of the structure to work on; the later cells all operate on the chain loaded here.
pdb_id = "1AF6"
# Build a `ProteinChain` for that identifier via `ProteinChain.from_rcsb`.
# NOTE(review): `out_membraine_chain` looks like a typo for "outer membrane"; the name is used by
# most later cells, so renaming it has to be done across the whole notebook at once.
out_membraine_chain = ProteinChain.from_rcsb(pdb_id)


In [19]:
# Print the loaded chain's sequence string as a quick check of what was loaded.
print(out_membraine_chain.sequence)

VDFHGYARSGIGWTGSGGEQQCFQTTGAQSKYRLGNECETYAELKLGQEVWKEGDKSFYFDTNVAYSVAQQNDWEATDPAFREANVQGKNLIEWLPGSTIWAGKRFYQRHDVHMIDFYYWDISGPGAGLENIDVGFGKLSLAATRSSEAGGSSSFASNNIYDYTNETANDVFDVRLAQMEINPGGTLELGVDYGRANLRDNYRLVDGASKDGWLFTAEHTQSVLKGFNKFVVQYATDSMTSQGKGLSQGSGVAFDNEKFAYNINNNGHMLRILDHGAISMGDNWDMMYVGMYQDINWDNDNGTKWWTVGIRPMYKWTPIMSTVMEIGYDNVESQRTGDKNNQYKITLAQQWQAGDSIWSRPAIRVFATYAKWDEKWGYDYTGNADNNANFGKAVPADFNGGSFGRGDSDEWTFGAQMEIWW


In [20]:
# Serialise the chain first: the viewer is fed PDB-formatted text, not a `ProteinChain` object.
pdb_str = out_membraine_chain.to_pdb_string()

# Cartoon representation using the "spectrum" colour scheme.
spectrum_cartoon = {"cartoon": {"color": "spectrum"}}

# Create a 500x500 viewer, load the structure, style it, fit it to the viewport and render it.
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle(spectrum_cartoon)
view.zoomTo()
view.show()

In [6]:
# Zero-based positions of the motif inside the loaded chain: 123, 124, ..., 145 (23 positions).
# NOTE(review): the recorded output of this cell carries execution count 6 while the chain is
# loaded in a cell with execution count 18, so the printed motif may predate the currently
# loaded chain -- re-run this cell (and re-check the indices) after changing `pdb_id`.
motif_inds = np.arange(123, 146)

# `ProteinChain` supports numpy-style indexing; slice once and read both fields from the slice.
motif_chain = out_membraine_chain[motif_inds]
motif_sequence = motif_chain.sequence
motif_atom37_positions = motif_chain.atom37_positions

print("Motif sequence: ", motif_sequence)
print("Motif atom37_positions shape: ", motif_atom37_positions.shape)

Motif sequence:  FAGGVEYAITPEIATRLEYQWTN
Motif atom37_positions shape:  (23, 37, 3)


In [7]:
# Residue numbers used for the viewer selection: the 0-based motif positions shifted by one.
# NOTE(review): this assumes the residue numbers in `pdb_str` are simply position + 1; if the
# PDB text carries the chain's own numbering (offsets or gaps) the highlight would be shifted
# -- confirm against the chain's residue numbering.
motif_res_inds = [int(position) + 1 for position in motif_inds]

grey_cartoon = {"cartoon": {"color": "lightgrey"}}
cyan_cartoon = {"cartoon": {"color": "cyan"}}

# Draw the whole chain in grey, then layer a cyan cartoon style over the motif residues.
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle(grey_cartoon)
view.addStyle({"resi": motif_res_inds}, cyan_cartoon)
view.zoomTo()
view.show()

In [8]:
prompt_length = 200
# Position in the prompt at which the motif is placed (an arbitrary choice).
motif_start = 72

# Sequence track: start from an all-mask prompt ("_" per position) and overwrite the motif window.
prompt_residues = ["_"] * prompt_length
prompt_residues[motif_start : motif_start + len(motif_sequence)] = list(motif_sequence)
sequence_prompt = "".join(prompt_residues)
print("Sequence prompt: ", sequence_prompt)
print("Length of sequence prompt: ", len(sequence_prompt))

# Structure track: start from all-nan coordinates (37 atom slots x 3 coordinates per position)
# and copy the motif coordinates into the same window.
structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
motif_coordinates = torch.tensor(motif_atom37_positions)
structure_prompt[motif_start : motif_start + len(motif_atom37_positions)] = motif_coordinates
print("Structure prompt shape: ", structure_prompt.shape)

# A position counts as unconditioned only when every one of its atom slots contains a nan.
atom_has_nan = torch.isnan(structure_prompt).any(dim=-1)
position_fully_masked = atom_has_nan.all(dim=-1)
conditioned_positions = torch.where(~position_fully_masked)[0].tolist()
print("Indices with structure conditioning: ", conditioned_positions)

# Bundle both tracks into a single `ESMProtein` prompt to pass to the model.
protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)

Sequence prompt:  ________________________________________________________________________FAGGVEYAITPEIATRLEYQWTN_________________________________________________________________________________________________________
Length of sequence prompt:  200
Structure prompt shape:  torch.Size([200, 37, 3])
Indices with structure conditioning:  [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94]


In [9]:
# Decoding budget: one step for every two masked ("_") positions in the sequence prompt.
num_masked_positions = sequence_prompt.count("_")

# Decoding parameters for the sequence track; temperature 0.5 controls how random the sampling is.
sequence_generation_config = GenerationConfig(
    track="sequence",
    num_steps=num_masked_positions // 2,
    temperature=0.5,
)

# Ask the model to fill in the masked positions of the prompt.
sequence_generation = model.generate(protein_prompt, sequence_generation_config)

print("Sequence Prompt:\n\t", protein_prompt.sequence)
print("Generated sequence:\n\t", sequence_generation.sequence)

  context_BHLD = F.scaled_dot_product_attention(
100%|██████████| 88/88 [00:04<00:00, 19.46it/s]


Sequence Prompt:
	 ________________________________________________________________________FAGGVEYAITPEIATRLEYQWTN_________________________________________________________________________________________________________
Generated sequence:
	 MKKSLLALALAVAAAPAMAQSAAPVVEAAAPAAAPAAAASGWSVGGSVGYAFGKLDGGGLDGAKVDVKGDGTFAGGVEYAITPEIATRLEYQWTNDKKVGDSTLKTENFAVTALYLLPVAEGFDVYGLLGVANSKAEVKAAGSVSDSKWGLQAGVGAKYAVTDNVFVGAEYRYTELKYDVGGADVTANPNIISVGLGYKF


In [10]:
# Decoding budget for the structure track: one step per eight positions of the generated protein.
num_structure_steps = len(sequence_generation) // 8

structure_prediction_config = GenerationConfig(
    track="structure",
    num_steps=num_structure_steps,
    temperature=0.7,
)

# The prompt carries only the generated sequence (no coordinates are passed in).
structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
structure_prediction = model.generate(
    structure_prediction_prompt,
    structure_prediction_config,
)

100%|██████████| 25/25 [00:01<00:00, 19.58it/s]


In [11]:
# Convert the generated structure back into a ProteinChain object
structure_prediction_chain = structure_prediction.to_protein_chain()

# Positions of the motif inside the generated protein (it was inserted at index 72 of the prompt)
motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))

# Superimpose the generated structure onto the original one using only the motif residues.
# `align` hands back the superimposed chain instead of moving the chain it is called on, so the
# result has to be kept; it was previously discarded and the unaligned chain was rendered below.
aligned_prediction_chain = structure_prediction_chain.align(
    out_membraine_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)
crmsd = structure_prediction_chain.rmsd(
    out_membraine_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)
print(
    "cRMSD of the motif in the generated structure vs the original structure: ", crmsd
)

# Side-by-side viewers: original structure on the left, aligned generated structure on the right,
# with the motif highlighted in cyan in both.
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(pdb_str, "pdb", viewer=(0, 0))
view.addModel(aligned_prediction_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
view.addStyle(
    {"resi": (motif_inds_in_generation + 1).tolist()},
    {"cartoon": {"color": "cyan"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()

cRMSD of the motif in the generated structure vs the original structure:  0.29928795446959316
