In [12]:
import py3Dmol
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.esm3 import ESM3
from huggingface_hub import login
from esm.sdk import client
import requests
from bs4 import BeautifulSoup
from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    GenerationConfig,
)
# Will instruct you how to get an API key from huggingface hub, make one with "Read" permission.
login()
model: ESM3InferenceClient = ESM3.from_pretrained("esm3-open").to("cuda") # or "cpu"
url = "http://prodata.swmed.edu/ecod/af2_pdb/domain/"


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [13]:
df = pd.read_csv('OMBB_data.csv')
df.head()

Unnamed: 0,id,strands,seq,seq_len
0,e1af6A1,18,VDFHGYARSGIGWTGSGGEQQCFQTTGAQSKYRLGNECETYAELKL...,421
1,e1kmoA2,22,IPQDFGIEAGVEGQLSPTSSQNNPKETHNLMVGGTADNGFGTALLY...,523
2,e1p4tA1,8,EGASGFYVQADAAHAKASSSLGSAKGFSPRISAGYRINDLRFAVDY...,155
3,e1prnA1,16,EISLNGYGRFGLQYVEDRGVGLEDTIISSRLRINIVGTTETDQGVT...,289
4,e1qd5A1,12,AVRGSIIANMLQEHDNPFTLYPYDTNYLIYTQTSDLNKEAIASYDW...,257


In [14]:
def getPdbId(id, url):
    try:
        response = requests.get(url + id)
        soup = BeautifulSoup(response.text, 'html.parser')
        pdb_id = None
        link = soup.find('a', title="Link to PDB")
        if link:
            href = link['href']
            pdb_id = href.split("structureId=")[-1]
        if pdb_id is None:
            print(f'No PDB ID found for {id}')
        return pdb_id
    except Exception as e:
        print(f'Error: {e}')
        return None


In [15]:
import pickle
import os
cache_path = 'protein_chains.pkl'
if os.path.exists(cache_path):
    with open(cache_path, 'rb') as file:
        protein_chains = pickle.load(file)
    print("Loaded list")
else:
    protein_chains = []
    for idx,row in tqdm(df.iterrows(), total=len(df), desc='Fetching ProteinChains'):
        id = row['id']
        pdb_id = getPdbId(id, url)
        out_membraine_chain = ProteinChain.from_rcsb(pdb_id)
        protein_chains.append(out_membraine_chain)
    with open(cache_path, 'wb') as file:
        pickle.dump(protein_chains, file)
    print("List saved successfully!")

Loaded list


In [16]:
print(len(protein_chains))
print(protein_chains[0].sequence)

example_chain = protein_chains[1]

102
VDFHGYARSGIGWTGSGGEQQCFQTTGAQSKYRLGNECETYAELKLGQEVWKEGDKSFYFDTNVAYSVAQQNDWEATDPAFREANVQGKNLIEWLPGSTIWAGKRFYQRHDVHMIDFYYWDISGPGAGLENIDVGFGKLSLAATRSSEAGGSSSFASNNIYDYTNETANDVFDVRLAQMEINPGGTLELGVDYGRANLRDNYRLVDGASKDGWLFTAEHTQSVLKGFNKFVVQYATDSMTSQGKGLSQGSGVAFDNEKFAYNINNNGHMLRILDHGAISMGDNWDMMYVGMYQDINWDNDNGTKWWTVGIRPMYKWTPIMSTVMEIGYDNVESQRTGDKNNQYKITLAQQWQAGDSIWSRPAIRVFATYAKWDEKWGYDYTGNADNNANFGKAVPADFNGGSFGRGDSDEWTFGAQMEIWW


In [17]:
view = py3Dmol.view(width=500, height=500)
# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string
pdb_str = example_chain.to_pdb_string()
# Load the PDB string into the `py3Dmol` view object
view.addModel(pdb_str, "pdb")
# Set the style of the protein chain
view.setStyle({"cartoon": {"color": "spectrum"}})
# Zoom in on the protein chain
view.zoomTo()
# Display the protein chain
view.show()

In [18]:
maskPercent = 0.1
maskPos = int(len(example_chain.sequence) * (1-maskPercent))
maskAmount = len(example_chain.sequence) - maskPos
# Create a mask for the sequence
print(f"Masking {maskPercent * 100}% ({maskAmount}) chars at the end of the sequence.")
sequence_prompt = example_chain.sequence[:maskPos] + ''.join(['_'] * maskAmount)
print("Sequence prompt:", sequence_prompt)

Masking 10.0% (67) chars at the end of the sequence.
Sequence prompt: ALTVVGDWLGDARENDVFEHAGARDVIRREDFAKTGATTMREVLNRIPGVSAPENNGTGSHDLAMNFGIRGLNPRLASRSTVLMDGIPVPFAPYGQPQLSLAPVSLGNMDAIDVVRGGGAVRYGPQSVGGVVNFVTRAIPQDFGIEAGVEGQLSPTSSQNNPKETHNLMVGGTADNGFGTALLYSGTRGSDWREHSATRIDDLMLKSKYAPDEVHTFNSLLQYYDGEADMPGGLSRADYDADRWQSTRPYDRFWGRRKLASLGYQFQPDSQHKFNIQGFYTQTLRSGYLEQGKRITLSPRNYWVRGIEPRYSQIFMIGPSAHEVGVGYRYLNESTHEMRYYTATSSGQLPSGSSPYDRDTRSGTEAHAWYLDDKIDIGNWTITPGMRFEHIESYQNNAITGTHEEVSYNAPLPALNVLYHLTDSWNLYANTEGSFGTVQYSQIGKAVQSGNVEPEKARTWELGTRYDDGALTAEMGLFLINFNNQYDSNQTNDTVTARGKTRHTGLETQARYDLGTLTPTLDNVSIYASYAYVNAEIREKGDTYGNLVPFSPKHKGTLGVDYKPGNWTFNLNSDFQSSQFADNANTVKESADGS___________________________________________________________________


In [19]:
view = py3Dmol.view(width=500, height=500)
inds = np.arange(0,maskPos)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
motif_res_inds = (
    inds + 1
).tolist()  # residue indices are 1-indexed in PDB files, so we add 1 to the indices
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()

In [20]:
sequence_generation_config = GenerationConfig(
    track="sequence",  # We want ESM3 to generate tokens for the sequence track
    num_steps=sequence_prompt.count("_")
    // 2,  # We'll use num(mask tokens) // 2 steps to decode the sequence
    temperature=0.5,  # We'll use a temperature of 0.5 to control the randomness of the decoding process
)
structure_prediction_config = GenerationConfig(
    track="structure",  # We want ESM3 to generate tokens for the structure track
    num_steps=len(sequence_prompt) // 8,
    temperature=0.7,
)
protein = ESMProtein(sequence=sequence_prompt)
# Now, we can use the `generate` method of the model to decode the sequence
sequence_generation = model.generate(protein, sequence_generation_config)
print("Sequence Prompt:\n\t", protein.sequence)
print("Generated sequence:\n\t", sequence_generation.sequence)
structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
structure_prediction = model.generate(
    structure_prediction_prompt, structure_prediction_config
)

100%|██████████| 33/33 [00:28<00:00,  1.15it/s]


Sequence Prompt:
	 ALTVVGDWLGDARENDVFEHAGARDVIRREDFAKTGATTMREVLNRIPGVSAPENNGTGSHDLAMNFGIRGLNPRLASRSTVLMDGIPVPFAPYGQPQLSLAPVSLGNMDAIDVVRGGGAVRYGPQSVGGVVNFVTRAIPQDFGIEAGVEGQLSPTSSQNNPKETHNLMVGGTADNGFGTALLYSGTRGSDWREHSATRIDDLMLKSKYAPDEVHTFNSLLQYYDGEADMPGGLSRADYDADRWQSTRPYDRFWGRRKLASLGYQFQPDSQHKFNIQGFYTQTLRSGYLEQGKRITLSPRNYWVRGIEPRYSQIFMIGPSAHEVGVGYRYLNESTHEMRYYTATSSGQLPSGSSPYDRDTRSGTEAHAWYLDDKIDIGNWTITPGMRFEHIESYQNNAITGTHEEVSYNAPLPALNVLYHLTDSWNLYANTEGSFGTVQYSQIGKAVQSGNVEPEKARTWELGTRYDDGALTAEMGLFLINFNNQYDSNQTNDTVTARGKTRHTGLETQARYDLGTLTPTLDNVSIYASYAYVNAEIREKGDTYGNLVPFSPKHKGTLGVDYKPGNWTFNLNSDFQSSQFADNANTVKESADGS___________________________________________________________________
Generated sequence:
	 ALTVVGDWLGDARENDVFEHAGARDVIRREDFAKTGATTMREVLNRIPGVSAPENNGTGSHDLAMNFGIRGLNPRLASRSTVLMDGIPVPFAPYGQPQLSLAPVSLGNMDAIDVVRGGGAVRYGPQSVGGVVNFVTRAIPQDFGIEAGVEGQLSPTSSQNNPKETHNLMVGGTADNGFGTALLYSGTRGSDWREHSATRIDDLMLKSKYAPDEVHTFNSLLQYYDGEADMPGGLSRADYDADRWQSTRPYDRFWGRRKLASLGYQFQPDSQHKFNIQGFYTQTLRSGYLEQGKRITL

100%|██████████| 82/82 [01:12<00:00,  1.13it/s]


In [21]:
# Convert the structure prediction to a ProteinChain object
structure_prediction_chain = structure_prediction.to_protein_chain()

# None masked positions
inds = np.arange(0,maskPos)

# Align the generated structure with the original structure using the non-masked sequence
structure_prediction_chain.align(
    example_chain, mobile_inds=inds, target_inds=inds
)

# Calculate RMSD
crmsd = structure_prediction_chain.rmsd(
    example_chain, mobile_inds=inds, target_inds=inds
)
print("cRMSD of the motif in the generated structure vs the original structure: ", crmsd)

# Visualize with py3Dmol
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(pdb_str, "pdb", viewer=(0, 0))
view.addModel(structure_prediction_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
view.addStyle(
    {"resi": (inds + 1).tolist()},
    {"cartoon": {"color": "cyan"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()


cRMSD of the motif in the generated structure vs the original structure:  1.5776776601762212
