# [Tutorial 5](https://github.com/evolutionaryscale/esm/tree/main/cookbook/tutorials): Guided Generation with ESM3

Guided generation lets you sample outputs from ESM3 that maximize an arbitrary scoring function.

For example, you may want to:
1. Guide generations towards higher quality metrics like pTM
2. Constrain the distribution of outputs to have certain amino acid frequencies or structural attributes
3. Minimize a biophysical energy function
4. Use experimental screening data to guide designs with a regression model

As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding, described in [Li et al., 2024](https://arxiv.org/abs/2408.08252), and constrained optimization using the Modified Differential Method of Multipliers from [Platt & Barr, 1987](https://proceedings.neurips.cc/paper_files/paper/1987/file/a1126573153ad7e9f44ba80e99316482-Paper.pdf).
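To make the derivative-free idea concrete, here is a toy sketch of a sample-and-select loop. It is not ESM3's implementation: the uniform residue proposals, the greedy "keep the best candidate" rule, and the names `toy_guided_generate` and `no_cysteine_score` are all assumptions made for illustration. The point is only that the scoring function is called as a black box and never differentiated.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
MASK = "_"


def toy_guided_generate(length, num_steps, num_samples_per_step, score_fn, seed=0):
    """Toy sample-and-select loop (hypothetical helper, not the ESM3 code)."""
    rng = random.Random(seed)
    sequence = [MASK] * length
    per_step = -(-length // num_steps)  # ceil division: positions filled per step
    for _ in range(num_steps):
        masked = [i for i, aa in enumerate(sequence) if aa == MASK]
        if not masked:
            break
        candidates = []
        for _ in range(num_samples_per_step):
            candidate = list(sequence)
            # A real model would propose residues; here they are drawn uniformly.
            for i in rng.sample(masked, min(per_step, len(masked))):
                candidate[i] = rng.choice(AMINO_ACIDS)
            candidates.append(candidate)
        # Keep the highest-scoring candidate; no gradients are required.
        sequence = max(candidates, key=lambda c: score_fn("".join(c)))
    return "".join(sequence)


def no_cysteine_score(seq):
    # Higher is better, so cysteines are counted negatively.
    return -seq.count("C")


toy_sequence = toy_guided_generate(
    length=64, num_steps=8, num_samples_per_step=10, score_fn=no_cysteine_score
)
print(toy_sequence, no_cysteine_score(toy_sequence))
```

Soft Value-Based Decoding samples among candidates according to their values rather than always taking the maximum, so the greedy selection above should be read as a simplification.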

This notebook walks through three examples of guided generation:

1. Guide towards high pTM for improved generation quality
2. Generate a protein with no cysteine (C) residues
3. Maximize protein globularity by minimizing the radius of gyration, while keeping pTM high



## Imports

In [None]:
!pip install git+https://github.com/evolutionaryscale/esm.git
!pip install py3dmol

In [None]:
import biotite.structure as bs
import py3Dmol

from esm.sdk.api import ESMProtein, GenerationConfig
from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction

## Creating a scoring function

To get started with the guided generation API, all you need is a callable class that inherits from `GuidedDecodingScoringFunction`. This class receives an `ESMProtein` object as input and returns a numerical score.


For example, one of the computational metrics we can use to measure the quality of a generated protein structure is the Predicted Template Modelling (pTM) score, so we'll use it to create a `PTMScoringFunction`.

Fortunately, every time we generate a protein using ESM3 (either locally or on Forge) we also get its pTM, so all our class needs to do when it's called is return the `ptm` attribute of its input.

In [None]:
# Create scoring function (e.g. PTM scoring function)
class PTMScoringFunction(GuidedDecodingScoringFunction):
    def __call__(self, protein: ESMProtein) -> float:
        # Minimal example of a scoring function that scores proteins based on their pTM score
        # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score
        assert protein.ptm is not None, "Protein must have pTM scores to be scored"
        return float(protein.ptm)
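Sequence-level scores follow the same pattern. As a sketch of the "amino acid frequencies" use case mentioned in the introduction, the function below scores a sequence by how close its hydrophobic fraction is to a target. The residue set, the default target of 0.4, and the name `hydrophobic_fraction_score` are illustrative assumptions, not part of the ESM SDK. Inside a `GuidedDecodingScoringFunction` subclass, `__call__` could simply return this value for `protein.sequence`.

```python
# Which residues count as hydrophobic is a modelling choice; this set is an assumption.
HYDROPHOBIC = set("AILMFVWY")


def hydrophobic_fraction_score(sequence: str, target: float = 0.4) -> float:
    """Negative distance between the hydrophobic fraction and a target (0.0 is best)."""
    # Ignore mask tokens so partially decoded sequences can still be scored.
    residues = [aa for aa in sequence if aa != "_"]
    if not residues:
        fraction = 0.0
    else:
        fraction = sum(aa in HYDROPHOBIC for aa in residues) / len(residues)
    return -abs(fraction - target)


print(hydrophobic_fraction_score("AAGG", target=0.5))  # 0.0, exactly on target
print(hydrophobic_fraction_score("GGGG", target=0.5))  # -0.5, no hydrophobic residues
```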

### Initialize your client

Guided generation is compatible with both local inference using the `ESM3` class and remote inference with the Forge client.

In [None]:
# To use the tokenizers and the open model you'll need to log in to Hugging Face

from huggingface_hub import notebook_login

notebook_login()

In [None]:
## Locally with ESM3-open
# from esm.models.esm3 import ESM3
# model = ESM3.from_pretrained().to("cuda")

## On Forge with larger ESM3 models
from getpass import getpass

from esm.sdk import client

token = getpass("Token from Forge console: ")
model = client(
    model="esm3-medium-2024-08", url="https://forge.evolutionaryscale.ai", token=token
)

## Guide towards high pTM for improved generation quality

Once your scoring function is defined and your model is initialized, you can create an `ESM3GuidedDecoding` instance to sample from it.

In [None]:
ptm_guided_decoding = ESM3GuidedDecoding(
    client=model, scoring_function=PTMScoringFunction()
)

In [None]:
# Start from a fully masked protein
PROTEIN_LENGTH = 256
starting_protein = ESMProtein(sequence="_" * PROTEIN_LENGTH)

# Call guided_generate
generated_protein = ptm_guided_decoding.guided_generate(
    protein=starting_protein,
    num_decoding_steps=len(starting_protein) // 8,
    num_samples_per_step=10,
)
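The two arguments above determine the compute budget. The following back-of-the-envelope calculation assumes that masked positions are spread evenly across decoding steps and that each sampled candidate is scored once per step. Both are assumptions about the schedule, not guarantees from the API.

```python
PROTEIN_LENGTH = 256
num_decoding_steps = PROTEIN_LENGTH // 8
num_samples_per_step = 10

# Average number of positions decoded per step under an even schedule.
positions_per_step = PROTEIN_LENGTH / num_decoding_steps
# Total candidates that the scoring function would be asked to evaluate.
total_candidates = num_decoding_steps * num_samples_per_step

print(num_decoding_steps, positions_per_step, total_candidates)  # 32 8.0 320
```

Under these assumptions, raising `num_samples_per_step` increases the number of scoring calls linearly, which matters when the scoring function is expensive.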

### Compare against baseline with no guidance

First, we sample a protein without any guidance to serve as a baseline. Without pTM guidance, the sampled protein may have no clear structure.

In [None]:
# Generate a protein WITHOUT guidance
generated_protein_no_guided: ESMProtein = model.generate(
    input=starting_protein,
    config=GenerationConfig(track="sequence", num_steps=len(starting_protein) // 8),
)  # type: ignore

# Fold
generated_protein_no_guided: ESMProtein = model.generate(
    input=generated_protein_no_guided,
    config=GenerationConfig(track="structure", num_steps=1),
)  # type: ignore

In [None]:
# Create a 1x2 grid of viewers (1 row, 2 columns)
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))

# Convert ESMProtein objects to ProteinChain objects
protein_chain1 = generated_protein_no_guided.to_protein_chain()
protein_chain2 = generated_protein.to_protein_chain()

# Add models to respective panels
view.addModel(protein_chain1.to_pdb_string(), "pdb", viewer=(0, 0))
view.addModel(protein_chain2.to_pdb_string(), "pdb", viewer=(0, 1))

# Set styles for each protein
view.setStyle({}, {"cartoon": {"color": "spectrum"}}, viewer=(0, 0))
view.setStyle({}, {"cartoon": {"color": "spectrum"}}, viewer=(0, 1))

# Zoom and center the view
view.zoomTo()
view.show()

In [None]:
print(f"pTM Without guidance: {generated_protein_no_guided.ptm:.3f}")
print(f"pTM With guidance: {generated_protein.ptm:.3f}")

## Generate a Protein with No Cysteines

Guided generation is not limited to structural metrics. You can also use it to guide sequence generation.

For example, we can create a `NoCysteineScoringFunction` that penalizes the protein if it contains cysteine residues.

In [None]:
class NoCysteineScoringFunction(GuidedDecodingScoringFunction):
    def __call__(self, protein: ESMProtein) -> float:
        # Penalize proteins that contain cysteine
        assert protein.sequence is not None, "Protein must have a sequence to be scored"
        # Note that we use a negative score here, to discourage the presence of cysteine
        return float(-protein.sequence.count("C"))

In [None]:
no_cysteine_guided_decoding = ESM3GuidedDecoding(
    client=model, scoring_function=NoCysteineScoringFunction()
)

In [None]:
# Start from a fully masked protein
PROTEIN_LENGTH = 256
starting_protein = ESMProtein(sequence="_" * PROTEIN_LENGTH)

# Call guided_generate
no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(
    protein=starting_protein,
    num_decoding_steps=len(starting_protein) // 8,
    num_samples_per_step=10,
)

Let's check our sequence!

If guided generation converged to `score == 0.00`, the resulting protein should contain no cysteine residues.

In [None]:
assert no_cysteine_protein.sequence is not None, "Protein must have a sequence"
print(no_cysteine_protein.sequence)
print(f"Number of cysteine residues: {no_cysteine_protein.sequence.count('C')}")
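Beyond counting a single residue, it can be useful to look at the full amino acid composition of a generated sequence. The helper below is a small sketch using only the standard library. The name `composition` and the example sequence are hypothetical stand-ins; in the notebook you would pass `no_cysteine_protein.sequence` instead.

```python
from collections import Counter


def composition(sequence: str) -> dict[str, float]:
    """Return residue frequencies, most common first (hypothetical helper)."""
    if not sequence:
        return {}
    counts = Counter(sequence)
    return {aa: n / len(sequence) for aa, n in counts.most_common()}


example_sequence = "MKTAYIAKQRQISFVKSHFSRQ"  # stand-in for a generated sequence
frequencies = composition(example_sequence)
for aa, freq in frequencies.items():
    print(f"{aa}: {freq:.2f}")
print("Contains cysteine:", "C" in frequencies)
```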

## Maximize Globularity

We use the radius of gyration as a proxy for globularity: the smaller the radius, the more compact the protein. We also encourage generations to have high pTM by adding a constraint.
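For intuition, the radius of gyration is the (mass-weighted) root-mean-square distance of the atoms from their center of mass. The pure-Python sketch below illustrates the definition on toy coordinates. The function name and the equal-mass default are assumptions made for this example; on real structures, use `bs.gyration_radius` as in the notebook.

```python
import math


def radius_of_gyration(coords, masses=None):
    """Root-mean-square distance from the center of mass (illustrative only)."""
    if masses is None:
        masses = [1.0] * len(coords)  # assumption: equal masses unless given
    total_mass = sum(masses)
    center = [
        sum(m * c[k] for m, c in zip(masses, coords)) / total_mass for k in range(3)
    ]
    weighted_sq = sum(
        m * sum((c[k] - center[k]) ** 2 for k in range(3))
        for m, c in zip(masses, coords)
    )
    return math.sqrt(weighted_sq / total_mass)


compact = [(1.0, 0.0, 0.0), (-1.0, 0.0, 0.0)]
extended = [(5.0, 0.0, 0.0), (-5.0, 0.0, 0.0)]
print(radius_of_gyration(compact), radius_of_gyration(extended))  # 1.0 5.0
```

The more spread-out point set has the larger radius, which is why the scoring function in this section uses the negative radius as the quantity to maximize.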

In [None]:
from esm.sdk.experimental import (
    ConstraintType,
    ESM3GuidedDecodingWithConstraints,
    GenerationConstraint,
)

In [None]:
class RadiusOfGyrationScoringFunction(GuidedDecodingScoringFunction):
    def __call__(self, protein: ESMProtein) -> float:
        # Use the negative radius of gyration as the score to maximize
        score = -1 * self.radius_of_gyration(protein)

        # Re-scale the score to be in a similar magnitude as pTM
        score = score / 100.0

        return score

    @staticmethod
    def radius_of_gyration(protein: ESMProtein) -> float:
        protein_chain = protein.to_protein_chain()
        arr = protein_chain.atom_array_no_insertions
        return float(bs.gyration_radius(arr))

In [None]:
# Constrain generation to have pTM >= 0.75
ptm_constraint = GenerationConstraint(
    scoring_function=PTMScoringFunction(),
    constraint_type=ConstraintType.GREATER_EQUAL,
    value=0.75,
)

radius_guided_decoding = ESM3GuidedDecodingWithConstraints(
    client=model,
    scoring_function=RadiusOfGyrationScoringFunction(),
    constraints=[ptm_constraint],  # Add list of constraints
    damping=1.0,  # Damping factor for the MDMM algorithm
    learning_rate=10.0,  # Learning rate for the MDMM algorithm
)

In [None]:
# Start from a fully masked protein
PROTEIN_LENGTH = 256
starting_protein = ESMProtein(sequence="_" * PROTEIN_LENGTH)

# Call guided_generate
radius_guided_protein = radius_guided_decoding.guided_generate(
    protein=starting_protein,
    num_decoding_steps=len(starting_protein) // 8,
    num_samples_per_step=10,
)

In [None]:
# Visualize the trajectory of the constrained generation
radius_guided_decoding.visualize_latest_trajectory()
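To build intuition for what the `damping` and `learning_rate` arguments control, here is a toy one-dimensional sketch in the spirit of the Modified Differential Method of Multipliers. It is not the library's implementation. The objective, the constraint, the candidate grid, and the names `augmented_score` and `constrained_search` are all assumptions made for illustration. The objective prefers small `x`, the constraint requires `x >= 1`, and a multiplier grows while the constraint is violated.

```python
def augmented_score(x, multiplier, damping, threshold=1.0):
    """Objective minus multiplier and damping penalties on the violated constraint."""
    objective = -x  # toy objective: smaller x scores higher
    violation = max(0.0, threshold - x)  # toy constraint: x >= threshold
    return objective - multiplier * violation - 0.5 * damping * violation**2


def constrained_search(num_steps=20, damping=1.0, learning_rate=0.5, threshold=1.0):
    """Derivative-free search over a grid with a multiplier update each step."""
    candidates = [i / 100 for i in range(201)]  # grid on [0, 2]
    multiplier = 0.0
    best = candidates[0]
    for _ in range(num_steps):
        # Pick the candidate with the best penalized score.
        best = max(
            candidates,
            key=lambda x: augmented_score(x, multiplier, damping, threshold),
        )
        # Raise the multiplier in proportion to the remaining violation.
        multiplier += learning_rate * max(0.0, threshold - best)
    return best, multiplier


best_x, final_multiplier = constrained_search()
print(best_x, final_multiplier)
```

With no penalty the search would pick `x = 0`. As the multiplier accumulates, the selected candidate moves toward the constraint boundary near `x = 1`. In this toy, a larger learning rate makes the multiplier react faster, and the damping term penalizes large violations more strongly.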

In [None]:
# Visualize the generated protein
view = py3Dmol.view(width=800, height=400)
view.addModel(radius_guided_protein.to_pdb_string(), "pdb")
view.setStyle({}, {"cartoon": {"color": "spectrum"}})
view.zoomTo()

In [None]:
# Check pTM
radius_guided_protein.ptm