# Per-position Inference for models trained on full data

In [2]:
import os
from itertools import chain

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
import torch
import torch.nn.functional as F
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer,
    EsmForMaskedLM,
    EsmTokenizer,
    pipeline,
)

tqdm.pandas(leave = False)

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Checkpoints of the models trained on the full data, keyed by model name.
# All of them live under the same checkpoint root directory.
_ckpt_root = "./01all_esm_models/deepspeed/esm/all_checkpoints_4good"
model_dict_full = {
    "m_8M_F": f"{_ckpt_root}/m_8M_full_batch_128_2025-02-10/checkpoint-500000",
    "m_35M_F": f"{_ckpt_root}/m_35M_full_batch_128_2025-02-10/checkpoint-500000",
    "m_150M_F": f"{_ckpt_root}/m_150M_full_batch_128_2025-02-11/checkpoint-500000",
    "m_350M_F": f"{_ckpt_root}/m_350M_full_batch_128_2025-01-29/checkpoint-500000",
    # NOTE(review): 650M points at checkpoint-395000 while every other model
    # uses checkpoint-500000 -- confirm this is the intended checkpoint.
    "m_650M_F": f"{_ckpt_root}/m_650M_full_batch_128_2025-01-29/checkpoint-395000",
}

# One tokenizer shared by all models, loaded from the public ESM-2 35M hub
# checkpoint (presumably the same vocabulary the custom models were trained on).
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")


## Metrics: per-residue masked prediction, aggregated by antibody region

In [5]:
def infer_and_group_stats(model, tokenizer, seq, cdr):
    """Mask every residue of a paired heavy/light sequence in turn and score the model.

    ``seq`` holds the heavy and light chains joined by the separator
    ``"<cls><cls>"``. Each residue (the separator excluded) is replaced by
    ``<mask>`` one at a time, and the model predicts it. Per-residue statistics
    are then aggregated over regions delimited by label changes in the CDR
    mask. The heavy/light junction is always treated as a region break.

    Args:
        model: masked-LM model, called as ``model(**tokenized, labels=labels)``;
            its output must expose ``.logits`` and ``.loss``.
        tokenizer: tokenizer matching ``model``; must define ``mask_token_id``.
        seq: ``heavy + "<cls><cls>" + light`` residue string.
        cdr: per-character region annotation of ``seq``. Two characters at the
            separator position are dropped -- NOTE(review): assumes the
            annotation uses exactly two characters for the separator; confirm.

    Returns:
        dict with the separator-free sequence, both chains, the region
        boundaries (``cdr_indices``), the predicted residues (one character per
        residue), the per-region mean accuracy, mean top-1 probability, median
        loss and mean perplexity, and the per-residue score/loss/perplexity lists.

    Raises:
        ValueError: if ``seq`` does not contain the chain separator.
    """
    sep = "<cls><cls>"
    sep_idx = seq.find(sep)
    if sep_idx == -1:
        # Without the separator every slice below would silently be wrong.
        raise ValueError(f"sequence does not contain the chain separator {sep!r}")

    heavy = seq[:sep_idx]
    light = seq[sep_idx + len(sep):]
    cdr_mask = cdr[:sep_idx] + cdr[sep_idx + 2:]
    flat_seq = seq.replace(sep, "")

    predicted = []
    scores = []
    losses = []
    perplexities = []

    with torch.no_grad():
        unmasked = tokenizer(seq, return_tensors="pt").to(device)["input_ids"]

        # Mask each residue in turn, skipping over the separator characters.
        positions = chain(range(sep_idx), range(sep_idx + len(sep), len(seq)))
        for i in positions:
            masked = seq[:i] + "<mask>" + seq[i + 1:]
            tokenized = tokenizer(masked, return_tensors="pt").to(device)
            is_mask = tokenized.input_ids == tokenizer.mask_token_id
            mask_pos = is_mask[0].nonzero(as_tuple=True)[0]
            # Only the masked position carries a label; -100 is ignored by the loss.
            labels = torch.where(is_mask, unmasked, -100)
            output = model(**tokenized, labels=labels)
            logits = output.logits
            masked_logits = logits[0, mask_pos]

            # Predicted residue. A special token (e.g. "<eos>") decodes to several
            # characters, which would shift every later residue in the predictions
            # string and corrupt the per-region slicing; keep one placeholder instead.
            decoded = tokenizer.decode(masked_logits.argmax(dim=-1))
            predicted.append(decoded if len(decoded) == 1 else "?")

            # Prediction confidence: probability of the top-1 token.
            scores.append(masked_logits.softmax(dim=-1).topk(1)[0].item())

            # Loss at the masked position.
            losses.append(output.loss.item())

            # Perplexity. Flatten using the logits' own width rather than
            # tokenizer.vocab_size, which need not match the model's output head.
            ce_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            perplexities.append(float(torch.exp(ce_loss)))

    predictions = "".join(predicted)

    # Region boundaries: 0, every index where the mask label changes, and the end.
    # Scanning from 1 avoids comparing index 0 with cdr_mask[-1], which could
    # add a duplicate 0 and hence an empty region.
    boundaries = {0, len(cdr_mask)}
    boundaries.update(
        i for i in range(1, len(cdr_mask)) if cdr_mask[i] != cdr_mask[i - 1]
    )
    # The heavy/light junction is always a region break, even when both sides
    # carry the same label. Insert it by value (not at a fixed list position) so
    # it can neither land out of order nor be duplicated.
    boundaries.add(sep_idx)
    cdr_idxs = sorted(boundaries)
    spans = list(zip(cdr_idxs, cdr_idxs[1:]))

    # Accuracy: fraction of residues in each region predicted correctly.
    region_mean_acc = [
        sum(t == p for t, p in zip(flat_seq[a:b], predictions[a:b])) / len(flat_seq[a:b])
        for a, b in spans
    ]
    # Prediction confidence (mean), loss (median) and perplexity (mean) per region.
    region_mean_scores = [np.mean(scores[a:b]) for a, b in spans]
    region_median_loss = [np.median(losses[a:b]) for a, b in spans]
    region_mean_perplexity = [np.mean(perplexities[a:b]) for a, b in spans]

    return {
        "sequence": flat_seq,
        "heavy": heavy,
        "light": light,
        "cdr_indices": cdr_idxs,
        "prediction": predictions,
        "accuracy_by_region": region_mean_acc,
        "score_by_region": region_mean_scores,
        "loss_by_region": region_median_loss,
        "perplexity_by_region": region_mean_perplexity,
        "score": scores,
        "loss": losses,
        "perplexity": perplexities
    }

## Load test data


In [6]:
# Seed for the random row sampling below, so the sample is reproducible.
seed = 42
# Number of sequences sampled from each test set for per-residue inference.
n_samples = 2000

# Test data, stored separately for unmutated (germline) and mutated sequences
germline_test_df = pd.read_csv("./04Per_residue_inference/GERMLINE_annotated_with_cdr_mask_last_version.csv")
mutated_test_df = pd.read_csv("./04Per_residue_inference/MUTATED_annotated_with_cdr_mask_last_version.csv")

# Draw a reproducible random subset of each test set for model inference
germline_test_df = germline_test_df.sample(n=n_samples, random_state=seed)
mutated_test_df = mutated_test_df.sample(n=n_samples, random_state=seed)

data_dict = {
    "germline": germline_test_df,
    "mutated": mutated_test_df,
}

### Inference on 2000 sequences across models


In [None]:
# Model m_8M_F on the germline test sample
name = 'm_8M_F'
seq_type = 'germline'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:
# Model m_8M_F on the mutated test sample
name = 'm_8M_F'
seq_type = 'mutated'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:
# Model m_35M_F on the germline test sample
name = 'm_35M_F'
seq_type = 'germline'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:
# Model m_35M_F on the mutated test sample
name = 'm_35M_F'
seq_type = 'mutated'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:

# Model m_150M_F on the germline test sample
name = 'm_150M_F'
seq_type = 'germline'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:

# Model m_150M_F on the mutated test sample
name = 'm_150M_F'
seq_type = 'mutated'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:

# Model m_350M_F on the germline test sample
name = 'm_350M_F'
seq_type = 'germline'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:

# Model m_350M_F on the mutated test sample
name = 'm_350M_F'
seq_type = 'mutated'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:

# Model m_650M_F on the germline test sample
name = 'm_650M_F'
seq_type = 'germline'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")

In [None]:
 
# Model m_650M_F on the mutated test sample
name = 'm_650M_F'
seq_type = 'mutated'  # key into data_dict: 'germline' or 'mutated'

# Retrieve the checkpoint path and the sampled test data
model_path = model_dict_full[name]
data = data_dict[seq_type]

# Drop the model loaded by a previous cell first, so it can be freed before
# the next checkpoint is loaded instead of both being resident at once.
model = None
torch.cuda.empty_cache()
model = EsmForMaskedLM.from_pretrained(model_path).to(device)

# Create the output directory before the long inference loop, so the final
# save cannot fail on a missing folder.
os.makedirs("./results_full", exist_ok=True)

# Run per-residue masked inference on every sampled sequence
inference_data = []
sequences = list(data.iterrows())

for _id, row in tqdm(sequences):
    d = infer_and_group_stats(
        model,
        tokenizer,
        row['text'],
        row['cdr_mask'],
    )
    inference_data.append(d)

# Collect the per-sequence results and save them as JSON
inference_df = pd.DataFrame(inference_data)
inference_df.to_json(f"./results_full/{name}_{seq_type}_{len(sequences)}.json")

print(f"Inference completed for {name} on {seq_type}. Results saved.")