# ESM-2 Embedding Similarity Analysis for Mutations

<img src="../figures/esm_mutation.svg" alt="ESM-2 Embedding Similarity Analysis for Mutations" width="350px">

ESM-2 is a family of protein language models trained on millions of protein sequences with self-supervised learning, so it learns patterns and relationships without labeled data. Its embeddings (the vectors it uses to represent proteins) capture both structural and evolutionary information, which is why it performs well on tasks such as protein structure and function prediction.

For our experiment, we used the 650-million-parameter version of ESM-2 to test how well mutation-based token replacements in evoBPE preserve biological meaning. Concretely, we wanted to see whether swapping amino acids in a way that mimics real mutations keeps the original protein's properties better than random substitutions do. Mutation tokens are relatively rare in evoBPE's training data, so we also wanted to check whether they still lead to more biologically meaningful changes.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["HF_HOME"] = "/cta/share/users/esm"

In [2]:
from time import time
import sqlite3
import pandas as pd
from tqdm import tqdm
import numpy as np
from transformers import EsmTokenizer, EsmModel, EsmForMaskedLM
import torch
import torch.nn.functional as F
from pandarallel import pandarallel
from Bio.Align import substitution_matrices
import ast
from sklearn.metrics.pairwise import cosine_similarity

import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import random
from scipy import stats

from vocabulary_functions import get_mutated, get_parents, set_difference, set_intersection, load_tokenizers, calc_agreement, calc_dice_idx_only

In [3]:
pandarallel.initialize(progress_bar=True, nb_workers=20)

INFO: Pandarallel will run on 20 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


## Load Datasets

In [176]:
# Connect to DB
db_file = "/cta/share/users/uniprot/human/human.db"
conn = sqlite3.connect(db_file)

uniref_id = '50'
df_protein = pd.read_sql(f"""SELECT Entry as uniprot_id, Sequence as sequence
                          FROM proteins
                          WHERE Entry IN (SELECT uniprot_accession FROM uniref{uniref_id}_distilled)""", conn)
df_protein = df_protein[df_protein['sequence'].str.len() < 512].reset_index(drop=True)

df_protein_sliced = pd.read_sql(f"SELECT uniprot_id, sequence FROM uniref{uniref_id}_domain_sliced_plddt70", conn)
df_protein_sliced = df_protein_sliced[df_protein_sliced['uniprot_id'].isin(df_protein['uniprot_id'])].reset_index(drop=True)

conn.close()

## Load Tokenizers

In [177]:
# 'dataset': {'uniref50', 'uniref90'}
# 'is_pretokenizer': {True, False}
# 'subs_matrix': {'blosum45', 'blosum62', 'pam70', 'pam250'}
# 'mutation_cutoff': {0.7, 0.8, 0.9}
# 'min_mutation_freq': {0, 0.05, 0.005}
# 'min_mutation_len': {3}
# 'max_mutation_len': {12}
# 'vocab_size': list=[800, 1600, 3200, 6400, 12800, 25600, 51200]

vocab_sizes = [51200]
uniref_id = "50"

tokenizer_opts_list = [
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': False,
        'subs_matrix': 'blosum62',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.05,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    },
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': False,
        'subs_matrix': 'blosum62',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.005,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    }
]

In [None]:
import json
from tokenizers import Tokenizer

tokenizer_list = load_tokenizers(tokenizer_opts_list, 'hf')
# Override the second entry with the tokenizer stored under blosum62_alldataset/1000000.
# Note: the dictionary key says 0.005 while the file name says 0.05.
tokenizer_list['PUMA blosum62 0.7 0.005 51200'] = Tokenizer.from_file("/cta/share/users/mutbpe/tokenizers/blosum62_alldataset/1000000/hf_uniref50_mutbpe_0.7_3_12_0.05_51200.json")

inner_vocab_list = load_tokenizers(tokenizer_opts_list, 'vocab')
with open("/cta/share/users/mutbpe/tokenizers/blosum62_alldataset/1000000/uniref50_mutbpe_0.7_3_12_0.05_51200.json") as json_file:
    inner_vocab_list['PUMA blosum62 0.7 0.005 51200'] = json.load(json_file)

vocab_list = {name: list(set(tok.get_vocab().keys())) for name, tok in tokenizer_list.items()}
methods = [name[:-len(str(vocab_sizes[0]))-1] for name in list(tokenizer_list.keys())[::len(vocab_sizes)]]
methods2names = {mn: mn.replace('mut', 'evo').replace('std', '').replace('blosum', 'BLOSUM').replace('pam', 'PAM').replace('pre', 'Pre') for mn in methods}
methods2names = {k: ' '.join(v.split()[:-2]) if 'evoBPE' in v else v for k, v in methods2names.items()}

inner_vocab_parents_list = {}
inner_vocab_mutated_list = {}
for k, v in inner_vocab_list.items():
    inner_vocab_parents_list[k] = get_parents(v)
    inner_vocab_mutated_list[k] = get_mutated(v)

for tokenizer_name in tokenizer_list.keys():
    for mutated_token, mutated_token_attr in inner_vocab_mutated_list[tokenizer_name].items():
        parent_token = mutated_token_attr['parent']
        inner_vocab_parents_list[tokenizer_name][parent_token]['mutations'] = inner_vocab_parents_list[tokenizer_name][parent_token].get('mutations', []) + [mutated_token]

In [179]:
# --- Vocabulary Lineage Construction ---
vocab_lineage_list = {}
for k, v in inner_vocab_list.items():
    vocab_lineage_list[k] = {token: {
        'frequency': -1, 'order': -1, 'parent_pair': [], 'parent_mutation': "",
        'parent_mutation_similarity': -1, 'partner_pair_self': False,
        'partner_pair_left': [], 'partner_pair_right': [], 'child_pair': [], 'child_mutation': []
    } for token in v.keys()}

for method_name, vocab in tqdm(inner_vocab_list.items(), desc="Building Vocabulary Lineage"):
    for token, inner_elements in vocab.items():
        lineage = vocab_lineage_list[method_name][token]
        lineage['frequency'] = inner_elements.get('frequency', -1)
        lineage['order'] = inner_elements.get('order', -1)
        lineage['parent_pair'] = inner_elements.get('pair', [])
        lineage['parent_mutation'] = inner_elements.get('parent', "")
        lineage['parent_mutation_similarity'] = inner_elements.get('similarity', -1)

        if 'pair' in inner_elements:
            p1, p2 = inner_elements['pair']
            if p1 == p2:
                vocab_lineage_list[method_name][p1]['partner_pair_self'] = True
                vocab_lineage_list[method_name][p1]['child_pair'].append(token)
            else:
                vocab_lineage_list[method_name][p1]['partner_pair_right'].append(p2)
                vocab_lineage_list[method_name][p2]['partner_pair_left'].append(p1)
                vocab_lineage_list[method_name][p1]['child_pair'].append(token)
                vocab_lineage_list[method_name][p2]['child_pair'].append(token)
        if 'parent' in inner_elements:
            parent = inner_elements['parent']
            vocab_lineage_list[method_name][parent]['child_mutation'].append(token)

Building Vocabulary Lineage: 100%|██████████| 2/2 [00:00<00:00, 13.46it/s]


In [180]:
for name, tokenizer in tqdm(list(tokenizer_list.items())):
    if 'pre' in name:
        df_protein_sliced[name] = [enc.tokens for enc in tokenizer.encode_batch(df_protein_sliced['sequence'])]
    else:
        df_protein[name] = [enc.tokens for enc in tokenizer.encode_batch(df_protein['sequence'])]

100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


In [181]:
df_protein_sliced = df_protein_sliced.groupby('uniprot_id').sum().reset_index()
df_protein = df_protein.set_index(['uniprot_id', 'sequence']).join(df_protein_sliced.set_index(['uniprot_id', 'sequence'])).reset_index()
df_protein.head()

| | uniprot_id | sequence | PUMA blosum62 0.7 0.05 51200 | PUMA blosum62 0.7 0.005 51200 |
|---|---|---|---|---|
| 0 | A0A087WZT3 | MELSAEYLREKLQRDLEAEHVLPSPGGVGQVRGETAASETQLGS | [MEL, SA, EYL, REKL, QRDL, EAEH, VL, PSP, GGVG... | [ME, LSAE, YLR, EKLQ, RDLE, AEHV, LPSP, GGVG, ... |
| 1 | A0A0B4J2F0 | MFRRLTFAQLLFATVLGIAGGVYIFQPVFEQYAKDQKELKEKMQLV... | [M, FRRL, TFA, QLL, FAT, VLG, IA, GGV, YI, FQ,... | [MFRR, LTF, AQ, LLF, AT, VLGI, AGGV, YI, FQ, P... |
| 2 | A0A0C5B5G6 | MRWQEMGYIFYPRKLR | [MRW, QEMG, YI, FY, PRKL, R] | [MRW, QE, MG, YI, FYP, RKLR] |
| 3 | A0A0K2S4Q6 | MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE... | [M, TQR, AGAA, ML, PSA, LLLL, CV, PGCL, TVSG, ... | [MTQ, RAG, AAM, LP, SALL, LLCV, PGC, LTV, SGP,... |
| 4 | A0A0S2Z4G9 | MNQSRSRSDGGSEETLPQDHNHHENERRWQQERLHREEAYYQFINE... | [MNQ, SRSR, SDGG, SEE, TLPQ, DH, NHH, ENERR, W... | [MNQ, SRSR, SDGG, SEE, TLP, QD, HN, HHE, NE, R... |


In [182]:
# --- Pre-computation Function ---

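def precompute_alternatives(sub_matrix):
    """
    Pre-computes a lookup table of score-matched substitutions: for every
    (original, mutated) residue pair, the other residues whose substitution
    score against the original equals the score of the mutation.
    """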
    print("Pre-computing alternative amino acid scores...")
    AMINO_ACIDS = sub_matrix.alphabet[:-1]
    precomputed_alternatives = {orig_aa: {} for orig_aa in AMINO_ACIDS}

    for orig_aa in tqdm(AMINO_ACIDS):
        for mut_aa in AMINO_ACIDS:
            if orig_aa == mut_aa:
                continue
            
            try:
                score_mutation = sub_matrix[(orig_aa, mut_aa)]
            except KeyError:
                continue

            possible_alternatives = []
            for alt_aa in AMINO_ACIDS:
                if alt_aa in (orig_aa, mut_aa):
                    continue
                try:
                    score_alternative = sub_matrix[(orig_aa, alt_aa)]
                    if score_alternative == score_mutation:
                        possible_alternatives.append((alt_aa, score_alternative))
                except KeyError:
                    continue
            
            # All candidates share the same score, so this stable sort keeps alphabet order; store only the amino acid
            possible_alternatives.sort(key=lambda x: x[1], reverse=True)
            precomputed_alternatives[orig_aa][mut_aa] = [alt[0] for alt in possible_alternatives]
            
    return precomputed_alternatives


# --- Enhanced Helper Functions ---

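def find_alternative_replacement_optimized(
    original_token,
    mutation_replacement,
    replacement_pool_set,
    precomputed_alts,
    vocab_lineage
):
    """
    Finds an alternative token using the pre-computed lookup table.

    At each position where the mutation differs from the original token, swaps in
    the first score-matched residue that yields a token outside both the replacement
    pool and the vocabulary. Returns the mutation unchanged if the lengths differ
    or no such swap exists.
    """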
    if len(original_token) != len(mutation_replacement):
        return mutation_replacement

    diff_positions = [
        i for i, (orig_aa, mut_aa) in enumerate(zip(original_token, mutation_replacement)) 
        if orig_aa != mut_aa
    ]

    if not diff_positions:
        return mutation_replacement

    alternative_token = mutation_replacement

    for pos in diff_positions:
        original_aa = original_token[pos]
        mutated_aa = mutation_replacement[pos]

        best_alternatives = precomputed_alts.get(original_aa, {}).get(mutated_aa, [])
        
        for alt_aa in best_alternatives:
            alt_token_list = list(alternative_token)
            alt_token_list[pos] = alt_aa
            new_token = "".join(alt_token_list)
            
            if (new_token not in replacement_pool_set) and (new_token not in vocab_lineage):
                alternative_token = new_token
                break

    return alternative_token


def create_random_alternative_baseline(original_token, mutation_replacement, sub_matrix):
    """
    Create random alternative replacements matched for BLOSUM scores.
    """
    if len(original_token) != len(mutation_replacement):
        return mutation_replacement
    
    # Note: this alphabet includes the ambiguity codes B and Z alongside the 20 standard residues
    AMINO_ACIDS = list('ABCDEFGHIKLMNPQRSTVWYZ')
    alternative_token = list(mutation_replacement)
    
    for i, (orig_aa, mut_aa) in enumerate(zip(original_token, mutation_replacement)):
        if orig_aa != mut_aa:
            try:
                target_score = sub_matrix[(orig_aa, mut_aa)]
                # Find amino acids with similar BLOSUM scores
                candidates = []
                for aa in AMINO_ACIDS:
                    if aa != orig_aa and aa != mut_aa:
                        try:
                            score = sub_matrix[(orig_aa, aa)]
                            if abs(score - target_score) <= 1:  # Allow ±1 score difference
                                candidates.append(aa)
                        except KeyError:
                            continue
                
                if candidates:
                    alternative_token[i] = random.choice(candidates)
            except KeyError:
                continue
    
    return ''.join(alternative_token)


# --- Enhanced Main Function ---

def run_mutation_experiment_optimized(df_protein, vocab_lineage_list, sub_matrix_precomputed_alternatives, create_baseline=True):
    """
    Generates mutated, alternative, and baseline sequences efficiently.
    """
    df_results = df_protein.copy()
    token_cols = list(tokenizer_list.keys())
    change_counters = {}

    for col_name in token_cols:
        print(f"\nProcessing column: {col_name}")
        change_counters[col_name] = 0

        precomputed_alternatives = sub_matrix_precomputed_alternatives[col_name.split()[1]]
        sub_matrix = substitution_matrices.load(col_name.split()[1].upper())
        
        all_mutated_sequences = []
        all_alternative_sequences = []
        all_baseline_sequences = [] if create_baseline else None
        
        vocab_lineage = vocab_lineage_list.get(col_name, {})

        for tokenized_sequence in tqdm(df_results[col_name]):
            new_mutated_sequence = []
            new_alternative_sequence = []
            new_baseline_sequence = [] if create_baseline else None
            
            for token in tokenized_sequence:
                token_info = vocab_lineage.get(token)
                
                if not token_info:
                    new_mutated_sequence.append(token)
                    new_alternative_sequence.append(token)
                    if create_baseline:
                        new_baseline_sequence.append(token)
                    continue
                
                if 'child_mutation' in token_info and token_info['child_mutation']:
                    replacement_pool = token_info.get('child_mutation', [])
                else:
                    new_mutated_sequence.append(token)
                    new_alternative_sequence.append(token)
                    if create_baseline:
                        new_baseline_sequence.append(token)
                    continue

                replacement_pool_set = set(replacement_pool + [token])

                for mutation_replacement in replacement_pool:
                    alternative_replacement = find_alternative_replacement_optimized(
                        token, 
                        mutation_replacement, 
                        replacement_pool_set, 
                        precomputed_alternatives,
                        vocab_lineage
                    )
                    
                    if alternative_replacement != mutation_replacement:
                        new_mutated_sequence.append(mutation_replacement)
                        new_alternative_sequence.append(alternative_replacement)
                        if create_baseline:
                            baseline_replacement = create_random_alternative_baseline(
                                token, mutation_replacement, sub_matrix
                            )
                            new_baseline_sequence.append(baseline_replacement)
                        change_counters[col_name] += 1
                        break
                else:
                    new_mutated_sequence.append(token)
                    new_alternative_sequence.append(token)
                    if create_baseline:
                        new_baseline_sequence.append(token)

            all_mutated_sequences.append(new_mutated_sequence)
            all_alternative_sequences.append(new_alternative_sequence)
            if create_baseline:
                all_baseline_sequences.append(new_baseline_sequence)

        df_results[f'{col_name} mutated'] = all_mutated_sequences
        df_results[f'{col_name} alternative'] = all_alternative_sequences
        if create_baseline:
            df_results[f'{col_name} baseline'] = all_baseline_sequences
        
    return df_results, change_counters
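
The score-matching idea behind `precompute_alternatives` and `find_alternative_replacement_optimized` can be sketched on a toy example. The sketch below is illustrative only: `TOY_SCORES`, `score_matched_alternatives`, `swap_residue`, and the token pair `ALK`/`AIK` are hypothetical names and values introduced here, and the real cells read scores from `Bio.Align.substitution_matrices` and additionally reject candidates that already exist in the vocabulary.

```python
# Toy excerpt of substitution scores for leucine (L). In the notebook these
# come from Bio.Align.substitution_matrices, not from a hand-written dict.
TOY_SCORES = {
    ("L", "I"): 2,
    ("L", "M"): 2,
    ("L", "V"): 1,
    ("L", "F"): 0,
}


def score_matched_alternatives(orig_aa, mut_aa, scores):
    """Return residues whose score against orig_aa equals that of mut_aa."""
    target = scores[(orig_aa, mut_aa)]
    return sorted(
        alt
        for (src, alt), score in scores.items()
        if src == orig_aa and alt != mut_aa and score == target
    )


def swap_residue(token, pos, new_aa):
    """Replace the residue at index pos of token with new_aa."""
    return token[:pos] + new_aa + token[pos + 1:]


# Hypothetical parent token "ALK" with mutation token "AIK" (L -> I at index 1).
alts = score_matched_alternatives("L", "I", TOY_SCORES)
control = swap_residue("AIK", 1, alts[0])
print(alts, control)
```

The control token differs from the parent at the same position and with the same substitution score as the mutation token, so any difference in downstream similarity cannot be attributed to the substitution score alone.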

In [199]:
# --- Enhanced ESM-2 Functions ---

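def get_esm2_model_and_tokenizer(model_name="facebook/esm2_t30_150M_UR50D"):
    """Initialize ESM-2 model and tokenizer.

    Note: the default checkpoint is the 150M-parameter model. Pass
    model_name="facebook/esm2_t33_650M_UR50D" to load the 650M model
    named in the introduction.
    """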
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    tokenizer = EsmTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name).to(device)
    model.eval()
    
    print(f"Loaded ESM-2 model: {model_name}")
    return model, tokenizer, device


def get_esm2_embeddings_and_logits_batch(sequences, model, tokenizer, device, batch_size=32, max_length=1024, extract_layers=None):
    """
    Get ESM-2 embeddings and logits for multiple protein sequences in batches.
    
    Args:
        extract_layers: List of layer indices to extract embeddings from (default: last layer only)
    """
    if extract_layers is None:
        extract_layers = [-1]  # Last layer only
    
    sequence_data = {}
    
    for i in tqdm(range(0, len(sequences), batch_size), desc="Processing embedding batches"):
        batch_sequences = sequences[i:i + batch_size]
        
        inputs = tokenizer(
            batch_sequences, 
            return_tensors="pt", 
            padding='longest', 
            truncation=True, 
            max_length=max_length
        ).to(device)
        
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            batch_logits = outputs.logits
            hidden_states = outputs.hidden_states
        
        for j, seq in enumerate(batch_sequences):
            seq_len = len(seq)
            
            # Store logits (for probability analysis)
            logits = batch_logits[j, 1:seq_len+1].cpu()  # Remove <cls>, keep actual sequence
            
            # Store embeddings from specified layers
            embeddings = {}
            for layer_idx in extract_layers:
                layer_embeddings = hidden_states[layer_idx][j, 1:seq_len+1].cpu()
                embeddings[layer_idx] = layer_embeddings
            
            sequence_data[seq] = {
                'logits': logits,
                'embeddings': embeddings
            }
    
    return sequence_data


def compute_masked_probabilities(sequence, changed_positions, model, tokenizer, device):
    """
    Compute masked probabilities for specific positions in a sequence.
    """
    probabilities = {}
    
    for pos in changed_positions:
        # Create masked sequence
        seq_list = list(sequence)
        original_aa = seq_list[pos]
        seq_list[pos] = tokenizer.mask_token
        masked_seq = ''.join(seq_list)
        
        # Tokenize and get predictions
        inputs = tokenizer(masked_seq, return_tensors="pt").to(device)
        
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
        
        # Get probabilities at the masked position
        mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        if len(mask_token_index) > 0:
            mask_pos = mask_token_index[0]
            position_logits = logits[0, mask_pos, :]
            position_probs = F.softmax(position_logits, dim=-1)
            
            # Get probabilities for all amino acids
            aa_probs = {}
            for aa in 'ABCDEFGHIKLMNPQRSTVWYZ':
                aa_token_id = tokenizer.convert_tokens_to_ids(aa)
                if aa_token_id is not None:
                    aa_probs[aa] = position_probs[aa_token_id].item()
            
            probabilities[pos] = {
                'original_aa': original_aa,
                'aa_probabilities': aa_probs
            }
    
    return probabilities
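
As a rough illustration of what happens at a single masked position, the sketch below turns logits into probabilities with a softmax and then compares residues, which is the kind of comparison `compute_rank_based_metrics` makes later in the notebook. The logits are made-up numbers and the `softmax` helper is written here for illustration; the notebook itself uses `F.softmax` on the ESM-2 output.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a dict mapping residue -> logit."""
    peak = max(logits.values())
    exps = {aa: math.exp(x - peak) for aa, x in logits.items()}
    total = sum(exps.values())
    return {aa: e / total for aa, e in exps.items()}


# Made-up logits for four residues at one masked position.
toy_logits = {"L": 2.0, "I": 1.0, "M": 0.5, "D": -1.0}
probs = softmax(toy_logits)

# Hypothetical original, mutated, and alternative residues at this position.
orig_aa, mut_aa, alt_aa = "L", "I", "M"
orig_wins_vs_mut = probs[orig_aa] > probs[mut_aa]
mut_wins_vs_alt = probs[mut_aa] > probs[alt_aa]
print({aa: round(p, 3) for aa, p in probs.items()})
print(orig_wins_vs_mut, mut_wins_vs_alt)
```

Because softmax is monotonic, comparing probabilities at one position gives the same ordering as comparing the raw logits; the probabilities matter only when values are averaged or compared across positions.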

In [200]:
# --- Analysis Functions ---

def reconstruct_sequence_from_tokens(tokens):
    """Reconstruct amino acid sequence from tokens."""
    if isinstance(tokens, str):
        tokens = ast.literal_eval(tokens)
    return ''.join(tokens)


def find_differing_amino_acid_positions(original_seq, mutated_seq, alternative_seq, baseline_seq=None):
    """Find positions where amino acids differ between sequences."""
    differing_positions = []
    sequences = [mutated_seq, alternative_seq]
    if baseline_seq:
        sequences.append(baseline_seq)
    
    min_len = min(len(original_seq), *[len(seq) for seq in sequences])
    
    for i in range(min_len):
        original_aa = original_seq[i]
        differs = any(seq[i] != original_aa for seq in sequences)
        if differs:
            differing_positions.append(i)
    
    return differing_positions


def normalize_embeddings(embeddings, method='l2'):
    """Normalize embeddings using different methods."""
    if method == 'l2':
        return F.normalize(embeddings, p=2, dim=-1)
    elif method == 'center_l2':
        centered = embeddings - embeddings.mean(dim=0, keepdim=True)
        return F.normalize(centered, p=2, dim=-1)
    else:
        return embeddings


def compute_similarities_multiple_methods(original_emb, comparison_emb, positions, method='concat'):
    """Compute similarities using different aggregation methods."""
    if not positions or original_emb is None or comparison_emb is None:
        return None
    
    # Ensure positions are within bounds
    max_pos = min(original_emb.shape[0] - 1, comparison_emb.shape[0] - 1)
    valid_positions = [pos for pos in positions if pos <= max_pos]
    
    if not valid_positions:
        return None
    
    if method == 'concat':
        # Concatenate embeddings at differing positions
        orig_concat = original_emb[valid_positions].flatten().unsqueeze(0).numpy()
        comp_concat = comparison_emb[valid_positions].flatten().unsqueeze(0).numpy()
        return cosine_similarity(orig_concat, comp_concat)[0, 0]
    
    elif method == 'position_wise':
        # Compute position-wise similarities and average
        similarities = []
        for pos in valid_positions:
            orig_vec = original_emb[pos].unsqueeze(0).numpy()
            comp_vec = comparison_emb[pos].unsqueeze(0).numpy()
            sim = cosine_similarity(orig_vec, comp_vec)[0, 0]
            similarities.append(sim)
        return np.mean(similarities)
    
    elif method == 'global_mean':
        # Global mean pooling over entire sequence
        orig_mean = original_emb.mean(dim=0).unsqueeze(0).numpy()
        comp_mean = comparison_emb.mean(dim=0).unsqueeze(0).numpy()
        return cosine_similarity(orig_mean, comp_mean)[0, 0]
    
    elif method == 'local_window':
        # Local window around each changed position
        window_size = 8
        similarities = []
        for pos in valid_positions:
            start = max(0, pos - window_size)
            end = min(original_emb.shape[0], pos + window_size + 1)
            
            orig_window = original_emb[start:end].mean(dim=0).unsqueeze(0).numpy()
            comp_window = comparison_emb[start:end].mean(dim=0).unsqueeze(0).numpy()
            sim = cosine_similarity(orig_window, comp_window)[0, 0]
            similarities.append(sim)
        return np.mean(similarities)


def compute_probability_metrics(original_seq, comparison_seq, differing_positions, sequence_data, tokenizer):
    """Compute probability-based metrics."""
    metrics = {}
    
    # Get logits for both sequences
    orig_logits = sequence_data.get(original_seq, {}).get('logits')
    comp_logits = sequence_data.get(comparison_seq, {}).get('logits')
    
    if orig_logits is None or comp_logits is None:
        return metrics
    
    # Per-position log-probabilities from the unmasked forward pass (not a masked pseudo-perplexity)
    orig_log_probs = F.log_softmax(orig_logits, dim=-1)
    comp_log_probs = F.log_softmax(comp_logits, dim=-1)
    
    # Get log probabilities at differing positions
    orig_position_probs = []
    comp_position_probs = []
    
    for pos in differing_positions:
        if pos < orig_logits.shape[0] and pos < comp_logits.shape[0]:
            orig_aa = original_seq[pos]
            comp_aa = comparison_seq[pos]
            
            # Get token IDs
            orig_token_id = tokenizer.convert_tokens_to_ids(orig_aa)
            comp_token_id = tokenizer.convert_tokens_to_ids(comp_aa)
            
            if orig_token_id < orig_log_probs.shape[1] and comp_token_id < comp_log_probs.shape[1]:
                orig_position_probs.append(orig_log_probs[pos, orig_token_id].item())
                comp_position_probs.append(comp_log_probs[pos, comp_token_id].item())
    
    if orig_position_probs and comp_position_probs:
        metrics['avg_log_prob_orig'] = np.mean(orig_position_probs)
        metrics['avg_log_prob_comp'] = np.mean(comp_position_probs)
        metrics['log_prob_diff'] = np.mean(orig_position_probs) - np.mean(comp_position_probs)
    
    return metrics


def compute_rank_based_metrics(original_seq, mutated_seq, alternative_seq, baseline_seq, differing_positions, model, tokenizer, device):
    """Compute rank-based metrics using masked probabilities."""
    rank_metrics = {'orig_wins_vs_mut': 0, 'orig_wins_vs_alt': 0, 'orig_wins_vs_base': 0, 'mut_wins_vs_alt': 0, 'total_positions': 0}
    
    if not differing_positions:
        return rank_metrics
    
    masked_probs = compute_masked_probabilities(original_seq, differing_positions, model, tokenizer, device)
    
    for pos in differing_positions:
        if pos not in masked_probs:
            continue
        
        aa_probs = masked_probs[pos]['aa_probabilities']
        orig_aa = original_seq[pos]
        mut_aa = mutated_seq[pos] if pos < len(mutated_seq) else orig_aa
        alt_aa = alternative_seq[pos] if pos < len(alternative_seq) else orig_aa
        base_aa = baseline_seq[pos] if (baseline_seq and pos < len(baseline_seq)) else orig_aa
        
        orig_prob = aa_probs.get(orig_aa, 0)
        mut_prob = aa_probs.get(mut_aa, 0)
        alt_prob = aa_probs.get(alt_aa, 0)
        base_prob = aa_probs.get(base_aa, 0)
        
        if orig_prob > mut_prob:
            rank_metrics['orig_wins_vs_mut'] += 1
        if orig_prob > alt_prob:
            rank_metrics['orig_wins_vs_alt'] += 1
        if orig_prob > base_prob:
            rank_metrics['orig_wins_vs_base'] += 1
        if mut_prob > alt_prob:
            rank_metrics['mut_wins_vs_alt'] += 1
        
        rank_metrics['total_positions'] += 1
    
    return rank_metrics


def get_blosum_scores(original_seq, comparison_seq, differing_positions, sub_matrix):
    """Get BLOSUM scores for differing positions."""
    scores = []
    for pos in differing_positions:
        if pos < len(original_seq) and pos < len(comparison_seq):
            orig_aa = original_seq[pos]
            comp_aa = comparison_seq[pos]
            try:
                score = sub_matrix[(orig_aa, comp_aa)]
                scores.append(score)
            except KeyError:
                continue
    return scores
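
The difference between the `'concat'` and `'position_wise'` options of `compute_similarities_multiple_methods` is easiest to see on small vectors. This is a minimal pure-Python sketch with made-up 2-dimensional embeddings; the `cosine` helper and the `changed` positions are assumptions for illustration, while the notebook operates on ESM-2 tensors with scikit-learn's `cosine_similarity`.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Toy per-residue embeddings: one row per sequence position (made-up values).
original = [[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]]
variant = [[1.0, 0.0], [0.0, 2.0], [0.0, 3.0]]
changed = [1, 2]  # hypothetical differing positions

# 'concat': flatten the rows at the changed positions, then one cosine overall.
flat_original = [x for pos in changed for x in original[pos]]
flat_variant = [x for pos in changed for x in variant[pos]]
sim_concat = cosine(flat_original, flat_variant)

# 'position_wise': one cosine per changed position, then the mean.
per_position = [cosine(original[pos], variant[pos]) for pos in changed]
sim_position_wise = sum(per_position) / len(per_position)

print(round(sim_concat, 4), round(sim_position_wise, 4))
```

With unnormalized rows the two numbers differ because `'concat'` weights each position by its vector norm. When every row has unit length, as after the `'l2'` option of `normalize_embeddings`, the concatenated cosine reduces to the mean of the per-position cosines, so the two methods agree.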

In [205]:
# --- Main Enhanced Experiment Function ---

def run_enhanced_protein_similarity_experiment(
    df_protein_oma, 
    sub_matrix_precomputed_alternatives,
    batch_size=32, 
    max_length=514, 
    extract_layers=(-1,),
    normalization_method='l2',
    similarity_methods=('concat', 'position_wise', 'global_mean'),
    include_probability_analysis=True,
    include_rank_analysis=True,
    include_baseline=True
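):
    """
    Embeds the original, mutated, alternative, and (optionally) baseline
    sequences with ESM-2 and compares each variant to the original at the
    positions where the sequences differ.
    """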
    
    # Initialize ESM-2 model
    model, tokenizer, device = get_esm2_model_and_tokenizer()
    
    # Identify tokenizer columns
    tokenizer_columns = list(tokenizer_list.keys())
    print(f"Found {len(tokenizer_columns)} tokenizer columns to process")
    
    # Collect all unique sequences
    print("Collecting unique sequences...")
    unique_sequences = set()
    
    for tokenizer_col in tokenizer_columns:
        mutated_col = tokenizer_col + ' mutated'
        alternative_col = tokenizer_col + ' alternative'
        baseline_col = tokenizer_col + ' baseline' if include_baseline else None
        
        for idx, row in df_protein_oma.iterrows():
            try:
                original_tokens = row[tokenizer_col]
                mutated_tokens = row[mutated_col]
                alternative_tokens = row[alternative_col]
                
                original_seq = reconstruct_sequence_from_tokens(original_tokens)
                mutated_seq = reconstruct_sequence_from_tokens(mutated_tokens)
                alternative_seq = reconstruct_sequence_from_tokens(alternative_tokens)
                
                unique_sequences.update([original_seq, mutated_seq, alternative_seq])
                
                if include_baseline and baseline_col in row:
                    baseline_tokens = row[baseline_col]
                    baseline_seq = reconstruct_sequence_from_tokens(baseline_tokens)
                    unique_sequences.add(baseline_seq)
                    
            except Exception:
                # Skip rows whose token columns cannot be reconstructed
                continue
    
    unique_sequences = list(unique_sequences)
    print(f"Found {len(unique_sequences)} unique sequences")
    
    # Get embeddings and logits for all sequences
    print("Computing ESM-2 embeddings and logits...")
    sequence_data = get_esm2_embeddings_and_logits_batch(
        unique_sequences, model, tokenizer, device, batch_size, max_length, extract_layers
    )
    
    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Process each tokenizer
    all_results = []
    
    for tokenizer_col in tokenizer_columns:
        print(f"\nProcessing tokenizer: {tokenizer_col}")
        
        # Get substitution matrix for BLOSUM analysis
        sub_matrix_name = tokenizer_col.split()[1]
        sub_matrix = substitution_matrices.load(sub_matrix_name.upper())
        
        mutated_col = tokenizer_col + ' mutated'
        alternative_col = tokenizer_col + ' alternative'
        baseline_col = tokenizer_col + ' baseline' if include_baseline else None
        
        # Initialize result containers
        results = {
            'similarities': {method: {'mut': [], 'alt': [], 'base': []} for method in similarity_methods},
            'probability_metrics': {'mut': [], 'alt': [], 'base': []},
            'rank_metrics': [],
            'blosum_scores': {'mut': [], 'alt': [], 'base': []},
            'position_counts': []
        }
        
        processed_proteins = 0
        
        for idx, row in tqdm(df_protein_oma.iterrows(), total=len(df_protein_oma), desc="Processing proteins"):
            try:
                # Reconstruct sequences
                original_seq = reconstruct_sequence_from_tokens(row[tokenizer_col])
                mutated_seq = reconstruct_sequence_from_tokens(row[mutated_col])
                alternative_seq = reconstruct_sequence_from_tokens(row[alternative_col])
                baseline_seq = None
                
                if include_baseline and baseline_col in row:
                    baseline_seq = reconstruct_sequence_from_tokens(row[baseline_col])
                
                # Skip if no changes
                if (original_seq == mutated_seq and original_seq == alternative_seq and 
                    (not include_baseline or original_seq == baseline_seq)):
                    continue
                
                # Find differing positions
                differing_positions = find_differing_amino_acid_positions(
                    original_seq, mutated_seq, alternative_seq, baseline_seq
                )
                
                if not differing_positions:
                    continue
                
                results['position_counts'].append(len(differing_positions))
                
                # Get sequence data
                orig_data = sequence_data.get(original_seq, {})
                mut_data = sequence_data.get(mutated_seq, {})
                alt_data = sequence_data.get(alternative_seq, {})
                base_data = sequence_data.get(baseline_seq, {}) if baseline_seq else {}
                
                # Process each layer
                for layer_idx in extract_layers:
                    orig_emb = orig_data.get('embeddings', {}).get(layer_idx)
                    mut_emb = mut_data.get('embeddings', {}).get(layer_idx)
                    alt_emb = alt_data.get('embeddings', {}).get(layer_idx)
                    base_emb = base_data.get('embeddings', {}).get(layer_idx) if baseline_seq else None
                    
                    if orig_emb is None:
                        continue
                    
                    # Normalize embeddings
                    orig_emb_norm = normalize_embeddings(orig_emb, normalization_method)
                    mut_emb_norm = normalize_embeddings(mut_emb, normalization_method) if mut_emb is not None else None
                    alt_emb_norm = normalize_embeddings(alt_emb, normalization_method) if alt_emb is not None else None
                    base_emb_norm = normalize_embeddings(base_emb, normalization_method) if base_emb is not None else None
                    
                    # Compute similarities using different methods
                    for method in similarity_methods:
                        if mut_emb_norm is not None:
                            sim_mut = compute_similarities_multiple_methods(
                                orig_emb_norm, mut_emb_norm, differing_positions, method
                            )
                            if sim_mut is not None:
                                results['similarities'][method]['mut'].append(sim_mut)
                        
                        if alt_emb_norm is not None:
                            sim_alt = compute_similarities_multiple_methods(
                                orig_emb_norm, alt_emb_norm, differing_positions, method
                            )
                            if sim_alt is not None:
                                results['similarities'][method]['alt'].append(sim_alt)
                        
                        if base_emb_norm is not None:
                            sim_base = compute_similarities_multiple_methods(
                                orig_emb_norm, base_emb_norm, differing_positions, method
                            )
                            if sim_base is not None:
                                results['similarities'][method]['base'].append(sim_base)
                
                # Probability analysis
                if include_probability_analysis:
                    prob_metrics_mut = compute_probability_metrics(
                        original_seq, mutated_seq, differing_positions, sequence_data, tokenizer
                    )
                    prob_metrics_alt = compute_probability_metrics(
                        original_seq, alternative_seq, differing_positions, sequence_data, tokenizer
                    )
                    
                    results['probability_metrics']['mut'].append(prob_metrics_mut)
                    results['probability_metrics']['alt'].append(prob_metrics_alt)
                    
                    if baseline_seq:
                        prob_metrics_base = compute_probability_metrics(
                            original_seq, baseline_seq, differing_positions, sequence_data, tokenizer
                        )
                        results['probability_metrics']['base'].append(prob_metrics_base)
                
                # Rank analysis
                if include_rank_analysis:
                    rank_metrics = compute_rank_based_metrics(
                        original_seq, mutated_seq, alternative_seq, baseline_seq, differing_positions, 
                        model, tokenizer, device
                    )
                    results['rank_metrics'].append(rank_metrics)
                
                # BLOSUM score analysis
                blosum_mut = get_blosum_scores(original_seq, mutated_seq, differing_positions, sub_matrix)
                blosum_alt = get_blosum_scores(original_seq, alternative_seq, differing_positions, sub_matrix)
                
                results['blosum_scores']['mut'].extend(blosum_mut)
                results['blosum_scores']['alt'].extend(blosum_alt)
                
                if baseline_seq:
                    blosum_base = get_blosum_scores(original_seq, baseline_seq, differing_positions, sub_matrix)
                    results['blosum_scores']['base'].extend(blosum_base)
                
                processed_proteins += 1
                
            except Exception as e:
                print(f"Error processing row {idx}: {e}")
                continue
        
        # Compile results for this tokenizer
        tokenizer_results = {
            'tokenizer': tokenizer_col,
            'processed_proteins': processed_proteins,
            'total_differing_positions': sum(results['position_counts']),
            'avg_differing_positions': np.mean(results['position_counts']) if results['position_counts'] else 0
        }
        
        # Add similarity results
        for method in similarity_methods:
            for comparison in ['mut', 'alt', 'base']:
                if results['similarities'][method][comparison]:
                    tokenizer_results[f'{method}_similarity_{comparison}'] = np.mean(results['similarities'][method][comparison])
        
        # Add rank results
        if results['rank_metrics']:
            total_wins_mut = sum(r['orig_wins_vs_mut'] for r in results['rank_metrics'])
            total_wins_alt = sum(r['orig_wins_vs_alt'] for r in results['rank_metrics'])
            total_wins_base = sum(r['orig_wins_vs_base'] for r in results['rank_metrics'])
            total_wins_mut_vs_alt = sum(r['mut_wins_vs_alt'] for r in results['rank_metrics'])
            total_positions = sum(r['total_positions'] for r in results['rank_metrics'])
            
            if total_positions > 0:
                # Pooled (micro-averaged) win rates: total wins over total differing positions
                tokenizer_results['rank_win_rate_vs_mut'] = total_wins_mut / total_positions
                tokenizer_results['rank_win_rate_vs_alt'] = total_wins_alt / total_positions
                tokenizer_results['rank_win_rate_vs_base'] = total_wins_base / total_positions
                tokenizer_results['rank_win_rate_mut_vs_alt'] = total_wins_mut_vs_alt / total_positions
        
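        # Add BLOSUM correlation analysis
        # NOTE: blosum_scores is built with extend() (every score returned per protein),
        # while similarities is built with append() (one value per protein and layer).
        # The correlation is computed only if the two lists happen to have equal length;
        # otherwise it is skipped silently and no correlation columns are written.
        for comparison in ['mut', 'alt', 'base']:
            blosum_scores = results['blosum_scores'][comparison]
            similarities = results['similarities']['concat'][comparison]  # Use concat method for correlation
            
            if len(blosum_scores) > 0 and len(similarities) > 0 and len(blosum_scores) == len(similarities):
                correlation, p_value = stats.pearsonr(blosum_scores, similarities)
                tokenizer_results[f'blosum_similarity_correlation_{comparison}'] = correlation
                tokenizer_results[f'blosum_similarity_p_value_{comparison}'] = p_value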
        
        all_results.append(tokenizer_results)
    
    return pd.DataFrame(all_results)
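
The BLOSUM correlation step in the function above runs only when the score list and the similarity list have the same length. Because the BLOSUM lists grow with `extend` and the similarity lists with `append`, the two may not line up. Below is a minimal sketch of one way to pair them, assuming each protein contributes one list of BLOSUM scores and one list of per-layer similarities. `pearson` and `paired_protein_records` are hypothetical helpers, not part of this notebook, and the numbers are made up for illustration.

```python
import math


def pearson(xs, ys):
    """Plain Pearson correlation for two equal-length sequences."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two equal-length sequences with at least 2 items")
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        return float("nan")  # correlation is undefined for constant input
    return cov / math.sqrt(var_x * var_y)


def paired_protein_records(per_protein_blosum, per_protein_layer_sims):
    """Collapse each protein to one (mean BLOSUM, mean similarity) pair."""
    pairs = []
    for blosum_scores, layer_sims in zip(per_protein_blosum, per_protein_layer_sims):
        if not blosum_scores or not layer_sims:
            continue  # skip proteins with nothing to compare
        mean_blosum = sum(blosum_scores) / len(blosum_scores)
        mean_sim = sum(layer_sims) / len(layer_sims)
        pairs.append((mean_blosum, mean_sim))
    return pairs


# Hypothetical toy data: three proteins, three layers each.
per_protein_blosum = [[1, 2, 0], [-1, -2], [3]]
per_protein_layer_sims = [
    [0.95, 0.94, 0.96],
    [0.90, 0.91, 0.89],
    [0.98, 0.97, 0.99],
]

pairs = paired_protein_records(per_protein_blosum, per_protein_layer_sims)
blosum_means = [b for b, _ in pairs]
sim_means = [s for _, s in pairs]
r = pearson(blosum_means, sim_means)
print(f"paired proteins: {len(pairs)}, Pearson r: {r:.3f}")
```

Pairing at the protein level guarantees equal lengths by construction. Keeping one record per differing position would be an alternative, but it needs a per-position similarity rather than a single value per protein.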

In [186]:
# Initialize and run the experiment
print("Starting enhanced protein mutation analysis...")

# Precompute alternatives for all substitution matrices
sub_matrix_precomputed_alternatives = {}
for sub_matrix_name in ['blosum45', 'blosum62', 'pam70', 'pam250']:
    sub_matrix = substitution_matrices.load(sub_matrix_name.upper())
    sub_matrix_precomputed_alternatives[sub_matrix_name] = precompute_alternatives(sub_matrix)

# Run mutation experiment with baseline
print("Running mutation experiment with baseline...")
df_protein_oma, change_counters = run_mutation_experiment_optimized(
    df_protein, vocab_lineage_list, sub_matrix_precomputed_alternatives, create_baseline=True
)

Starting enhanced protein mutation analysis...
Pre-computing alternative amino acid scores...


100%|██████████| 23/23 [00:00<00:00, 835.91it/s]


Pre-computing alternative amino acid scores...


100%|██████████| 23/23 [00:00<00:00, 809.04it/s]


Pre-computing alternative amino acid scores...


100%|██████████| 23/23 [00:00<00:00, 1137.82it/s]


Pre-computing alternative amino acid scores...


100%|██████████| 23/23 [00:00<00:00, 1212.38it/s]


Running mutation experiment with baseline...

Processing column: PUMA blosum62 0.7 0.05 51200


100%|██████████| 58149/58149 [00:12<00:00, 4535.16it/s]



Processing column: PUMA blosum62 0.7 0.005 51200


100%|██████████| 58149/58149 [00:14<00:00, 4067.80it/s]


In [208]:
# Run enhanced similarity experiment
print("Running enhanced similarity experiment...")
results_df = run_enhanced_protein_similarity_experiment(
    df_protein_oma[:100],
    sub_matrix_precomputed_alternatives,
    batch_size=64,
    extract_layers=[6, 12, -1],  # Multiple layers
    similarity_methods=['concat', 'position_wise', 'global_mean', 'local_window'],
    include_probability_analysis=False,
    include_rank_analysis=True,
    include_baseline=True
)

Running enhanced similarity experiment...
Using device: cuda
Loaded ESM-2 model: facebook/esm2_t30_150M_UR50D
Found 2 tokenizer columns to process
Collecting unique sequences...
Found 689 unique sequences
Computing ESM-2 embeddings and logits...


Processing embedding batches: 100%|██████████| 11/11 [00:04<00:00,  2.22it/s]



Processing tokenizer: PUMA blosum62 0.7 0.05 51200


Processing proteins: 100%|██████████| 100/100 [00:20<00:00,  4.85it/s]



Processing tokenizer: PUMA blosum62 0.7 0.005 51200


Processing proteins: 100%|██████████| 100/100 [00:22<00:00,  4.51it/s]


In [210]:
results_df

| | tokenizer | processed_proteins | total_differing_positions | avg_differing_positions | concat_similarity_mut | concat_similarity_alt | concat_similarity_base | position_wise_similarity_mut | position_wise_similarity_alt | position_wise_similarity_base | global_mean_similarity_mut | global_mean_similarity_alt | global_mean_similarity_base | local_window_similarity_mut | local_window_similarity_alt | local_window_similarity_base | rank_win_rate_vs_mut | rank_win_rate_vs_alt | rank_win_rate_vs_base | rank_win_rate_mut_vs_alt |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | PUMA blosum62 0.7 0.05 51200 | 98 | 1014 | 10.346939 | 0.927336 | 0.943783 | 0.929539 | 0.927336 | 0.943783 | 0.929539 | 0.998688 | 0.99884 | 0.998875 | 0.994957 | 0.996405 | 0.995617 | 0.905325 | 0.97929 | 0.883629 | 0.852071 |
| 1 | PUMA blosum62 0.7 0.005 51200 | 99 | 1095 | 11.060606 | 0.928825 | 0.945729 | 0.931397 | 0.928825 | 0.945729 | 0.931397 | 0.998657 | 0.998866 | 0.999026 | 0.99492 | 0.996618 | 0.995743 | 0.927854 | 0.911416 | 0.835616 | 0.754338 |
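
In the results above, the `concat_*` and `position_wise_*` similarity columns are identical for both tokenizers. One possible explanation rests on three assumptions that are not confirmed in this section: `normalize_embeddings` L2-normalizes each position's vector, `concat` takes the cosine of the concatenated position vectors, and `position_wise` averages the per-position cosines. Under those assumptions the two metrics are mathematically the same quantity, as the sketch below shows with random vectors.

```python
import math
import random


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


rng = random.Random(0)
n_positions, dim = 5, 8

# Hypothetical stand-ins for embeddings at the differing positions,
# with every position vector L2-normalized (an assumption).
orig = [l2_normalize([rng.gauss(0, 1) for _ in range(dim)]) for _ in range(n_positions)]
mut = [l2_normalize([rng.gauss(0, 1) for _ in range(dim)]) for _ in range(n_positions)]

# "position_wise": mean of the per-position cosine similarities.
position_wise = sum(cosine(o, m) for o, m in zip(orig, mut)) / n_positions

# "concat": cosine similarity of the flattened (concatenated) vectors.
flat_orig = [v for row in orig for v in row]
flat_mut = [v for row in mut for v in row]
concat = cosine(flat_orig, flat_mut)

print(f"position_wise={position_wise:.6f}  concat={concat:.6f}")
```

With unit-length rows, each concatenated vector has norm `sqrt(n_positions)`, so the concatenated cosine reduces to the sum of per-position dot products divided by `n_positions`. If this is what happens in the experiment, the two columns carry the same information and one of them could be dropped.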


In [211]:
results_df[results_df.columns[[0,1,-4,-3,-2,-1]]]

| | tokenizer | processed_proteins | rank_win_rate_vs_mut | rank_win_rate_vs_alt | rank_win_rate_vs_base | rank_win_rate_mut_vs_alt |
|---|---|---|---|---|---|---|
| 0 | PUMA blosum62 0.7 0.05 51200 | 98 | 0.905325 | 0.97929 | 0.883629 | 0.852071 |
| 1 | PUMA blosum62 0.7 0.005 51200 | 99 | 0.927854 | 0.911416 | 0.835616 | 0.754338 |
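
The `rank_win_rate_*` columns are pooled rates: the experiment function sums wins over all proteins and divides by the summed `total_positions`. The sketch below contrasts that with a per-protein average, which weights every protein equally regardless of how many positions differ. The key names `orig_wins_vs_mut` and `total_positions` come from the code above. The two helper functions and the toy records are hypothetical.

```python
def pooled_win_rate(records, wins_key, total_key="total_positions"):
    """Total wins divided by total positions (micro-average)."""
    total = sum(r[total_key] for r in records)
    if total == 0:
        return None
    return sum(r[wins_key] for r in records) / total


def mean_per_protein_win_rate(records, wins_key, total_key="total_positions"):
    """Average of each protein's own win rate (macro-average)."""
    rates = [r[wins_key] / r[total_key] for r in records if r[total_key] > 0]
    return sum(rates) / len(rates) if rates else None


# Hypothetical per-protein rank metrics: one long and one short protein.
records = [
    {"orig_wins_vs_mut": 9, "total_positions": 10},
    {"orig_wins_vs_mut": 1, "total_positions": 2},
]

pooled = pooled_win_rate(records, "orig_wins_vs_mut")
per_protein = mean_per_protein_win_rate(records, "orig_wins_vs_mut")
print(f"pooled={pooled:.3f}  per-protein mean={per_protein:.3f}")
```

The pooled rate is dominated by proteins with many differing positions. Reporting both numbers would show whether a few heavily modified proteins drive the result.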
