# ESM3 Inverse Folding Notebook

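This notebook performs inverse folding with the ESM3 model: given a protein structure (a PDB code or an uploaded PDB file), it generates candidate sequences for that structure and displays them colored by per-position conservation.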


In [None]:
# @title Input API keys, then hit `Runtime` -> `Run all`
# @markdown The Forge API is our hosted service that provides access to the full suite of ESM3 models.
# @markdown To utilize the Forge API, users must first agree to the [Terms of Service](https://forge.evolutionaryscale.ai/termsofservice) and generate an access token via the [Forge console](https://forge.evolutionaryscale.ai/console).
# @markdown The console also provides a comprehensive list of models available to each user.

import os

# @markdown ### Authentication
# @markdown Paste your token from the [Forge console](https://forge.evolutionaryscale.ai/console)
forge_token = ""  # @param {type:"string"}
# Publish the token through the ESM_API_KEY environment variable.
# NOTE(review): presumably get_forge_client() in the run cell reads this
# variable; an empty token is exported as an empty string — confirm how the
# client reacts to that.
os.environ["ESM_API_KEY"] = forge_token

# @markdown ### Model Selection
# @markdown Enter the model name from the [Forge console page](https://forge.evolutionaryscale.ai/console) that you would like to use:
# Passed to get_forge_client() in the run cell.
model_name = "esm3-medium-2024-08"  # @param {type:"string"}

# @markdown ### Input Structure
# An empty pdb_code makes get_pdb() prompt for a file upload instead of
# downloading the entry.
pdb_code = ""  # @param {type:"string"}
# Forwarded as chain_id to ESMProtein.from_pdb() in the run cell.
# NOTE(review): "detect" presumably asks the loader to pick the chain
# automatically — confirm against the ESM SDK.
chain = "detect"  # @param {type:"string"}
# @markdown Enter PDB code or leave blank to upload file
# @markdown Specify a chain if uploading a complex

# @markdown ### Design Parameters
# Sampling temperature handed to GenerationConfig for every generation.
temperature = 0.1  # @param {type:"slider", min:0.0, max:1.0, step:0.01}
# Number of copies of the prompt submitted in the batch, i.e. how many
# sequences are requested.
num_sequences = 8  # @param {type:"integer"}

In [None]:
# @title Install dependencies
import os

# Install the ESM SDK and the helper packages used by this notebook.
#
# `sys.executable -m pip` guarantees the packages land in the environment of
# the running kernel (a bare `pip` on PATH may belong to another interpreter),
# and `check_call` raises CalledProcessError on a non-zero exit status so a
# failed install is reported here instead of as a confusing ImportError later.
import subprocess
import sys

for requirement_args in (
    ["git+https://github.com/evolutionaryscale/esm"],
    [
        "pydssp",
        "pygtrie",
        "dna-features-viewer",
        "py3dmol",
        "nest-asyncio",
        "ipywidgets",
    ],
):
    subprocess.check_call([sys.executable, "-m", "pip", "install", *requirement_args])

# Imported only after installation, hence the late import.
import nest_asyncio  # noqa: E402

# NOTE(review): presumably required so the SDK's async calls can run inside
# the notebook's already-running event loop — confirm.
nest_asyncio.apply()

In [None]:
# @title Run Inverse Folding
import numpy as np
from esm.sdk.api import ESMProtein, ESMProteinError, GenerationConfig
from esm.widgets.utils.clients import get_forge_client
from google.colab import files
from IPython.display import HTML


def get_pdb(pdb_code=""):
    """Return the path of a local PDB file to use as the input structure.

    Args:
        pdb_code: PDB identifier to download from files.rcsb.org. When it is
            ``None`` or blank, the user is prompted to upload a file instead.

    Returns:
        ``"tmp.pdb"`` for an uploaded file, otherwise ``"<pdb_code>.pdb"``
        in the current working directory.

    Raises:
        ValueError: if no file was uploaded, or if ``pdb_code`` contains
            characters other than ASCII letters, digits and underscores.
        urllib.error.URLError: if the download fails (``HTTPError`` for an
            unknown entry).
    """
    import re
    import urllib.request

    code = "" if pdb_code is None else str(pdb_code).strip()

    if not code:
        upload_dict = files.upload()
        if not upload_dict:
            raise ValueError("No file was uploaded and no PDB code was given.")
        # Only the first uploaded file is used.
        pdb_bytes = next(iter(upload_dict.values()))
        with open("tmp.pdb", "wb") as out:
            out.write(pdb_bytes)
        return "tmp.pdb"

    # The code ends up in both a URL and a local file name, so reject anything
    # that could escape either (path separators, shell metacharacters, ...).
    if re.fullmatch(r"[A-Za-z0-9_]+", code) is None:
        raise ValueError(f"Invalid PDB code: {pdb_code!r}")

    pdb_path = f"{code}.pdb"
    # Keep the no-clobber behaviour: reuse a previously downloaded file.
    if not os.path.exists(pdb_path):
        # Download to a side file first so an interrupted transfer never
        # leaves a truncated "<code>.pdb" that would be reused next time.
        partial_path = f"{pdb_path}.part"
        urllib.request.urlretrieve(
            f"https://files.rcsb.org/view/{code}.pdb", partial_path
        )
        os.replace(partial_path, pdb_path)
    return pdb_path


print("Loading structure...")
pdb_path = get_pdb(pdb_code)

# Build the prompt from the structure file and drop its sequence so that the
# sequence track is what gets generated.
protein = ESMProtein.from_pdb(pdb_path, chain_id=chain)
protein.sequence = None

print("Running inverse folding...")
client = get_forge_client(model_name)
# Submit `num_sequences` copies of the same prompt, each with its own
# sequence-track generation config at the chosen temperature.
generations = client.batch_generate(
    inputs=[protein] * num_sequences,
    configs=[GenerationConfig(track="sequence", temperature=temperature)]
    * num_sequences,
)

# Split the batch into successful sequences and (index, error) pairs; the
# index is the position of the failed item within the batch.
# A distinct loop variable keeps `protein` bound to the prompt built above.
errors: list[tuple[int, ESMProteinError]] = []
sequences: list[str] = []
for i, generation in enumerate(generations):
    if isinstance(generation, ESMProteinError):
        errors.append((i, generation))
    else:
        sequences.append(generation.sequence)


def calculate_conservation_scores(sequences: list[str]) -> np.ndarray:
    """Score how conserved each position is across equal-length sequences.

    For every position the Shannon entropy of the observed characters is
    computed, divided by ``log(256)`` and mapped to
    ``max(0, 0.5 - normalized_entropy) / 0.5``: roughly 1 for a fully
    conserved position, decreasing as the position gets more variable.

    Any ASCII character is accepted (letters, ``_``, ``-``, ...); each
    distinct character is counted as its own symbol.

    Args:
        sequences: ASCII strings that all have the same length.

    Returns:
        A 1-D float array with one score per position. It is empty when
        ``sequences`` is empty or the sequences have length zero.

    Raises:
        ValueError: if the sequences differ in length.
        UnicodeEncodeError: if a sequence contains non-ASCII characters.
    """
    if len(sequences) == 0:
        return np.zeros(0, dtype=float)

    length = len(sequences[0])
    if any(len(seq) != length for seq in sequences):
        raise ValueError("All sequences must have the same length.")
    if length == 0:
        return np.zeros(0, dtype=float)

    # One byte per residue, laid out as (num_sequences, length).
    num_symbols = 256
    codes = np.frombuffer("".join(sequences).encode("ascii"), dtype=np.uint8)
    codes = codes.reshape(len(sequences), length).astype(np.intp)

    # Count every (position, symbol) pair with a single bincount by giving
    # each position its own block of `num_symbols` bins.
    flat_bins = codes + np.arange(length) * num_symbols
    counts = np.bincount(flat_bins.ravel(), minlength=length * num_symbols)
    counts = counts.reshape(length, num_symbols)

    # Calculate entropy (-sum(p log p)); the epsilon keeps log() finite for
    # symbols that never occur (their term is 0 * log(eps) = 0).
    probabilities = counts / counts.sum(axis=1, keepdims=True)
    entropy = -np.sum(probabilities * np.log(probabilities + 1e-9), axis=1)

    # Convert to conservation score (1 - normalized entropy)
    max_entropy = np.log(256)
    # Magic constant to make displaying non-conserved residues more apparent
    conservation_scores = np.maximum(0, 0.5 - (entropy / max_entropy)) / 0.5

    return conservation_scores


def display_sequences(sequences: list[str]):
    """Render the sequences as monospace HTML, shaded by conservation.

    Each residue is wrapped in a ``<span>`` whose background opacity is the
    conservation score of its position, so conserved columns appear darker.

    Args:
        sequences: equal-length sequences to show, one per output line.
            When empty, a short message is printed and nothing is rendered.
    """
    if not sequences:
        # Nothing to score or draw; avoid failing inside the score computation.
        print("No sequences to display.")
        return

    conservation_scores = calculate_conservation_scores(sequences)

    rows = []
    for sequence in sequences:
        # Color the background of every residue by its column's score.
        spans = "".join(
            f'<span style="background-color: rgba(9, 121, 105,{conservation_scores[j]})">{residue}</span>'
            for j, residue in enumerate(sequence)
        )
        rows.append(spans + "<br>")

    html_output = (
        '<pre style="line-height:1.0;letter-spacing:3px;font-family:monospace;margin:0;padding:0">'
        + "".join(rows)
        + "</pre>"
    )
    display(HTML(html_output))


# Render the successful generations. The guard matters when every generation
# failed: rendering an empty list would raise before the errors below are
# reported.
if sequences:
    display_sequences(sequences)
else:
    print("No sequences were generated successfully.")

# Report each failed generation with its position in the batch.
for i, error in errors:
    print(f"Error code {error.error_code} at index {i}: {error.error_msg}")