# ESM Multi-Target Fine-Tuning and Embedding Comparison

This notebook walks through the process of using a pre-trained ESM model for a **multi-class classification** task. It generates protein embeddings, fine-tunes the model on multiple targets (including a 'Non-Binder' category), and compares the embedding space before and after fine-tuning using PCA and t-SNE.

## 1. Setup and Imports

First, let's install and import all the necessary libraries. Ensure you have a GPU available for faster training.

In [None]:
%pip install -q transformers datasets scikit-learn pandas torch seaborn matplotlib accelerate
%pip install -q esm biopython py3Dmol httpx

In [None]:
import os
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from scipy.special import softmax
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
import hashlib
import json
import configparser

# ESM3 Imports for generation
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
from esm.sdk import client as cl

# Biopython Imports
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
import py3Dmol
import requests

## 2. Configuration

Set all your parameters in this cell.

In [None]:
# Secrets (API tokens) are read from an INI file kept outside the repository.
credentials_path = '../credentials.ini'
credentials = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list of files it
# did parse, so check it here instead of failing later with an opaque KeyError in Config.
if not credentials.read(credentials_path):
    raise FileNotFoundError(f"Could not read credentials file: {credentials_path}")

class Config:
    """Central run configuration: data paths, model choice, training and generation settings."""

    # --- Input and Output ---
    # IMPORTANT: Make sure this CSV file is uploaded to your notebook environment!
    CSV_PATH = "../Data/TIMP_binder_MMP3_AB.csv"
    OUTPUT_DIR = "../esm_gen_out"
    SEQ_COL = "Full Seq" # Column with protein sequences
    LABEL_COL = "Target" # Column with target labels
    BINDING_COL = "Encoding" # Column indicating positive (1) or negative (0) binding

    # --- Model and Training ---
    MODEL_ID = "facebook/esm2_t33_650M_UR50D"
    #MODEL_ID = "facebook/esm2_t30_150M_UR50D"
    #MODEL_ID = "facebook/esm2_t6_8M_UR50D" # Smaller model for testing; switch to larger for better results
    #MODEL_ID = "Synthyra/ESMplusplus_small" # not functional yet
    TEST_SIZE = 0.2
    EPOCHS = 25
    # NOTE(review): LEARNING_RATE and DECAY_TYPE are not passed to TrainingArguments in
    # Step 4b, so the Trainer defaults are what is actually used.
    LEARNING_RATE = 1e-3
    DECAY_TYPE = 'linear'
    BATCH_SIZE = 8
    FORCE_RETRAIN = False # Set to True to retrain even if a saved model exists

    # --- Generation Parameters ---
    HF_TOKEN = credentials['huggingFace']['token']
    # Only needed for the hosted Forge ESM3 client (commented out in Step 5a), so it is optional.
    FORGE_TOKEN = credentials.get('forge', 'token', fallback=None)
    CANONICAL_TEMPLATE_UNIPROT_ID = "P35625" # Human TIMP-3
    TARGET_UNIPROT_ID = "P08254" # Human MMP3
    NUM_SEQUENCES_TO_GENERATE = 10000 # Number of new sequences to create for each class
    LOOP = "AB"

config = Config()

## 3. Helper Functions

These utility functions, carried over from the original script, handle device detection, UniProt sequence retrieval, embedding computation, plotting, dataset wrapping, and classification metrics.

In [None]:
def get_device():
    """Pick the torch device to run on.

    The preference order is CUDA GPU, then Apple-Silicon MPS, then CPU.

    Returns:
        torch.device: the first available backend in that order.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():  # Apple Silicon
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
    
def get_uniprot_sequence(accession_id, timeout=30):
    """Fetch a protein sequence from UniProtKB.

    Requests the entry in FASTA format from the UniProt REST API and returns the
    sequence with the header line(s) and line breaks removed.

    Args:
        accession_id: UniProtKB accession, e.g. "P35625".
        timeout: seconds to wait for the HTTP request before giving up. Without a
            timeout a stalled connection would block the notebook indefinitely.

    Returns:
        The sequence as one string, or None if the request failed (network error,
        timeout, non-200 status) or the response contained no sequence lines.
    """
    url = f"https://rest.uniprot.org/uniprotkb/{accession_id}.fasta"
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        # Callers treat None as "could not fetch", so report and return instead of raising.
        print(f"Error fetching sequence for {accession_id}: {err}")
        return None

    if response.status_code != 200:
        print(f"Error fetching sequence for {accession_id}. Status code: {response.status_code}")
        return None

    # FASTA: '>' header line(s) followed by the (line-wrapped) sequence.
    sequence = "".join(
        line.strip()
        for line in response.text.splitlines()
        if line.strip() and not line.startswith(">")
    )
    return sequence or None

def compute_embeddings(sequences, model, tokenizer, device):
    """Compute one mean-pooled embedding vector per protein sequence.

    Each sequence is tokenized on its own (truncated to 1022 tokens), run through
    ``model`` with ``output_hidden_states=True``, and the last hidden layer is averaged
    over the token dimension. Special tokens added by the tokenizer are included in
    the average.

    Args:
        sequences: iterable of amino-acid strings.
        model: a Hugging Face model whose output exposes ``hidden_states``.
        tokenizer: tokenizer matching ``model``.
        device: torch device to run inference on.

    Returns:
        np.ndarray with one row per input sequence. If ``sequences`` is empty, an empty
        ``(0, hidden_size)`` float32 array is returned instead of raising. ``hidden_size``
        comes from ``model.config`` and defaults to 0 if unavailable.

    Side effects:
        Moves ``model`` to ``device`` and switches it to eval mode.
    """
    model.to(device)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for seq in tqdm(sequences, desc="Computing embeddings"):
            tokens = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1022).to(device)
            output = model(**tokens, output_hidden_states=True)
            # Presumably (1, seq_len, hidden): average over tokens, drop the batch axis.
            # .float() lets half/bfloat16 models convert to NumPy (bf16 has no NumPy dtype).
            embedding = output.hidden_states[-1].mean(dim=1).squeeze(0).float().cpu().numpy()
            embeddings.append(embedding)

    if not embeddings:
        # np.vstack([]) raises; return a well-shaped empty matrix so callers can stack it.
        hidden_size = getattr(getattr(model, "config", None), "hidden_size", 0)
        return np.empty((0, hidden_size), dtype=np.float32)
    return np.vstack(embeddings)

# Default color cycle (10 colors) for all later seaborn plots in this notebook. Categorical
# hue levels draw from this palette in order, so passing the same hue_order keeps a class's
# color stable across figures. NOTE(review): with more than 10 classes seaborn may switch
# to a different palette.
sns.set_palette(["#009100","#FF0000","#0000FF","#800080","#FFA500","#00FFFF","#FFC0CB","#A52A2A","#808000","#000000"])

def plot_single(data, labels, title, filename, hue_order=None, style=None):
    """Draw a single 2-D scatter plot, save it to ``filename`` and display it inline.

    Args:
        data: array indexable as ``data[:, 0]`` / ``data[:, 1]`` (x / y coordinates).
        labels: per-point values mapped to color (hue).
        title: figure title.
        filename: destination passed to ``savefig``.
        hue_order: optional explicit ordering of hue levels.
        style: optional per-point values mapped to marker style.
    """
    fig, ax = plt.subplots(figsize=(12, 10))
    sns.scatterplot(
        ax=ax,
        x=data[:, 0],
        y=data[:, 1],
        hue=labels,
        style=style,
        s=50,
        alpha=0.7,
        hue_order=hue_order,
    )
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.legend(title="Target")
    fig.tight_layout()
    fig.savefig(filename)
    plt.show()  # render inline in the notebook
    plt.close(fig)

def plot_comparison(before_data, after_data, labels, reduction_method, out_dir, model_tag, timestamp, hue_order=None, filename=None):
    """Draw, save and display a two-panel before/after fine-tuning scatter comparison.

    Args:
        before_data: 2-D coordinates (indexable as ``[:, 0]`` / ``[:, 1]``) from the pre-trained model.
        after_data: 2-D coordinates from the fine-tuned model, same row order as ``before_data``.
        labels: per-point hue values shared by both panels.
        reduction_method: display name of the projection, e.g. "PCA" or "t-SNE".
        out_dir: directory for the output image (Path or str).
        model_tag: model identifier used in the title and filename.
        timestamp: run timestamp used in the filename.
        hue_order: optional explicit ordering of hue levels.
        filename: optional explicit output path; overrides the derived name.

    Returns:
        The path the figure was saved to.
    """
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    fig.suptitle(f'{reduction_method} Comparison ({model_tag})', fontsize=20)

    panels = (
        (axes[0], before_data, "Before Fine-Tuning"),
        (axes[1], after_data, "After Fine-Tuning"),
    )
    for ax, coords, panel_title in panels:
        sns.scatterplot(ax=ax, x=coords[:, 0], y=coords[:, 1], hue=labels, s=50, alpha=0.7, hue_order=hue_order)
        ax.set_title(panel_title, fontsize=16)
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        ax.legend(title="Target")

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    if filename is None:
        # "t-SNE" -> "tsne" so the file name matches the tsne_*.png names used elsewhere
        # (e.g. the tsne_comparison_* path defined in Step 4e).
        method_slug = reduction_method.lower().replace("-", "")
        filename = Path(out_dir) / f"{method_slug}_comparison_{model_tag}_{timestamp}.png"
    fig.savefig(filename)
    plt.show()  # Display the plot inline
    plt.close(fig)
    return filename

class ProteinDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized encodings with integer class labels.

    ``encodings`` maps a field name (e.g. ``input_ids``, ``attention_mask``) to a
    per-example sequence of values, as returned by a tokenizer called on a list of
    strings. ``labels`` holds one integer class id per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field_name, per_example in self.encodings.items():
            sample[field_name] = torch.tensor(per_example[idx])
        # CrossEntropy-style losses expect int64 class targets.
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

def compute_metrics_multiclass(pred):
    """Metric callback for the Hugging Face Trainer on the multi-class task.

    Args:
        pred: evaluation prediction object with ``predictions`` (logits, or a tuple
            whose first element is the logits) and ``label_ids``.

    Returns:
        dict with 'accuracy', plus weighted-average 'f1', 'precision' and 'recall'.
        Weighting by class support accounts for label imbalance.
    """
    logits = pred.predictions
    if isinstance(logits, tuple):  # some models return extra outputs after the logits
        logits = logits[0]
    y_true = pred.label_ids
    y_pred = logits.argmax(-1)

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

## 4. Main Pipeline

This is the main execution block. It will perform all the steps from the original script in sequence.

In [None]:
# Setup paths and parameters
out_dir = Path(config.OUTPUT_DIR)
out_dir.mkdir(parents=True, exist_ok=True)
device = get_device()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")  # suffix for this run's output files
model_tag = config.MODEL_ID.split('/')[-1]  # e.g. "esm2_t33_650M_UR50D"

print(f"Using device: {device}")
print(f"Using model: {config.MODEL_ID}")

# Load and process data
print("--- Loading and Processing Data ---")
df = pd.read_csv(config.CSV_PATH)
df = df.dropna(subset=[config.SEQ_COL, config.LABEL_COL, config.BINDING_COL]) # Drop rows with missing critical info
df = df.drop_duplicates(subset=[config.SEQ_COL]) # Keep the first row for each distinct sequence
df = df.copy() # Avoid SettingWithCopyWarning

# Work with string labels: 'Non-Binder' is added below, and sorting or joining a mix of
# numeric targets and strings would raise a TypeError.
df[config.LABEL_COL] = df[config.LABEL_COL].astype(str)

# --- Preprocessing Step: Treat non-binders as a separate class ---
df.loc[df[config.BINDING_COL] == 0, config.LABEL_COL] = 'Non-Binder'
print(f"Total sequences in dataset: {len(df)}")

# Convert text labels to integers
unique_labels = sorted(df[config.LABEL_COL].unique()) # Sort for consistent mapping
label2id = {label: i for i, label in enumerate(unique_labels)}
id2label = {i: label for label, i in label2id.items()}
df['numerical_labels'] = df[config.LABEL_COL].map(label2id)

num_labels = len(unique_labels)
print(f"Found {num_labels} unique classes for training: {', '.join(unique_labels)}")
print("\nClass distribution:")
class_counts = df[config.LABEL_COL].value_counts()
print(class_counts)

# The stratified train/validation split in Step 4b needs at least 2 sequences per class.
rare_classes = class_counts[class_counts < 2]
if not rare_classes.empty:
    print(f"\nWARNING: these classes have fewer than 2 sequences and will break the stratified split: {', '.join(rare_classes.index)}")

sequences = df[config.SEQ_COL].tolist()
labels = df['numerical_labels'].tolist()
text_labels = df[config.LABEL_COL].tolist() # For plotting

### Step 4a: Pre-trained Model Embeddings

In [None]:
print("\n--- Step 4a: Pre-trained Model Embeddings ---")

# Caching setup for embeddings: the cache key combines the model tag with an MD5 of the
# concatenated sequence list, so a different model or dataset triggers a recompute.
sequences_hash = hashlib.md5("".join(sequences).encode()).hexdigest()
emb_cache_file = out_dir / f"embeddings_before_{model_tag}_{sequences_hash}.npy"

if emb_cache_file.exists():
    print(f"Loading pre-trained embeddings from cache: {emb_cache_file}")
    embeddings_before = np.load(emb_cache_file)
else:
    trust_code = 'esm' in config.MODEL_ID.lower()
    # The classification head is not used here: compute_embeddings reads only hidden states.
    model = AutoModelForSequenceClassification.from_pretrained(
        config.MODEL_ID,
        num_labels=num_labels,
        trust_remote_code=trust_code,
    )
    if 'facebook/esm2' in config.MODEL_ID.lower():
        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID) # facebook/esm2_t33_650M_UR50D uses AutoTokenizer
    else:
        tokenizer = model.tokenizer  # Use the tokenizer associated with the model

    embeddings_before = compute_embeddings(sequences, model, tokenizer, device)
    np.save(emb_cache_file, embeddings_before)
    print(f"Saved pre-trained embeddings to cache: {emb_cache_file}")

    # Release the pre-trained model before later steps load another copy of the weights.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# PCA and t-SNE on pre-trained embeddings.
# Perplexity must be below the sample count; derive it from the embeddings themselves.
print("\nRunning PCA and t-SNE on pre-trained embeddings...")
n_samples_before = len(embeddings_before)
pca_before = PCA(n_components=2).fit_transform(embeddings_before)
tsne_before = TSNE(n_components=2, perplexity=min(30, n_samples_before - 1), random_state=42).fit_transform(embeddings_before)

# ------------------ #
# Save results
pca_before_csv = out_dir / f'pca_before_{model_tag}_{timestamp}.csv'
tsne_before_csv = out_dir / f'tsne_before_{model_tag}_{timestamp}.csv'
pd.DataFrame(pca_before, columns=['PC1', 'PC2']).to_csv(pca_before_csv, index=False)
pd.DataFrame(tsne_before, columns=['TSNE1', 'TSNE2']).to_csv(tsne_before_csv, index=False)
print("Saved PCA and t-SNE results for pre-trained embeddings.")

### Step 4b: Fine-tuning the Model with a Classification Head

In [None]:
print("\n--- Step 4b: Fine-tuning Classification Head ---")

finetuned_model_path = out_dir / f"finetuned_{model_tag}"

# The split is recomputed on every run (fixed random_state) because the validation set
# is needed later for the classification report even when training is skipped.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    sequences, labels, test_size=config.TEST_SIZE, random_state=42, stratify=labels
)

if not finetuned_model_path.exists() or config.FORCE_RETRAIN:
    if config.FORCE_RETRAIN and finetuned_model_path.exists():
        print("Forcing re-training.")

    trust_code = 'esm' in config.MODEL_ID.lower()
    model = AutoModelForSequenceClassification.from_pretrained(
        config.MODEL_ID,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
        trust_remote_code=trust_code,
    )

    if 'facebook/esm2' in config.MODEL_ID.lower():
        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID) # facebook/esm2... uses AutoTokenizer
    else:
        tokenizer = model.tokenizer  # Use the tokenizer associated with the model

    # padding=True pads each split to its longest sequence up front.
    train_encodings = tokenizer(train_texts, truncation=True, padding=True)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True)

    train_dataset = ProteinDataset(train_encodings, train_labels)
    val_dataset = ProteinDataset(val_encodings, val_labels)

    # NOTE(review): config.LEARNING_RATE and config.DECAY_TYPE are not passed here, so the
    # Trainer defaults apply. Wire them in deliberately; 1e-3 may be too high for
    # fine-tuning all weights of a large ESM model.
    training_args = TrainingArguments(
        output_dir=str(out_dir / 'training_checkpoints'),
        num_train_epochs=config.EPOCHS,
        per_device_train_batch_size=config.BATCH_SIZE,
        per_device_eval_batch_size=config.BATCH_SIZE,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir=str(out_dir / 'logs'),
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Keep only the most recent checkpoints on disk (a full checkpoint per epoch of a
        # large model adds up quickly); with load_best_model_at_end the best one is kept too.
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="f1", # weighted F1 from compute_metrics_multiclass
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_multiclass, # Use the multi-class metrics function
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    trainer.train()
    trainer.save_model(finetuned_model_path)
    tokenizer.save_pretrained(finetuned_model_path)
    print(f"Fine-tuned model saved to {finetuned_model_path}")

    print("\n--- Generating Validation Curve ---")
    # Use a dedicated name: assigning to `df` here would overwrite the experimental dataset
    # loaded in the data cell, which later cells still read.
    log_df = pd.DataFrame(trainer.state.log_history)

    # Training rows carry 'loss', evaluation rows carry 'eval_loss'. Either column can be
    # missing (e.g. fewer optimizer steps than logging_steps), so guard the lookups.
    if 'loss' in log_df.columns:
        train_logs = log_df[log_df['loss'].notna()].dropna(axis=1, how='all').reset_index(drop=True)
    else:
        train_logs = pd.DataFrame(columns=['step', 'loss'])
    if 'eval_loss' in log_df.columns:
        eval_logs = log_df[log_df['eval_loss'].notna()].dropna(axis=1, how='all').reset_index(drop=True)
    else:
        eval_logs = pd.DataFrame(columns=['step', 'eval_loss'])

    # Plotting
    validation_curve_file = out_dir / f"validation_curve_{model_tag}_{timestamp}.png"
    plt.figure(figsize=(10, 6))
    plt.plot(train_logs['step'], train_logs['loss'], label='Training Loss')
    plt.plot(eval_logs['step'], eval_logs['eval_loss'], label='Validation Loss', marker='o')

    plt.title('Training and Validation Loss Curve')
    plt.xlabel('Training Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(validation_curve_file)
    plt.show()
    plt.close()

else:
    print(f"Found existing fine-tuned model. Loading from {finetuned_model_path}")

### Step 4c: Detailed Classification Report

In [None]:
print("\n--- Step 4c: Detailed Classification Report ---")

# Load the best fine-tuned model saved in Step 4b. trust_remote_code mirrors the other
# model loads so custom-code ESM variants can also be loaded from disk.
trust_code = 'esm' in config.MODEL_ID.lower()
model = AutoModelForSequenceClassification.from_pretrained(finetuned_model_path, trust_remote_code=trust_code)
if 'facebook/esm2' in config.MODEL_ID.lower():
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID) # facebook/esm2... uses AutoTokenizer
else:
    tokenizer = model.tokenizer  # Use the tokenizer associated with the model

model.to(device)

# Create a dataset for the validation texts
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
val_dataset = ProteinDataset(val_encodings, val_labels)

# This Trainer is used only for batched prediction (here and via `trainer` in Step 5c).
# Explicit arguments keep its scratch directory under out_dir rather than the default
# 'tmp_trainer' in the working directory.
predict_args = TrainingArguments(
    output_dir=str(out_dir / 'prediction_tmp'),
    per_device_eval_batch_size=config.BATCH_SIZE,
)
trainer = Trainer(model=model, args=predict_args)
predictions, _, _ = trainer.predict(val_dataset)
# Some models return extra outputs alongside the logits; keep only the logits.
if isinstance(predictions, tuple):
    predictions = predictions[0]
y_pred = np.argmax(predictions, axis=1)

print(f'Classification Report for model: {model_tag}\n')
# Passing `labels` explicitly keeps rows aligned with target_names even when a class is
# absent from both the validation labels and the predictions.
class_ids = list(range(num_labels))
report_text = classification_report(
    val_labels,
    y_pred,
    labels=class_ids,
    target_names=[id2label[i] for i in class_ids],
    zero_division=0,
)
print(report_text)

# Save report to a file
class_report_path = out_dir / f'classification_report_{model_tag}_{timestamp}.txt'
with open(class_report_path, 'w') as f:
    f.write(report_text)
print(f"\nClassification report saved to {class_report_path}")

### Step 4d: Post-fine-tuning Model Embeddings

In [None]:
print("\n--- Step 4d: Post-fine-tuning Model Embeddings ---")

# Keyed on the same sequence hash as the pre-trained cache from Step 4a.
emb_after_cache_file = out_dir / f"embeddings_after_{model_tag}_{sequences_hash}.npy"

# FORCE_RETRAIN also bypasses this cache because the fine-tuned weights change.
# NOTE(review): the cache key does not identify the fine-tuned weights themselves; delete the
# .npy if the model in finetuned_model_path is replaced without FORCE_RETRAIN.
if emb_after_cache_file.exists() and not config.FORCE_RETRAIN:
    print(f"Loading fine-tuned embeddings from cache: {emb_after_cache_file}")
    embeddings_after = np.load(emb_after_cache_file)
else:
    # Model and tokenizer are already loaded from the previous step
    embeddings_after = compute_embeddings(sequences, model, tokenizer, device)
    np.save(emb_after_cache_file, embeddings_after)
    print(f"Saved fine-tuned embeddings to cache: {emb_after_cache_file}")

# PCA and t-SNE on fine-tuned embeddings.
# Perplexity must be below the sample count. Derive it from the embeddings rather than the
# global `df`, which is not guaranteed to still hold the experimental dataset.
print("\nRunning PCA and t-SNE on fine-tuned embeddings...")
n_samples_after = len(embeddings_after)
pca_after = PCA(n_components=2).fit_transform(embeddings_after)
tsne_after = TSNE(n_components=2, perplexity=min(30, n_samples_after - 1), random_state=42).fit_transform(embeddings_after)

# Save results
pca_after_csv = out_dir / f'pca_after_{model_tag}_{timestamp}.csv'
tsne_after_csv = out_dir / f'tsne_after_{model_tag}_{timestamp}.csv'
pd.DataFrame(pca_after, columns=['PC1', 'PC2']).to_csv(pca_after_csv, index=False)
pd.DataFrame(tsne_after, columns=['TSNE1', 'TSNE2']).to_csv(tsne_after_csv, index=False)
print("Saved PCA and t-SNE results for fine-tuned embeddings.")

### Step 4e: Generating Plots and Comparison

In [None]:
print("\n--- Step 4e: Generating Plots ---")
# A fixed hue order (label-id order) keeps each class on the same color in every figure.
hue_order = list(label2id.keys())

# Output image paths, gathered here so they can be referenced in a summary.
(pca_before_png, pca_after_png,
 tsne_before_png, tsne_after_png,
 pca_comp_png, tsne_comp_png) = (
    out_dir / f"pca_before_{model_tag}_{timestamp}.png",
    out_dir / f"pca_after_{model_tag}_{timestamp}.png",
    out_dir / f"tsne_before_{model_tag}_{timestamp}.png",
    out_dir / f"tsne_after_{model_tag}_{timestamp}.png",
    out_dir / f"pca_comparison_{model_tag}_{timestamp}.png",
    out_dir / f"tsne_comparison_{model_tag}_{timestamp}.png",
)

# Individual plots: (heading, 2-D coordinates, figure title, output path)
single_plot_specs = [
    ("\nPCA before fine-tuning:", pca_before, "PCA - Before Fine-tuning", pca_before_png),
    ("\nPCA after fine-tuning:", pca_after, "PCA - After Fine-tuning", pca_after_png),
    ("\nt-SNE before fine-tuning:", tsne_before, "t-SNE - Before Fine-tuning", tsne_before_png),
    ("\nt-SNE after fine-tuning:", tsne_after, "t-SNE - After Fine-tuning", tsne_after_png),
]
for heading, coords, plot_title, png_path in single_plot_specs:
    print(heading)
    plot_single(coords, text_labels, plot_title, png_path, hue_order=hue_order)

# Side-by-side comparison plots: (heading, before coords, after coords, method name)
comparison_specs = [
    ("\nPCA Comparison:", pca_before, pca_after, "PCA"),
    ("\nt-SNE Comparison:", tsne_before, tsne_after, "t-SNE"),
]
for heading, before_coords, after_coords, method_name in comparison_specs:
    print(heading)
    plot_comparison(before_coords, after_coords, text_labels, method_name, out_dir, model_tag, timestamp, hue_order=hue_order)

## 5. Generate and Evaluate New Sequences

Next, we use ESM3 to generate new protein sequences by masking and regenerating a loop of the canonical TIMP-3 template. We then classify the novel sequences with the fine-tuned model.

### Step 5a. Initialize ESM3 Generation Model

In [None]:
print("\n--- Step 5a: Initializing ESM3 Generation Model ---")
login(config.HF_TOKEN)

print("Setting up ESM3 generation model...")
# Use the device detected earlier when it is CUDA; otherwise fall back to CPU instead of
# failing on machines without CUDA. NOTE(review): MPS is deliberately not used here,
# because ESM3 support for it is unverified.
gen_device = device if device.type == "cuda" else torch.device("cpu")
try:
    gen_model: ESM3InferenceClient = ESM3.from_pretrained("esm3-open").to(gen_device)
    # Alternative: hosted Forge API instead of local weights
    # gen_model: ESM3InferenceClient = cl("esm3-medium-2024-08", token=config.FORGE_TOKEN)
    print(f"ESM3 generation model loaded successfully on {gen_device}.")
except Exception as e:
    # Best effort: later cells check for gen_model being None and skip generation.
    print(f"Could not load ESM3 model. Error: {e}")
    gen_model = None # type: ignore


### Step 5b: Generate New Sequences for Each Class

In [None]:
all_generated_seqs = []

# Slice bounds (begin, end) of each regenerable loop in the mature TIMP-3 sequence
# (signal peptide removed below). The end bound is exclusive: residues
# template[begin:end] are replaced by mask characters.
LOOP_BOUNDS = {
    "AB": (30, 36),      # normally 35
    "C": (62, 68),       # 78
    "EF": (92, 96),
    "GH": (127, 137),
    "Multi": (143, 153),
}

# Generate masks for the selected loop. Fail fast on an unknown name instead of masking
# an arbitrary position.
loop = config.LOOP # Options: "AB", "C", "EF", "GH", "Multi"
if loop not in LOOP_BOUNDS:
    raise ValueError(f"Unknown loop {loop!r}; expected one of: {', '.join(LOOP_BOUNDS)}")
begin_loop, end_loop = LOOP_BOUNDS[loop]
loop_mask = "_" * (end_loop - begin_loop)

print("\n--- Step 5b: Generating New Sequences from a Canonical Template ---")

# Get canonical Human TIMP-3 sequence
print(f"Fetching canonical starting sequence from UniProt: {config.CANONICAL_TEMPLATE_UNIPROT_ID}")
template_sequence = get_uniprot_sequence(config.CANONICAL_TEMPLATE_UNIPROT_ID)

if template_sequence:
    template_sequence = template_sequence[23:] # Remove leading signal peptide for TIMP-3
    print(f"Successfully fetched Human TIMP-3 sequence (UniProt: {config.CANONICAL_TEMPLATE_UNIPROT_ID})")
    print(f"Sequence Length: {len(template_sequence)}")
    print(f"Using template: {template_sequence[:30]}...")

if not template_sequence:
    print("Could not fetch template sequence. Skipping generation.")
    generated_df = pd.DataFrame(columns=["sequence", "generation_count"])
elif gen_model is None:
    # Step 5a sets gen_model to None when loading fails.
    print("ESM3 generation model is not available (see Step 5a). Skipping generation.")
    generated_df = pd.DataFrame(columns=["sequence", "generation_count"])
else:
    # Define the loop region to be masked and regenerated
    print(f"Masking loop {loop} from position {begin_loop} to {end_loop} (length {end_loop - begin_loop})")
    masked_prompt = template_sequence[:begin_loop] + loop_mask + template_sequence[end_loop:]
    protein_prompt = ESMProtein(sequence=masked_prompt)

    for _ in tqdm(range(config.NUM_SEQUENCES_TO_GENERATE), desc="Generating novel sequences"):
        generated_protein = gen_model.generate(protein_prompt, GenerationConfig(track="sequence", num_steps=8, temperature=1.5)) # default temp is 0.7
        all_generated_seqs.append({"sequence": generated_protein.sequence})

    print("\nProcessing generated sequences to find unique candidates and their frequencies...")
    # Explicit columns keep 'sequence' present even if nothing was generated.
    temp_df = pd.DataFrame(all_generated_seqs, columns=["sequence"])
    counts = temp_df['sequence'].value_counts()
    generated_df = pd.DataFrame({'sequence': counts.index, 'generation_count': counts.values})
    generated_df.to_csv(out_dir / "generated_sequences.csv", index=False)
    latex_table = generated_df.to_latex(index=False, caption=f"TIMP-3 {loop} Loop Variants", label=f"tab:timp3_loops_{loop}", escape=False)
    with open(out_dir / f"timp3_{loop}_loops_table.tex", 'w', encoding='utf-8') as f:
        f.write(latex_table)
    print(f"Generated {len(temp_df)} total sequences, resulting in {len(generated_df)} unique sequences.")

print(generated_df.head())

In [None]:
# Guard: this cell is not ready to run yet. Delete the next line to enable it.
raise RuntimeError("DONT RUN THIS BLOCK YET: multi-loop generation is disabled; remove this raise to enable it.")
import pandas as pd
from tqdm.auto import tqdm  # same notebook-friendly progress bar as the top-level import

# Relies on config, get_uniprot_sequence, ESMProtein, gen_model, GenerationConfig and
# out_dir from earlier cells.

# --- Main List for All Generated Sequences ---
# Collects sequences from every loop so they can be summarized together.
all_generated_seqs = []

# --- Loops to process and their mask coordinates ---
# (begin, end) slice bounds into the mature TIMP-3 sequence; end is exclusive.
LOOP_BOUNDS = {
    "AB": (30, 36),
    "C": (62, 68),
    "EF": (92, 96),
    "GH": (127, 137),
    "Multi": (143, 153),
}
loops_to_process = ["AB", "C", "EF", "GH", "Multi"]

# --- Fetch the template sequence once ---
print("\n--- Step 5b: Fetching Canonical Template ---")
print(f"Fetching canonical starting sequence from UniProt: {config.CANONICAL_TEMPLATE_UNIPROT_ID}")
template_sequence = get_uniprot_sequence(config.CANONICAL_TEMPLATE_UNIPROT_ID)

if template_sequence:
    template_sequence = template_sequence[23:] # Remove leading signal peptide for TIMP-3
    print(f"Successfully fetched Human TIMP-3 sequence (UniProt: {config.CANONICAL_TEMPLATE_UNIPROT_ID})")
    print(f"Sequence Length: {len(template_sequence)}")
    print(f"Using template: {template_sequence[:30]}...")

if not template_sequence:
    print("Could not fetch template sequence. Skipping generation.")
    generated_df = pd.DataFrame()
elif gen_model is None:
    # Step 5a sets gen_model to None when loading fails.
    print("ESM3 generation model is not available (see Step 5a). Skipping generation.")
    generated_df = pd.DataFrame()
else:
    # --- Iterate Over Each Defined Loop ---
    for loop in loops_to_process:
        print(f"\n--- Generating New Sequences for Loop: {loop} ---")

        if loop not in LOOP_BOUNDS:
            print(f"Warning: Loop '{loop}' is not defined. Skipping.")
            continue
        begin_loop, end_loop = LOOP_BOUNDS[loop]
        loop_mask = "_" * (end_loop - begin_loop)

        # Define the loop region to be masked and regenerated
        print(f"Masking loop {loop} from position {begin_loop} to {end_loop} (length {end_loop - begin_loop})")
        masked_prompt = template_sequence[:begin_loop] + loop_mask + template_sequence[end_loop:]
        protein_prompt = ESMProtein(sequence=masked_prompt)

        # Generate the specified number of sequences for the current loop
        for _ in tqdm(range(config.NUM_SEQUENCES_TO_GENERATE), desc=f"Generating for loop {loop}"):
            generated_protein = gen_model.generate(protein_prompt, GenerationConfig(track="sequence", num_steps=8, temperature=1.5))
            # Record which loop was regenerated alongside the sequence
            all_generated_seqs.append({
                "sequence": generated_protein.sequence,
                "loop": loop,
            })

    # --- Process and Save All Results After the Loop is Done ---
    print("\n--- Processing All Generated Sequences ---")
    if all_generated_seqs:
        temp_df = pd.DataFrame(all_generated_seqs)

        # Count occurrences per (sequence, loop) pair
        generated_df = temp_df.groupby(['sequence', 'loop']).size().reset_index(name='generation_count')

        # Sort for readability: by loop, then most frequently generated first
        generated_df = generated_df.sort_values(by=['loop', 'generation_count'], ascending=[True, False])

        # Save the consolidated data to a single CSV file
        output_csv_path = out_dir / "all_generated_sequences.csv"
        generated_df.to_csv(output_csv_path, index=False)
        print(f"Saved all unique sequences to {output_csv_path}")

        # Generate a single LaTeX table for all loops
        latex_table = generated_df.to_latex(
            index=False,
            caption="All Generated TIMP-3 Loop Variants",
            label="tab:timp3_all_loops",
            escape=False,
            longtable=True, # Good for long tables
        )
        output_tex_path = out_dir / "timp3_all_loops_table.tex"
        with open(output_tex_path, 'w', encoding='utf-8') as f:
            f.write(latex_table)
        print(f"Saved LaTeX table to {output_tex_path}")

        print(f"Generated {len(temp_df)} total sequences, resulting in {len(generated_df)} unique sequence/loop combinations.")
    else:
        print("No sequences were generated.")
        generated_df = pd.DataFrame()

print("\n--- Final DataFrame Head ---")
print(generated_df.head())

### Step 5c: Analyze Generated Sequences

In [None]:
print("\n--- Step 5c: Analyzing Generated Sequences with Fine-tuned Model ---")

# Normalize the "nothing generated" case (generation may have been skipped) to a frame
# that still has a 'sequence' column, so the filtering below cannot raise KeyError.
if generated_df.empty:
    print("No generated sequences to analyze.")
if 'sequence' not in generated_df.columns:
    generated_df = pd.DataFrame(columns=['sequence'])

# Drop generated sequences that already occur in the experimental dataset. `sequences` is
# the experimental sequence list built in the data-loading cell.
original_sequences_set = set(sequences)
initial_generated_count = len(generated_df)
generated_df = generated_df[~generated_df['sequence'].isin(original_sequences_set)].reset_index(drop=True)
final_generated_count = len(generated_df)

print(f"Removed {initial_generated_count - final_generated_count} generated sequences that already existed in the original dataset.")
print(f"Proceeding to analyze {final_generated_count} novel, unique sequences.")

generated_sequences_list = generated_df['sequence'].tolist()

if final_generated_count == 0:
    print("No novel sequences remained after filtering against the original dataset. Skipping analysis of generated sequences.")
    # Keep the prediction columns and an empty embedding block so later steps still work.
    generated_df['predicted_class'] = []
    generated_df['confidence'] = []
    generated_embeddings = np.empty((0, embeddings_after.shape[1]), dtype=embeddings_after.dtype)
else:
    # Compute embeddings for the generated sequences
    print("Computing embeddings for generated sequences...")
    generated_embeddings = compute_embeddings(generated_sequences_list, model, tokenizer, device)

    # Predict the class for each generated sequence
    print("Predicting classes for generated sequences...")
    gen_encodings = tokenizer(generated_sequences_list, truncation=True, padding=True)
    # The labels here are just placeholders, they aren't used in prediction
    gen_dataset = ProteinDataset(gen_encodings, [0] * len(generated_sequences_list))

    gen_predictions = trainer.predict(gen_dataset)
    pred_logits = gen_predictions.predictions
    # Some models return extra outputs alongside the logits; keep only the logits.
    if isinstance(pred_logits, tuple):
        pred_logits = pred_logits[0]
    pred_probs = softmax(pred_logits, axis=1)
    confidence_scores = np.max(pred_probs, axis=1)  # probability of the predicted class
    predicted_class_ids = np.argmax(pred_logits, axis=1)

    generated_df['predicted_class'] = [id2label[int(i)] for i in predicted_class_ids]
    generated_df['confidence'] = confidence_scores

    generated_df.to_csv(out_dir / "generated_sequences_confidence.csv", index=False)

# Combine all data and run a single t-SNE for a consistent coordinate space
print("Running t-SNE on combined experimental and generated data...")
combined_embeddings = np.vstack([embeddings_after, generated_embeddings])
tsne_combined = TSNE(n_components=2, perplexity=min(30, len(combined_embeddings)-1), random_state=42).fit_transform(combined_embeddings)

# Create a master DataFrame for plotting: experimental rows first (true labels), then
# generated rows (predicted labels), matching the row order of combined_embeddings.
plot_df = pd.DataFrame()
plot_df['label'] = text_labels + generated_df['predicted_class'].tolist()
plot_df['source'] = ['Experimental'] * len(text_labels) + ['Generated'] * len(generated_df)
plot_df['tsne1'] = tsne_combined[:, 0]
plot_df['tsne2'] = tsne_combined[:, 1]

### Step 5d: Generate Plots for the Generated Sequences

In [None]:
# Plot the joint t-SNE from Step 5c: experimental points carry their true labels,
# generated points their predicted class.
print("Generating new t-SNE plots...")
experimental_df = plot_df[plot_df['source'] == 'Experimental']
generated_plot_df = plot_df[plot_df['source'] == 'Generated']

# Define filenames for the plots
combined_plot_filename = out_dir / f"tsne_combined_{model_tag}_{timestamp}.png"
experimental_only_filename = out_dir / f"tsne_experimental_only_{model_tag}_{timestamp}.png"
generated_only_filename = out_dir / f"tsne_generated_only_{model_tag}_{timestamp}.png"
side_by_side_filename = out_dir / f"tsne_side_by_side_{model_tag}_{timestamp}.png"

# --- Combined Plot ---
print("Generating combined t-SNE plot...")
plot_single(plot_df[['tsne1', 'tsne2']].values, plot_df['label'],
            "t-SNE of Experimental and Generated Sequences", combined_plot_filename,
            hue_order=hue_order, style=plot_df['source'])

# --- Experimental Data Only ---
print("Generating experimental data only t-SNE plot...")
plot_single(experimental_df[['tsne1', 'tsne2']].values, experimental_df['label'],
            "t-SNE of Experimental Data Only", experimental_only_filename,
            hue_order=hue_order)

# --- Generated Data Only ---
print("Generating generated data only t-SNE plot...")
plot_single(generated_plot_df[['tsne1', 'tsne2']].values, generated_plot_df['label'],
            "t-SNE of Generated Data Only (Colored by Predicted Class)", generated_only_filename,
            hue_order=hue_order)

# --- Side-by-Side Comparison ---
print("Generating side-by-side t-SNE comparison plot...")
fig, axes = plt.subplots(1, 2, figsize=(20, 8))
fig.suptitle('t-SNE Comparison: Experimental vs. Generated Data', fontsize=20)
side_by_side_panels = (
    (axes[0], experimental_df, "Experimental Data"),
    (axes[1], generated_plot_df, "Generated Data (Predicted)"),
)
for ax, panel_df, panel_title in side_by_side_panels:
    sns.scatterplot(ax=ax, data=panel_df, x='tsne1', y='tsne2', hue='label', s=50, alpha=0.7, hue_order=hue_order)
    ax.set_title(panel_title, fontsize=16)
    # Same axis labels and legend title as plot_single / plot_comparison.
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    if not panel_df.empty:  # an empty panel has no labeled artists for a legend
        ax.legend(title="Target")
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
fig.savefig(side_by_side_filename)
plt.show()
plt.close(fig)

## 6. Top 10 Generated Sequences

This section identifies and displays the top 10 most confidently predicted sequences for each target class.

In [None]:
if not generated_df.empty and 'predicted_class' in generated_df.columns:
    print("\n--- Top 10 Most Confident Generated Sequences Per Predicted Class ---")

    # Rank by confidence, breaking ties by how often the sequence was generated. This is
    # the same ordering used for the validation candidates in Section 7.
    sort_columns = ['confidence', 'generation_count'] if 'generation_count' in generated_df.columns else ['confidence']

    for class_label in unique_labels:
        print(f"\n\n{'='*50}")
        print(f"Top 10 for class: {class_label}")
        print(f"{'='*50}")

        top_10 = (
            generated_df[generated_df['predicted_class'] == class_label]
            .sort_values(sort_columns, ascending=False)
            .head(10)
        )

        # File- and LaTeX-label-safe form of the class name. The class is part of the
        # \label so the per-class tables do not share one LaTeX label.
        class_slug = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in str(class_label))
        latex_table = top_10.to_latex(
            index=False,
            caption=f"TIMP-3 {loop} Loop Variants ({class_label})",
            label=f"tab:timp3_loops_{loop}_{class_slug}_top",
            escape=False,
        )
        with open(out_dir / f"timp3_{loop}_loops_table_{class_slug}_top.tex", 'w', encoding='utf-8') as f:
            f.write(latex_table)

        if top_10.empty:
            print("No generated sequences were confidently predicted for this class.")
        else:
            for _, row in top_10.iterrows():
                print(f"  Confidence: {row.confidence:.4f} | Sequence: {row.sequence}")
else:
    print("\n--- No generated sequences to analyze ---")


## 7. Validate Top Candidates

This section performs computational validation on the top-ranked sequences. We will check for structural integrity (pLDDT), prepare files for binding prediction, and assess sequence plausibility.



In [None]:
# Collect up to 10 top-ranked generated sequences per predicted class: highest confidence
# first, with ties broken by how often the sequence was generated.
top_candidates = {}
has_predictions = (not generated_df.empty) and ('predicted_class' in generated_df.columns)

if has_predictions:
    for class_label in unique_labels:
        # To leave 'Non-Binder' out of validation, skip that class here.
        class_rows = generated_df.loc[generated_df['predicted_class'] == class_label]
        ranked = class_rows.sort_values(by=['confidence', 'generation_count'], ascending=[False, False])
        best_sequences = ranked.head(10)['sequence'].tolist()
        if best_sequences:
            top_candidates[class_label] = best_sequences

if not top_candidates:
    print("No top candidates were identified in the previous step. Skipping validation.")
### Heat map

Create heat maps of amino acid frequency at each position, first for all generated sequences and then for each class. Then generate grouped heat maps based on biochemical properties of the amino acids (charge, hydropathy, size, and chemical class).

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

# --- Amino acid group definitions used by the grouped heatmaps ---
# Maps a property name -> {group name -> list of one-letter amino acid codes}.
# Each property below assigns every one of the 20 standard amino acids to exactly
# one group, so grouped per-position frequencies keep the same column totals.
# The order of the inner keys sets the row order of the grouped heatmaps.
# You can customize these groups or add new ones as needed.
GROUPINGS = {
    # Side-chain charge (histidine counted as positive here).
    'charge': {
        'Positive': ['R', 'H', 'K'],
        'Negative': ['D', 'E'],
        'Neutral': ['A', 'N', 'C', 'Q', 'G', 'I', 'L', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
    },
    'hydropathy': {
        'Hydrophobic': ['A', 'V', 'I', 'L', 'M', 'F', 'Y', 'W'],
        'Hydrophilic': ['R', 'N', 'D', 'C', 'Q', 'E', 'H', 'K', 'S', 'T'],
        'Neutral': ['G', 'P']
    },
    'size': {
        'Tiny': ['A', 'C', 'G', 'S'],
        'Small': ['D', 'N', 'P', 'T', 'V'],
        'Medium': ['E', 'H', 'I', 'L', 'K', 'M', 'Q'],
        'Large': ['F', 'R', 'W', 'Y']
    },
    'chemical': {
        'Aliphatic': ['A', 'G', 'I', 'L', 'P', 'V'],
        'Aromatic': ['F', 'W', 'Y'],
        'Acidic': ['D', 'E'],
        'Basic': ['R', 'H', 'K'],
        'Hydroxylic': ['S', 'T'],
        'Amide': ['N', 'Q'],
        'Sulfur': ['C', 'M']
    }
}

# --- New Section: Helper Function for Grouped Heatmaps ---
def generate_grouped_heatmap(freq_df, group_dict, title, filename):
    """
    Generate, display, and save a heatmap of amino acid frequencies summed by group.

    Args:
        freq_df (pd.DataFrame): DataFrame with amino acids (one-letter codes) as
            index and sequence positions as columns.
        group_dict (dict): Maps group name -> list of amino acid codes. The key
            order sets the row order of the heatmap.
        title (str): Title for the plot.
        filename (str or Path): Path to save the output image.

    Notes:
        Rows of ``freq_df`` whose amino acid is not listed in ``group_dict`` map
        to no group and are dropped by the groupby. Groups that match none of the
        amino acids in ``freq_df`` are drawn as all-zero rows rather than blank
        (NaN) rows.
    """
    # Map each amino acid to its group (if an amino acid is listed twice, the
    # last group wins).
    aa_to_group_map = {aa: group for group, aa_list in group_dict.items() for aa in aa_list}

    # Group the index through the mapping and sum frequencies within each group.
    grouped_freq_df = freq_df.groupby(aa_to_group_map).sum()

    # Keep the groups in the order they were defined; a group with no matching
    # residues has zero frequency, not an unknown (NaN) one.
    group_order = list(group_dict.keys())
    grouped_freq_df = grouped_freq_df.reindex(group_order, fill_value=0)

    # Generate and save the heatmap
    plt.figure(figsize=(10, max(4, len(group_order) * 0.8)))
    sns.heatmap(grouped_freq_df, cmap='viridis', annot=True, fmt=".2f", linewidths=.5)
    plt.title(title, fontsize=16)
    plt.xlabel('Position in Sequence', fontsize=12)
    plt.ylabel('Amino Acid Group', fontsize=12)
    plt.tight_layout()
    plt.savefig(filename)
    plt.show()
    plt.close()
    print(f"  Saved grouped heatmap: '{filename}'")

# Re-read the scored generations from disk so this cell can run on its own.
generated_df = pd.read_csv(out_dir / "generated_sequences_confidence.csv")

# 0-based, half-open slice [begin_loop, end_loop) of the loop region to analyse.
begin_loop = 30
end_loop = 36

amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
loop_positions = list(range(begin_loop, end_loop))


def _loop_frequency_table(sequences, begin, end, alphabet):
    """Return per-position residue frequencies of ``seq[begin:end]`` over *sequences*.

    Only string sequences long enough to contain the whole loop are counted, and
    the counts are divided by the number of such sequences (so each column sums
    to 1 when every residue is in *alphabet*). Residues outside *alphabet* are
    ignored. Rows are the letters of *alphabet*; columns are 1-based positions.
    """
    positions = list(range(begin, end))
    counts = pd.DataFrame(0, index=list(alphabet), columns=positions)
    loop_len = end - begin
    n_valid = 0
    for seq in sequences:
        if not isinstance(seq, str):
            continue  # e.g. a missing value read back from the CSV
        loop_seq = seq[begin:end]
        if len(loop_seq) != loop_len:
            continue  # sequence too short to contain the whole loop
        n_valid += 1
        for pos, aa in zip(positions, loop_seq):
            if aa in counts.index:
                counts.loc[aa, pos] += 1
    freq = counts / n_valid if n_valid else counts
    freq.columns = [pos + 1 for pos in positions]  # 1-based labels for the plot
    return freq


def _save_individual_heatmap(freq_df, title, filename):
    """Draw, display, save, and close a per-residue frequency heatmap."""
    plt.figure(figsize=(8, 10))
    sns.heatmap(freq_df, cmap='viridis', annot=False)  # annotations would be too dense
    plt.title(title, fontsize=16)
    plt.xlabel('Position in Sequence', fontsize=12)
    plt.ylabel('Amino Acid', fontsize=12)
    plt.tight_layout()
    plt.savefig(filename)
    plt.show()
    plt.close()


all_candidates_by_class = {}
if 'predicted_class' in generated_df.columns:
    all_candidates_by_class = generated_df.groupby('predicted_class')['sequence'].apply(list).to_dict()

if not all_candidates_by_class:
    print("No sequences were found or the 'predicted_class' column is missing.")
else:
    # === Generate a single heatmap for ALL sequences combined ===
    print("Generating combined heatmap for all sequences...")
    all_sequences = [seq for class_seqs in all_candidates_by_class.values() for seq in class_seqs]
    freq_df_all = _loop_frequency_table(all_sequences, begin_loop, end_loop, amino_acids)

    all_filename = out_dir / f'heatmap_all_candidates_individual_{config.LOOP}_{timestamp}.png'
    _save_individual_heatmap(freq_df_all, 'Amino Acid Frequency in Loop (All Candidates)', all_filename)
    print(f"Saved '{all_filename}'")

    # --- Grouped heatmaps for ALL sequences ---
    print("\nGenerating grouped heatmaps for all sequences...")
    for group_name, group_dict in GROUPINGS.items():
        title = f'Grouped AA Frequency ({group_name.capitalize()}) - All Candidates'
        filename = out_dir / f'heatmap_all_candidates_grouped_by_{group_name}_{config.LOOP}_{timestamp}.png'
        generate_grouped_heatmap(freq_df_all, group_dict, title, filename)

    # === Generate a separate heatmap for EACH class ===
    print("\nGenerating separate heatmaps for each class...")
    for class_label, sequences in all_candidates_by_class.items():
        print(f"- Processing class: {class_label}")
        freq_df_class = _loop_frequency_table(sequences, begin_loop, end_loop, amino_acids)

        # Include the loop id like every other output so runs on different loops
        # do not overwrite each other's files.
        class_filename = out_dir / f'heatmap_for_{class_label}_individual_{config.LOOP}_{timestamp}.png'
        _save_individual_heatmap(freq_df_class, f'Amino Acid Frequency for Class: {class_label}', class_filename)
        print(f"  Saved '{class_filename}'")

        # --- Grouped heatmaps for EACH class ---
        print(f"  Generating grouped heatmaps for class: {class_label}...")
        for group_name, group_dict in GROUPINGS.items():
            title = f'Grouped AA Frequency ({group_name.capitalize()}) - Class: {class_label}'
            filename = out_dir / f'heatmap_for_{class_label}_grouped_by_{group_name}_{config.LOOP}_{timestamp}.png'
            generate_grouped_heatmap(freq_df_class, group_dict, title, filename)

    print("\nAll heatmaps generated successfully.")

### Step 7a: Predict 3D Structure and Assess Integrity (pLDDT)
A good candidate should fold into a high-confidence 3D structure. ESM3 can predict this structure and provide a pLDDT score (predicted Local Distance Difference Test) for each amino acid, where >90 is very high confidence and >70 is generally confident. We'll visualize the top candidate for the first class.

In [None]:
print("\n--- Predicting 3D structures for top candidates ---")
validation_dir = out_dir / "validation_outputs"
validation_dir.mkdir(parents=True, exist_ok=True)
pdb_files = {}  # class label -> paths of the PDB files written successfully

if gen_model:
    for class_label, sequences_list in top_candidates.items():
        print(f"\nProcessing top candidates for: {class_label}")
        pdb_files[class_label] = []
        for i, seq in enumerate(sequences_list):
            try:
                protein = ESMProtein(sequence=seq)
                # Use the generation model to predict structure
                structure = gen_model.generate(protein, GenerationConfig(track="structure", num_steps=8))

                pdb_filename = validation_dir / f"{class_label}_candidate_{i+1}.pdb"
                structure.to_pdb(str(pdb_filename))
                pdb_files[class_label].append(str(pdb_filename))
                print(f"Saved structure for {class_label} candidate {i+1} to {pdb_filename}")
            except Exception as e:
                # Best effort: one failed prediction should not abort the whole batch.
                print(f"Could not predict structure for {class_label} candidate {i+1}. Error: {e}")

    # top_candidates may be empty, and every prediction for a class may have
    # failed; in both cases there is simply nothing to visualize.
    first_class = next(iter(top_candidates), None)
    first_pdb_file = (pdb_files.get(first_class) or [None])[0]

    if first_pdb_file:
        print(f"\n--- Visualizing structure for the #1 candidate of the '{first_class}' class ---")
        with open(first_pdb_file, 'r') as f:
            pdb_data = f.read()

        view = py3Dmol.view(width=800, height=600)
        view.addModel(pdb_data, "pdb")
        # NOTE(review): 'pLDDT' is not a built-in 3Dmol.js colorscheme, so this likely
        # falls back to default colouring. If ESM3's to_pdb stores pLDDT in the B-factor
        # column, {'prop': 'b', 'gradient': 'roygb', 'min': ..., 'max': ...} would colour
        # by confidence -- confirm the B-factor scale (0-1 vs 0-100) before switching.
        view.setStyle({'cartoon': {'colorscheme': 'pLDDT'}})
        view.zoomTo()
        view.show()
else:
    print("ESM3 generation model not loaded. Skipping structure prediction.")

### Step 7b: Prepare for Binding Prediction (Co-folding)
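To predict whether each variant binds its target, we can use a co-folding tool like AlphaFold-Multimer or ESMFold-Multimer. This requires a FASTA file containing both the target protein and the candidate sequence. The target depends on the candidate's predicted class: MMP-3, MMP-9, or ADAM17, with any other class using the configured default target. This step prepares those files.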

The FASTA files can then be uploaded to Google Drive and run through the ColabFold batch notebook, which predicts the structure of each complex and reports the confidence scores used to judge folding and binding (https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/batch/AlphaFold2_batch.ipynb).


In [None]:
print("\n--- Preparing FASTA files for co-folding analysis ---")
fasta_dir = validation_dir / "cofold_fasta_files"
fasta_dir.mkdir(parents=True, exist_ok=True)

uniprot_mmp3 = "P08254"
uniprot_mmp9 = "P14780"
uniprot_adam17 = "P78536"
# Target protein (UniProt accession) per class label; any other label
# (e.g. 'Non-Binder') falls back to the configured default target.
target_uniprot_by_class = {
    "MMP3": uniprot_mmp3,
    "MMP9": uniprot_mmp9,
    "ADAM17": uniprot_adam17,
}
target_sequence_cache = {}  # accession -> sequence, so each target is fetched only once

for class_label, sequences_list in top_candidates.items():
    accession = target_uniprot_by_class.get(class_label, config.TARGET_UNIPROT_ID)
    if accession not in target_sequence_cache:
        target_sequence_cache[accession] = get_uniprot_sequence(accession)
    target_sequence = target_sequence_cache[accession]

    # NOTE(review): assumes get_uniprot_sequence returns a falsy value on failure
    # (its definition is outside this chunk); never write "None" into a FASTA file.
    if not target_sequence:
        print(f"Could not fetch target sequence {accession} for {class_label}. Skipping.")
        continue

    print(f"Target for {class_label}: UniProt {accession}")
    print(f"Sequence Length: {len(target_sequence)}aa")
    print(f"Sequence: {target_sequence[:60]}...")
    for i, seq in enumerate(sequences_list):
        # Two-record FASTA: target first, then the candidate binder.
        fasta_filename = fasta_dir / f"{class_label}_candidate_{i+1}_vs_TIMP3.fasta"
        with open(fasta_filename, 'w', encoding='utf-8') as f:
            f.write(f">{class_label} Target Protein\n{target_sequence}\n")
            f.write(f">{class_label}_Candidate_{i+1} Predicted Binder\n{seq}\n")

        # ColabFold complex format: one record, chains separated by a colon.
        complex_fasta_filename = fasta_dir / f"complex_{class_label}_candidate_{i+1}.fasta"
        with open(complex_fasta_filename, 'w', encoding='utf-8') as f:
            f.write(f">{class_label}_complex_candidate_{i+1}\n")
            f.write(f"{target_sequence}:{seq}\n")
print(f"\nFASTA files for all top candidates are saved in '{fasta_dir}'.")
print("These files can now be submitted to a co-folding service like AlphaFold-Multimer.")

### Step 7c: Analyze AlphaFold Output from File

Automated Batch Validation of ColabFold Outputs. This cell automates the analysis of all ColabFold `.zip` outputs.
**Action:**
1.  Run your candidate FASTA files through the ColabFold notebook.
2.  Download the resulting `.zip` files.
3.  Place all of them inside the `validation_outputs` directory created by the previous step.
4.  Run this cell. It will generate a full report for every candidate.

In [None]:
import zipfile
from pathlib import Path
from IPython.display import display, HTML

# --- SETUP: Define the directory where you placed your zip files ---
out_dir = Path(config.OUTPUT_DIR)
validation_dir = out_dir / "validation_outputs"
if not validation_dir.exists():
    print(f"Creating directory: {validation_dir}")
    validation_dir.mkdir(parents=True)

print(f"Searching for ColabFold .zip outputs in: {validation_dir}")

# Find all zip files in the directory
zip_files = list(validation_dir.glob("*.zip"))

# Zero-padded model rank used in ColabFold output file names; "001" is the top model.
candidate_num = "001"

# Per-candidate scores consumed by the summary cell. Initialised unconditionally so
# that cell does not fail with NameError when no zip files were found.
results_list = []

# 3Dmol.js has no built-in 'pLDDT' colorscheme. AlphaFold/ColabFold PDB files store
# per-residue pLDDT (0-100) in the B-factor column, so colour the cartoon by that
# property (red = low confidence, blue = high).
plddt_colorscheme = {'prop': 'b', 'gradient': 'roygb', 'min': 50, 'max': 90}

if not zip_files:
    print("\nERROR: No .zip files found in the validation directory.")
    print("Please make sure to place your ColabFold outputs there before running this cell.")
else:
    print(f"Found {len(zip_files)} candidate zip files to analyze.")

    # --- LOOP THROUGH EACH ZIP FILE AND ANALYZE ---
    for zip_file_path in sorted(zip_files):
        candidate_name = zip_file_path.stem
        display(HTML(f"<hr><h2>Analyzing Candidate: {candidate_name}</h2>"))

        # Unzip the contents into a subdirectory
        unzip_dir = validation_dir / candidate_name
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(unzip_dir)

        # Find the files for the top-ranked model (rank 1)
        try:
            scores_file = next(unzip_dir.glob(f"*_scores_rank_{candidate_num}_*.json"))
            pae_file = next(unzip_dir.glob("*_predicted_aligned_error_v1.json"))
            pdb_file = next(unzip_dir.glob(f"*_unrelaxed_rank_{candidate_num}_*.pdb"))
        except StopIteration:
            print(f"WARNING: Could not find all required output files for rank 1 in {candidate_name}. Skipping.")
            continue

        # --- EXTRACT AND DISPLAY KEY SCORES ---
        with open(scores_file, 'r') as f:
            scores_data = json.load(f)

        ipTM_score = scores_data.get("iptm")
        # `or [0]` also covers an empty list, for which np.mean would return NaN.
        mean_plddt = np.mean(scores_data.get("plddt") or [0])

        results_list.append({
            "Candidate": candidate_name,
            "ipTM": ipTM_score,
            "Mean pLDDT": mean_plddt
        })

        print("\n--- Confidence Scores ---")
        if ipTM_score is not None:
            print(f"Interface pTM (ipTM): {ipTM_score:.4f}")
            if ipTM_score > 0.85:
                print("   Interpretation: High confidence binding prediction.")
            else:
                print("   Interpretation: Low confidence binding prediction.")

        if mean_plddt > 0:
            print(f"Mean pLDDT: {mean_plddt:.2f} (Overall structure confidence)")

        # --- PLOT THE PAE MATRIX FROM RAW JSON DATA ---
        with open(pae_file, 'r') as f:
            pae_data = json.load(f)

        pae_matrix = pae_data.get('predicted_aligned_error')
        if pae_matrix:
            print("\n--- Predicted Aligned Error (PAE) Plot ---")
            plt.figure(figsize=(8, 6))
            plt.imshow(pae_matrix, cmap='Greens_r', vmin=0, vmax=30)
            plt.colorbar(label="Expected Position Error (Å)")
            plt.title(f"PAE Plot for {candidate_name}")
            plt.xlabel("Scored Residue")
            plt.ylabel("Aligned Residue")
            plt.savefig(validation_dir / f"PAE_{candidate_name}_{candidate_num}.png")
            plt.show()
            plt.close()

        # --- VISUALIZE THE 3D STRUCTURE ---
        print("\n--- Predicted 3D Structure (Colored by pLDDT) ---")
        with open(pdb_file, 'r') as f:
            pdb_data = f.read()

        view = py3Dmol.view(width=800, height=600)
        view.addModel(pdb_data, "pdb")
        view.setStyle({'cartoon': {'colorscheme': plddt_colorscheme}})
        view.zoomTo()
        view.show()

In [None]:
# --- FINAL SUMMARY ---
def _style_summary(df):
    """Return a Styler with viridis gradients on the score columns and fixed precision."""
    return (
        df.style
        .background_gradient(subset=['ipTM'], cmap='viridis')
        .background_gradient(subset=['Mean pLDDT'], cmap='viridis')
        .format({'ipTM': '{:.4f}', 'Mean pLDDT': '{:.2f}'})
    )


if results_list:
    display(HTML("<hr><h1>Final Summary Ranking</h1>"))
    summary_df = pd.DataFrame(results_list)
    summary_df_sorted = summary_df.sort_values(by="ipTM", ascending=False).reset_index(drop=True)

    # Style the dataframe for better readability
    styled_summary_df = _style_summary(summary_df_sorted)
    display(styled_summary_df)

    styled_summary_df.to_excel(validation_dir / 'AlphaFold_Summary.xlsx', engine='openpyxl', index=False)

    # Tidy candidate names for LaTeX on a copy of the data: underscores are special
    # in LaTeX and ColabFold names end in ".result". Cleaning the data (rather than
    # the rendered LaTeX string) leaves the \label{...} key untouched.
    cleaned_names = (
        summary_df_sorted['Candidate']
        .str.replace("_", " ", regex=False)
        .str.replace(".result", "", regex=False)
    )
    latex_df = summary_df_sorted.assign(Candidate=cleaned_names)
    latex_table = _style_summary(latex_df).hide(axis="index").to_latex(
        caption="AlphaFold2-multimer Summary",
        label="tab:af2_multi_sum",
        convert_css=True,
    )
    with open(validation_dir / "AlphaFold_Summary.tex", 'w', encoding='utf-8') as f:
        f.write(latex_table)

## 8. Final Summary

In [None]:
def summary_value(name, namespace=None):
    """Return the object bound to *name* in the notebook namespace, or a placeholder.

    Some outputs (for example the generated-sequence t-SNE plots) may not exist
    when the corresponding earlier cells were skipped; looking them up by name
    keeps this summary from failing with NameError in that case.

    Args:
        name (str): Variable name to look up.
        namespace (dict, optional): Mapping to search; defaults to this module's
            (i.e. the notebook's) global namespace.
    """
    ns = globals() if namespace is None else namespace
    return ns.get(name, "(not generated in this session)")


print("\n[DONE] All tasks completed!")
print("\n--- Summary of Output Files ---")
print(f"- Fine-tuned model saved to: {summary_value('out_dir')}")
print(f"- Classification report saved to: {summary_value('class_report_path')}")
print("\n- Embeddings:")
print(f"  - Before fine-tuning (cached): {summary_value('emb_cache_file')}")
print(f"  - After fine-tuning (cached): {summary_value('emb_after_cache_file')}")
print("\n- CSV Coordinates:")
print(f"  - PCA (before): {summary_value('pca_before_csv')}")
print(f"  - t-SNE (before): {summary_value('tsne_before_csv')}")
print(f"  - PCA (after): {summary_value('pca_after_csv')}")
print(f"  - t-SNE (after): {summary_value('tsne_after_csv')}")
print("\n- Plots:")
print(f"  - PCA Comparison: {summary_value('pca_comp_png')}")
print(f"  - t-SNE Comparison: {summary_value('tsne_comp_png')}")
print(f"  - Validation Curve: {summary_value('validation_curve_file')}")
print(f"  - Combined t-SNE Plot: {summary_value('combined_plot_filename')}")
print(f"  - Experimental Only t-SNE Plot: {summary_value('experimental_only_filename')}")
print(f"  - Generated Only t-SNE Plot: {summary_value('generated_only_filename')}")
print(f"  - Side-by-Side t-SNE Comparison: {summary_value('side_by_side_filename')}")
print("\nPlease check the output directory for all generated files.")
