## ESM-2: Single Protein Evolution ##

ESM-2 is a state-of-the-art protein model trained on a masked language modelling objective. It is suitable for fine-tuning on a wide range of tasks that take protein sequences as input. For detailed information on the model architecture and training data, please refer to the accompanying paper. You may also be interested in some demo notebooks (PyTorch, TensorFlow) which demonstrate how to fine-tune ESM-2 models on your tasks of interest.

Several ESM-2 checkpoints of varying sizes are available on the Hub. Larger checkpoints generally have somewhat better accuracy, but require much more memory and time to train.
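
To give a rough feel for the size trade-off, the sketch below estimates the memory needed just to hold the weights in float32. The parameter counts are the nominal ones read off the checkpoint names, so they are approximations, and the estimate ignores activations, gradients, optimizer state, and the CUDA context, all of which add to the real footprint.

```python
# Nominal parameter counts, read off the checkpoint names (approximate values).
NOMINAL_PARAMS = {
    "esm2_t6_8M_UR50D": 8e6,
    "esm2_t12_35M_UR50D": 35e6,
    "esm2_t30_150M_UR50D": 150e6,
    "esm2_t33_650M_UR50D": 650e6,
    "esm2_t36_3B_UR50D": 3e9,
    "esm2_t48_15B_UR50D": 15e9,
}


def weight_memory_gib(n_params: float, bytes_per_param: int = 4) -> float:
    """Memory for the weights alone, in GiB (float32 uses 4 bytes per parameter)."""
    return n_params * bytes_per_param / 2**30


for name, n_params in NOMINAL_PARAMS.items():
    print(f"{name:<22} ~{weight_memory_gib(n_params):6.2f} GiB of float32 weights")
```

Passing `bytes_per_param=2` gives the corresponding estimate for half-precision weights.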

The first method evolves a single protein sequence. The sequence is first written in FASTA format, a widely used text-based format for representing nucleotide or peptide sequences. Each record begins with a description line starting with '>', followed by the sequence itself on the subsequent lines.
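
As a small illustration of the FASTA layout described above, the following sketch writes a single record and reads it back. The helper names `to_fasta` and `parse_fasta` are hypothetical and are not part of EvoProtGrad; the sequence is the example protein used later in this notebook.

```python
def to_fasta(header: str, sequence: str, width: int = 60) -> str:
    """Format one record: a '>' description line, then the sequence wrapped to `width`."""
    chunks = [sequence[i:i + width] for i in range(0, len(sequence), width)]
    return f">{header}\n" + "\n".join(chunks) + "\n"


def parse_fasta(text: str) -> dict:
    """Parse FASTA text into {header: sequence}, joining wrapped sequence lines."""
    records = {}
    header = None
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            header = line[1:]
            records[header] = []
        elif header is not None:
            records[header].append(line)
    return {name: "".join(parts) for name, parts in records.items()}


seq = ("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGG"
       "PGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN")
fasta_text = to_fasta("Input_Sequence", seq)
print(fasta_text)
```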

The ESM-2 model and its tokenizer are then loaded as the expert system for directed evolution. The model, pretrained on vast protein sequence data, understands the complex relationships between amino acids. The tokenizer converts the protein sequences into a format that the ESM-2 model can process.
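
To make the tokenizer's role concrete, here is a toy character-level encoder. It is only a sketch: the vocabulary and the id assignment below are hypothetical and do not match the real ESM-2 vocabulary, although the idea of wrapping the residues in start and end tokens follows what ESM-style tokenizers do.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 canonical residues
RESIDUES = set(AMINO_ACIDS)
SPECIAL_TOKENS = ["<cls>", "<pad>", "<eos>", "<unk>"]

# Hypothetical id assignment for illustration; the real ESM-2 ids differ.
TOY_VOCAB = {token: i for i, token in enumerate(SPECIAL_TOKENS + list(AMINO_ACIDS))}
ID_TO_TOKEN = {i: token for token, i in TOY_VOCAB.items()}


def toy_encode(sequence: str) -> list:
    """Map each residue to an integer id, wrapped in <cls> ... <eos>."""
    ids = [TOY_VOCAB["<cls>"]]
    ids += [TOY_VOCAB.get(residue, TOY_VOCAB["<unk>"]) for residue in sequence]
    ids.append(TOY_VOCAB["<eos>"])
    return ids


def toy_decode(ids: list) -> str:
    """Map ids back to residues, dropping the special tokens."""
    return "".join(ID_TO_TOKEN[i] for i in ids if ID_TO_TOKEN[i] in RESIDUES)


ids = toy_encode("MALWMRLLPL")
print(ids)
print(toy_decode(ids))
```

In the notebook itself this step is handled by `AutoTokenizer.from_pretrained(checkpoint)`, which loads the vocabulary that matches the chosen checkpoint.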

Directed evolution is started with EvoProtGrad's `DirectedEvolution` class, with the ESM-2 model specified as the expert. The process runs several parallel chains of Markov Chain Monte Carlo (MCMC) steps. Each chain explores the sequence space, proposing mutations at each step. The EvoProtGrad framework then evaluates these mutations based on the expert model's predictions, accepting mutations that are likely to improve the desired protein characteristics.
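
The propose, score, and accept loop can be sketched with a plain Metropolis sampler. This is a deliberately simplified toy and not EvoProtGrad's sampler: the proposals here are uniformly random single-residue substitutions rather than gradient-informed ones, `toy_score` is a hypothetical stand-in for an expert, and proposals that exceed `max_mutations` are simply rejected.

```python
import math
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def toy_score(sequence: str) -> float:
    """Hypothetical stand-in for an expert score: the fraction of alanines."""
    return sequence.count("A") / len(sequence)


def metropolis_chain(wild_type, score_fn, n_steps=20, max_mutations=10,
                     temperature=0.05, rng=None):
    """Run one chain and return the (sequence, score) state after every step."""
    rng = rng or random.Random()
    current, current_score = wild_type, score_fn(wild_type)
    history = [(current, current_score)]
    for _ in range(n_steps):
        # Propose a single substitution at a uniformly random position.
        pos = rng.randrange(len(current))
        proposal = current[:pos] + rng.choice(AMINO_ACIDS) + current[pos + 1:]
        n_mutations = sum(a != b for a, b in zip(proposal, wild_type))
        if n_mutations <= max_mutations:
            proposal_score = score_fn(proposal)
            delta = proposal_score - current_score
            # Always accept improvements; accept worse moves with prob exp(delta / T).
            if delta >= 0 or rng.random() < math.exp(delta / temperature):
                current, current_score = proposal, proposal_score
        history.append((current, current_score))
    return history


wild_type = "MALWMRLLPLLALLALWGPD"
# Parallel chains are independent runs that start from the same wild type.
chains = [metropolis_chain(wild_type, toy_score, rng=random.Random(seed))
          for seed in range(3)]
for i, history in enumerate(chains):
    best_seq, best_score = max(history, key=lambda state: state[1])
    print(f"chain {i}: best score {best_score:.3f}  {best_seq}")
```

A lower `temperature` makes the chain greedier, while a higher one lets it accept more score-decreasing mutations and explore more widely.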

### Resources ###

- Blog post: https://huggingface.co/blog/AmelieSchreiber/directed-evolution-with-esm2
- Model weights: https://huggingface.co/facebook/esm2_t30_150M_UR50D
- ESM GitHub repository: https://github.com/facebookresearch/esm

In [1]:
import os
import gc
import numpy as np
import pandas as pd
import ipywidgets as widgets
from pathlib import Path

from matplotlib import pyplot as plt
from torch import cuda

# EvoProtGrad and Hugging Face imports
import evo_prot_grad
from transformers import AutoTokenizer, EsmForMaskedLM

# PyTorch
import torch

# Appearance of the Notebook
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Import this module with autoreload
%load_ext autoreload
%autoreload 2
import esm
from esm.evoprotgrad import EvoProtGrad
from esm.evoprotgrad import torch_device
print(f'Project module version: {esm.__version__}')
print(f'PyTorch version:        {torch.__version__}')

Project module version: 0.0.post1.dev23+gc9ac203
PyTorch version:        2.1.2+cu121


### Set the GPU device ###

In [2]:
# Where do we want to put the model weights.
project_dir = os.path.normpath('/n/data1/hms/ccb/projects/esm')
cache_dir = os.path.join(project_dir, 'model_weights')
Path(cache_dir).mkdir(exist_ok=True, parents=True)

# Free up GPU memory
gc.collect()
torch.cuda.empty_cache()

# Get the device for the model
device_dict = torch_device()
display(device_dict)
torch.set_float32_matmul_precision(precision='high')
!nvidia-smi

# Now, get the torch device object
device = device_dict.get('device')
print(device)

{'device_id': 0,
 'device': device(type='cuda', index=0),
 'device_name': 'NVIDIA A100-SXM4-80GB',
 'cudnn_version': 8906,
 'torch_version': '2.1.2+cu121'}

Mon May 20 09:29:00 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:DD:00.0 Off |                    0 |
| N/A   26C    P0              63W / 500W |      7MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

### Create the expert model ###

In [3]:
# https://huggingface.co/facebook/esm2_t33_650M_UR50D
esm_checkpoints = {
    't48_15B': 'facebook/esm2_t48_15B_UR50D',
    't36_3B': 'facebook/esm2_t36_3B_UR50D',
    't33_650M': 'facebook/esm2_t33_650M_UR50D',
    't30_150M': 'facebook/esm2_t30_150M_UR50D',
    't12_35M': 'facebook/esm2_t12_35M_UR50D',
    't6_8M': 'facebook/esm2_t6_8M_UR50D',
    'default': 'facebook/esm2_t30_150M_UR50D'
}

def set_expert(name='esm', checkpoint='default', device=None, cache_dir=None):
    """
    Build an EvoProtGrad expert from an ESM-2 checkpoint.

    Args:
        name: (str) The name of the expert. Default is 'esm'.
        checkpoint: (str) A key of `esm_checkpoints`. Default is 'default'.
        device: (str or torch.device) The device to run the expert on. Default is None.
        cache_dir: (str) Directory in which to cache the downloaded weights. Default is None.
    Returns:
        expert: The expert object that has been set.
    """
    # Fail early on an unknown key instead of passing None to from_pretrained
    if checkpoint not in esm_checkpoints:
        raise KeyError(f'Unknown checkpoint {checkpoint!r}; choose one of {list(esm_checkpoints)}')
    checkpoint = esm_checkpoints[checkpoint]
    expert = evo_prot_grad.get_expert(
        expert_name=name,
        model=EsmForMaskedLM.from_pretrained(checkpoint, cache_dir=cache_dir),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir),
        temperature=0.95,
        device=device)
    return expert

# The A100 has enough memory for the larger checkpoints; here we load the 150M model.
# The weights may need to be downloaded, which can take a while.
checkpoint = 't30_150M'
print(f'Loading model and weights for {checkpoint} model. This can take a while.')
expert = set_expert(checkpoint=checkpoint, device=device, cache_dir=cache_dir)

# Save some expert parameters
model_dict = {'model': expert.model,
              'device': expert.device,
              'temperature': expert.temperature,
              'vocabulary': expert.alphabet,
              'tokenizer': expert.tokenizer}

# Let's see how much GPU memory we are using
!nvidia-smi

Loading model and weights for t30_150M model. This can take a while.
Mon May 20 09:29:02 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:DD:00.0 Off |                    0 |
| N/A   26C    P0              71W / 500W |   1125MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

In [5]:
def run_evo_prot_grad(raw_protein_sequence, expert):
    # Convert raw protein sequence to the format expected by EvoProtGrad
    # Usually, protein sequences are handled in FASTA format, so we create a mock FASTA string
    fasta_format_sequence = f">Input_Sequence\n{raw_protein_sequence}"

    # Save the mock FASTA string to a temporary file
    temp_fasta_path = "temp_input_sequence.fasta"
    with open(temp_fasta_path, "w") as file:
        file.write(fasta_format_sequence)

    # Initialize Directed Evolution with the ESM-2 expert
    directed_evolution = evo_prot_grad.DirectedEvolution(
        wt_fasta=temp_fasta_path,    # path to the temporary FASTA file
        output='all',                # can be 'best', 'last', or 'all' variants
        experts=[expert],            # list of experts, in this case only ESM-2
        parallel_chains=1,           # number of parallel chains to run
        n_steps=20,                  # number of MCMC steps per chain
        max_mutations=10,            # maximum number of mutations per variant
        verbose=True                 # print debug info
    )

    # Run the evolution process
    variants, scores = directed_evolution()

    return variants, scores

In [6]:
raw_protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"  # Replace with your protein sequence
variants, scores = run_evo_prot_grad(raw_protein_sequence, expert=expert)

>Wildtype sequence: M A L W M R L L P L L A L L A L W G P D P A A A F V N Q H L C G S H L V E A L Y L V C G E R G F F Y T P K T R R E A E D L Q V G Q V E L G G G P G A G S L Q P L A L E G S L Q K R G I V E Q C C T S I C S L Y Q L E N Y C N
step 0 acceptance rate: 6.6221
>chain 0, Product of Experts score: -0.6628
M A L W M R L L P L L A L L A L W G P D P A A A F V N Q H L C G S H L V E A L Y L V C G E R G F F Y T P K T R R E A E D L Q V G Q V N L

In [6]:
# Process the results
#for variant, score in zip(variants, scores):
#    print(f"Variant: {variant}, Score: {score}")
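
One way to process the results is to pair each variant with its score, drop duplicates, and sort. The sketch below assumes only that `variants` and `scores` can be flattened into two sequences of equal length (nested lists or a NumPy array both work) and that sequences may come back space-separated as in the verbose output; the exact shapes returned by `DirectedEvolution` are not confirmed here, and the demo data is made up.

```python
def flatten(nested):
    """Flatten nested lists/tuples, or anything with .tolist() such as a NumPy array."""
    if hasattr(nested, "tolist"):
        nested = nested.tolist()
    if isinstance(nested, (list, tuple)):
        flat = []
        for item in nested:
            flat.extend(flatten(item))
        return flat
    return [nested]


def rank_variants(variants, scores, top_k=5):
    """Return up to top_k unique (sequence, best_score) pairs, highest score first."""
    flat_variants = [v.replace(" ", "") for v in flatten(variants)]
    flat_scores = [float(s) for s in flatten(scores)]
    if len(flat_variants) != len(flat_scores):
        raise ValueError("variants and scores must contain the same number of items")
    best = {}
    for sequence, score in zip(flat_variants, flat_scores):
        if sequence not in best or score > best[sequence]:
            best[sequence] = score
    return sorted(best.items(), key=lambda item: item[1], reverse=True)[:top_k]


# Made-up demo data standing in for the output of run_evo_prot_grad.
demo_variants = [["M A L W"], ["M A L A"], ["M A L A"], ["M P L W"]]
demo_scores = [[-0.66], [1.20], [0.90], [0.10]]
ranked = rank_variants(demo_variants, demo_scores)
for sequence, score in ranked:
    print(f"Variant: {sequence}, Score: {score:.2f}")
```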

In [7]:
# Run the class method. This puts the results into a nice table, sorted by score
epg = EvoProtGrad(expert=expert, cache_dir=cache_dir)
output_dir = os.path.join(os.environ['HOME'], 'data', 'protein_evolution')
Path(output_dir).mkdir(parents=True, exist_ok=True)
var_df = epg.single_evolute(raw_protein_sequence=raw_protein_sequence, output_dir=output_dir)
display(var_df)

|   | variant | score | pos | source | target | sequence |
|---|---------|-------|-----|--------|--------|----------|
| 0 | 12 | 3.38427 | [4, 5, 16, 25, 26, 68, 73, 75, 93, 102] | [M, R, W, V, N, G, A, S, Q, Y] | [A, A, T, T, V, A, H, W, D, A] | MALWAALLPLLALLALTGPDPAAAFTVQHLCGSHLVEALYLVCGER... |
| 1 | 13 | 3.38427 | [102] | [Y] | [P] | MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGER... |
| 2 | 6 | 2.354532 | [4, 5, 16, 25, 26, 68, 75, 93, 102] | [M, R, W, V, N, G, S, Q, Y] | [P, A, T, T, V, A, W, D, A] | MALWPALLPLLALLALTGPDPAAAFTVQHLCGSHLVEALYLVCGER... |
| 3 | 7 | 2.354532 | [4, 5, 16, 25, 26, 68, 75, 93, 102] | [M, R, W, V, N, G, S, Q, Y] | [P, A, T, T, V, A, W, D, A] | MALWPALLPLLALLALTGPDPAAAFTVQHLCGSHLVEALYLVCGER... |
| 4 | 10 | 2.354532 | [4, 5, 16, 25, 26, 68, 75, 93, 102] | [M, R, W, V, N, G, S, Q, Y] | [P, A, T, T, V, A, W, D, A] | MALWPALLPLLALLALTGPDPAAAFTVQHLCGSHLVEALYLVCGER... |
| 5 | 11 | 2.354532 | [4, 5, 16, 25, 26, 68, 75, 93, 102] | [M, R, W, V, N, G, S, Q, Y] | [P, A, T, T, V, A, W, D, A] | MALWPALLPLLALLALTGPDPAAAFTVQHLCGSHLVEALYLVCGER... |
| 6 | 9 | 2.354532 | [4, 5, 16, 25, 26, 68, 75, 93, 102] | [M, R, W, V, N, G, S, Q, Y] | [P, A, T, T, V, A, W, D, A] | MALWPALLPLLALLALTGPDPAAAFTVQHLCGSHLVEALYLVCGER... |
| 7 | 8 | 2.354532 | [4, 5, 16, 25, 26, 68, 75, 93, 102] | [M, R, W, V, N, G, S, Q, Y] | [P, A, T, T, V, A, W, D, A] | MALWPALLPLLALLALTGPDPAAAFTVQHLCGSHLVEALYLVCGER... |
| 8 | 18 | 1.01667 | [25, 26, 56, 63, 72, 83, 98, 102] | [V, N, E, G, G, G, I, Y] | [A, E, L, H, P, P, A, P] | MALWMRLLPLLALLALWGPDPAAAFAEQHLCGSHLVEALYLVCGER... |
| 9 | 19 | 1.01667 | [25, 26, 56, 63, 72, 83, 98, 102] | [V, N, E, G, G, G, I, Y] | [A, E, L, H, P, P, A, P] | MALWMRLLPLLALLALWGPDPAAAFAEQHLCGSHLVEALYLVCGER... |
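
The `pos`, `source`, and `target` columns are enough to rebuild a variant from the wild type. In the rows shown, the `pos` values appear to be 0-based indices into the sequence. The helper below is a hypothetical sketch (not part of the project module) that applies one row's substitutions and checks each `source` residue against the wild type along the way.

```python
def apply_mutations(wild_type, positions, sources, targets):
    """Apply substitutions at 0-based positions, verifying the expected source residues."""
    residues = list(wild_type)
    for pos, source, target in zip(positions, sources, targets):
        if residues[pos] != source:
            raise ValueError(f"expected {source} at position {pos}, found {residues[pos]}")
        residues[pos] = target
    return "".join(residues)


wild_type = ("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGG"
             "PGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN")

# First row of the table: ten substitutions.
variant = apply_mutations(
    wild_type,
    positions=[4, 5, 16, 25, 26, 68, 73, 75, 93, 102],
    sources=list("MRWVNGASQY"),
    targets=list("AATTVAHWDA"),
)
print(variant)
```

The first 46 residues of the result can be compared with the truncated `sequence` column as a quick consistency check.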
