## Getting Started ##
### Predicting the Effects of Mutations on Protein Function with ESM-2 ###

Mutations in a protein sequence can have complex effects, ranging from detrimental to function, through neutral or inconsequential, to beneficial. Even single point mutations, or a small number of mutations, have been shown to cause drastic conformational changes, resulting in "fold-switching" and changes in the 3D structure of the folded protein. Judging the effects of mutations is difficult, but protein language models such as the ESM-2 family can provide a lot of information about how mutations affect the fold and function of proteins.

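References:

- Meier et al., "Language models enable zero-shot prediction of the effects of mutations on protein function" (bioRxiv): https://www.biorxiv.org/content/10.1101/2021.07.09.450648v2
- Amelie Schreiber's Hugging Face blog post on mutation scoring: https://huggingface.co/blog/AmelieSchreiber/mutation-scoring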

In [1]:
import functools
import os

import ipywidgets as widgets
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Huggingface imports
from transformers import AutoTokenizer, EsmForMaskedLM

#PyTorch
import torch

# Appearance of the Notebook
from IPython.display import display, HTML
# Inject CSS so the `.container` element spans the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))
# Wider NumPy line width and larger pandas row/column/width limits for on-screen display.
np.set_printoptions(linewidth=110)
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 100)
pd.set_option('display.width', 1000)

# Import this module with autoreload
# (mode 2 reloads modules before code is executed, so edits to the
# project module are picked up without restarting the kernel).
%load_ext autoreload
%autoreload 2
import esm
# Print the versions in use so the run can be reproduced.
print(f'Project module version: {esm.__version__}')
print(f'PyTorch version:        {torch.__version__}')

Project module version: 0.0.post1.dev18+g803595c.d20240131
PyTorch version:        2.1.2+cu121


In [2]:
def interactive_heatmap(protein_sequence):
    """Display Start/End sliders that drive a mutation-effect heatmap.

    Two integer sliders select a 1-indexed, inclusive residue window of
    ``protein_sequence``. Each time a slider is released,
    ``generate_heatmap(protein_sequence, start, end)`` is re-run and its plot
    is shown below the sliders.

    Args:
        protein_sequence: Amino-acid sequence as a string of one-letter codes.

    Raises:
        ValueError: If ``protein_sequence`` is empty (the sliders would have
            no valid range).
    """
    sequence_length = len(protein_sequence)
    if sequence_length == 0:
        raise ValueError("protein_sequence must not be empty")

    # Define interactive widgets.
    # continuous_update=False: only fire when the slider is released. Each
    # update runs the model once per selected position, so recomputing for
    # every intermediate drag value would be needlessly slow.
    start_slider = widgets.IntSlider(
        value=1, min=1, max=sequence_length, step=1,
        description='Start:', continuous_update=False,
    )
    end_slider = widgets.IntSlider(
        value=sequence_length, min=1, max=sequence_length, step=1,
        description='End:', continuous_update=False,
    )

    ui = widgets.HBox([start_slider, end_slider])

    def update_heatmap(start, end):
        # The two sliders are independent, so an inverted window is possible.
        # Say so in the output area instead of leaving it silently blank.
        if start > end:
            print(f"Start ({start}) must not be greater than End ({end}).")
            return
        generate_heatmap(protein_sequence, start, end)

    out = widgets.interactive_output(update_heatmap, {'start': start_slider, 'end': end_slider})

    # Display the interactive widgets
    display(ui, out)

@functools.lru_cache(maxsize=None)
def _load_esm_model(model_name):
    """Load the tokenizer and masked-LM for ``model_name`` once and memoize them."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)
    model.eval()  # inference only; keep dropout disabled
    return tokenizer, model


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None,
                     model_name="facebook/esm2_t6_8M_UR50D"):
    """Plot a heatmap of masked-LM log-likelihood ratios for point mutations.

    For every position in the inclusive window ``[start_pos, end_pos]`` the
    residue is replaced by the mask token, the model is run, and for each of
    the 20 standard amino acids the value
    ``log p(mutant) - log p(wild type)`` at the masked position is stored.
    The resulting 20 x window-length matrix is shown with matplotlib.

    Args:
        protein_sequence: Amino-acid sequence as a string of one-letter codes.
        start_pos: First position of the window, 1-indexed.
        end_pos: Last position of the window, inclusive. ``None`` means the
            last residue of the tokenized sequence.
        model_name: Hugging Face model identifier passed to
            ``from_pretrained``. Defaults to the small ESM-2 checkpoint.

    Raises:
        ValueError: If the window does not satisfy
            ``1 <= start_pos <= end_pos <= sequence length``.
    """
    # Load the model and tokenizer (memoized, so repeated calls -- e.g. from
    # slider callbacks -- do not reload the weights every time).
    tokenizer, model = _load_esm_model(model_name)

    # Tokenize the input sequence
    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    # Assumes the tokenizer adds exactly one leading and one trailing special
    # token, so residue i (1-indexed) sits at token index i.
    sequence_length = input_ids.shape[1] - 2

    # Adjust end position if not specified
    if end_pos is None:
        end_pos = sequence_length

    # Reject windows that would touch the special tokens, run past the
    # sequence, or be empty/inverted (which would give a negative array shape).
    if not 1 <= start_pos <= end_pos <= sequence_length:
        raise ValueError(
            f"Invalid window: need 1 <= start_pos <= end_pos <= {sequence_length}, "
            f"got start_pos={start_pos}, end_pos={end_pos}"
        )

    # List of amino acids and their token ids (loop-invariant, looked up once)
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    amino_acid_ids = torch.tensor(tokenizer.convert_tokens_to_ids(amino_acids))

    # Initialize heatmap: one row per amino acid, one column per position
    window_length = end_pos - start_pos + 1
    heatmap = np.zeros((len(amino_acids), window_length))

    # Calculate LLRs for each position and amino acid
    for position in range(start_pos, end_pos + 1):
        # Mask the target position
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, position] = tokenizer.mask_token_id

        # Get logits for the masked token
        with torch.no_grad():
            logits = model(masked_input_ids).logits

        # log_softmax is the numerically stable form of log(softmax(x))
        log_probabilities = torch.nn.functional.log_softmax(logits[0, position], dim=0)

        # LLR of every variant relative to the wild-type residue, in one shot
        wt_token_id = input_ids[0, position]
        llr = log_probabilities[amino_acid_ids] - log_probabilities[wt_token_id]
        heatmap[:, position - start_pos] = llr.numpy()

    # Visualize the heatmap
    fig, ax = plt.subplots(figsize=(15, 5))
    image = ax.imshow(heatmap, cmap="viridis", aspect="auto")
    ax.set_xticks(range(window_length))
    ax.set_xticklabels(list(protein_sequence[start_pos-1:end_pos]))
    ax.set_yticks(range(len(amino_acids)))
    ax.set_yticklabels(amino_acids)
    ax.set_xlabel("Position in Protein Sequence")
    ax.set_ylabel("Amino Acid Mutations")
    ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
    fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")
    plt.show()

In [3]:
# Example usage: open the Start/End slider UI and LLR heatmap for a sample protein sequence.
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
interactive_heatmap(protein_sequence)

HBox(children=(IntSlider(value=1, description='Start:', max=108, min=1), IntSlider(value=108, description='End…

Output()