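# ESM3 Structure Prediction Tests

Structure prediction with the EvolutionaryScale Forge API: a single-chain sequence with `esm3-small-2024-08`, then the 5JXE complex with `esm3-medium-multimer-2024-09`, each visualized with py3Dmol.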

In [None]:
# from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.sdk import client

In [7]:
import os

# Read the Forge API token from the environment rather than hard-coding it in the notebook.
# The previous hard-coded empty token is rejected by Forge (401 Unauthorized).
model = client(
    model="esm3-small-2024-08",
    url="https://forge.evolutionaryscale.ai",
    token=os.environ.get("ESM_API_KEY", ""),
)

In [19]:
# Single-chain test sequence for the monomer structure-prediction example.
sequence = (
    "FVNQHLCGSHLVEALYLVCGERGFFYTPKT"
    "RREAEDLQGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYC"
)

In [20]:
# Monomer structure prediction: ask ESM3 to generate the structure track for `sequence`.
structure_prediction_config = GenerationConfig(
    track="structure",  # We want ESM3 to generate tokens for the structure track
    num_steps=8,
    temperature=0.7,
)

structure_prediction_prompt = ESMProtein(sequence=sequence)

structure_prediction = model.generate(
    structure_prediction_prompt, structure_prediction_config
)

# Forge returns (rather than raises) an error object on failure, e.g.
# ESMProteinError(error_code=401, error_msg=...). Stop here with a clear message instead
# of letting the next cell fail with an AttributeError on `.to_protein_chain()`.
error_msg = getattr(structure_prediction, "error_msg", None)
if error_msg is not None:
    raise RuntimeError(
        f"ESM3 generation failed (error_code={getattr(structure_prediction, 'error_code', None)}): {error_msg}"
    )

In [None]:
# structure_prediction

In [22]:
# Convert the generated ESMProtein into a ProteinChain, which provides `to_pdb_string()`
structure_prediction_chain = structure_prediction.to_protein_chain()

In [None]:
# Display the predicted monomer structure as PDB-format text
structure_prediction_chain.to_pdb_string()

In [None]:
## Use py3Dmol to visualize the structure
import py3Dmol

# PDB text of the monomer prediction and a rainbow cartoon style for the whole model
monomer_pdb = structure_prediction_chain.to_pdb_string()
spectrum_cartoon = {"cartoon": {"color": "spectrum"}}

view = py3Dmol.view(width=800, height=800)
view.addModel(monomer_pdb, "pdb")
view.setStyle(spectrum_cartoon)
view.zoomTo()
view.show()

## Multimer (complex) structure prediction

In [2]:
from esm.sdk import client
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain
import string

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
import os

model_name = "esm3-medium-multimer-2024-09"
# Never commit Forge credentials: read the token from the environment instead.
# NOTE(review): the token that was previously hard-coded here is exposed in the notebook
# history and should be revoked/rotated in the Forge console.
token = os.environ.get("ESM_API_KEY", "")

In [25]:
## PDB: 5JXE
# Per-chain sequences (presumably heavy chain, light chain and antigen, going by the
# variable names -- confirm against the PDB entry). Long literals use implicit concatenation.
h_seq = (
    "QVQLVQSGVEVKKPGASVKVSCKASGYTFTNYYMYWVRQAPGQGLEWMG"
    "GINPSNGGTNFNEKFKNRVTLTTDSSTTTAYMELKSLQFDDTAVYYCAR"
    "RDYRFDMGFDYWGQGTTVTVSSASTKGPSVFPLAPCSRSTSESTAALGC"
    "LVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLG"
    "TKTYTCNVDHKPSNTKVDKRV"
)
l_seq = (
    "EIVLTQSPATLSLSPGERATLSCRASKGVSTSGYSYLHWYQQKPGQAPRLLIY"
    "LASYLESGVPARFSGSGSGTDFTLTISSLEPEDFAVYYCQHSRDLPLTFGGGTKVEIK"
    "RTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQ"
    "DSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC"
)
a_seq = (
    "NPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQ"
    "DSRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTE"
)
# Multimer input: chains separated by "|"
sequence = f"{h_seq}|{l_seq}|{a_seq}"

In [47]:
# Show the combined "|"-separated multimer input sequence
sequence

'QVQLVQSGVEVKKPGASVKVSCKASGYTFTNYYMYWVRQAPGQGLEWMGGINPSNGGTNFNEKFKNRVTLTTDSSTTTAYMELKSLQFDDTAVYYCARRDYRFDMGFDYWGQGTTVTVSSASTKGPSVFPLAPCSRSTSESTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTKTYTCNVDHKPSNTKVDKRV|EIVLTQSPATLSLSPGERATLSCRASKGVSTSGYSYLHWYQQKPGQAPRLLIYLASYLESGVPARFSGSGSGTDFTLTISSLEPEDFAVYYCQHSRDLPLTFGGGTKVEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC|NPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDSRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTE'

In [54]:
def fix_multichain_pdb_str(pdb_string: str = "", split_res: str = "UNK") -> str:
    """
    Fixes chain identifiers in a PDB string for multiple chains.

    Chain breaks are assumed to appear as residues named ``split_res`` (the
    <UNK> residue by default). Those separator residues are dropped, and every
    segment between them gets its own chain identifier (A, B, C, ...).

    Args:
        pdb_string (str): Raw PDB output from an ESM3 multimer model.
        split_res (str): Residue name marking a chain break. Default is "UNK".

    Returns:
        str: PDB string with proper chain identifiers and separator residues
        removed. Lines that are not ATOM/HETATM records are passed through
        unchanged.

    Raises:
        ValueError: If more chains are needed than there are single-character
            chain identifiers (52).
    """
    chain_ids = string.ascii_uppercase + string.ascii_lowercase  # A-Z + a-z = 52 chains max
    separator = split_res.strip()

    fixed_lines = []
    chain_idx = 0
    chain_has_atoms = False  # has the current chain received any atom yet?
    in_separator = False     # are we inside a run of separator atom lines?
    for line in pdb_string.splitlines():
        if line.startswith(("ATOM", "HETATM")):
            residue = line[17:20].strip()  # residue name, PDB columns 18-20
            if residue == separator:
                # A separator residue can span several atom lines (N, CA, C, ...).
                # Only note the break here; the chain advances once, when the next
                # real residue appears, instead of once per separator atom.
                in_separator = True
                continue  ## Skip separator residues
            if in_separator and chain_has_atoms:
                chain_idx += 1
                chain_has_atoms = False
            in_separator = False
            if chain_idx >= len(chain_ids):
                raise ValueError(
                    f"Too many chains for single-character chain IDs (max {len(chain_ids)})"
                )
            ## Place the ATOM in a chain (chain ID is PDB column 22)
            line = line[:21] + chain_ids[chain_idx] + line[22:]
            chain_has_atoms = True
        fixed_lines.append(line)
    return "\n".join(fixed_lines)


In [43]:
def fold_sequence(
    sequence: str,
    model_name: str,
    token: str,
    num_steps=None,
    temperature: float = 0,
    url: str = "https://forge.evolutionaryscale.ai",
) -> str:
    """
    Predict the structure of a (possibly multi-chain) sequence with ESM3 on Forge.

    Args:
        sequence: Amino-acid sequence; chains separated by "|". All whitespace is ignored.
        model_name: Forge model name, e.g. "esm3-medium-multimer-2024-09".
        token: Forge API token.
        num_steps: Number of structure generation steps. Defaults to
            len(sequence) // 4 (at least 1), the previously hard-coded choice.
        temperature: Sampling temperature passed to GenerationConfig (default 0, as before).
        url: Forge endpoint.

    Returns:
        PDB string with one chain identifier per "|"-separated segment and the
        chain-break residues removed (via fix_multichain_pdb_str).

    Raises:
        ValueError: If the sequence is empty after removing whitespace.
        RuntimeError: If Forge returns an error object instead of a protein
            (e.g. 401 Unauthorized for a bad token).
    """
    # Remove all whitespace (spaces, newlines, tabs) left over from pasted sequences.
    sequence = "".join(sequence.split())
    if not sequence:
        raise ValueError("sequence is empty")

    if num_steps is None:
        # Original heuristic (len() includes the "|" separators); clamp so very short
        # inputs still get at least one step.
        num_steps = max(1, len(sequence) // 4)

    model = client(model=model_name, url=url, token=token)

    ## Generate the structure
    structure_prediction = model.generate(
        ESMProtein(sequence=sequence),
        GenerationConfig(
            track="structure", num_steps=num_steps, temperature=temperature
        ),
    )

    # Forge returns (rather than raises) an error object on failure, e.g.
    # ESMProteinError(error_code=401, error_msg=...). Fail loudly here instead of
    # crashing with an AttributeError on `.to_pdb_string()`.
    error_msg = getattr(structure_prediction, "error_msg", None)
    if error_msg is not None:
        raise RuntimeError(
            f"ESM3 generation failed (error_code={getattr(structure_prediction, 'error_code', None)}): {error_msg}"
        )

    ## Fix the PDB string to have separate chain identifiers and remove "UNK" residues
    return fix_multichain_pdb_str(structure_prediction.to_pdb_string())


In [44]:
# NOTE(review): this displays the notebook-level `structure_prediction` left over from an
# earlier cell, not the local variable inside `fold_sequence`. The recorded output is a
# 401 "Unauthorized" error object returned by Forge.
structure_prediction

ESMProteinError(error_code=401, error_msg='Failure in generate: {"status":"error","message":"Unauthorized"}')

In [45]:
# Fold the multimer sequence and get back a chain-split PDB string
complex_prediction = fold_sequence(sequence, model_name, token)

In [46]:
## write to file
# Persist the chain-split complex prediction as a PDB file in the working directory
output_path = "5JXE_mm_pred_split.pdb"
with open(output_path, mode="w") as pdb_file:
    pdb_file.write(complex_prediction)

In [40]:
# Peek at the first few PDB records instead of dumping the entire structure into the output.
N_PREVIEW_LINES = 12
complex_prediction.split("\n")[:N_PREVIEW_LINES]

['ATOM      1  N   GLN A   1     -15.794 -13.325   5.257  1.00  0.93           N  ',
 'ATOM      2  CA  GLN A   1     -14.375 -13.438   4.938  1.00  0.93           C  ',
 'ATOM      3  C   GLN A   1     -14.041 -12.693   3.648  1.00  0.93           C  ',
 'ATOM      4  CB  GLN A   1     -13.966 -14.906   4.816  1.00  0.93           C  ',
 'ATOM      5  O   GLN A   1     -14.340 -13.172   2.553  1.00  0.93           O  ',
 'ATOM      6  CG  GLN A   1     -14.314 -15.744   6.039  1.00  0.93           C  ',
 'ATOM      7  CD  GLN A   1     -15.805 -15.986   6.181  1.00  0.93           C  ',
 'ATOM      8  NE2 GLN A   1     -16.213 -16.513   7.330  1.00  0.93           N  ',
 'ATOM      9  OE1 GLN A   1     -16.583 -15.701   5.265  1.00  0.93           O  ',
 'ATOM     10  N   VAL A   2     -13.707 -11.569   4.068  1.00  0.96           N  ',
 'ATOM     11  CA  VAL A   2     -13.062 -10.750   3.047  1.00  0.96           C  ',
 'ATOM     12  C   VAL A   2     -11.796 -11.444   2.550  1.00  0

In [55]:
try:
    from esm.sdk import client
    from esm.sdk.api import ESMProtein, GenerationConfig
except ModuleNotFoundError as e:
    # ImportError (still an Exception subclass) with explicit chaining keeps the original traceback.
    raise ImportError(f"esm module not found: {str(e)}") from e

import os

## Sampling settings for the multimer structure prediction
temperature = 0.7
num_steps = 8

## PDB: 5JXE
h_seq = "QVQLVQSGVEVKKPGASVKVSCKASGYTFTNYYMYWVRQAPGQGLEWMGGINPSNGGTNFNEKFKNRVTLTTDSSTTTAYMELKSLQFDDTAVYYCARRDYRFDMGFDYWGQGTTVTVSSASTKGPSVFPLAPCSRSTSESTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTKTYTCNVDHKPSNTKVDKRV"
l_seq = "EIVLTQSPATLSLSPGERATLSCRASKGVSTSGYSYLHWYQQKPGQAPRLLIYLASYLESGVPARFSGSGSGTDFTLTISSLEPEDFAVYYCQHSRDLPLTFGGGTKVEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC"
a_seq = "NPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDSRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTE"
# Chains are joined with "|", the chain-break separator in the multimer input
aa_sequence = "|".join([h_seq, l_seq, a_seq])

model_name = "esm3-medium-multimer-2024-09"
# Read the Forge token from the environment; never hard-code credentials in a notebook.
# NOTE(review): the token previously hard-coded here is exposed in the notebook history
# and should be revoked/rotated.
token = os.environ.get("ESM_API_KEY", "")

try:
    model = client(model=model_name, url="https://forge.evolutionaryscale.ai", token=token)
except Exception as e:
    raise RuntimeError(f"Error getting ESM model with token: {str(e)}") from e

## Generate the protein structure
structure_prediction_config = GenerationConfig(
    track="structure",
    num_steps=num_steps,
    temperature=temperature,
)

structure_prediction_prompt = ESMProtein(sequence=aa_sequence)

structure_prediction = model.generate(
    structure_prediction_prompt,
    structure_prediction_config
)

# Forge returns (rather than raises) an error object on failure, e.g.
# ESMProteinError(error_code=401, error_msg=...). Stop here with a clear message.
error_msg = getattr(structure_prediction, "error_msg", None)
if error_msg is not None:
    raise RuntimeError(
        f"ESM3 generation failed (error_code={getattr(structure_prediction, 'error_code', None)}): {error_msg}"
    )

structure_prediction_chain = structure_prediction.to_protein_chain()

pdb_string = structure_prediction_chain.to_pdb_string()

## If multimer model was used, fix the chain identifiers
if "multimer" in model_name:
    pdb_string = fix_multichain_pdb_str(pdb_string=pdb_string)

pdb_string

'ATOM      1  N   GLN A   1     -16.382 -13.047   4.084  1.00  0.94           N  \nATOM      2  CA  GLN A   1     -14.938 -13.250   4.031  1.00  0.94           C  \nATOM      3  C   GLN A   1     -14.325 -12.535   2.830  1.00  0.94           C  \nATOM      4  CB  GLN A   1     -14.606 -14.742   3.980  1.00  0.94           C  \nATOM      5  O   GLN A   1     -14.174 -13.128   1.759  1.00  0.94           O  \nATOM      6  CG  GLN A   1     -15.355 -15.575   5.011  1.00  0.94           C  \nATOM      7  CD  GLN A   1     -16.638 -16.173   4.463  1.00  0.94           C  \nATOM      8  NE2 GLN A   1     -17.370 -16.885   5.312  1.00  0.94           N  \nATOM      9  OE1 GLN A   1     -16.968 -15.994   3.286  1.00  0.94           O  \nATOM     10  N   VAL A   2     -13.943 -11.418   3.316  1.00  0.96           N  \nATOM     11  CA  VAL A   2     -13.188 -10.562   2.406  1.00  0.96           C  \nATOM     12  C   VAL A   2     -11.925 -11.287   1.944  1.00  0.96           C  \nATOM     13  CB

## Visualize

In [31]:
## Use py3Dmol to visualize the structure
import py3Dmol

# One cartoon colour per chain ID produced by fix_multichain_pdb_str (A, B, C)
CHAIN_COLORS = {"A": "blue", "B": "green", "C": "red"}

view = py3Dmol.view(width=800, height=800)
view.addModel(complex_prediction, "pdb")
## color by chain
# setStyle(selection, style): the chain selector must be its own first argument.
# A single dict that mixes "chain" into the style is treated as a style for *all*
# atoms, so each call overwrote the previous one instead of colouring one chain.
for chain_id, color in CHAIN_COLORS.items():
    view.setStyle({"chain": chain_id}, {"cartoon": {"color": color}})
view.zoomTo()
view.show()