## ESM-2 ##

ESM-2 is a state-of-the-art protein model trained on a masked language modelling objective. It is suitable for fine-tuning on a wide range of tasks that take protein sequences as input. For detailed information on the model architecture and training data, please refer to the accompanying paper. You may also be interested in some demo notebooks (PyTorch, TensorFlow) which demonstrate how to fine-tune ESM-2 models on your tasks of interest.

Several ESM-2 checkpoints of varying sizes are available on the Hugging Face Hub. Larger checkpoints generally achieve somewhat better accuracy but require much more memory and training time.

The checkpoint used in this notebook (30 layers, about 150M parameters) is available here:

https://huggingface.co/facebook/esm2_t30_150M_UR50D
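ESM-2 checkpoint names encode model size: in esm2_t30_150M_UR50D, t30 is the number of transformer layers and 150M the approximate parameter count. Other sizes on the Hub include esm2_t6_8M_UR50D, esm2_t12_35M_UR50D, esm2_t33_650M_UR50D, esm2_t36_3B_UR50D and esm2_t48_15B_UR50D. The helper below is a hypothetical convenience (not part of transformers or EvoProtGrad) that parses a checkpoint name, assuming this naming pattern holds.

```python
import re
from typing import Tuple

# Assumed naming pattern for ESM-2 Hub checkpoints: "esm2_t<layers>_<params><M|B>_UR50D"
CHECKPOINT_PATTERN = re.compile(r"^esm2_t(?P<layers>\d+)_(?P<params>\d+)(?P<unit>[MB])_UR50D$")


def parse_esm2_checkpoint(name: str) -> Tuple[int, int]:
    """Return (number of layers, approximate parameter count) for an ESM-2 checkpoint name."""
    # Accept either the bare name or the full "facebook/..." Hub identifier
    match = CHECKPOINT_PATTERN.match(name.split("/")[-1])
    if match is None:
        raise ValueError(f"Unrecognised ESM-2 checkpoint name: {name}")
    # "M" means millions of parameters, "B" means billions
    scale = 10**6 if match.group("unit") == "M" else 10**9
    return int(match.group("layers")), int(match.group("params")) * scale


print(parse_esm2_checkpoint("facebook/esm2_t30_150M_UR50D"))  # (30, 150000000)
```

A helper like this can be handy when sweeping over checkpoint sizes to trade accuracy against memory and runtime.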

In [1]:
import os
import numpy as np
import pandas as pd
import ipywidgets as widgets
from pathlib import Path

from matplotlib import pyplot as plt

# EvoProtGrad and Hugging Face imports
import evo_prot_grad
from transformers import AutoTokenizer, EsmForMaskedLM

# PyTorch
import torch

# Appearance of the Notebook
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
np.set_printoptions(linewidth=110)
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 100)
pd.set_option('display.width', 1000)

# Import this module with autoreload
%load_ext autoreload
%autoreload 2
import esm
from esm.evoprotgrad import EvoProtGrad
from esm.evoprotgrad import torch_device
print(f'Project module version: {esm.__version__}')
print(f'PyTorch version:        {torch.__version__}')

Project module version: 0.0.post1.dev23+gc9ac203
PyTorch version:        2.1.2+cu121


### Single Protein Evolution ###

The first method evolves a single protein sequence. The sequence is first written in FASTA format, a widely used text-based format for nucleotide and peptide sequences: each record begins with a header line starting with '>', followed by the sequence itself on one or more subsequent lines.
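A minimal sketch of writing and reading FASTA records with plain Python. The function names are hypothetical; the 60-character line wrapping is a common convention, not a requirement (the notebook writes the whole sequence on one line, which is also valid FASTA).

```python
from typing import List, Optional, Tuple


def to_fasta(header: str, sequence: str, width: int = 60) -> str:
    """Format one sequence as a FASTA record, wrapping the sequence at `width` characters."""
    lines = [f">{header}"]
    lines += [sequence[i:i + width] for i in range(0, len(sequence), width)]
    return "\n".join(lines) + "\n"


def parse_fasta(text: str) -> List[Tuple[str, str]]:
    """Parse FASTA text into (header, sequence) pairs; sequence lines are concatenated."""
    records: List[Tuple[str, str]] = []
    header: Optional[str] = None
    chunks: List[str] = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            # A new header closes the previous record, if any
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records
```

For example, parse_fasta(to_fasta("Input_Sequence", "MALWMRLLPLL")) returns [("Input_Sequence", "MALWMRLLPLL")].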

The ESM-2 model and its tokenizer are then loaded as the expert that scores candidate sequences during directed evolution. Pretrained on a large corpus of protein sequences, the model has learned statistical dependencies between amino acids at different positions. The tokenizer converts each protein sequence into the token ids that the ESM-2 model takes as input.
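The ESM-2 tokenizer maps each residue to a single token and wraps the sequence in special start (<cls>) and end (<eos>) tokens; during masked language modelling, selected positions are replaced with a <mask> token and the model predicts the missing residue. The sketch below imitates that behaviour with a hypothetical toy vocabulary; the real vocabulary and token ids come from AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D").

```python
from typing import Dict, List

# Hypothetical toy vocabulary for illustration only; the real ESM-2 token ids differ.
SPECIAL_TOKENS = ["<cls>", "<pad>", "<eos>", "<unk>", "<mask>"]
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB: Dict[str, int] = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + list(AMINO_ACIDS))}


def toy_tokenize(sequence: str) -> List[int]:
    """Map each residue to one token id and wrap the sequence in <cls> ... <eos>."""
    ids = [VOCAB["<cls>"]]
    # Residues outside the toy alphabet fall back to <unk>
    ids += [VOCAB.get(residue, VOCAB["<unk>"]) for residue in sequence.upper()]
    ids.append(VOCAB["<eos>"])
    return ids


def mask_position(ids: List[int], position: int) -> List[int]:
    """Replace the token for residue `position` (0-based, excluding <cls>) with <mask>."""
    masked = list(ids)
    masked[position + 1] = VOCAB["<mask>"]  # +1 skips the leading <cls> token
    return masked
```

Because there is one token per residue, residue i sits at token index i + 1, which is why mask_position shifts by one.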

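Directed evolution is then run with EvoProtGrad's DirectedEvolution class, using ESM-2 as the expert. The sampler runs one or more Markov chain Monte Carlo (MCMC) chains in parallel (parallel_chains), each for a fixed number of steps (n_steps). At each step, a chain proposes mutations to its current sequence, using gradients of the expert score to focus proposals on promising substitutions, and then accepts or rejects the proposal based on the change in the expert score. The total number of mutations relative to the wild type is capped by max_mutations. With ESM-2 as the only expert, the search favours variants that the language model scores highly; this is a proxy for, not a guarantee of, improved protein properties.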
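To make the propose/accept loop concrete, here is a toy single-chain Metropolis-Hastings sketch. It is not EvoProtGrad's algorithm: EvoProtGrad uses gradient-informed proposals, whereas this sketch swaps in uniformly random point substitutions and takes a hypothetical score_fn in place of the ESM-2 expert. It keeps only the overall structure and the n_steps and max_mutations budget.

```python
import math
import random
from typing import Callable, List, Tuple

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def hamming(a: str, b: str) -> int:
    """Number of positions at which two equal-length sequences differ."""
    return sum(x != y for x, y in zip(a, b))


def toy_directed_evolution(
    wildtype: str,
    score_fn: Callable[[str], float],
    n_steps: int = 20,
    max_mutations: int = 10,
    temperature: float = 1.0,
    seed: int = 0,
) -> Tuple[List[str], List[float]]:
    """Single-chain Metropolis-Hastings over point substitutions (toy, no gradients)."""
    rng = random.Random(seed)
    current, current_score = wildtype, score_fn(wildtype)
    history: List[str] = []
    scores: List[float] = []
    for _ in range(n_steps):
        # Propose a uniformly random substitution at a uniformly random position
        pos = rng.randrange(len(current))
        new_aa = rng.choice([aa for aa in AMINO_ACIDS if aa != current[pos]])
        proposal = current[:pos] + new_aa + current[pos + 1:]
        # Proposals that exceed the mutation budget relative to wild type are rejected
        if hamming(proposal, wildtype) <= max_mutations:
            proposal_score = score_fn(proposal)
            # The proposal is symmetric, so accept with probability min(1, exp(delta / T))
            delta = (proposal_score - current_score) / temperature
            if delta >= 0 or rng.random() < math.exp(delta):
                current, current_score = proposal, proposal_score
        # Record the chain state after every step, accepted or not
        history.append(current)
        scores.append(current_score)
    return history, scores
```

For instance, toy_directed_evolution("MALWMRLLPLL", lambda s: s.count("K") / len(s)) drifts towards lysine-rich variants while never straying more than max_mutations substitutions from the input.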

Reference: https://huggingface.co/blog/AmelieSchreiber/directed-evolution-with-esm2

In [4]:
def run_evo_prot_grad(raw_protein_sequence):
    # EvoProtGrad reads the wild-type sequence from a FASTA file,
    # so wrap the raw sequence in a minimal single-record FASTA string
    fasta_format_sequence = f">Input_Sequence\n{raw_protein_sequence}\n"

    # Save the FASTA string to a temporary file in the working directory
    temp_fasta_path = "temp_input_sequence.fasta"
    with open(temp_fasta_path, "w") as file:
        file.write(fasta_format_sequence)

    # Use the GPU if one is available, otherwise fall back to the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_name = "facebook/esm2_t30_150M_UR50D"

    # Load the ESM-2 model and tokenizer as the expert
    esm2_expert = evo_prot_grad.get_expert(
        'esm',
        model=EsmForMaskedLM.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        temperature=0.95,
        device=device
    )

    # Initialize Directed Evolution with the ESM-2 expert
    directed_evolution = evo_prot_grad.DirectedEvolution(
        wt_fasta=temp_fasta_path,    # path to the temporary FASTA file
        output='all',                # 'best', 'last', or 'all' variants
        experts=[esm2_expert],       # list of experts; here only ESM-2
        parallel_chains=1,           # number of parallel MCMC chains
        n_steps=20,                  # number of MCMC steps per chain
        max_mutations=10,            # maximum number of mutations per variant
        verbose=True                 # print debug info
    )

    # Run the evolution process
    variants, scores = directed_evolution()

    # Optional: print the results. Each entry of `variants` holds one
    # space-separated sequence per chain, so strip the spaces before use.
    # for variant, score in zip(variants, scores):
    #     print(f"Variant: {''.join(variant[0].split())}, Score: {score}")

    return variants, scores

In [11]:
# Get the device for the model
device_dict = torch_device()
display(device_dict)
torch.set_float32_matmul_precision(precision='high')
!nvidia-smi

{'device_id': 0,
 'device': device(type='cuda', index=0),
 'device_name': 'NVIDIA A100-SXM4-80GB',
 'cudnn_version': 8906,
 'torch_version': '2.1.2+cu121'}

Sat May 18 15:23:32 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:CD:00.0 Off |                    0 |
| N/A   26C    P0              71W / 500W |   4771MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [5]:
raw_protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"  # Replace with your protein sequence
variants, scores = run_evo_prot_grad(raw_protein_sequence)

>Wildtype sequence: M A L W M R L L P L L A L L A L W G P D P A A A F V N Q H L C G S H L V E A L Y L V C G E R G F F Y T P K T R R E A E D L Q V G Q V E L G G G P G A G S L Q P L A L E G S L Q K R G I V E Q C C T S I C S L Y Q L E N Y C N
step 0 acceptance rate: 10.2521
>chain 0, Product of Experts score: 0.7196
[0mM[0m [0mA[0m [0mL[0m [0mW[0m [0mM[0m [0mR[0m [0mL[0m [0mL[0m [0mP[0m [0mL[0m [0mL[0m [0mA[0m [0mL[0m [0mL[0m [0mA[0m [0mL[0m [0mW[0m [0mG[0m [0mP[0m [0mD[0m [0mP[0m [0mA[0m [91mG[0m [0mA[0m [0mF[0m [0mV[0m [0mN[0m [0mQ[0m [0mH[0m [0mL[0m [0mC[0m [0mG[0m [0mS[0m [0mH[0m [0mL[0m [0mV[0m [0mE[0m [0mA[0m [0mL[0m [0mY[0m [0mL[0m [0mV[0m [0mC[0m [0mG[0m [0mE[0m [0mR[0m [0mG[0m [0mF[0m [0mF[0m [0mY[0m [0mT[0m [0mP[0m [0mK[0m [0mT[0m [0mR[0m [0mR[0m [0mE[0m [0mA[0m [0mE[0m [0mD[0m [0mL[0m [0mQ[0m [0mV[0m [0mG[0m [0mQ[0m [0mV[0m [0mE[0m [0mL[0m [0m

In [6]:
# Run single-sequence evolution via the project's EvoProtGrad class
epg = EvoProtGrad()
output_dir = os.path.join(os.environ['HOME'], 'data', 'protein_evolution')
Path(output_dir).mkdir(parents=True, exist_ok=True)
var_df = epg.single_evolute(raw_protein_sequence=raw_protein_sequence, output_dir=output_dir)
display(var_df)

| Unnamed: 0 | variant | score | pos | source | target | sequence |
|---|---|---|---|---|---|---|
| 0 | 14 | 1.120728 | [24] | [F] | [H] | MALWMRLLPLLALLALWGPDPAAAHVNQHLCGSHLVEALYLVCGER... |
| 1 | 12 | 0.869624 | [16, 36, 56, 60, 71, 80, 82, 98, 103, 109] | [W, E, E, L, P, A, E, I, Q, N] | [L, D, S, A, E, T, G, L, E, D] | MALWMRLLPLLALLALLGPDPAAAFVNQHLCGSHLVDALYLVCGER... |
| 2 | 13 | 0.869624 | [] | [] | [] | MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGER... |
| 3 | 11 | 0.64782 | [16, 36, 56, 60, 71, 80, 98, 103, 109] | [W, E, E, L, P, A, I, Q, N] | [L, D, S, A, E, T, L, E, D] | MALWMRLLPLLALLALLGPDPAAAFVNQHLCGSHLVDALYLVCGER... |
| 4 | 0 | 0.213814 | [16] | [W] | [L] | MALWMRLLPLLALLALLGPDPAAAFVNQHLCGSHLVEALYLVCGER... |
| 5 | 8 | -0.158664 | [16, 56, 60, 71, 80, 103] | [W, E, L, P, A, Q] | [L, S, A, E, T, E] | MALWMRLLPLLALLALLGPDPAAAFVNQHLCGSHLVEALYLVCGER... |
| 6 | 7 | -0.158664 | [16, 56, 60, 71, 80, 103] | [W, E, L, P, A, Q] | [L, S, A, E, T, E] | MALWMRLLPLLALLALLGPDPAAAFVNQHLCGSHLVEALYLVCGER... |
| 7 | 4 | -0.158664 | [16, 56, 60, 71, 80, 103] | [W, E, L, P, A, Q] | [L, S, A, E, T, E] | MALWMRLLPLLALLALLGPDPAAAFVNQHLCGSHLVEALYLVCGER... |
| 8 | 6 | -0.158664 | [16, 56, 60, 71, 80, 103] | [W, E, L, P, A, Q] | [L, S, A, E, T, E] | MALWMRLLPLLALLALLGPDPAAAFVNQHLCGSHLVEALYLVCGER... |
| 9 | 9 | -0.333209 | [16, 36, 56, 60, 71, 80, 103] | [W, E, E, L, P, A, Q] | [L, D, S, A, E, T, E] | MALWMRLLPLLALLALLGPDPAAAFVNQHLCGSHLVDALYLVCGER... |
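In the table above, rows 5 to 8 list the same mutations and score, because a chain can keep or revisit the same sequence across steps. A minimal sketch that keeps the best-scoring row per distinct sequence and ranks the result; the helper name is hypothetical, and it works on plain records such as those returned by var_df.to_dict("records"). Deduplicate on the full sequence column of var_df, since the displayed sequences are truncated.

```python
from typing import Dict, List


def unique_variants_by_score(records: List[Dict]) -> List[Dict]:
    """Keep the highest-scoring record for each distinct sequence, sorted by score (descending)."""
    best: Dict[str, Dict] = {}
    for rec in records:
        seq = rec["sequence"]
        # Replace the stored record only if this one scores higher
        if seq not in best or rec["score"] > best[seq]["score"]:
            best[seq] = rec
    return sorted(best.values(), key=lambda r: r["score"], reverse=True)
```

With pandas, var_df.drop_duplicates(subset="sequence") followed by sort_values("score", ascending=False) achieves a similar result directly on the DataFrame.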


### Paired Protein Evolution ###

The second method extends this approach to a pair of protein sequences, written as one string with the two chains separated by ':'. Before evolution, the ':' is replaced by a linker of 20 glycine ('G') residues, so that both chains can be passed to ESM-2 as a single sequence and mutations in each chain are scored in the context of the other.

As in single-protein evolution, the linked sequence is written to a FASTA file and passed to DirectedEvolution. The start and end indices of the linker are supplied as a preserved region, so the sampler does not mutate the linker and the boundary between the two chains stays fixed.
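A minimal sketch of the linking step and its inverse, with hypothetical helper names. It locates the linker from the position of the ':' rather than by searching for 20 G's, so glycines at the end of the first chain cannot shift the computed span. Splitting back by position assumes the sampler only makes substitutions, so sequence length and indices do not change.

```python
from typing import Tuple


def link_chains(paired_sequence: str, linker_length: int = 20) -> Tuple[str, Tuple[int, int]]:
    """Replace the ':' between two chains with a poly-G linker.

    Returns the linked sequence and the (start, end) indices of the linker,
    with `end` exclusive (Python slice convention).
    """
    if paired_sequence.count(":") != 1:
        raise ValueError("Expected exactly one ':' separating the two chains")
    chain_a, chain_b = paired_sequence.split(":")
    linker = "G" * linker_length
    start = len(chain_a)  # the linker begins right after chain A
    return chain_a + linker + chain_b, (start, start + linker_length)


def unlink_chains(linked_sequence: str, linker_span: Tuple[int, int]) -> str:
    """Split a linked sequence back into 'chainA:chainB' using the linker position."""
    start, end = linker_span
    return linked_sequence[:start] + ":" + linked_sequence[end:]
```

For example, link_chains("MKV:GAT") returns ("MKV" + "G" * 20 + "GAT", (3, 23)), and unlink_chains recovers "MKV:GAT" even though chain B itself starts with a glycine.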

During evolution, mutations can be proposed at any position in either chain outside the linker, and every proposal is scored on the full linked sequence, so each chain is evaluated in the context of the other.
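To see where mutations landed, the hypothetical helper below compares an evolved linked sequence with the linked wild type, checks that the linker is unchanged, and reports substitutions per chain in wild-type/position/variant notation. It assumes substitution-only edits (equal lengths) and 0-based positions, matching the pos column of the single-protein results; chain B positions are counted from its first residue. EvoProtGrad returns residues separated by spaces, so remove them first, e.g. with ''.join(seq.split()).

```python
from typing import Dict, List, Tuple


def mutations_per_chain(
    wildtype: str, variant: str, linker_span: Tuple[int, int]
) -> Dict[str, List[str]]:
    """List substitutions (e.g. 'K1R', 0-based) in each chain and check the linker is unchanged."""
    if len(wildtype) != len(variant):
        raise ValueError("Expected equal-length sequences (substitutions only)")
    start, end = linker_span  # end is exclusive
    if variant[start:end] != wildtype[start:end]:
        raise ValueError("Linker region was mutated")
    result: Dict[str, List[str]] = {"chain_a": [], "chain_b": []}
    for i, (wt_aa, var_aa) in enumerate(zip(wildtype, variant)):
        if wt_aa == var_aa:
            continue
        if i < start:
            result["chain_a"].append(f"{wt_aa}{i}{var_aa}")
        else:
            # Re-index chain B from the first residue after the linker
            result["chain_b"].append(f"{wt_aa}{i - end}{var_aa}")
    return result
```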

In [8]:
def run_evo_prot_grad_on_paired_sequence(paired_protein_sequence):
    # The two chains must be separated by exactly one ':'
    if paired_protein_sequence.count(':') != 1:
        raise ValueError("Expected exactly one ':' separating the two chains")

    # Replace ':' with a linker of 20 glycine ('G') residues
    separator = 'G' * 20
    sequence_with_separator = paired_protein_sequence.replace(':', separator)

    # The linker starts where the ':' was. Using the ':' position (rather than
    # searching for 20 G's) avoids a shifted span when chain A ends in glycines.
    separator_start_index = paired_protein_sequence.index(':')
    separator_end_index = separator_start_index + len(separator)

    # Format the sequence as a single-record FASTA string
    fasta_format_sequence = f">Paired_Protein_Sequence\n{sequence_with_separator}\n"

    # Save the sequence to a temporary file in the working directory
    temp_fasta_path = "temp_paired_sequence.fasta"
    with open(temp_fasta_path, "w") as file:
        file.write(fasta_format_sequence)

    # Use the GPU if one is available, otherwise fall back to the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_name = "facebook/esm2_t30_150M_UR50D"

    # Load the ESM-2 model and tokenizer as the expert
    esm2_expert = evo_prot_grad.get_expert(
        'esm',
        model=EsmForMaskedLM.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        temperature=0.95,
        device=device
    )

    # Initialize Directed Evolution with the linker as a preserved region
    directed_evolution = evo_prot_grad.DirectedEvolution(
        wt_fasta=temp_fasta_path,
        output='all',
        experts=[esm2_expert],
        parallel_chains=1,
        n_steps=20,
        max_mutations=10,
        verbose=True,
        preserved_regions=[(separator_start_index, separator_end_index)]  # keep the poly-G linker fixed
    )

    # Run the evolution process
    variants, scores = directed_evolution()

    # Optional: map results back to the 'chainA:chainB' format. Each entry of
    # `variants` holds one space-separated sequence per chain, so strip spaces first.
    # for variant, score in zip(variants, scores):
    #     evolved_sequence = ''.join(variant[0].split()).replace(separator, ':')
    #     print(f"Evolved Paired Sequence: {evolved_sequence}, Score: {score}")

    return variants, scores

In [9]:
paired_protein_sequence = "MLTEVMEVWHGLVIAVVSLFLQACFLTAINYLLSRHMAHKSEQILKAASLQVPRPSPGHHHPPAVKEMKETQTERDIPMSDSLYRHDSDTPSDSLDSSCSSPPACQATEDVDYTQVVFSDPGELKNDSPLDYENIKEITDYVNVNPERHKPSFWYFVNPALSEPAEYDQVAM:MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL"  # Replace with your paired protein sequences
variants, scores = run_evo_prot_grad_on_paired_sequence(paired_protein_sequence)

>Wildtype sequence: M L T E V M E V W H G L V I A V V S L F L Q A C F L T A I N Y L L S R H M A H K S E Q I L K A A S L Q V P R P S P G H H H P P A V K E M K E T Q T E R D I P M S D S L Y R H D S D T P S D S L D S S C S S P P A C Q A T E D V D Y T Q V V F S D P G E L K N D S P L D Y E N I K E I T D Y V N V N P E R H K P S F W Y F V N P A L S E P A E Y D Q V A M G G G G G G G G G G G G G G G G G G G G M A S P G S G F W S F G S E D G S G D S E N P G T A R A W C Q V A Q K F T G G I G N K L C A L L Y G D A E K P A E S G G S Q P P R A A A R K A A C A C D Q K P C S C S K V D V N Y A F L H A T D L L P A C D G E R P T L A F L Q D V M N I L L Q Y V V K S F D R S T K V I D F H Y P N E L L Q E Y N W E L A D Q P Q N L E E I L M H C Q T T L K Y A I K T G H P R Y F N Q L S T G L D M V G L A A D W L T S T A N T N M F T Y E I A P V F V L L E Y V T L K K M R E I I G W P G G S G D G I F S P G G A I S N M Y A M M I A R F K M F P E V K E K G M A A L P R L I A F T S E H S H F S L K K G A A A L G I G T D S 

In [10]:
print(len(variants))
print(len(scores))
print(variants[0])

20
20
['M L T S V M E V W H G L V I A V V S L F L Q A C F L T A I N Y L L S R H M A H K S E Q I L K A A S L Q V P R P S P G H H H P P A V K E M K E T Q T E R D I P M S D S L Y R H D S D T P S D S L D S S C S S P P A C Q A T E D V D Y T Q V V F S D P G E L K N D S P L D Y E N I K E I T D Y V N V N P E R H K P S F W Y F V N P A L S E P A E Y D Q V A M G G G G G G G G G G G G G G G G G G G G M A S P G S G F W S F G S E D G S G D S E N P G T A R A W C Q V A Q K F T G G I G N K L C A L L Y G D A E K P A E S G G S Q P P R A A A R K A A C A C D Q K P C S C S K V D V N Y A F L H A T D L L P A C D G E R P T L A F L Q D V M N I L L Q Y V V K S F D R S T K V I D F H Y P N E L L Q E Y N W E L A D Q P Q N L E E I L M H C Q T T L K Y A I K T G H P R Y F N Q L S T G L D M V G L A A D W L T S T A N T N M F T Y E I A P V F V L L E Y V T L K K M R E I I G W P G G S G D G I F S P G G A I S N M Y A M M I A R F K M F P E V K E K G M A A L P R L I A F T S E H S H F S L K K G A A A L G I G T D S V I L I K C 
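The printed output shows that, with output='all', variants has one entry per MCMC step (20 here) and that each entry is a list holding one space-separated sequence per chain. A minimal sketch, assuming that structure, that converts the entries into plain sequence strings for one chain and drops repeats while preserving order; both helper names are hypothetical.

```python
from typing import List, Sequence


def clean_variants(variants: Sequence[Sequence[str]], chain: int = 0) -> List[str]:
    """Turn per-step entries of space-separated sequences into plain strings for one chain."""
    return ["".join(step[chain].split()) for step in variants]


def distinct_in_order(sequences: Sequence[str]) -> List[str]:
    """Drop repeated sequences, keeping the first occurrence of each."""
    seen = set()
    unique: List[str] = []
    for seq in sequences:
        if seq not in seen:
            seen.add(seq)
            unique.append(seq)
    return unique
```

For the paired run, the cleaned strings can then be split back into two chains at the linker indices used for preserved_regions.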