# Protein Language Models (pLMs)

We have two main exercises in this module.

First, we'll explore protein language models and the format of the predictions they make. This is mostly background for an **expert applications scientist**; a method developer will spend more time with the inner workings of a pLM. This knowledge is conceptually useful and will help you understand the literature.

Then, we'll benchmark several pLMs to gauge which models perform best for various classes of proteins. It's a good idea to understand the quality of the predictions you get before applying any model. You may be surprised by the performance differences between models.

## Setup and dependencies

In [1]:
if 'google.colab' in str(get_ipython()):
    print("Running on Google Colab. Executing Colab-specific commands...")
    # Mount Google Drive to access files
    from google.colab import drive
    drive.mount('/content/drive')

    # Drive location for the fasta files
    data_loc = '/content/drive/My Drive/AIDrivenDesignOfBiologics/AIDrivenDesignOfBiologics-PEGSEurope-2025/ProteinLanguageModels/pLM_basics_and_benchmarking'

    print("Installing dependencies...")
    !pip install biopython # pandas matplotlib seaborn tqdm transformers scipy torch torchvision torchaudio

else:
    print("Not running on Google Colab. Skipping Colab-specific commands.")
    print("Running in a local environment or Jupyter Notebook.")
    data_loc = '/home/davidnannemann/AIDD4B/ProteinLMs/'

Running on Google Colab. Executing Colab-specific commands...
Mounted at /content/drive
Installing dependencies...
Collecting biopython
  Downloading biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Downloading biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.2/3.2 MB 70.2 MB/s eta 0:00:00
Installing collected packages: biopython
Successfully installed biopython-1.86


In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import random
from tqdm import tqdm
from Bio import SeqIO
from scipy import stats

import warnings
warnings.filterwarnings('ignore')

# Model-specific imports
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
import torch
import torch.nn.functional as F

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

print("Dependencies loaded successfully!")

Dependencies loaded successfully!


## Protein Language Models Overview

### Model Loading Functions

There are many ways to load a pLM. Here, we use functions from the Hugging Face `transformers` library. You'll encounter other loading methods and will need to adapt to them when a model of interest is only available through a particular library.

As an **expert applications scientist** who may not be an ML engineer, when you see a warning you don't understand, it's a good idea to check that your model-loading process is sound. Look closely at the log as ProtBERT loads, and ask your favorite chatbot to help you interpret it. FYI, the ProtBERT warning is benign: the unused weights belong to the pooler and the next-sentence-prediction head (`cls.seq_relationship`), which `BertForMaskedLM` does not use.

In [3]:
class ProteinLanguageModel:
    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type
        self.tokenizer = None
        self.model = None
        self.pipeline = None

    def load_model(self):
        """Load the specified protein language model"""
        try:
            if self.model_type == "esm":
                # ESM models have specific loading requirements
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = AutoModelForMaskedLM.from_pretrained(self.model_name)

            elif self.model_type == "protbert":
                # ProtBERT uses spaces between amino acids
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = AutoModelForMaskedLM.from_pretrained(self.model_name)

            elif self.model_type == "igbert":
                # IgBERT for antibody sequences (similar to BERT but antibody-specific)
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = AutoModelForMaskedLM.from_pretrained(self.model_name)

            # elif self.model_type == "amplify":
            #     # AMPLIFY uses ESM-like architecture but with optimized training
            #     self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            #     self.model = AutoModelForMaskedLM.from_pretrained(self.model_name)

            # Create pipeline for easier prediction
            self.pipeline = pipeline("fill-mask",
                                    model=self.model,
                                    tokenizer=self.tokenizer,
                                    device=0 if torch.cuda.is_available() else -1)

            print(f"✓ Loaded {self.model_name}")
            return True

        except Exception as e:
            print(f"✗ Failed to load {self.model_name}: {str(e)}")
            print(f"  Error details: {type(e).__name__}")
            return False

# Initialize models with correct model paths
models = {
    "ESM": ProteinLanguageModel("facebook/esm2_t33_650M_UR50D", "esm"),
    "ProtBERT": ProteinLanguageModel("Rostlab/prot_bert", "protbert"),
    "IgBERT": ProteinLanguageModel("Exscientia/IgBert", "igbert"),
    #"AMPLIFY": ProteinLanguageModel("chandar-lab/AMPLIFY_350M", "amplify"),
}

print("Attempting to load protein language models...")
print("Note: This may take several minutes for first-time downloads.")
print("=" * 60)

# Load available models
available_models = {}
for name, model in models.items():
    print(f"\nLoading {name}...")
    if model.load_model():
        available_models[name] = model
    else:
        print(f"  Skipping {name} - will continue with other models")

print(f"\n{'='*60}")
print(f"MODEL LOADING SUMMARY:")
print(f"Successfully loaded {len(available_models)} out of {len(models)} models")
print(f"Available models: {list(available_models.keys())}")

if len(available_models) == 0:
    print("\n⚠️  WARNING: No models loaded successfully!")
    print("This might be due to:")
    print("- Network connectivity issues")
    print("- Insufficient memory/GPU memory")
    print("- Model repositories being temporarily unavailable")
    print("- Missing dependencies")
    print("\nTrying fallback models...")

    # Fallback to smaller/more reliable models
    fallback_models = {
        "ESM-2-35M": ProteinLanguageModel("facebook/esm2_t12_35M_UR50D", "esm"),
        "ProtBERT-BFD": ProteinLanguageModel("Rostlab/prot_bert_bfd", "protbert"),
    }

    for name, model in fallback_models.items():
        print(f"\nTrying fallback: {name}...")
        if model.load_model():
            available_models[name] = model
            break

if len(available_models) > 0:
    print(f"\n✅ Ready to proceed with {len(available_models)} model(s)")

    # Show model details
    for name, model in available_models.items():
        print(f"\n{name}:")
        print(f"  Model path: {model.model_name}")
        print(f"  Model type: {model.model_type}")
        print(f"  Tokenizer vocab size: {len(model.tokenizer) if model.tokenizer else 'Unknown'}")
else:
    print("\n❌ No models could be loaded. Please check your internet connection")
    print("and try running the notebook again.")

Attempting to load protein language models...
Note: This may take several minutes for first-time downloads.

Loading ESM...


Device set to use cuda:0


✓ Loaded facebook/esm2_t33_650M_UR50D

Loading ProtBERT...


Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cuda:0


✓ Loaded Rostlab/prot_bert

Loading IgBERT...


Device set to use cuda:0


✓ Loaded Exscientia/IgBert

MODEL LOADING SUMMARY:
Successfully loaded 3 out of 3 models
Available models: ['ESM', 'ProtBERT', 'IgBERT']

✅ Ready to proceed with 3 model(s)

ESM:
  Model path: facebook/esm2_t33_650M_UR50D
  Model type: esm
  Tokenizer vocab size: 33

ProtBERT:
  Model path: Rostlab/prot_bert
  Model type: protbert
  Tokenizer vocab size: 30

IgBERT:
  Model path: Exscientia/IgBert
  Model type: igbert
  Tokenizer vocab size: 30


## Dataset Preparation

We're going to look at performance across three protein families: antibodies, enzymes, and a viral protein. These families are illustrative examples, not a comprehensive set.

The three protein families:

1. `antibody`: antibodies are a classic target for biologics drug design. The dataset consists of the paired heavy- and light-chain (HC-LC) variable regions of antibodies in Thera-SAbDab (Oct. 2024), deduplicated at 70% identity to reduce the number of sequences to fewer than 100.
2. `tev_protease`: a classic enzyme that we'll revisit in the next module. The sequences come from a UniProt grouping of sequences with at least 50% identity, resulting in ~50 sequences.
3. `neuraminidase`: who doesn't want to build a better flu vaccine? This dataset starts from the UniProt entry for A/California/04/2009 (H1N1) neuraminidase, from which sequences with up to 50% identity were extracted. These were then clustered at 90% identity to reduce the number of sequences from over 40k to just 75.

FYI, clustering was performed with CD-HIT, and the cluster representatives were carried into this analysis. A true scientific benchmark would need more careful thought about dataset composition. The goal here is to show you _how_ to organize a benchmark, not to actually run one.

In [4]:
from itertools import islice

def load_sequences_from_fasta(file_paths, num_records=50):
    """
    Load sequences from multiple FASTA files into a dictionary.

    Args:
        file_paths: dict mapping class name to FASTA file path
        num_records: maximum number of records to read from each file
            (the first `num_records` records are kept; None reads all of them)

    Returns:
        dict: {class_name: [sequence_str, ...]}
    """
    sequences_dict = {}
    for class_name, fasta_path in file_paths.items():
        records = SeqIO.parse(fasta_path, "fasta")
        # islice stops after num_records records without reading the rest of the file
        sequences_dict[class_name] = [str(record.seq) for record in islice(records, num_records)]
    return sequences_dict

# Sample protein sequences for demonstration
fasta_files = {
    "antibody":      f"{data_loc}/TheraSabDab_sequences.fasta",
    "tev_protease":  f"{data_loc}/tev_swissprot_blast.fasta", # Use TEV_protease_sequences.fasta if you have a high-end GPU with lots of memory
    "neuraminidase": f"{data_loc}/neuraminidase_clusters.out.fasta"
}
sample_sequences = load_sequences_from_fasta(fasta_files)

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/My Drive/AIDrivenDesignOfBiologics/AIDrivenDesignOfBiologics-PEGSEurope-2025/ProteinLanguageModels/pLM_basics_and_benchmarking/TheraSabDab_sequences.fasta'

In [None]:
# Create DataFrame
data_rows = []
for protein_class, sequences in sample_sequences.items():
    for i, sequence in enumerate(sequences):
        data_rows.append({
            'sequence_id': f"{protein_class}_{i+1}",
            'protein_class': protein_class,
            'sequence': sequence,
            'sequence_length': len(sequence)
        })

df = pd.DataFrame(data_rows)
print(f"Dataset shape: {df.shape}")
print(f"Protein classes: {df['protein_class'].value_counts().tolist()}")
df.head()

# How do protein language models work?

Recall that pLMs such as ESM-2, ProtBERT, and IgBERT are trained by masking positions in the sequence and optimizing the weights so that the model assigns high probability to the native amino acid at each masked position.
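The training objective can be written as a cross-entropy loss computed only at the masked positions. The sketch below uses random toy tensors (the vocabulary size, sequence length, and masked positions are invented, not taken from any real pLM). It follows the PyTorch convention, also used by Hugging Face `transformers`, of labelling unmasked positions with `-100` so that `F.cross_entropy` ignores them.

```python
import torch
import torch.nn.functional as F

gen = torch.Generator().manual_seed(0)   # local generator: leaves the global torch seed untouched

vocab_size, seq_len = 25, 10             # toy sizes, not those of a real pLM
native_ids = torch.randint(0, vocab_size, (seq_len,), generator=gen)  # stand-in native sequence
mlm_logits = torch.randn(seq_len, vocab_size, generator=gen)          # stand-in model output

# Only masked positions carry a label; every other position is set to -100 and ignored.
masked = [2, 5, 7]
labels = torch.full((seq_len,), -100, dtype=torch.long)
labels[masked] = native_ids[masked]

mlm_loss = F.cross_entropy(mlm_logits, labels, ignore_index=-100)

# Same value by hand: mean negative log-probability of the native token at each masked position.
log_probs = F.log_softmax(mlm_logits, dim=-1)
manual_loss = -log_probs[masked, native_ids[masked]].mean()

print(f"MLM loss: {mlm_loss.item():.4f}  (manual: {manual_loss.item():.4f})")
```

Minimizing this loss over many sequences is what pushes the model toward the native amino acid at masked positions.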

We do the same at inference: mask one or more positions and predict the probability of each amino acid at those positions. Some protocols select only the most probable amino acid (top-_k_ with _k_ = 1), while others consider the _k_ > 1 most probable candidates.
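Here is a small sketch of top-_k_ selection with `torch.topk`. The probabilities are invented for a single masked position over the 20 standard amino acids (they sum to 1). With a real model, they would come from a softmax over the model's logits.

```python
import torch

AA_ORDER = list("ACDEFGHIKLMNPQRSTVWY")

# Invented probabilities for one masked position, in AA_ORDER (sum = 1.0)
toy_probs = torch.tensor([0.02, 0.01, 0.03, 0.01, 0.02, 0.30, 0.01, 0.02, 0.05, 0.04,
                          0.01, 0.03, 0.02, 0.02, 0.03, 0.25, 0.08, 0.02, 0.01, 0.02])

def top_k_residues(probs, k):
    """Return the k most probable amino acids with their probabilities, best first."""
    values, indices = torch.topk(probs, k)
    return [(AA_ORDER[i], round(v, 3)) for v, i in zip(values.tolist(), indices.tolist())]

print(top_k_residues(toy_probs, 1))  # k = 1: [('G', 0.3)]
print(top_k_residues(toy_probs, 3))  # k = 3: [('G', 0.3), ('S', 0.25), ('T', 0.08)]
```

Keeping _k_ > 1 candidates is useful when you want a shortlist of plausible substitutions rather than a single answer.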

## Predicting amino acid probabilities at each position

The cell below contains a function that formats a sequence containing masked residues. Different pLMs use different special tokens (for example, `<mask>` for ESM and `[MASK]` for the BERT-based models), and the BERT-based pLMs (ProtBERT, IgBERT) expect each token to be separated by a space. This function handles these differences for us.

The second function in the cell masks each position of a protein sequence in turn and gathers the predicted probabilities for that position. We'll explore this data below.
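The core of the BERT-style formatting is one tokenizing rule: keep bracketed special tokens whole and split everything else into single residues. As an illustration only, the same rule can be written as a regular expression. `space_tokens` is a hypothetical helper, not part of the notebook, and it does not perform the `---` linker conversion that `format_sequence_for_model` handles.

```python
import re

# Try the bracketed special tokens first; otherwise take a single character.
TOKEN_PATTERN = re.compile(r"\[MASK\]|\[SEP\]|.")

def space_tokens(sequence):
    """Hypothetical helper: put spaces between residues, keeping [MASK]/[SEP] intact."""
    return " ".join(TOKEN_PATTERN.findall(sequence))

print(space_tokens("EVQ[MASK]VES"))        # E V Q [MASK] V E S
print(space_tokens("EVQL[SEP]DIQ[MASK]"))  # E V Q L [SEP] D I Q [MASK]
```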

In [None]:
def format_sequence_for_model(sequence, model_type):
    """Format a (masked) sequence for the given model type. Start/end special tokens are added later by the tokenizer, not here."""
    if model_type == "protbert":
        # ProtBERT requires spaces between amino acids but NOT within [MASK] tokens
        # First convert <mask> to [MASK], then handle spacing carefully
        sequence = sequence.replace("<mask>", "[MASK]")
        sequence = sequence.replace("---", "")  # Paired usage of antibody sequences

        # Split sequence into tokens, preserving [MASK] as single units
        tokens = []
        i = 0
        while i < len(sequence):
            if sequence[i:i+6] == "[MASK]":
                # Add [MASK] as a single token
                tokens.append("[MASK]")
                i += 6
            else:
                # Add individual amino acid
                tokens.append(sequence[i])
                i += 1

        # Join with spaces
        return ' '.join(tokens)
    elif model_type == "igbert":
        # IgBERT also requires spaces between amino acids but NOT within [MASK]/[SEP] tokens
        # First convert <mask> to [MASK] and the "---" HC-LC linker to [SEP], then handle spacing carefully
        sequence = sequence.replace("<mask>", "[MASK]")
        sequence = sequence.replace("---", "[SEP]")  # Paired usage of antibody sequences

        # Split sequence into tokens, preserving [MASK] as single units
        tokens = []
        i = 0
        while i < len(sequence):
            if sequence[i:i+6] == "[MASK]":
                # Add [MASK] as a single token
                tokens.append("[MASK]")
                i += 6
            elif sequence[i:i+5] == "[SEP]":
                # Add [SEP] as a single token
                tokens.append("[SEP]")
                i += 5
            else:
                # Add individual amino acid
                tokens.append(sequence[i])
                i += 1
        # Join with spaces
        return ' '.join(tokens)
    else:
        # ESM and others: no spaces and <mask> tokens; each "-" linker character becomes a <pad> token
        sequence = sequence.replace("-", "<pad>")  # Paired usage of antibody sequences
        return sequence

def get_positionwise_probs(model, sequence, model_type="esm"):
    """
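    Mask each position in turn and collect the predicted probability of each of the 20 standard amino acids.

    Scores come from the fill-mask pipeline (softmax over the full vocabulary); an amino acid
    that is not among the top 20 tokens returned for a position is left at 0.
    Returns: probs_matrix (L x 20), aa_order (list), native_indices (list; -1 for non-standard residues)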
    """
    aa_order = list("ACDEFGHIKLMNPQRSTVWY")
    probs_matrix = []
    native_indices = []
    seq_length = len(sequence)

    for pos in range(seq_length):
        # Mask one position
        seq_list = list(sequence)
        native_aa = seq_list[pos]
        native_indices.append(aa_order.index(native_aa) if native_aa in aa_order else -1)

        # Insert mask token
        if model_type == "protbert" or model_type == "igbert":
            mask_token = "[MASK]"
        else:
            mask_token = "<mask>"
        seq_list[pos] = mask_token
        masked_seq = ''.join(seq_list)
        formatted_seq = format_sequence_for_model(masked_seq, model_type)

        # Get predictions for this mask
        results = model.pipeline(formatted_seq, top_k=20)
        # results: list of dicts, each with 'token_str' and 'score'
        aa_probs = np.zeros(20)
        for res in results:
            aa = res['token_str'].strip().upper()
            if aa in aa_order:
                aa_probs[aa_order.index(aa)] = res['score']
        probs_matrix.append(aa_probs)

    probs_matrix = np.array(probs_matrix)  # shape: (L, 20)
    return probs_matrix, aa_order, native_indices

# Example usage:
#torch.cuda.empty_cache()
sequence = sample_sequences["antibody"][0]
model = available_models["IgBERT"]  # or your preferred model
probs_matrix, aa_order, native_indices = get_positionwise_probs(model, sequence, model.model_type)


In [None]:
# Plotting the heatmap
plt.figure(figsize=(9, 4))
palette = sns.color_palette("light:b", as_cmap=True)

ax = sns.heatmap(probs_matrix.T, cmap=palette, xticklabels=list(sequence), yticklabels=aa_order, cbar_kws={'label': 'Probability'})

# Highlight native residue at each position
for i, aa_idx in enumerate(native_indices):
    if aa_idx >= 0:
        ax.add_patch(plt.Rectangle((i, aa_idx), 1, 1, fill=False, edgecolor='black', lw=2))

#plt.xlim((0,55))
#plt.xlim((68,118))
plt.xticks(rotation=0)
plt.xlabel("Sequence Position")
plt.ylabel("Amino Acid")
plt.title("Amino Acid Probabilities at Each Position (Native Boxed)")
plt.tight_layout()
plt.show()

This graph is a bit cramped, but layouts of this kind are always challenging. Uncomment `plt.xlim((0,55))` (or `plt.xlim((68,118))`) to zoom in, and explore different portions of the graph.

With an antibody expert nearby, look at conservation in different regions, e.g. the conserved Cys/Trp/Phe residues and the various CDRs and framework regions (FWRs). Are there any trends?

## The inner workings of a pLM

Previously, we let the `pipeline` function from the `transformers` library do the work of turning model outputs into amino-acid probabilities. Let's look in detail at the underlying workflow for predicting a single position.

### Tokenization

Tokens are how the sequence is represented to the pLM.

Which tokens are available for each model: `ESM`, `ProtBERT`, and `IgBERT`? Explore the token vocabulary of our loaded models. Note how the vocabularies differ between models, and how to explore the vocabulary of a new-to-you model.

In [None]:
# Show amino acid tokens specifically
model = available_models["IgBERT"]
standard_aa = "ACDEFGHIKLMNPQRSTVWY"
extended_aa = standard_aa + "BXZUO-."  # B, X, Z are sometimes used
extended_dict = {
    "B": "Asp or Asn",
    "X": "Any AA",
    "Z": "Glu or Gln",
    "U": "Selenocysteine",
    "O": "Pyrrolysine",
    "-": "Gap/Padding",
    ".": "Gap/Padding"
}

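token_dict = model.tokenizer.get_vocab()
aa_set = set(standard_aa)
# Sort by index so the vocabulary prints in the model's own order;
# exact-match lookups avoid substring false positives (e.g. multi-letter tokens)
for token, index in sorted(token_dict.items(), key=lambda item: item[1]):
    if token in aa_set:
        print(f"Token: {token:>6}, Index: {index:>3}")
    elif token in extended_dict:
        print(f"Token: {token:>6}, Index: {index:>3} (non-standard: {extended_dict[token]})")
    else:
        print(f"Token: {token:>6}, Index: {index:>3} (special)")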


### Prediction at a single masked position


Pick an interesting position and follow the code. Find where in the code each of the steps below happens.

1. The position is masked and the sequence is formatted.
2. The tokens are embedded into high-dimensional embedding vectors.
3. The output of the model is a vector of logits at each position, one per token in the vocabulary. A logit is an unnormalized score. (Each pLM has its own set of tokens, sometimes including nucleotides and gaps, too.)
4. Probabilities are calculated from the logits using the softmax function.
5. The index of the top probability is found, and the order of the tokens (which we also extract from the model) maps that index back to the top amino acid.
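Steps 3 to 5 can be sketched on a toy logit vector, independent of any real model. The six-token vocabulary and the logit values are invented. With a real model, the logits would be `outputs.logits[0, mask_idx, :]` and the vocabulary would come from `tokenizer.get_vocab()`.

```python
import torch
import torch.nn.functional as F

# Invented vocabulary (token -> index), mimicking tokenizer.get_vocab()
step_vocab = {"[PAD]": 0, "[MASK]": 1, "A": 2, "C": 3, "G": 4, "S": 5}
index_to_token = {index: token for token, index in step_vocab.items()}

# Step 3: one unnormalized logit per vocabulary token at the masked position (invented values)
step_logits = torch.tensor([-5.0, -4.0, 1.0, 0.5, 2.0, 1.5])

# Step 4: softmax turns the logits into probabilities that sum to 1
step_probs = F.softmax(step_logits, dim=0)

# Step 5: the index of the highest probability maps back to a token (here "G")
best_index = int(step_probs.argmax())
print(index_to_token[best_index], round(step_probs[best_index].item(), 3))
```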


In [None]:
# Choose a position to analyze (e.g., position 5 or 30 or 104)
position = 1  # 0-based index, e.g. sequence position minus 1

# Select model and sequence
model = available_models["IgBERT"]
sequence = sample_sequences["antibody"][0]
model_type = model.model_type

# Prepare masked sequence for the chosen position
seq_list = list(sequence)
native_aa = seq_list[position]

if model_type == "protbert" or model_type == "igbert":
    mask_token = "[MASK]"
else:
    mask_token = "<mask>"
seq_list[position] = mask_token
masked_seq = ''.join(seq_list)
formatted_seq = format_sequence_for_model(masked_seq, model_type)

# Tokenize and get model outputs
inputs = model.tokenizer(formatted_seq, return_tensors="pt")
device = next(model.model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
    outputs = model.model(**inputs, output_hidden_states=True)

# Find the mask token index in the input
mask_token_id = model.tokenizer.mask_token_id if model_type != "esm" else model.tokenizer.convert_tokens_to_ids("<mask>")
mask_idx = (inputs["input_ids"][0] == mask_token_id).nonzero(as_tuple=True)[0].item()

`inputs` is a dictionary of tensors produced by the tokenizer. `input_ids` holds the vocabulary index of the token at each position of the formatted, masked sequence. The tokenizer also prepends a start token (`[CLS]` for the BERT-style models, `<cls>` for ESM), so for residues before any HC-LC linker the mask token sits at index `position + 1`. Find the mask token's ID there and check that this index matches `mask_idx`.
The other tensors are bookkeeping: `attention_mask` marks real tokens (1) versus padding (0), and BERT-style tokenizers also return `token_type_ids`, which label sequence segments.
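To see why the mask lands at `position + 1`, here is a toy, hypothetical tokenizer that mimics what a BERT-style tokenizer does with a space-separated sequence: it adds start and end tokens, then maps every token to its vocabulary index. The vocabulary, its IDs, and the `toy_tokenize` function are all made up for illustration. Real tokenizers have their own vocabularies and return tensors rather than lists.

```python
# Hypothetical vocabulary; real models use their own IDs (see tokenizer.get_vocab()).
toy_vocab = {"[CLS]": 0, "[SEP]": 1, "[MASK]": 2, "E": 3, "V": 4, "Q": 5, "L": 6}

def toy_tokenize(formatted_seq):
    """Mimic a BERT-style tokenizer on a space-separated sequence."""
    tokens = ["[CLS]"] + formatted_seq.split() + ["[SEP]"]   # start and end tokens are added
    input_ids = [toy_vocab[token] for token in tokens]
    attention_mask = [1] * len(input_ids)                     # 1 = real token, 0 would mark padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}

toy_position = 1                                   # 0-based index of the masked residue
toy_inputs = toy_tokenize("E [MASK] Q L")
toy_mask_idx = toy_inputs["input_ids"].index(toy_vocab["[MASK]"])

print(toy_inputs)
print(toy_mask_idx == toy_position + 1)            # the leading [CLS] shifts indices by one
```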

In [None]:
print("Input:", inputs)

In [None]:
# The embedding is the underlying representation of the masked sequence position, capturing contextual information.
# Embedding dimension is a property of the model, associated with the parameters of its architecture.
# An embedding can be extracted from each layer of the model; here we use the final layer, but other
# layers are known to have meaningful information.

# Get embedding for the masked site
embedding = outputs.hidden_states[-1][0, mask_idx, :].cpu().numpy()

print("Embedding shape:", embedding.shape)

In [None]:
# The output logits are a tensor of shape (1, sequence_length, vocab_size), where vocab_size is the number of tokens in
# the model's vocabulary. The logits represent the model's raw predictions for each token in the vocabulary at each
# position in the sequence.

print("Full output logits shape:", outputs.logits.shape)

In [None]:
# Get logits for the masked site
logits = outputs.logits[0, mask_idx, :]

# Logits for the masked position: one value per token in the model's vocabulary
print("Shape:", logits.shape, "- one logit per vocabulary token (amino acids plus special tokens).")

# Get token indices for the 20 standard amino acids.
# Find the logits for those tokens at the respective index.
# There are other code mechanisms to do this, but this is straightforward.
aa_order = list("ACDEFGHIKLMNPQRSTVWY")
aa_token_ids = [model.tokenizer.convert_tokens_to_ids(aa) for aa in aa_order] # logit indices for standard AAs in aa_order
aa_logits = logits[aa_token_ids] # logits ordered as in aa_order

print("Logits for the 20 standard amino acids:", aa_logits.cpu().numpy())


In [None]:
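# Probabilities are calculated using softmax on the logits and should sum to 1.
# Softmax over only the 20 amino-acid logits renormalizes over the standard amino acids,
# so these values can differ slightly from the pipeline scores, which normalize over the full vocabulary.
aa_probs = F.softmax(aa_logits, dim=0).cpu().numpy()

print("Sum of probabilities (should be 1):", aa_probs.sum())
print("\n")
for aa, prob, logit in zip(aa_order, aa_probs, aa_logits):
    print(f"Amino Acid: {aa}, Probability: {prob:.3f}, Logit: {logit.item():.4f}")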
print("\n")

# Show probabilities and native/best amino acids
best_aa = aa_order[aa_probs.argmax()]
print(f"At position {position}: Native AA = {native_aa}; Probability = {round(float(aa_probs[aa_order.index(native_aa)]),3)}")
print(f"At position {position}:   Best AA = {best_aa}; Probability = {round(float(aa_probs[aa_order.index(best_aa)]),3)}")

Let's understand the relationship between logits and probabilities with a plot, and visualize the best amino acid at the position of interest.
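Softmax computes each probability as exp(logit) divided by a normalizing sum shared by all amino acids at that position, so the scatter plot should trace a single exponential curve. Three consequences are easy to check on invented logits: ranking by logit and by probability gives the same order, a logit difference equals the log of the probability ratio, and adding a constant to every logit leaves the probabilities unchanged.

```python
import math
import torch
import torch.nn.functional as F

demo_logits = torch.tensor([2.0, 1.0, 0.0, -1.0])   # invented values
demo_probs = F.softmax(demo_logits, dim=0)

# 1) Same ranking: softmax is monotonic
same_order = torch.equal(demo_logits.argsort(descending=True),
                         demo_probs.argsort(descending=True))

# 2) A logit difference is a log probability ratio: z_0 - z_1 = log(p_0 / p_1) = 1.0 here
log_ratio = math.log(demo_probs[0].item() / demo_probs[1].item())

# 3) Shifting every logit by a constant leaves the probabilities unchanged
shifted_probs = F.softmax(demo_logits + 10.0, dim=0)

print(same_order, round(log_ratio, 4), torch.allclose(demo_probs, shifted_probs))
```

The third property is why raw logits are only comparable within a position, while probabilities can be compared across positions.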

In [None]:
x,y = aa_logits.cpu().numpy(), aa_probs
plt.scatter(x, y)

# Annotate each point
for i, txt in enumerate(aa_order):
    plt.annotate(txt, (x[i], y[i]), # xy is the point to annotate
                 xytext=(x[i] + 0.01, y[i] + 0.01), # xytext is the text position
                 textcoords="data", # coordinate system for xytext
                 ha='left', va='bottom') # horizontal and vertical alignment

plt.xlabel("Logits")
plt.ylabel("Probability")
plt.title("Logits vs Probabilities for each Amino Acid")
plt.grid(True)
plt.show()

Cycle back and look at other positions. I recommend positions 5, 30, or 104, as these are interesting (hint, hint). Unless you changed it, the example sequence is an antibody, so look in the ranges 24-34 and 98-108, which roughly correspond to CDR1 and CDR3.

For each position, consider whether the native amino acid is recommended at that site or whether the optimal identity is another amino acid.

## Masking Strategy and Functions

Now that we know a bit about the prediction outputs of pLMs, let's take a look at how these models perform in a benchmark.

Below are functions for our benchmark.

First is a function to generate inputs with a percentage of positions masked. A pLM outputs logits for every position simultaneously. Unmasked positions provide context for the masked positions undergoing design. Because we're using multiple language models, we need special methods to handle masking based on the tokens used during training. Let's look at the masked sequences.
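For a fair comparison, you may want every model to see exactly the same masked positions. `create_masked_sequence` does not guarantee this by itself, because each call draws new positions from the global random state. One option is to draw the positions once per sequence with a local, seeded `random.Random` and then insert each model's own mask token. The helpers `choose_mask_positions` and `apply_mask` in the sketch below are hypothetical; their selection rules (skip the first and last residues and `-` linker characters) mirror those of `create_masked_sequence`.

```python
import random

def choose_mask_positions(sequence, mask_fraction=0.15, seed=0):
    """Hypothetical helper: pick positions to mask once, reproducibly.

    Skips the first and last residues and any '-' linker characters.
    """
    rng = random.Random(seed)  # local RNG, so the notebook's global random state is untouched
    maskable = [i for i, aa in enumerate(sequence)
                if 0 < i < len(sequence) - 1 and aa != "-"]
    n_to_mask = min(max(1, int(len(sequence) * mask_fraction)), len(maskable))
    return sorted(rng.sample(maskable, n_to_mask))

def apply_mask(sequence, positions, mask_token):
    """Hypothetical helper: replace the chosen positions with a model-specific mask token."""
    residues = list(sequence)
    for pos in positions:
        residues[pos] = mask_token
    return "".join(residues)

demo_seq = "EVQLVESGG---DIQMTQSPS"        # invented toy sequence with an HC-LC style linker
demo_positions = choose_mask_positions(demo_seq, seed=42)
for token in ("<mask>", "[MASK]"):         # same positions, different mask tokens
    print(apply_mask(demo_seq, demo_positions, token))
```

In a benchmark loop, you would call `choose_mask_positions` once per sequence and reuse the result for every model.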

Then, we have a function that runs the masked sequence through the pLM and makes a prediction of the most probable residue at each masked position.
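Masking several positions in one input changes the shape of the fill-mask pipeline's output. With one mask, it returns a single list of candidate dictionaries. With several masks, it returns one such list per mask, in sequence order. The hypothetical helper `top_predictions` below normalizes both shapes and keeps the top candidate per mask, much as `predict_masked_positions` does. The result dictionaries are mocked with only the `token_str` and `score` keys the helper reads (real pipeline results carry additional keys), and the values are invented.

```python
def top_predictions(results):
    """Hypothetical helper: return one (token, score) pair per masked position.

    Accepts either a list of candidate dicts (one mask) or a list of such
    lists (several masks); each candidate list is assumed sorted best-first.
    """
    if results and isinstance(results[0], dict):   # single mask: wrap into a list of lists
        results = [results]
    return [(candidates[0]["token_str"].strip(), candidates[0]["score"]) for candidates in results]

# Mocked outputs with the same nesting as the pipeline's (tokens and scores are invented)
one_mask = [{"token_str": "G", "score": 0.61}, {"token_str": "S", "score": 0.20}]
two_masks = [[{"token_str": "L", "score": 0.90}], [{"token_str": " Y", "score": 0.45}]]

print(top_predictions(one_mask))   # [('G', 0.61)]
print(top_predictions(two_masks))  # [('L', 0.9), ('Y', 0.45)]
```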

In these first two cells, example output is provided. Take a look at the code, too. The actual benchmark is run after we put together the functions.

In [None]:
def create_masked_sequence(sequence, mask_percentage=0.15, model_type="esm"):
    """
    Create a masked version of a protein sequence with appropriate mask tokens

    Args:
        sequence: Original protein sequence
        mask_percentage: Fraction of positions to mask (e.g. 0.15 = 15%)
        model_type: Type of model to determine correct mask token

    Returns:
        masked_sequence: Sequence with masked positions
        masked_positions: List of positions that were masked
        original_residues: List of original residues at masked positions
    """
    seq_list = list(sequence)
    seq_length = len(sequence)

    # Determine correct mask token for model type
    if model_type == "protbert":
        mask_token = "[MASK]"
    elif model_type == "igbert":
        mask_token = "[MASK]"
    else:  # ESM (and AMPLIFY) use <mask>
        mask_token = "<mask>"

    # Calculate number of positions to mask
    num_to_mask = max(1, int(seq_length * mask_percentage))

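    # Randomly select positions to mask (avoid the first/last residues and '-' linker characters).
    # Indices refer to positions in the full sequence, so they can be used directly with seq_list.
    maskable_positions = [i for i, x in enumerate(seq_list) if 0 < i < seq_length - 1 and x != "-"]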
    masked_positions = sorted(random.sample(maskable_positions,
                                   min(num_to_mask, len(maskable_positions))))

    # Store original residues
    original_residues = [sequence[pos] for pos in sorted(masked_positions)]

    # Apply masking
    for pos in masked_positions:
        seq_list[pos] = mask_token

    masked_sequence = ''.join(seq_list)

    return masked_sequence, masked_positions, original_residues


# Example masking
for model_name in ["esm","protbert","igbert"]:
    example_seq = sample_sequences["antibody"][0]
    masked_seq, positions, residues = create_masked_sequence(example_seq, model_type=model_name)

    print(f"Masking example sequence for {model_name}")
    print("Original sequence:")
    print(example_seq)
    print(f"\nMasked sequence (positions {sorted(positions)}):")
    print(format_sequence_for_model(masked_seq,model_type=model_name))
    print(f"\nOriginal residues at masked positions: {residues}")
    print("\n\n")

In [None]:
def predict_masked_positions(model, masked_sequence, masked_positions, model_type):
    """
    Predict amino acids at masked positions using a protein language model

    Args:
        model: Loaded protein language model
        masked_sequence: Sequence with masked positions
        masked_positions: List of masked position indices
        model_type: Type of model for formatting

    Returns:
        predictions: List of predicted amino acids
        probabilities: List of prediction probabilities
        results: Raw prediction results from the model
    """

    # Format sequence for the specific model
    formatted_sequence = format_sequence_for_model(masked_sequence, model_type)

    predictions = []
    probabilities = []

    try:
        # Verify mask tokens are present
        expected_mask = "[MASK]" if model_type in ("protbert", "igbert") else "<mask>"
        if expected_mask not in formatted_sequence:
            print(f"Warning: No {expected_mask} tokens found in sequence for {model_type}")
            print(f"Formatted sequence: {formatted_sequence[:100]}...")
            # Return three values, matching the normal return signature
            return ['X'] * len(masked_positions), [0.0] * len(masked_positions), []

        # Get predictions for all masked positions
        results = model.pipeline(formatted_sequence, top_k=1)

        #print(f"Raw results for {model_type}: {results}")

        # Handle different result formats
        if not isinstance(results, list):
            results = [results]

        # If we have nested lists (multiple masks), flatten appropriately
        if len(results) > 0 and isinstance(results[0], list):
            # Multiple masks case - results is a list of lists
            flattened_results = []
            for mask_results in results:
                if isinstance(mask_results, list) and len(mask_results) > 0:
                    flattened_results.append(mask_results[0])  # Take top prediction
            results = flattened_results

        # Process predictions
        for result in results:
            if isinstance(result, dict):
                pred_token = result['token_str'].strip()
                pred_prob = result['score']

                # Clean up prediction token based on model type
                if model_type == "protbert":
                    # ProtBERT may return tokens with spaces or special characters
                    pred_token = pred_token.replace(' ', '').replace('▁', '')
                    # Get just the amino acid character
                    if len(pred_token) == 1 and pred_token.isalpha():
                        predictions.append(pred_token.upper())
                    else:
                        # Handle unexpected token format
                        predictions.append('X')
                else:
                    # ESM and other models
                    if len(pred_token) == 1 and pred_token.isalpha():
                        predictions.append(pred_token.upper())
                    else:
                        predictions.append('X')

                probabilities.append(pred_prob)

        # Ensure we have predictions for all masked positions
        while len(predictions) < len(masked_positions):
            predictions.append('X')
            probabilities.append(0.0)

        # Truncate if we have too many predictions
        predictions = predictions[:len(masked_positions)]
        probabilities = probabilities[:len(masked_positions)]

    except Exception as e:
        print(f"Prediction error for {model_type}: {str(e)}")
        print(f"Formatted sequence: {formatted_sequence[:100]}...")
        # Return placeholder values if prediction fails
        predictions = ['X'] * len(masked_positions)
        probabilities = [0.0] * len(masked_positions)
        results = []

    return predictions, probabilities, results

def calculate_recovery_rate(original_residues, predicted_residues):
    """Calculate the percentage of correctly predicted residues"""
    # Mismatched lengths or an empty mask cannot be scored (and would divide by zero)
    if len(original_residues) != len(predicted_residues) or len(original_residues) == 0:
        return 0.0

    correct_predictions = sum(1 for orig, pred in zip(original_residues, predicted_residues)
                              if orig == pred)

    return (correct_predictions / len(original_residues)) * 100

# Test prediction function
for test_model in list(available_models.values()):
    #test_model = list(available_models.values())[0]
    print(f"Testing {test_model.model_name}")

    test_model_type = test_model.model_type

    example_seq = sample_sequences["antibody"][0]
    masked_seq, positions, residues = create_masked_sequence(example_seq, model_type=test_model.model_type)
    print("Original sequence:")
    print(example_seq)
    print(f"\nMasked sequence (positions {positions}):")
    print(format_sequence_for_model(masked_seq,model_type=test_model.model_type))

    predictions, probs, results = predict_masked_positions(test_model, masked_seq, positions, test_model_type)
    recovery_rate = calculate_recovery_rate(residues, predictions)

    print(f"Test prediction results:")
    print(f"Original: {residues}")
    print(f"Predicted: {predictions}")
    print(f"Recovery rate: {recovery_rate:.1f}%")
    print("\n\n")

    # for x in results:
    #     print(x)
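Recovery rates are easier to interpret next to simple reference points. Guessing uniformly among the 20 standard amino acids recovers 5% of masked residues on average. The sketch below is a hypothetical baseline, not part of the benchmark: it masks random positions and always predicts the most frequent residue among the unmasked positions. `majority_baseline`, its parameters, and the toy heavy-chain fragment are illustrative assumptions (the sequence must have at least two residues).

```python
import random
from collections import Counter

def recovery_rate(original, predicted):
    """Percentage of masked positions where the prediction matches the original residue."""
    if len(original) != len(predicted) or not original:
        return 0.0
    return 100.0 * sum(o == p for o, p in zip(original, predicted)) / len(original)

def majority_baseline(sequence, mask_fraction=0.15, seed=0):
    """Hypothetical baseline: predict the most common unmasked residue at every masked site."""
    rng = random.Random(seed)
    n_mask = max(1, int(len(sequence) * mask_fraction))
    positions = sorted(rng.sample(range(len(sequence)), n_mask))
    masked = set(positions)
    # Only the unmasked residues are visible to this "model"
    visible = [aa for i, aa in enumerate(sequence) if i not in masked]
    guess = Counter(visible).most_common(1)[0][0]
    original = [sequence[i] for i in positions]
    return positions, recovery_rate(original, [guess] * n_mask)

seq = "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVS"  # illustrative fragment
positions, rate = majority_baseline(seq)
print(f"Masked {len(positions)} positions; majority-residue baseline recovery: {rate:.1f}%")
```

A pLM that does not clearly beat both the 5% uniform rate and a composition-based guess like this one is adding little information for that protein class.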

## Running Benchmark Across All Models

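Protein language models run most efficiently when inputs are batched. To make each step of the benchmark easy to follow, we also provide a sequential version that processes one masked sequence at a time. Note that the "batched" version below builds all masked sequences for a model up front and iterates over them in chunks of `batch_size`, but still calls the prediction function once per sequence within each chunk.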
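At the tensor level, batching means stacking several tokenized inputs into one rectangular array so the model runs a single forward pass over all of them. The minimal sketch below pads variable-length lists of token IDs and builds the matching attention mask; Hugging Face tokenizers do the equivalent when called on a list of strings with `padding=True`. The helper names and the token IDs are hypothetical, and the real pad ID depends on the tokenizer (`tokenizer.pad_token_id`).

```python
def pad_batch(token_id_lists, pad_id):
    """Right-pad token ID lists to a common length and build attention masks.

    Returns (input_ids, attention_mask) as lists of lists; 1 marks a real token,
    0 marks padding that the model should ignore.
    """
    max_len = max(len(ids) for ids in token_id_lists)
    input_ids, attention_mask = [], []
    for ids in token_id_lists:
        n_pad = max_len - len(ids)
        input_ids.append(list(ids) + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask

def chunks(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Hypothetical token IDs for three masked sequences of different lengths
batch = [[0, 5, 32, 7, 2], [0, 9, 2], [0, 4, 32, 6, 11, 2]]
ids, mask = pad_batch(batch, pad_id=1)
for row_ids, row_mask in zip(ids, mask):
    print(row_ids, row_mask)
```

`chunks` mirrors the `range(0, len(batch_data), batch_size)` loop in `_run_batched_evaluation`; a true batched forward pass would additionally pad each chunk as `pad_batch` does and send it to the model in one call.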

In [None]:
def run_comprehensive_evaluation(df, available_models, num_replicates=10, mask_percentage=0.15,
                               use_batching=True, batch_size=8, debug=False):
    """
    Run masked residue prediction evaluation across all models and sequences

    Args:
        df: DataFrame with protein sequences
        available_models: Dictionary of loaded models
        num_replicates: Number of masking replicates per sequence and model
        mask_percentage: Fraction of residues to mask (e.g. 0.15 masks 15%)
        use_batching: Whether to use the chunked (batched) evaluation loop
        batch_size: Number of masked sequences per chunk
        debug: If True, print masked inputs and predictions for each item

    Returns:
        benchmark_df: DataFrame with detailed results
    """
    if use_batching:
        print("Using batched predictions for efficiency...")
        benchmark_data = _run_batched_evaluation(df, available_models, num_replicates,
                                        mask_percentage, batch_size, debug=debug)
    else:
        print("Using sequential predictions (slower but more transparent)...")
        benchmark_data = _run_sequential_evaluation(df, available_models, num_replicates,
                                           mask_percentage, debug=debug)

    #print(benchmark_data)
    return pd.DataFrame(benchmark_data)

def _run_sequential_evaluation(df, available_models, num_replicates, mask_percentage, debug=False):
    """Sequential approach - clearer for educational purposes"""
    benchmark_data = []
    total_iterations = len(df) * len(available_models) * num_replicates

    with tqdm(total=total_iterations, desc="Running sequential predictions") as pbar:

        for idx, row in df.iterrows():
            sequence = row['sequence']
            protein_class = row['protein_class']
            sequence_id = row['sequence_id']

            for model_name, model in available_models.items():

                for replicate in range(num_replicates):

                    # Create masked sequence with correct token for model type
                    masked_seq, masked_pos, original_res = create_masked_sequence(
                        sequence, mask_percentage, model_type=model.model_type
                    )
                    if debug:
                        print("Original sequence:")
                        print(sequence)
                        print(f"\nMasked sequence (positions {masked_pos}):")
                        print(format_sequence_for_model(masked_seq,model_type=model.model_type))

                    # Make predictions
                    predictions, probabilities, results = predict_masked_positions(
                        model, masked_seq, masked_pos, model.model_type
                    )

                    # Calculate recovery rate
                    recovery_rate = calculate_recovery_rate(original_res, predictions)

                    if debug:
                        print(f"Prediction results:")
                        print(f"Original: {original_res}")
                        print(f"Predicted: {predictions}")
                        print(f"Recovery rate: {recovery_rate:.1f}%")
                        print("\n\n")

                    # Store results
                    #print(predictions)
                    data = {
                        'sequence_id': sequence_id,
                        'sequence': sequence,
                        'protein_class': protein_class,
                        'model_name': model_name,
                        'replicate': replicate + 1,
                        'num_masked': len(masked_pos),
                        'recovery_rate': recovery_rate,
                        'avg_probability': np.mean(probabilities) if probabilities else 0.0,
                        'masked_positions': masked_pos,
                        'original_residues': original_res,
                        'predicted_residues': predictions
                    }
                    if debug:
                        print(data)
                    benchmark_data.append(data)

                    pbar.update(1)
    return benchmark_data

def _run_batched_evaluation(df, available_models, num_replicates, mask_percentage, batch_size, debug=False):
    """Efficient batched approach for faster processing"""
    benchmark_data = []

    for model_name, model in available_models.items():
        print(f"\nProcessing model: {model_name}")

        # Prepare all masked sequences for this model
        batch_data = []

        for idx, row in df.iterrows():
            sequence = row['sequence']
            protein_class = row['protein_class']
            sequence_id = row['sequence_id']

            for replicate in range(num_replicates):
                # Create masked sequence with correct token for model type
                masked_seq, masked_pos, original_res = create_masked_sequence(
                    sequence, mask_percentage, model.model_type
                )

                batch_data.append({
                    'sequence_id': sequence_id,
                    'sequence': sequence,
                    'protein_class': protein_class,
                    'replicate': replicate + 1,
                    'masked_sequence': masked_seq,
                    'masked_positions': masked_pos,
                    'original_residues': original_res,
                    'formatted_sequence': format_sequence_for_model(masked_seq, model.model_type)
                })

        # Process in batches - but handle each sequence individually within the batch
        for i in tqdm(range(0, len(batch_data), batch_size),
                     desc=f"Batched predictions for {model_name}"):

            batch = batch_data[i:i + batch_size]

            # Process each sequence in the batch individually
            for batch_item in batch:
                try:
                    if debug:
                        print(f"Original sequence: {batch_item['sequence']}")
                        print(f"\nMasked sequence (positions {batch_item['masked_positions']}):")
                        print(batch_item['formatted_sequence'])

                    # Use the same prediction logic as the sequential version
                    predictions, probabilities, results = predict_masked_positions(
                        model,
                        batch_item['masked_sequence'],
                        batch_item['masked_positions'],
                        model.model_type
                    )

                    if debug:
                        print(f"Prediction results:")
                        print(f"Original: {batch_item['original_residues']}")
                        print(f"Predicted: {predictions}")
                        print()

                    # Calculate recovery rate
                    recovery_rate = calculate_recovery_rate(
                        batch_item['original_residues'], predictions
                    )

                    # Store results
                    benchmark_data.append({
                        'sequence_id': batch_item['sequence_id'],
                        'sequence': batch_item['sequence'],
                        'protein_class': batch_item['protein_class'],
                        'model_name': model_name,
                        'replicate': batch_item['replicate'],
                        'num_masked': len(batch_item['masked_positions']),
                        'recovery_rate': recovery_rate,
                        'probabilities': probabilities,
                        'avg_probability': np.mean(probabilities) if probabilities else 0.0,
                        'masked_positions': batch_item['masked_positions'],
                        'original_residues': batch_item['original_residues'],
                        'predicted_residues': predictions
                    })

                except Exception as e:
                    print(f"Prediction error for {model_name} on sequence {batch_item['sequence_id']}: {str(e)}")

                    # Store failed prediction with default values
                    benchmark_data.append({
                        'sequence_id': batch_item['sequence_id'],
                        'sequence': batch_item['sequence'],
                        'protein_class': batch_item['protein_class'],
                        'model_name': model_name,
                        'replicate': batch_item['replicate'],
                        'num_masked': len(batch_item['masked_positions']),
                        'recovery_rate': 0.0,
                        'probabilities': 'Failed',
                        'avg_probability': 0.0,
                        'masked_positions': batch_item['masked_positions'],
                        'original_residues': batch_item['original_residues'],
                        'predicted_residues': ['X'] * len(batch_item['masked_positions'])
                    })

    return benchmark_data

# Run the evaluation
print("Starting comprehensive evaluation...")

# Choose evaluation method (set exactly one):
use_batching = True    # Chunked loop over pre-built masked sequences (default)
#use_batching = False  # Sequential loop: easier to follow step by step

if use_batching:
    print("\n=== CHUNKED APPROACH (Batched) ===")
    print("Masked sequences are built per model, then processed in chunks of batch_size")
    results_df = run_comprehensive_evaluation(
        df, available_models, num_replicates=3, use_batching=True, batch_size=8, debug=False
    )
else:
    print("\n=== EDUCATIONAL APPROACH (Sequential) ===")
    print("This approach is clearer to understand step-by-step")
    results_df = run_comprehensive_evaluation(
        df, available_models, num_replicates=3, use_batching=False, debug=False
    )

print(f"\nEvaluation complete! Results shape: {results_df.shape}")
print("Average recovery rates by model and protein class:")
print(results_df.groupby(['model_name','protein_class'])['recovery_rate'].mean().round(2))
print(f"Results: {results_df.shape[0]} predictions")

## 7. Statistical Analysis

In [None]:
# Calculate summary statistics
summary_stats = results_df.groupby(['protein_class', 'model_name']).agg({
    'recovery_rate': ['mean', 'std', 'count'],
    'avg_probability': ['mean', 'std']
}).round(3)

summary_stats.columns = ['_'.join(col).strip() for col in summary_stats.columns]
summary_stats = summary_stats.reset_index()

print("Summary Statistics:")
print(summary_stats)

# Statistical significance testing
from scipy.stats import kruskal, mannwhitneyu

def perform_statistical_tests(results_df):
    """Perform statistical tests to compare model performance"""

    test_results = []

    print("Performing statistical analysis...")
    print(f"Available models: {results_df['model_name'].unique()}")
    print(f"Available protein classes: {results_df['protein_class'].unique()}")
    print(f"Total data points: {len(results_df)}")

    # Check if we have enough models to compare
    unique_models = results_df['model_name'].unique()
    if len(unique_models) < 2:
        print(f"Warning: Only {len(unique_models)} model(s) available. Need at least 2 for comparison.")
        return pd.DataFrame(test_results)

    # Test for each protein class
    for protein_class in results_df['protein_class'].unique():
        print(f"\nAnalyzing protein class: {protein_class}")
        class_data = results_df[results_df['protein_class'] == protein_class]

        # Get recovery rates for each model
        model_groups = []
        model_names = []

        for model in class_data['model_name'].unique():
            model_data = class_data[class_data['model_name'] == model]['recovery_rate'].values
            print(f"  {model}: {len(model_data)} data points, mean={np.mean(model_data):.2f}")

            if len(model_data) > 0:  # Only include models with data
                model_groups.append(model_data)
                model_names.append(model)

        print(f"  Found {len(model_groups)} model groups for comparison")

        # Perform tests based on number of groups
        if len(model_groups) >= 3:
            # Kruskal-Wallis test for multiple groups (non-parametric ANOVA)
            try:
                h_stat, p_value = kruskal(*model_groups)
                test_results.append({
                    'protein_class': protein_class,
                    'test': 'Kruskal-Wallis',
                    'comparison': f"All {len(model_groups)} models",
                    'statistic': round(h_stat, 4),
                    'p_value': round(p_value, 6),
                    'significant': p_value < 0.05,
                    'interpretation': 'Significant differences between models' if p_value < 0.05 else 'No significant differences'
                })
                print(f"  Kruskal-Wallis: H={h_stat:.4f}, p={p_value:.6f}")
            except Exception as e:
                print(f"  Error in Kruskal-Wallis test: {e}")

        elif len(model_groups) == 2:
            # Mann-Whitney U test for two groups
            try:
                u_stat, p_value = mannwhitneyu(model_groups[0], model_groups[1], alternative='two-sided')
                test_results.append({
                    'protein_class': protein_class,
                    'test': 'Mann-Whitney U',
                    'comparison': f"{model_names[0]} vs {model_names[1]}",
                    'statistic': round(u_stat, 4),
                    'p_value': round(p_value, 6),
                    'significant': p_value < 0.05,
                    'interpretation': f'{model_names[0]} significantly different from {model_names[1]}' if p_value < 0.05 else 'No significant difference'
                })
                print(f"  Mann-Whitney U: U={u_stat:.4f}, p={p_value:.6f}")
            except Exception as e:
                print(f"  Error in Mann-Whitney U test: {e}")

        else:
            print(f"  Insufficient groups ({len(model_groups)}) for statistical testing")

    # Overall comparison across all protein classes
    print(f"\nOverall comparison across all protein classes:")
    overall_groups = []
    overall_names = []

    for model in results_df['model_name'].unique():
        model_data = results_df[results_df['model_name'] == model]['recovery_rate'].values
        print(f"  {model}: {len(model_data)} total data points, mean={np.mean(model_data):.2f}")

        if len(model_data) > 0:
            overall_groups.append(model_data)
            overall_names.append(model)

    if len(overall_groups) >= 3:
        try:
            h_stat, p_value = kruskal(*overall_groups)
            test_results.append({
                'protein_class': 'All Classes',
                'test': 'Kruskal-Wallis',
                'comparison': f"All {len(overall_groups)} models",
                'statistic': round(h_stat, 4),
                'p_value': round(p_value, 6),
                'significant': p_value < 0.05,
                'interpretation': 'Significant overall differences between models' if p_value < 0.05 else 'No significant overall differences'
            })
            print(f"  Overall Kruskal-Wallis: H={h_stat:.4f}, p={p_value:.6f}")
        except Exception as e:
            print(f"  Error in overall Kruskal-Wallis test: {e}")

    elif len(overall_groups) == 2:
        try:
            u_stat, p_value = mannwhitneyu(overall_groups[0], overall_groups[1], alternative='two-sided')
            test_results.append({
                'protein_class': 'All Classes',
                'test': 'Mann-Whitney U',
                'comparison': f"{overall_names[0]} vs {overall_names[1]}",
                'statistic': round(u_stat, 4),
                'p_value': round(p_value, 6),
                'significant': p_value < 0.05,
                'interpretation': f'{overall_names[0]} significantly different from {overall_names[1]} overall' if p_value < 0.05 else 'No significant overall difference'
            })
            print(f"  Overall Mann-Whitney U: U={u_stat:.4f}, p={p_value:.6f}")
        except Exception as e:
            print(f"  Error in overall Mann-Whitney U test: {e}")

    return pd.DataFrame(test_results)

# Perform statistical analysis
statistical_results = perform_statistical_tests(results_df)

print(f"\n{'='*80}")
print("Statistical Test Results:")
print(f"{'='*80}")

if len(statistical_results) > 0:
    # Display results in a formatted way
    for idx, row in statistical_results.iterrows():
        print(f"\n{row['protein_class']} - {row['test']}:")
        print(f"  Comparison: {row['comparison']}")
        print(f"  Statistic: {row['statistic']}")
        print(f"  P-value: {row['p_value']}")
        print(f"  Significant: {'Yes' if row['significant'] else 'No'}")
        print(f"  Interpretation: {row['interpretation']}")

    # Also display as DataFrame
    print(f"\n{'-'*80}")
    print("Summary Table:")
    print(statistical_results.to_string(index=False))
else:
    print("No statistical tests could be performed.")
    print("This may be due to:")
    print("- Only one model available")
    print("- Insufficient data points")
    print("- All models having identical performance")

# Additional descriptive statistics
print(f"\n{'='*80}")
print("Additional Descriptive Statistics:")
print(f"{'='*80}")

# Model performance comparison
model_performance = results_df.groupby('model_name')['recovery_rate'].agg([
    'count', 'mean', 'std', 'min', 'max'
]).round(3)
model_performance.columns = ['N', 'Mean', 'Std', 'Min', 'Max']
print("\nOverall Model Performance:")
print(model_performance)

# Class-specific performance
class_performance = results_df.groupby(['protein_class', 'model_name'])['recovery_rate'].agg([
    'count', 'mean', 'std'
]).round(3)
class_performance.columns = ['N', 'Mean', 'Std']
print("\nClass-Specific Performance:")
print(class_performance)
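`kruskal` compares models by ranking all recovery rates together and asking whether some model's rates sit systematically higher in that joint ranking. The minimal sketch below computes the H statistic by hand in plain Python so the mechanics are visible: pool the values, give tied values the average of their ranks, then compare each group's rank sum with what chance would produce. It omits the tie correction that `scipy.stats.kruskal` applies, so when values are tied (common for recovery rates, which take a limited set of percentages) SciPy's statistic will be somewhat larger. The function names and toy numbers are illustrative.

```python
def average_ranks(values):
    """1-based ranks of values, with tied values sharing the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda k: values[k])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with values[order[i]]
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1  # average of ranks i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks

def kruskal_h(*groups):
    """Kruskal-Wallis H statistic without tie correction."""
    pooled = [v for g in groups for v in g]
    n_total = len(pooled)
    ranks = average_ranks(pooled)
    weighted_sum, start = 0.0, 0
    for g in groups:
        rank_sum = sum(ranks[start:start + len(g)])
        weighted_sum += rank_sum ** 2 / len(g)
        start += len(g)
    return 12.0 / (n_total * (n_total + 1)) * weighted_sum - 3 * (n_total + 1)

# Toy recovery rates (%) for three hypothetical models
h = kruskal_h([10.0, 20.0, 30.0], [40.0, 50.0, 60.0], [70.0, 80.0, 90.0])
print(f"H = {h:.2f}")
```

Perfectly separated groups, as in this toy example, give the largest H possible for their sizes; identical groups give H = 0.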

## 8. Visualization

In [None]:
# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Figure 1: Overall Model Performance by Protein Class
fig, ax = plt.subplots(figsize=(5, 4))

# Calculate mean recovery rates
mean_recovery = results_df.groupby(['protein_class', 'model_name'])['recovery_rate'].mean().reset_index()

# Create bar plot
sns.barplot(data=mean_recovery, x='protein_class', y='recovery_rate', hue='model_name', ax=ax)
ax.set_title('Average Recovery Rate by Protein Class and Model', fontsize=16, fontweight='bold')
ax.set_xlabel('Protein Class', fontsize=14)
ax.set_ylabel('Recovery Rate (%)', fontsize=14)
ax.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')

# Add value labels on bars
for container in ax.containers:
    ax.bar_label(container, fmt='%.1f', fontsize=10)

plt.tight_layout()
plt.show()


In [None]:

# Figure 3: Distribution of Recovery Rates
fig, ax = plt.subplots(figsize=(6, 4))
# Create violin plot showing distribution of recovery rates
sns.violinplot(data=results_df, x='model_name', y='recovery_rate', hue='protein_class', ax=ax)
ax.set_title('Distribution of Recovery Rates by Model', fontsize=16, fontweight='bold')
ax.set_xlabel('Model', fontsize=14)
ax.set_ylabel('Recovery Rate (%)', fontsize=14)

plt.tight_layout()
plt.show()


In [None]:
# Figure 4: Correlation between Prediction Probability and Recovery
g = sns.relplot(data=results_df, x='avg_probability', y='recovery_rate', col='model_name',
                hue='protein_class', alpha=0.6)
# relplot returns a FacetGrid, so label the grid rather than a single Axes
g.set_axis_labels('Average Prediction Probability', 'Recovery Rate (%)')
g.set_titles(col_template='{col_name}')
g.figure.suptitle('Recovery Rate vs. Average Prediction Probability', fontsize=16, fontweight='bold', y=1.05)

plt.show()

## 9. Model Selection Recommendations

In [None]:
def generate_model_recommendations(results_df, summary_stats):
    """Generate recommendations for model selection based on results"""

    recommendations = {}

    # Overall best performing model
    overall_performance = results_df.groupby('model_name')['recovery_rate'].mean().sort_values(ascending=False)
    best_overall = overall_performance.index[0]

    recommendations['overall_best'] = {
        'model': best_overall,
        'performance': overall_performance.iloc[0],
        'reason': 'Highest average recovery rate across all protein classes'
    }

    # Best model for each protein class
    class_performance = results_df.groupby(['protein_class', 'model_name'])['recovery_rate'].mean()

    for protein_class in results_df['protein_class'].unique():
        class_data = class_performance[protein_class].sort_values(ascending=False)
        best_for_class = class_data.index[0]

        recommendations[f'best_for_{protein_class}'] = {
            'model': best_for_class,
            'performance': class_data.iloc[0],
            'reason': f'Highest recovery rate for {protein_class} sequences'
        }

    # Most consistent model (lowest std deviation)
    consistency = results_df.groupby('model_name')['recovery_rate'].std().sort_values()
    most_consistent = consistency.index[0]

    recommendations['most_consistent'] = {
        'model': most_consistent,
        'performance': consistency.iloc[0],
        'reason': 'Lowest standard deviation in recovery rates'
    }

    return recommendations

recommendations = generate_model_recommendations(results_df, summary_stats)

print("=== MODEL SELECTION RECOMMENDATIONS ===\n")

for rec_type, rec_data in recommendations.items():
    print(f"{rec_type.replace('_', ' ').title()}:")
    print(f"  Model: {rec_data['model']}")
    print(f"  Performance: {rec_data['performance']:.2f}")
    print(f"  Reason: {rec_data['reason']}")
    print()

## Scoring whole sequences

Protein evolutionary fitness is often estimated by summing the position-wise log-likelihoods of a sequence's residues under a pLM. When these log-likelihoods come from a single forward pass over the unmasked sequence, this is known as the "marginal method".
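For a sequence $x = (x_1, \dots, x_L)$, the marginal score sums the log-probability of each native residue given the full, unmasked sequence:

$$
S_{\text{marginal}}(x) = \sum_{i=1}^{L} \log p\left(x_i \mid x_1, \dots, x_L\right)
$$

Every term is a log-probability, so the score is at most zero, and a higher (less negative) value means the model finds the sequence more plausible.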

As observed in other notebooks, the probability a model assigns to an amino acid at a site changes when the identity of that site is visible in the input: the model can effectively see the answer, which typically inflates the probability of the native residue. Computing each site's probability with that site masked is the more realistic approach, and summing these masked log-probabilities gives the "masked marginal method".
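The toy sketch below illustrates why masking matters by comparing an unmasked sum with the masked marginal score $\sum_{i} \log p(x_i \mid x_{\setminus i})$, where $x_{\setminus i}$ is the sequence with position $i$ masked. `toy_model` is a hypothetical stand-in for a pLM, not a real model: when it can see the residue at the queried position it places half of its probability mass on that residue (mimicking information leakage); when the position is masked it falls back to a uniform distribution over the 20 amino acids.

```python
import math

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
MASK = "<mask>"

def toy_model(tokens, pos):
    """Hypothetical pLM stand-in: probability of each amino acid at `pos`.

    If tokens[pos] is visible, half the mass goes to that residue (leakage);
    the remaining mass is spread uniformly over the 20 amino acids.
    """
    uniform = 1.0 / len(AMINO_ACIDS)
    seen = tokens[pos]
    if seen == MASK:
        return {aa: uniform for aa in AMINO_ACIDS}
    return {aa: 0.5 * uniform + (0.5 if aa == seen else 0.0) for aa in AMINO_ACIDS}

def marginal_score(sequence):
    """One 'forward pass' on the unmasked sequence; sum native log-probabilities."""
    tokens = list(sequence)
    return sum(math.log(toy_model(tokens, i)[aa]) for i, aa in enumerate(sequence))

def masked_marginal_score(sequence):
    """Mask each position in turn and sum the native log-probability at that position."""
    total = 0.0
    for i, aa in enumerate(sequence):
        tokens = list(sequence)
        tokens[i] = MASK
        total += math.log(toy_model(tokens, i)[aa])
    return total

seq = "ACDK"
print(f"marginal:        {marginal_score(seq):.3f}")
print(f"masked marginal: {masked_marginal_score(seq):.3f}")
```

The unmasked score comes out far higher only because the toy model "sees" each answer; the masked score reflects what it can infer from context alone (here, nothing beyond uniform).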

We previously defined `get_positionwise_probs`, which calculates amino acid probabilities with each site masked in turn, using the Hugging Face `pipeline` API. The `masked_marginal_method` and `wildtype_marginal_method` functions defined below therefore compute essentially the same quantity and should rank sequence variants consistently, but their scores differ slightly because of implementation choices in token handling, numerical precision, normalization, or other processing.
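One source of such small differences is normalization. The minimal sketch below uses made-up logits over a tiny hypothetical vocabulary to compare a log-softmax over the full vocabulary (including special tokens) with one restricted to amino acid tokens. The absolute log-probabilities differ, but the gap between any two amino acids at the same position equals their logit gap in both cases, so substitutions at that position are ranked the same way.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a dict of token -> logit."""
    m = max(logits.values())
    log_z = m + math.log(sum(math.exp(v - m) for v in logits.values()))
    return {tok: v - log_z for tok, v in logits.items()}

# Made-up logits at one masked position (hypothetical vocabulary)
logits = {"A": 2.0, "G": 1.0, "V": 0.5, "<cls>": 1.5, "<eos>": 0.0}
amino_acids = {"A", "G", "V"}

full_vocab = log_softmax(logits)
aa_only = log_softmax({t: v for t, v in logits.items() if t in amino_acids})

for aa in ("A", "G", "V"):
    print(f"{aa}: full vocab {full_vocab[aa]:.3f}   amino acids only {aa_only[aa]:.3f}")
# The A-G gap equals the logit gap under either normalization
print(full_vocab["A"] - full_vocab["G"], aa_only["A"] - aa_only["G"])
```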

Implementing the masked marginal method explicitly, as `masked_marginal_method` does, is useful when no ready-made `pipeline` is available for a model.
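Without a `pipeline`, the masked marginal method needs one masked copy of the input per scored position. A minimal sketch of that bookkeeping in plain Python, using a hypothetical helper name, builds the masked copies and records which native residue each copy should be scored against. Skipping non-standard symbols such as `X` is an assumption here; the copies could then be tokenized together and sent to the model in batches instead of one forward pass per position.

```python
STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY")

def build_masked_variants(sequence, mask_token="<mask>"):
    """Return (position, native_residue, masked_tokens) for every standard residue.

    Each masked_tokens list is the sequence split into residues with exactly one
    position replaced by mask_token, ready to be joined/formatted for a tokenizer.
    """
    residues = list(sequence)
    variants = []
    for i, aa in enumerate(residues):
        if aa not in STANDARD_AA:
            continue  # assumption: skip X and other non-standard symbols
        masked = residues.copy()
        masked[i] = mask_token
        variants.append((i, aa, masked))
    return variants

for pos, native, tokens in build_masked_variants("MKXV"):
    print(pos, native, " ".join(tokens))
```

Passing `mask_token="[MASK]"` gives the BERT-style variant; for a real model, `tokenizer.mask_token` supplies the correct string.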

In [None]:
def add_fitness_scoring_methods(protein_model_class):
    """Add fitness scoring methods to the existing ProteinLanguageModel class"""

    def marginal_method(self, sequence):
        """
        Marginal method: Sum position-wise log-probabilities across sequence
        Single forward pass, extract probabilities for actual residues at each position
        """
        formatted_seq = format_sequence_for_model(sequence, self.model_type)

        # Tokenize sequence
        inputs = self.tokenizer(formatted_seq, return_tensors="pt")

        # Move inputs to same device as model
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        logits = outputs.logits

        # Calculate log-likelihood for actual sequence
        log_likelihood = 0.0

        # For ProtBERT/IgBERT, tokens might be individual amino acids
        if self.model_type in ['protbert', 'igbert']:
            tokens = formatted_seq.split()
        # For ESM, tokens are also individual amino acids
        else:
            tokens = list(formatted_seq.replace("<sep>", "X"))

        for i, (token, token_logits) in enumerate(zip(tokens, logits[0][1:])):
            # Apply softmax to get probabilities
            probs = F.softmax(token_logits, dim=-1)

            # Skip special tokens
            if token in ['[CLS]', '[SEP]', '<cls>', '<eos>', '<pad>', '[PAD]','X']:
                continue

            if token.upper() in "ACDEFGHIKLMNPQRSTVWY":

                # Get token ID for the original amino acid
                aa_token_id = self.tokenizer.convert_tokens_to_ids(token)

                # Probability of the native residue (logits[0][1:] above already skips the CLS token)
                prob = probs[aa_token_id].item()
                log_likelihood += np.log(prob + 1e-10)  # Add small epsilon to avoid log(0)
            #     print(token, i, aa_token_id, prob, np.log(prob + 1e-10))

            # else:
            #     print(token, "not an amino acid, skipping")

        return log_likelihood

    def masked_marginal_method(self, sequence):
        """
        Masked marginal: Mask each position individually, sum log-likelihoods
        Multiple forward passes, one for each position
        """
        total_log_likelihood = 0.0
        device = next(self.model.parameters()).device

        simple_formatted_seq = format_sequence_for_model(sequence, self.model_type)

        # For ProtBERT/IgBERT, tokens might be individual amino acids
        if self.model_type in ['protbert', 'igbert']:
            simple_formatted_seq = simple_formatted_seq.split()
        # For ESM, tokens are also individual amino acids
        else:
            simple_formatted_seq = list(simple_formatted_seq.replace("<sep>", "X"))

        # Use the tokenizer's own mask string ("[MASK]" for BERT-style, "<mask>" for ESM-style)
        mask_token = self.tokenizer.mask_token

        for i in range(len(simple_formatted_seq)):
            original_aa = simple_formatted_seq[i]
            # Skip separators and non-standard residues, as marginal_method does
            if original_aa.upper() not in set("ACDEFGHIKLMNPQRSTVWY"):
                continue

            # Replace residue i with the mask token, then format for this model
            masked_seq = sequence[:i] + mask_token + sequence[i+1:]
            formatted_seq = format_sequence_for_model(masked_seq, self.model_type)

            # Tokenize
            inputs = self.tokenizer(formatted_seq, return_tensors="pt")

            # Move inputs to same device as model
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits

            # Find mask position in tokenized sequence
            tokens = self.tokenizer.tokenize(formatted_seq) # does not include start and end tokens
            mask_pos = None

            for j, token in enumerate(tokens):
                if token == mask_token:
                    mask_pos = j  # index without the CLS token; the +1 offset is applied below
                    break

            if mask_pos is not None:
                # Get probabilities for mask position
                probs = F.softmax(logits[0, mask_pos+1], dim=-1) # +1 to account for CLS token

                # Get probability of the original amino acid at position i
                aa_token_id = self.tokenizer.convert_tokens_to_ids(original_aa)

                prob = probs[aa_token_id].item()
                total_log_likelihood += np.log(prob + 1e-10)
                #print(i, mask_pos, original_aa, aa_token_id, prob, np.log(prob + 1e-10), total_log_likelihood)

        return total_log_likelihood

    def wildtype_marginal_method(self, sequence):
        """
        Wild-type marginal: Single forward pass, extract all position probabilities
        Uses the existing get_positionwise_probs logic but returns log-likelihood
        """
        # Use your existing function to get position-wise probabilities
        probs_matrix, aa_order, native_indices = get_positionwise_probs(self, sequence, self.model_type)
        #print(sequence)
        #print(probs_matrix)
        #print(aa_order)
        #print(native_indices)

        # Calculate log-likelihood for native sequence
        total_log_likelihood = 0.0
        for pos, native_idx in enumerate(native_indices):
            if native_idx >= 0:  # Valid amino acid
                prob = probs_matrix[pos, native_idx]
                total_log_likelihood += np.log(prob + 1e-10)
                #print(pos, native_idx, prob, np.log(prob + 1e-10))
        return total_log_likelihood

    def score_sequence(self, sequence, method="marginal"):
        """
        Score a protein sequence using specified method

        Args:
            sequence: Protein sequence string
            method: One of ["marginal", "masked_marginal", "wildtype_marginal"]

        Returns:
            Fitness score (log-likelihood)
        """
        if method == "marginal":
            return self.marginal_method(sequence)
        elif method == "masked_marginal":
            return self.masked_marginal_method(sequence)
        elif method == "wildtype_marginal":
            return self.wildtype_marginal_method(sequence)
        else:
            raise ValueError(f"Unknown method: {method}")

    # Add methods to the class
    protein_model_class.marginal_method = marginal_method
    protein_model_class.masked_marginal_method = masked_marginal_method
    protein_model_class.wildtype_marginal_method = wildtype_marginal_method
    protein_model_class.score_sequence = score_sequence

    return protein_model_class

# Apply the fitness scoring methods to your existing class
ProteinLanguageModel = add_fitness_scoring_methods(ProteinLanguageModel)
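The per-position loop in `wildtype_marginal_method` can also be written as a single vectorized lookup. This is a minimal sketch, assuming (as the method's indexing implies) that `probs_matrix` has one row per residue and one column per amino acid in `aa_order`, and that a negative entry in `native_indices` marks a position to skip. `native_log_likelihood` is a hypothetical helper name, and the 3-by-4 matrix is toy data.

```python
import numpy as np

def native_log_likelihood(probs_matrix, native_indices, eps=1e-10):
    """Sum log p(native residue) over positions; negative indices are skipped."""
    probs_matrix = np.asarray(probs_matrix, dtype=float)
    native_indices = np.asarray(native_indices, dtype=int)
    valid = native_indices >= 0                       # mask out invalid positions
    positions = np.arange(len(native_indices))[valid]
    native_probs = probs_matrix[positions, native_indices[valid]]  # fancy indexing: one prob per position
    return float(np.log(native_probs + eps).sum())   # eps avoids log(0)

# Toy example: 3 positions, 4-letter alphabet; the last position has no valid index
probs = np.array([[0.70, 0.10, 0.10, 0.10],
                  [0.20, 0.50, 0.20, 0.10],
                  [0.25, 0.25, 0.25, 0.25]])
native = [0, 1, -1]
ll = native_log_likelihood(probs, native)
print(f"log-likelihood: {ll:.4f}")
```

Vectorizing only matters for long sequences or large batches; the result is the same sum the loop computes.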

In [None]:
def fitness_scoring(available_models, test_sequences):
    """
    Score one test sequence with every available model and all three methods

    Args:
        available_models: Dictionary of loaded ProteinLanguageModel instances
        test_sequences: Dictionary with an "antibody" key, a list of sequences,
            or a single sequence; only the first sequence is scored

    Returns:
        Dictionary mapping model name -> {method: score}, with None for failed methods
    """

    # Get test sequence
    if isinstance(test_sequences, dict):
        test_sequence = test_sequences["antibody"][0]
    elif isinstance(test_sequences, list):
        test_sequence = test_sequences[0]
    else:
        test_sequence = str(test_sequences)

    print(f"Scoring sequence: {test_sequence[:50]}{'...' if len(test_sequence) > 50 else ''}")
    print(f"Sequence length: {len(test_sequence)}")
    print("=" * 80)

    results = {}

    for model_name, model in available_models.items():
        print(f"\nProcessing {model_name} ({model.model_type}):")
        print("-" * 40)

        model_results = {}

        # Test all three methods
        methods = ["marginal", "masked_marginal", "wildtype_marginal"]

        for method in methods:
            try:
                print(f"  Computing {method} score...", end=" ")
                score = model.score_sequence(test_sequence, method=method)
                model_results[method] = score
                print(f"{score:.3f}")

            except Exception as e:
                print(f"ERROR: {str(e)}")
                model_results[method] = None

        results[model_name] = model_results

        # Per-residue scores (total log-likelihood / sequence length) for comparison
        for method in methods:
            if model_results.get(method) is not None:
                per_residue = model_results[method] / len(test_sequence)
                print(f"  Per-residue {method}: {per_residue:.3f}")

    # Summary comparison
    print("\n" + "=" * 80)
    print("SUMMARY - Fitness Scores by Method:")
    print("=" * 80)

    for method in ["marginal", "masked_marginal", "wildtype_marginal"]:
        print(f"\n{method.upper()}:")
        for model_name in available_models.keys():
            score = results[model_name].get(method)
            if score is not None:
                print(f"  {model_name:12s}: {score:8.3f}")
            else:
                print(f"  {model_name:12s}: {'ERROR':>8s}")

    return results
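The nested dictionary returned by `fitness_scoring` is easier to compare as a table. A minimal sketch, assuming pandas and the `{model_name: {method: score}}` layout built above, with `None` for methods that failed; `results_to_frame` is a hypothetical helper, and the toy numbers are made up for illustration. It adds a length-normalized column for every method.

```python
import pandas as pd

def results_to_frame(results, seq_len):
    """Tabulate {model: {method: score}} as a models-by-methods DataFrame.

    Failed methods (None) become NaN; the *_per_residue columns divide each
    total log-likelihood by the sequence length.
    """
    df = pd.DataFrame(results).T.astype(float)          # rows: models, columns: methods
    per_residue = (df / seq_len).add_suffix("_per_residue")
    return pd.concat([df, per_residue], axis=1)

# Toy input in the same layout (made-up numbers, not model outputs)
toy_results = {
    "model_a": {"marginal": -150.0, "masked_marginal": -160.0, "wildtype_marginal": -140.0},
    "model_b": {"marginal": -155.0, "masked_marginal": None, "wildtype_marginal": -145.0},
}
table = results_to_frame(toy_results, seq_len=100)
print(table)
```

With the notebook's objects this would be `results_to_frame(fitness_scoring(available_models, sample_sequences), len(sample_sequences["antibody"][0]))`.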


In [None]:
# Score the first antibody sample with IgBERT using the masked-marginal method
model = available_models["IgBERT"]
sequence = sample_sequences["antibody"][0]
print(sequence)
fitness_score = model.score_sequence(sequence, method="masked_marginal")
print(fitness_score)


### Comparing variants

In [None]:
# Convenience function for comparing sequence variants
def compare_sequence_variants(model, sequences, method="wildtype_marginal"):
    """
    Compare fitness scores for multiple sequence variants

    Args:
        model: ProteinLanguageModel instance
        sequences: List of sequences to compare
        method: Scoring method to use

    Returns:
        List of (sequence, score) tuples sorted by score (descending)
    """
    results = []

    for seq in sequences:
        try:
            score = model.score_sequence(seq, method=method)
            results.append((seq, score))
        except Exception as e:
            print(f"Error scoring sequence {seq[:20]}...: {str(e)}")
            results.append((seq, float('-inf')))

    # Sort by score (higher is better for log-likelihood)
    results.sort(key=lambda x: x[1], reverse=True)

    return results
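A raw log-likelihood is hard to interpret on its own, so when ranking point mutants it can be clearer to report each score relative to the native sequence: a positive delta means the model finds the variant more likely than the native. A minimal sketch, assuming a `{label: score}` dictionary that includes a `"native"` entry; `rank_by_delta` is a hypothetical helper and the scores are invented.

```python
def rank_by_delta(scores, native_label="native"):
    """Return (label, score - native_score) pairs, highest delta first."""
    native_score = scores[native_label]
    deltas = {label: s - native_score for label, s in scores.items() if label != native_label}
    return sorted(deltas.items(), key=lambda kv: kv[1], reverse=True)

# Hypothetical scores (made-up numbers)
scores = {"native": -120.0, "mutA": -121.5, "mutB": -125.2, "mutC": -119.4}
ranked = rank_by_delta(scores)
for label, delta in ranked:
    print(f"{label}: {delta:+.2f}")
```

With the notebook's objects, the input could be built as `scores = {variant_dict[seq]: s for seq, s in compare_sequence_variants(model, variants)}`.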

In [None]:
# Compare sequences: single-residue scan over 0-based positions 22-34 (1-based 23-35)
import random

native_sequence = sample_sequences["antibody"][0]

variants = [native_sequence]
variant_dict = {native_sequence: 'native'}  # map each sequence to a readable label
for idx in range(22, 35):
    mutated_seq = list(native_sequence)
    native_aa = mutated_seq[idx]

    mutated_aa = random.choice("S")  # only "S" (serine scan); try "DEHKNQRST" for polar/charged residues
    if mutated_aa != native_aa:  # skip positions that are already serine
        mutated_seq[idx] = mutated_aa

        variant = "".join(mutated_seq)
        variants.append(variant)
        variant_dict[variant] = f"{native_aa}{idx+1}{mutated_aa}"  # <native><1-based position><mutant>


Optionally, implement your own variant-generation scheme: for example, random substitutions across the whole sequence, or several simultaneous mutations within the CDRs.
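One way to approach this is a small generator that places a chosen number of random substitutions either anywhere in the sequence or only within a window of positions. This is a sketch, not the notebook's method: `random_variants` is a hypothetical helper, the 20-letter alphabet is an assumption, and `toy_native` is a toy antibody-like fragment rather than one of the sample sequences. It returns `{sequence: label}` in the same shape as `variant_dict`.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids

def random_variants(native, n_variants, n_mutations=1, positions=None,
                    alphabet=AMINO_ACIDS, seed=None, max_tries=10_000):
    """Generate up to n_variants unique variants, each with n_mutations substitutions.

    positions: optional iterable of 0-based indices to mutate (e.g. a CDR window);
    defaults to the whole sequence. Labels join mutations such as
    "<native><1-based position><new>" with "_".
    """
    rng = random.Random(seed)
    candidates = list(positions) if positions is not None else list(range(len(native)))
    variants = {}
    for _ in range(max_tries):            # cap attempts in case few unique variants exist
        if len(variants) >= n_variants:
            break
        seq = list(native)
        labels = []
        for idx in sorted(rng.sample(candidates, n_mutations)):  # distinct positions
            new_aa = rng.choice([aa for aa in alphabet if aa != native[idx]])
            seq[idx] = new_aa
            labels.append(f"{native[idx]}{idx + 1}{new_aa}")
        variants["".join(seq)] = "_".join(labels)
    return variants

# Toy fragment (not one of the notebook's sample sequences)
toy_native = "EVQLVESGGGLVQPGGSLRLSCAASGFTFS"
whole_seq = random_variants(toy_native, n_variants=5, seed=0)
cdr_double = random_variants(toy_native, n_variants=5, n_mutations=2,
                             positions=range(25, 30), seed=0)
for seq, label in cdr_double.items():
    print(label, seq)
```

To score the output with the comparison code, extend both structures, e.g. `variant_dict.update(cdr_double)` and `variants += list(cdr_double)`.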

In [None]:
# Score every variant with each method using `model` (IgBERT, selected above)
variant_score_df = pd.DataFrame(columns=["marginal", "masked_marginal", "wildtype_marginal"])
for method in ["wildtype_marginal", "marginal", "masked_marginal"]:
    print(f"\n=== Comparing variants using method: {method} ===")
    comparison = compare_sequence_variants(model, variants, method=method)
    for seq, score in comparison:
        label = variant_dict[seq]
        print(f"\tVariant: {label}  Score: {score:.3f}")
        variant_score_df.loc[label, method] = score

# Row-by-row .loc assignment can leave object dtype; convert for sorting and plotting
variant_score_df = variant_score_df.astype(float)


In [None]:
variant_score_df.sort_values(by="masked_marginal", ascending=False)

In [None]:
ax = sns.scatterplot(data=variant_score_df, x="masked_marginal", y="wildtype_marginal",
                     hue=variant_score_df.index, legend=False)

# Annotate each point with its mutation label
x = variant_score_df["masked_marginal"]
y = variant_score_df["wildtype_marginal"]
for i, txt in enumerate(variant_score_df.index):
    ax.annotate(txt, (x.iloc[i], y.iloc[i]),  # xy is the point to annotate (positional access)
                xytext=(3, 3),                # small offset for the label...
                textcoords="offset points",   # ...in points, independent of the data scale
                ha='left', va='bottom')       # horizontal and vertical alignment
plt.show()
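The scatter plot shows agreement between two scoring methods visually; a rank correlation summarizes each pair of methods in one number. A minimal sketch using pandas' built-in Spearman correlation, `DataFrame.corr(method="spearman")`; `method_agreement` is a hypothetical helper, the `.astype(float)` call guards against object-dtype columns, and `toy_df` (made-up numbers) stands in for `variant_score_df`.

```python
import pandas as pd

METHODS = ("marginal", "masked_marginal", "wildtype_marginal")

def method_agreement(score_df, methods=METHODS):
    """Pairwise Spearman rank correlation between scoring methods across variants."""
    return score_df[list(methods)].astype(float).corr(method="spearman")

# Toy scores for four hypothetical variants (made-up numbers)
toy_df = pd.DataFrame(
    {"marginal": [-1.0, -2.0, -3.0, -4.0],
     "masked_marginal": [-10.0, -20.0, -30.0, -40.0],
     "wildtype_marginal": [-4.0, -3.0, -2.0, -1.0]},
    index=["var1", "var2", "var3", "var4"],
)
print(method_agreement(toy_df).round(2))
```

On the notebook's data the call would be `method_agreement(variant_score_df)`; values near 1 mean two methods rank the variants in nearly the same order.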