# ESM Multi-Target Hierarchical Fine-Tuning and Embedding Comparison

This notebook walks through a layered process of fine-tuning a pre-trained ESM model. It first trains on an entire dataset, then iteratively fine-tunes new models on specific subsets of the data in a hierarchical fashion.

For each level of fine-tuning, it:
1.  Generates protein embeddings from the current base model.
2.  Fine-tunes a new model on a specific subset of data.
3.  Generates embeddings from the newly fine-tuned model.
4.  Produces a classification report, a training/validation loss curve, and comparison plots (PCA and t-SNE) for that specific level.
5.  Saves all artifacts into a dedicated, nested directory.

## 1. Setup and Imports

First, let's install and import all the necessary libraries. Ensure you have a GPU available for faster training.

In [None]:
%pip install -q transformers datasets scikit-learn numpy pandas torch seaborn matplotlib accelerate
%pip install -q esm biopython py3Dmol huggingface_hub

In [None]:
import os
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from scipy.special import softmax
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_fscore_support, accuracy_score, fbeta_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
from functools import partial
import hashlib
import json
import shutil
import requests

## 2. Configuration

Set all your parameters in this cell. This replaces the command-line arguments from the original script.

In [None]:
class Config:
    """Central place for every tunable parameter used by the notebook cells below."""

    # --- Input and Output ---
    # IMPORTANT: Make sure this CSV file is uploaded to your notebook environment!
    CSV_PATH = "../Data/TIMP_binder_all.csv"
    # Root directory under which every run writes its artifacts.
    OUTPUT_DIR = "../Local/esm_hierarchical_out"
    SEQ_COL = "Full Seq"       # Column with protein sequences
    LABEL_COL = "Target" # The ultimate target column for classification
    BINDING_COL = "Encoding"   # Column indicating positive (1) or negative (0) binding
    # NOTE(review): COUNT_COL is not referenced by any code visible in this notebook.
    COUNT_COL = "Count"

    # --- Hierarchy Definition ---
    # One fine-tuning level is run per column, in list order.
    # NOTE: HIERARCHY_COLS should not contain LABEL_COL
    HIERARCHY_COLS = ["Ligand_Subtype", "Loop_Subtype"] # <-- IMPORTANT: UPDATE THIS LIST

    # --- Model and Training ---
    # Alternative checkpoints are kept below for quick switching; exactly one MODEL_ID is active.
    #MODEL_ID = "facebook/esm2_t36_3B_UR50D"
    #MODEL_ID = "facebook/esm2_t33_650M_UR50D"
    MODEL_ID = "facebook/esm2_t30_150M_UR50D"
    #MODEL_ID = "facebook/esm2_t6_8M_UR50D" # Smaller model for testing
    TEST_SIZE = 0.2
    EPOCHS = 25
    # NOTE(review): run_finetuning_pipeline does not pass LEARNING_RATE to TrainingArguments,
    # so the Trainer's default learning rate is what is actually used — confirm the intent.
    LEARNING_RATE = 1e-3
    BATCH_SIZE = 8
    FORCE_RETRAIN = False # Set to True to retrain even if a saved model exists

# Single shared instance read by the pipeline cells.
config = Config()

## 3. Helper Functions

These are the utility functions from the script for device detection, embedding computation, plotting, dataset wrapping, and metric computation.

In [None]:
def get_device():
    """Return the preferred ``torch.device``: CUDA, then Apple MPS, then CPU.

    The MPS backend is looked up defensively because ``torch.backends.mps``
    does not exist on every PyTorch build; a missing attribute is treated
    as "MPS not available" instead of raising ``AttributeError``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    mps_backend = getattr(torch.backends, "mps", None)  # For Apple Silicon
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")

def compute_embeddings(sequences, model, tokenizer, device, run_name=""):
    """Return one embedding row per sequence, stacked into a single array.

    Each sequence is tokenized on its own (truncated to ``max_length=1022``),
    run through ``model`` with hidden states enabled, and reduced by taking
    the mean of the last hidden state over ``dim=1`` (presumably the token
    axis). The model is moved to ``device`` and put in eval mode first;
    gradients are disabled for the whole pass.

    Args:
        sequences: Iterable of sequence strings to embed.
        model: Model accepting tokenizer output and ``output_hidden_states``.
        tokenizer: Tokenizer matching ``model``.
        device: Device on which inference runs.
        run_name: Label shown in the progress-bar description.

    Returns:
        ``np.ndarray`` built with ``np.vstack`` from the per-sequence vectors.
    """
    model.to(device)
    model.eval()

    def _embed_one(sequence):
        # One sequence per forward pass, so no padding takes part in the mean.
        encoded = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1022).to(device)
        result = model(**encoded, output_hidden_states=True)
        final_layer = result.hidden_states[-1]
        pooled = final_layer.mean(dim=1)
        return pooled.squeeze().cpu().numpy()

    progress = tqdm(sequences, desc=f"Computing embeddings ({run_name})")
    with torch.no_grad():
        per_sequence = [_embed_one(sequence) for sequence in progress]
    return np.vstack(per_sequence)

# Install a fixed ten-colour palette as the seaborn default for every plot created after this cell runs.
sns.set_palette(["#009100","#FF0000","#0000FF","#800080","#FFA500","#00FFFF","#FFC0CB","#A52A2A","#808000","#000000"])

def plot_comparison(before_data, after_data, labels, reduction_method, out_dir, model_tag, timestamp, hue_order=None, run_name=""):
    """Draw side-by-side "before" and "after" scatter plots and save them as a PNG.

    Args:
        before_data: 2-D array; columns 0 and 1 are plotted in the left panel.
        after_data: 2-D array; columns 0 and 1 are plotted in the right panel.
        labels: Per-point values used for the colour (hue) of both panels.
        reduction_method: Name shown in the title and, lower-cased, in the file name.
        out_dir: ``Path`` of the directory that receives the PNG.
        model_tag: Model identifier used in the title and file name.
        timestamp: String appended to the file name.
        hue_order: Optional explicit ordering of the hue categories.
        run_name: Run identifier shown in the figure title.
    """
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    fig.suptitle(f'{reduction_method} Comparison for {run_name} ({model_tag})', fontsize=20)

    # Both panels share identical styling; only the data and the title differ.
    panels = (
        (axes[0], before_data, "Before This Training Step"),
        (axes[1], after_data, "After This Training Step"),
    )
    for axis, points, panel_title in panels:
        sns.scatterplot(ax=axis, x=points[:, 0], y=points[:, 1], hue=labels, s=50, alpha=0.7, hue_order=hue_order)
        axis.set_title(panel_title, fontsize=16)
        axis.set_xlabel("Dimension 1")
        axis.set_ylabel("Dimension 2")
        axis.legend(title="Target")

    # Leave head-room for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    filename = out_dir / f"{reduction_method.lower()}_comparison_{model_tag}_{timestamp}.png"
    plt.savefig(filename)
    print(f"Saved comparison plot to: {filename}")
    plt.show()
    plt.close()

class ProteinDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with integer class labels.

    ``encodings`` is a mapping from field name to an indexable collection
    (one entry per example); ``labels`` is an indexable collection of the
    same length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The label list defines the dataset size.
        return len(self.labels)

    def __getitem__(self, idx):
        # Convert every encoded field of example ``idx`` to a tensor, then
        # attach the label as a long tensor under the key 'labels'.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample


def compute_metrics_multiclass(pred, id2label: dict):
    """Compute overall and binder-only classification metrics.

    The class whose name contains "non-binder" (case-insensitive) is
    excluded from the ``binder_*`` metrics; the remaining metrics cover
    every class.

    Args:
        pred: Object exposing ``label_ids`` (true class ids) and
            ``predictions`` (logits, or a tuple whose first item is the logits).
        id2label: Mapping of class id to class name. Keys may be ints or
            numeric strings (as produced by a JSON round-trip); they are
            normalised to ``int`` so they compare correctly with the
            integer labels handed to scikit-learn.

    Returns:
        Dict of metric name to value. ``binder_fbeta_score`` (beta = 0.5,
        macro-averaged over binder classes) is the one used for model selection.
    """
    BETA = 0.5

    labels = pred.label_ids
    preds_data = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    preds = np.argmax(preds_data, axis=-1)

    # Normalise once: int ids, string names. Without this, string keys would
    # never equal the int non-binder id and the exclusion below would be a no-op.
    id_to_name = {int(label_id): str(label_name) for label_id, label_name in id2label.items()}

    # First class whose name mentions "non-binder" (case-insensitive), if any.
    non_binder_id = next(
        (label_id for label_id, label_name in id_to_name.items() if "non-binder" in label_name.lower()),
        None,
    )
    if non_binder_id is None:
        print("Warning: 'Non-Binder' class not found in id2label dict. Calculating metrics on all classes.")

    # Every class except the non-binder one (all classes when none was found).
    binder_labels = [label_id for label_id in id_to_name if label_id != non_binder_id]

    if binder_labels:
        # F-beta with beta=0.5 weights precision more heavily than recall.
        binder_fbeta_score = fbeta_score(
            labels,
            preds,
            beta=BETA,
            labels=binder_labels,
            average='macro',
            zero_division=0
        )
        binder_precision, binder_recall, binder_f1, _ = precision_recall_fscore_support(
            labels,
            preds,
            labels=binder_labels,
            average='macro',
            zero_division=0
        )
    else:
        # Only a non-binder class exists: there is nothing to average over.
        binder_fbeta_score = binder_precision = binder_recall = binder_f1 = 0.0

    # Metrics over all classes.
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    f_beta_macro = fbeta_score(labels, preds, beta=BETA, average='macro', zero_division=0)
    acc = accuracy_score(labels, preds)

    return {
        'accuracy': acc,
        'f1_weighted': f1_weighted,
        'f1_macro': f1_macro,
        'f_beta_macro': f_beta_macro,
        'precision_weighted': precision_weighted,
        'precision_macro': precision_macro,
        'recall_weighted': recall_weighted,
        'recall_macro': recall_macro,
        'binder_precision': binder_precision,
        'binder_recall': binder_recall,
        'binder_f1': binder_f1,
        'binder_fbeta_score': binder_fbeta_score # For training
    }

Pipeline function

In [None]:
# Cell 7: Updated run_finetuning_pipeline Function

def _encode_dataset(tokenizer, texts, labels):
    """Tokenize ``texts`` (padded, truncated) and pair them with ``labels``."""
    encodings = tokenizer(texts, truncation=True, padding=True)
    return ProteinDataset(encodings, labels)


def _plot_loss_curves(log_history, run_name, save_path):
    """Plot the training/validation loss found in a Trainer log history and save it."""
    df_history = pd.DataFrame(log_history)
    plt.figure(figsize=(10, 6))
    plotted_any = False
    # Runs shorter than `logging_steps` never log a step-level 'loss', so the
    # column may be missing entirely; plot only what was actually recorded.
    if 'loss' in df_history.columns:
        train_loss = df_history[df_history['loss'].notna()]
        plt.plot(train_loss['step'], train_loss['loss'], label='Training Loss')
        plotted_any = True
    if 'eval_loss' in df_history.columns:
        eval_loss = df_history[df_history['eval_loss'].notna()]
        plt.plot(eval_loss['step'], eval_loss['eval_loss'], label='Validation Loss', marker='o')
        plotted_any = True
    plt.title(f'Training and Validation Loss ({run_name})')
    plt.xlabel('Training Steps')
    plt.ylabel('Loss')
    if plotted_any:
        plt.legend()
    plt.grid(True)
    plt.savefig(save_path)
    plt.show()
    plt.close()


def _write_validation_reports(model, val_dataset, val_labels, id2label, run_name, output_path, model_tag, timestamp, device):
    """Predict on the validation set once and write the standard and expanded reports."""
    report_trainer = Trainer(model=model.to(device))
    prediction_output = report_trainer.predict(val_dataset)
    logits = prediction_output.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    y_pred = np.argmax(logits, axis=1)

    # Pass `labels` explicitly: a validation subset may be missing some classes,
    # and scikit-learn raises when `target_names` is longer than the classes it
    # discovers on its own.
    class_ids = sorted(id2label)
    report_text = classification_report(
        val_labels,
        y_pred,
        labels=class_ids,
        target_names=[str(id2label[class_id]) for class_id in class_ids],
        zero_division=0,
    )
    print(f'Classification Report for {run_name}:\n{report_text}')
    class_report_path = output_path / f'classification_report_{model_tag}_{timestamp}.txt'
    with open(class_report_path, 'w') as f:
        f.write(f"Classification Report for run: {run_name}\n\n{report_text}")
    print(f"Classification report saved to {class_report_path}")

    # --- Expanded report: the same predictions scored with the custom metrics ---
    all_custom_metrics = compute_metrics_multiclass(prediction_output, id2label=id2label)
    custom_metrics_lines = ["--- Comprehensive Performance Metrics ---"]
    for metric_name, metric_value in all_custom_metrics.items():
        # e.g. 'binder_fbeta_score' -> 'Binder F-Beta Score'
        formatted_name = metric_name.replace('_', ' ').replace('fbeta', 'F-beta').title()
        custom_metrics_lines.append(f"{formatted_name}: {metric_value:.4f}")
    custom_metrics_text = "\n".join(custom_metrics_lines)
    print(custom_metrics_text)

    # Written next to the run's other artifacts so that identically named runs
    # under different parents cannot overwrite each other's report.
    final_report_path = output_path / f'Expanded_classification_report_{run_name}.txt'
    with open(final_report_path, 'w') as f:
        f.write("--- Model Evaluation on the Fixed Validation Set ---\n")
        f.write(f"Model: {run_name}\n\n")
        f.write("--- Standard Classification Report ---\n")
        f.write(report_text)
        f.write("\n\n")
        f.write(custom_metrics_text)
    print(f"\nFinal, comprehensive report saved to: {final_report_path}")


def run_finetuning_pipeline(train_df: pd.DataFrame, validation_df: pd.DataFrame, output_path: Path, base_model_id: str, run_name: str, config: Config):
    """Fine-tune ``base_model_id`` on ``train_df`` and save all artifacts for the run.

    Steps: embed every sequence with the base model, fine-tune (or reuse a
    previously saved model unless ``config.FORCE_RETRAIN`` is set), write
    validation reports when validation data exists, embed again with the
    fine-tuned model, and save PCA / t-SNE before-vs-after plots.

    Args:
        train_df: Training rows; must contain ``config.SEQ_COL`` and ``config.LABEL_COL``.
        validation_df: Validation rows with the same columns; may be empty, in
            which case evaluation, early stopping and reports are skipped.
        output_path: Directory for this run's artifacts (created if missing).
        base_model_id: Hub id or local path of the model to start from.
        run_name: Human-readable name used in logs, titles and file names.
        config: Notebook configuration object.

    Returns:
        Path of the directory holding the fine-tuned model.

    The input DataFrames are not modified.
    """
    output_path.mkdir(parents=True, exist_ok=True)
    device = get_device()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_tag = config.MODEL_ID.split('/')[-1]
    finetuned_model_path = output_path / f"finetuned_{model_tag}"

    # For plotting and embedding, we use the combined data from this run
    df_subset = pd.concat([train_df, validation_df])

    print(f"--- Starting Pipeline for: {run_name} ---")
    print(f"Training data size: {len(train_df)}, Validation data size: {len(validation_df)}")
    print(f"Output directory: {output_path}")
    print(f"Base model: {base_model_id}")

    # --- 1. Data Processing for the current subset ---
    # The label space is derived from train + validation together so both
    # splits share one id mapping.
    unique_labels = sorted(df_subset[config.LABEL_COL].unique())
    label2id = {label: i for i, label in enumerate(unique_labels)}
    id2label = {i: label for label, i in label2id.items()}
    num_labels = len(unique_labels)

    # Map labels into plain lists instead of adding a column to the caller's frames.
    train_texts = train_df[config.SEQ_COL].tolist()
    train_labels = train_df[config.LABEL_COL].map(label2id).tolist()
    val_texts = validation_df[config.SEQ_COL].tolist()
    val_labels = validation_df[config.LABEL_COL].map(label2id).tolist()

    # Data for plots
    sequences = df_subset[config.SEQ_COL].tolist()
    text_labels = df_subset[config.LABEL_COL].tolist()

    # --- 2. Pre-trained Model Embeddings ---
    print("\nStep A: Generating embeddings from the BASE model...")
    # NOTE(review): this substring test also matches local paths containing "esm";
    # confirm that trust_remote_code is really required for these checkpoints.
    trust_code = 'esm' in base_model_id.lower() or 'synthyra' in base_model_id.lower()
    base_model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id, num_labels=num_labels, trust_remote_code=trust_code, ignore_mismatched_sizes=True
    )
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID if 'esm2' in config.MODEL_ID else base_model_id)
    embeddings_before = compute_embeddings(sequences, base_model, tokenizer, device, run_name)
    # The base model is only needed for the "before" embeddings; release it so
    # it does not occupy accelerator memory during training.
    del base_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- 3. Fine-tuning the Model ---
    print("\nStep B: Fine-tuning the model...")
    has_validation_data = not validation_df.empty
    # Built outside the training branch: the report step needs it even when a
    # previously saved model is reused.
    val_dataset = _encode_dataset(tokenizer, val_texts, val_labels) if has_validation_data else None

    if finetuned_model_path.exists() and not config.FORCE_RETRAIN:
        print(f"Found existing fine-tuned model. Loading from {finetuned_model_path}")
    else:
        model_for_training = AutoModelForSequenceClassification.from_pretrained(
            base_model_id, num_labels=num_labels, id2label=id2label, label2id=label2id, trust_remote_code=trust_code, ignore_mismatched_sizes=True
        )
        compute_metrics_with_context = partial(compute_metrics_multiclass, id2label=id2label)
        train_dataset = _encode_dataset(tokenizer, train_texts, train_labels)

        # NOTE(review): config.LEARNING_RATE is not passed here, so the Trainer
        # default is used. Left unchanged on purpose: wiring in the configured
        # value would alter training substantially — confirm the intended rate.
        training_args = TrainingArguments(
            output_dir=str(output_path / 'training_checkpoints'),
            num_train_epochs=config.EPOCHS,
            per_device_train_batch_size=config.BATCH_SIZE,
            per_device_eval_batch_size=config.BATCH_SIZE,
            warmup_steps=100,
            weight_decay=0.01,
            logging_dir=str(output_path / 'logs'),
            logging_steps=10,
            # Evaluation, best-model loading and early stopping all require validation data.
            evaluation_strategy="epoch" if has_validation_data else "no",
            save_strategy="epoch",
            load_best_model_at_end=has_validation_data,
            metric_for_best_model="binder_fbeta_score" if has_validation_data else None,
            greater_is_better=True
        )

        trainer = Trainer(
            model=model_for_training,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics_with_context,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] if has_validation_data else None
        )

        trainer.train()
        trainer.save_model(finetuned_model_path)
        tokenizer.save_pretrained(finetuned_model_path)
        print(f"Fine-tuned model saved to {finetuned_model_path}")

        if has_validation_data:
            val_curve_path = output_path / f"validation_curve_{model_tag}_{timestamp}.png"
            _plot_loss_curves(trainer.state.log_history, run_name, val_curve_path)
            print(f"Validation curve saved to {val_curve_path}")

    # Loaded once from disk and shared by the report and the "after" embeddings,
    # so both branches above (fresh training or cached model) behave the same.
    finetuned_model = AutoModelForSequenceClassification.from_pretrained(finetuned_model_path)

    # --- 4. Detailed Classification Report on Validation Set ---
    if has_validation_data:
        print("\nStep C: Generating classification report on the fixed validation set...")
        _write_validation_reports(
            finetuned_model, val_dataset, val_labels, id2label,
            run_name, output_path, model_tag, timestamp, device,
        )

    # --- 5. Post-fine-tuning Embeddings & Plotting ---
    print("\nStep D: Generating embeddings from the NEWLY fine-tuned model...")
    embeddings_after = compute_embeddings(sequences, finetuned_model, tokenizer, device, run_name)
    print("\nStep E: Running PCA/t-SNE and generating comparison plots...")
    # t-SNE requires perplexity to be smaller than the number of samples.
    perplexity = min(30, len(df_subset) - 1) if len(df_subset) > 1 else 1
    pca_before = PCA(n_components=2).fit_transform(embeddings_before)
    tsne_before = TSNE(n_components=2, perplexity=perplexity, random_state=42, max_iter=1000).fit_transform(embeddings_before)
    pca_after = PCA(n_components=2).fit_transform(embeddings_after)
    tsne_after = TSNE(n_components=2, perplexity=perplexity, random_state=42, max_iter=1000).fit_transform(embeddings_after)
    hue_order = list(label2id.keys())
    plot_comparison(pca_before, pca_after, text_labels, "PCA", output_path, model_tag, timestamp, hue_order=hue_order, run_name=run_name)
    plot_comparison(tsne_before, tsne_after, text_labels, "t-SNE", output_path, model_tag, timestamp, hue_order=hue_order, run_name=run_name)

    print(f"--- Finished Pipeline for: {run_name} ---\n")
    return finetuned_model_path

Hierarchy function

In [None]:
# Cell 8: Updated Hierarchical Execution Logic

def process_level(df_train: pd.DataFrame, df_validation: pd.DataFrame, parent_model_path: Path, hierarchy_level: int, config: Config, parent_output_path: Path):
    """Fine-tune one model per category of the current hierarchy column, then recurse.

    For every non-missing category of ``config.HIERARCHY_COLS[hierarchy_level]``
    found in ``df_train``, the training and validation frames are filtered to
    that category, a model is fine-tuned starting from ``parent_model_path``,
    and the next hierarchy level is processed on the filtered frames with the
    new model as parent. Recursion stops once ``hierarchy_level`` is past the
    last hierarchy column.

    Args:
        df_train: Training rows available at this level.
        df_validation: Validation rows available at this level.
        parent_model_path: Model that each category's fine-tuning starts from.
        hierarchy_level: Zero-based index into ``config.HIERARCHY_COLS``.
        config: Notebook configuration object.
        parent_output_path: Directory under which each category's run directory is created.
    """
    if hierarchy_level >= len(config.HIERARCHY_COLS):
        return

    col_name = config.HIERARCHY_COLS[hierarchy_level]
    print(f"==================== STARTING HIERARCHY LEVEL {hierarchy_level + 1} (Column: '{col_name}') ====================")

    # Missing values are dropped (rows without a category belong to no subset),
    # and sorting by str() keeps the ordering well-defined for non-string values.
    categories = sorted(df_train[col_name].dropna().unique(), key=str)

    # Iterate through categories present in the TRAINING data for this level
    for category in categories:
        print(f"\nProcessing Category: '{category}' from column '{col_name}'")

        # Filter both the training and validation sets for the current category
        train_subset = df_train[df_train[col_name] == category].copy()
        validation_subset = df_validation[df_validation[col_name] == category].copy()

        # At least one full batch and two distinct classes are required to train.
        if len(train_subset) < config.BATCH_SIZE or len(train_subset[config.LABEL_COL].unique()) < 2:
            print(f"SKIPPING '{category}': Insufficient training data or fewer than 2 classes.")
            continue
        if validation_subset.empty:
            print(f"WARNING: No validation data found for '{category}'. Training will proceed without evaluation-based early stopping.")

        # Keep only characters that are safe in a directory name; str() allows
        # numeric categories as well.
        safe_category_name = "".join(c for c in str(category) if c.isalnum() or c in (' ', '_')).rstrip()
        run_name = f"L{hierarchy_level+1}_{safe_category_name.replace(' ', '_')}"
        output_path = parent_output_path / run_name

        newly_tuned_model_path = run_finetuning_pipeline(
            train_df=train_subset,
            validation_df=validation_subset,
            output_path=output_path,
            base_model_id=str(parent_model_path),
            run_name=run_name,
            config=config
        )

        # Recurse into the next hierarchy column using only this category's rows.
        process_level(
            df_train=train_subset,
            df_validation=validation_subset,
            parent_model_path=newly_tuned_model_path,
            hierarchy_level=hierarchy_level + 1,
            config=config,
            parent_output_path=output_path
        )
    print(f"==================== FINISHED HIERARCHY LEVEL {hierarchy_level + 1} ====================")

## 4. Main Pipeline

This is the main execution block. It will perform all the steps from the original script in sequence.

In [None]:
print("--- Loading and Preparing Initial Data ---")
# Drop rows missing a sequence, label or binding flag, then keep one row per unique sequence.
main_df = pd.read_csv(config.CSV_PATH)
main_df = main_df.dropna(subset=[config.SEQ_COL, config.LABEL_COL, config.BINDING_COL])
main_df = main_df.drop_duplicates(subset=[config.SEQ_COL])
main_df = main_df.copy()
# Every row flagged as non-binding (BINDING_COL == 0) is relabelled to the single 'Non-Binder' class.
main_df.loc[main_df[config.BINDING_COL] == 0, config.LABEL_COL] = 'Non-Binder'

# --- THREE-WAY DATA SPLIT ---
print("\n--- Creating fixed Train, Validation, and Test sets ---")
# 1. Split off the final, untouchable test set (fraction taken from config.TEST_SIZE)
main_train_val_df, final_test_df = train_test_split(
    main_df, test_size=config.TEST_SIZE, random_state=42, stratify=main_df[config.LABEL_COL]
)
# 2. Split the remaining data into a training set and a fixed validation set
main_train_df, global_validation_df = train_test_split(
    main_train_val_df, test_size=0.15, random_state=42, stratify=main_train_val_df[config.LABEL_COL]
)
print(f"Total training examples: {len(main_train_df)}")
print(f"Total (fixed) validation examples: {len(global_validation_df)}")
print(f"Total (hold-out) test examples: {len(final_test_df)}")

# --- LEVEL 0: GLOBAL FINE-TUNING ---
print("\n==================== STARTING LEVEL 0 (GLOBAL FINE-TUNING) ====================")
global_output_path = Path(config.OUTPUT_DIR) / "L0_Global_Finetuning"
global_model_path = run_finetuning_pipeline(
    train_df=main_train_df, # Pass the dedicated training set
    validation_df=global_validation_df, # Pass the dedicated validation set
    output_path=global_output_path,
    base_model_id=config.MODEL_ID,
    run_name="L0_Global",
    config=config
)
print("==================== FINISHED LEVEL 0 ====================")

# --- START HIERARCHICAL FINE-TUNING (LEVEL 1 and beyond) ---
# Each hierarchy level starts from the model produced by its parent level.
if config.HIERARCHY_COLS:
    process_level(
        df_train=main_train_df, # Start recursion with the main train/validation sets
        df_validation=global_validation_df,
        parent_model_path=global_model_path,
        hierarchy_level=0,
        config=config,
        parent_output_path=global_output_path
    )
else:
    print("\nNo hierarchy columns defined. Skipping layered fine-tuning.")

print("\n[DONE] All hierarchical fine-tuning tasks completed!")

Final Evaluation

In [None]:
# Cell 9: Final Unbiased Evaluation on the Hold-Out Test Set for ALL Models

print("\n==================== FINAL EVALUATION ON HOLD-OUT TEST SET ====================")

# --- 1. Find all trained models ---
# rglob recursively finds every saved model directory within the main output folder.
model_tag = config.MODEL_ID.split('/')[-1]
output_root = Path(config.OUTPUT_DIR)
all_model_paths = sorted(output_root.rglob(f"finetuned_{model_tag}"))

if not all_model_paths:
    print("No trained models found to evaluate. Please ensure the training process completed successfully.")
else:
    print(f"Found {len(all_model_paths)} models to evaluate on the hold-out test set.")

device = get_device()
# The tokenizer is the same for every fine-tuning stage, so it is loaded at
# most once (lazily, inside the loop's error handling) and then reused.
tokenizer = None

# --- 2. Loop through each model and evaluate it ---
for model_to_evaluate_path in all_model_paths:
    # Use the parent directory's name for a descriptive title (e.g., L0_Global, L1_Subclass_A)
    run_name = model_to_evaluate_path.parent.name
    print(f"\n{'='*25} EVALUATING: {run_name} {'='*25}")
    print(f"Model Path: {model_to_evaluate_path}")

    try:
        final_model = AutoModelForSequenceClassification.from_pretrained(model_to_evaluate_path)
        final_model.to(device)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
    except Exception as e:
        print(f"Could not load model from {model_to_evaluate_path}. Skipping. Error: {e}")
        continue # Move to the next model

    # --- 3. Prepare the Test Dataset ---
    # Use the exact label mapping saved with each specific model.
    with open(model_to_evaluate_path / 'config.json', 'r') as f:
        model_config_data = json.load(f)
    label2id = model_config_data['label2id']
    id2label = model_config_data['id2label']  # JSON keys are strings

    # Map test labels to this model's ids on a fresh frame (the shared hold-out
    # frame is left untouched) and drop rows whose label the model never saw.
    # NOTE(review): rows are filtered by label only, so a sub-model is also scored
    # on test rows from other hierarchy categories that share its labels — confirm
    # whether category-level filtering is wanted here.
    test_df_filtered = final_test_df.assign(
        numerical_labels=final_test_df[config.LABEL_COL].map(label2id)
    ).dropna(subset=['numerical_labels'])

    if test_df_filtered.empty:
        print(f"No applicable data found in the hold-out test set for the labels of model '{run_name}'. Skipping.")
        continue

    test_labels = test_df_filtered['numerical_labels'].astype(int).tolist()
    test_sequences = test_df_filtered[config.SEQ_COL].tolist()

    test_encodings = tokenizer(test_sequences, truncation=True, padding=True)
    test_dataset = ProteinDataset(test_encodings, test_labels)

    # --- 4. Use the Trainer to Get Predictions ---
    final_trainer = Trainer(model=final_model)
    predictions, _, _ = final_trainer.predict(test_dataset)
    y_pred = np.argmax(predictions, axis=1)

    # --- 5. Generate and Save a Unique Report ---
    print(f"\n--- Final Unbiased Classification Report for {run_name} ---")
    # Pass `labels` explicitly so a class that is absent from the filtered test
    # rows still lines up with its name instead of raising a length-mismatch error.
    class_ids = sorted(int(label_id) for label_id in id2label)
    target_names = [id2label[str(class_id)] for class_id in class_ids]
    final_report_text = classification_report(
        test_labels, y_pred, labels=class_ids, target_names=target_names, zero_division=0
    )
    print(final_report_text)

    # Build the file name from the run's path below the output root so that
    # identically named runs under different parents do not overwrite each other.
    relative_parts = model_to_evaluate_path.parent.relative_to(output_root).parts
    report_tag = "__".join(relative_parts) or run_name
    final_report_path = output_root / f'FINAL_TEST_REPORT_{report_tag}.txt'
    with open(final_report_path, 'w') as f:
        f.write("--- Final Model Evaluation on the Hold-Out Test Set ---\n\n")
        f.write(f"Model: {run_name}\n")
        f.write(f"Model Path: {model_to_evaluate_path}\n\n")
        f.write(final_report_text)
    print(f"Final, unbiased report saved to: {final_report_path}")

print(f"\n{'='*25} All Evaluations Complete {'='*25}")