# ESM Multi-Target Fine-Tuning and Embedding Comparison

This notebook walks through the process of using a pre-trained ESM model for a **multi-class classification** task. It generates protein embeddings, fine-tunes the model on multiple targets (including a 'Non-Binder' category), and compares the embedding space before and after fine-tuning using PCA and t-SNE.

## 1. Setup and Imports

First, install and import the necessary libraries. A GPU (CUDA, or MPS on Apple Silicon) is strongly recommended for training; the notebook falls back to the CPU if neither is available.

In [None]:
%pip install -q transformers datasets scikit-learn pandas torch seaborn matplotlib

In [None]:
import os
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_fscore_support, accuracy_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
import hashlib
import json

## 2. Configuration

Set all your parameters in this cell. This replaces the command-line arguments from the original script.

In [None]:
class Config:
    """Run-wide settings for this notebook (stand-in for the original CLI flags).

    All settings are class-level constants. The rest of the notebook reads
    them through the module-level ``config`` instance created below.
    """

    # ---- Input / output ---------------------------------------------------
    # The CSV must be reachable from the notebook's working directory.
    CSV_PATH = "../Data/TIMP_binder_MMP9_AB.csv"
    OUTPUT_DIR = "../Local/esm_multitarget_out"
    SEQ_COL = "Full Seq"      # protein sequence column
    LABEL_COL = "Target"      # target-name column (class label)
    BINDING_COL = "Encoding"  # 1 = binder, 0 = non-binder (relabelled 'Non-Binder')

    # ---- Model ------------------------------------------------------------
    # Other checkpoints that have been tried:
    #   "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t33_650M_UR50D",
    #   "EvolutionaryScale/esmc-300m-2024-12" (ESM-C; not functional yet).
    MODEL_ID = "facebook/esm2_t30_150M_UR50D"

    # ---- Training ---------------------------------------------------------
    TEST_SIZE = 0.2   # fraction of sequences held out for validation
    EPOCHS = 25
    # NOTE: LEARNING_RATE and DECAY_TYPE are not passed to TrainingArguments
    # (those arguments are commented out there), so Trainer defaults apply.
    LEARNING_RATE = 0.001
    DECAY_TYPE = 'linear'
    BATCH_SIZE = 8
    FORCE_RETRAIN = False  # True: retrain even if a saved fine-tuned model exists
    TUNE_AGAIN = True      # True: continue training from the stored fine-tuned weights


config = Config()

## 3. Helper Functions

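These utility functions, carried over from the original script, handle device detection, embedding computation, plotting, wrapping tokenized data as a PyTorch dataset for the `Trainer`, and computing evaluation metrics.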

In [None]:
def get_device():
    """Pick the best available torch device.

    Preference order is CUDA, then Apple-Silicon MPS, then CPU.

    Returns:
        torch.device for the selected backend.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():  # Apple Silicon GPU
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)

def compute_embeddings(sequences, model, tokenizer, device):
    """Embed each sequence as the mean of the model's last hidden layer.

    Sequences are tokenized and run one at a time, so no padding is involved.
    The tokenizer truncates each one to at most 1022 tokens. As a side effect,
    the model is moved to ``device`` and switched to eval mode.

    Args:
        sequences: iterable of sequence strings.
        model: a Hugging Face model that accepts ``output_hidden_states=True``.
        tokenizer: the matching tokenizer.
        device: torch device to run on.

    Returns:
        np.ndarray with one row per input sequence, in input order.
        An empty ``sequences`` makes ``np.vstack`` raise ValueError.
    """
    model.to(device)
    model.eval()

    def _embed_one(seq):
        batch = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1022).to(device)
        last_hidden = model(**batch, output_hidden_states=True).hidden_states[-1]
        # Average-pool over dim=1, the token/sequence axis. Any special tokens
        # the tokenizer adds are included in the average.
        return last_hidden.mean(dim=1).squeeze().cpu().numpy()

    # Gradient mode is global state, so no_grad also covers calls made inside _embed_one.
    with torch.no_grad():
        rows = [_embed_one(seq) for seq in tqdm(sequences, desc="Computing embeddings")]
    return np.vstack(rows)

# Global seaborn palette for every scatter plot below: blue, red, green.
# NOTE(review): only three colours are defined. With more classes than
# colours, seaborn may ignore this palette and choose its own; verify the
# plots if classes are added.
sns.set_palette([
    "#0000FF",  # blue
    "#FF0000",  # red
    "#009100",  # green
])

def plot_single(data, labels, title, filename, hue_order=None):
    """Scatter-plot the first two columns of ``data``, coloured by ``labels``.

    Args:
        data: 2-D array-like; columns 0 and 1 are used as x and y.
        labels: per-row class names used as the hue.
        title: figure title.
        filename: path the PNG is saved to.
        hue_order: optional explicit class order, for consistent colours across plots.

    The figure is saved, shown inline, and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.scatterplot(
        x=data[:, 0],
        y=data[:, 1],
        hue=labels,
        hue_order=hue_order,
        s=50,
        alpha=0.7,
        ax=ax,
    )
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.legend(title="Target")
    fig.tight_layout()
    fig.savefig(filename)
    plt.show()  # display inline
    plt.close(fig)

def plot_comparison(before_data, after_data, labels, reduction_method, out_dir, model_tag, timestamp, hue_order=None):
    """Save and show side-by-side scatter plots of 2-D coordinates before and after fine-tuning.

    Args:
        before_data, after_data: arrays whose first two columns are plotted,
            row-aligned with ``labels``.
        labels: per-point class names used as the hue.
        reduction_method: display name such as "PCA" or "t-SNE". It appears in
            the title and, lower-cased with hyphens removed, in the filename
            (so "t-SNE" -> "tsne").
        out_dir: output directory; it must support the ``/`` operator
            (e.g. ``pathlib.Path``).
        model_tag, timestamp: strings embedded in the title/filename.
        hue_order: optional explicit class order, for consistent colours.

    Returns:
        The path of the saved PNG.
    """
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    fig.suptitle(f'{reduction_method} Comparison ({model_tag})', fontsize=20)

    # The two panels differ only in their data and title.
    panels = (("Before Fine-Tuning", before_data), ("After Fine-Tuning", after_data))
    for ax, (panel_title, coords) in zip(axes, panels):
        sns.scatterplot(ax=ax, x=coords[:, 0], y=coords[:, 1], hue=labels, s=50, alpha=0.7, hue_order=hue_order)
        ax.set_title(panel_title, fontsize=16)
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        ax.legend(title="Target")

    # Reserve space at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    # Drop hyphens so "t-SNE" yields "tsne_comparison_..." (the name the
    # notebook's summary reports) instead of "t-sne_comparison_...".
    method_slug = reduction_method.lower().replace('-', '')
    filename = out_dir / f"{method_slug}_comparison_{model_tag}_{timestamp}.png"
    fig.savefig(filename)
    plt.show()  # display inline
    plt.close(fig)
    return filename

class ProteinDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized encodings with integer class ids.

    ``encodings`` maps a field name (e.g. ``input_ids``, ``attention_mask``) to
    a per-example sequence, in the form a tokenizer returns for a list of
    strings. ``labels`` is a parallel sequence of class ids. Each item is a
    dict of tensors plus a ``labels`` tensor of dtype ``torch.long``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # Dataset size is driven by the label list.
        return len(self.labels)

    def __getitem__(self, idx):
        item = {}
        for field, column in self.encodings.items():
            item[field] = torch.tensor(column[idx])
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

def compute_metrics_multiclass(pred):
    """Score Trainer predictions: accuracy plus support-weighted precision, recall and F1.

    Args:
        pred: the evaluation object the Trainer passes in. It needs
            ``label_ids`` and ``predictions``; the predicted class is the
            argmax of ``predictions`` over the last axis.

    Returns:
        dict with keys 'accuracy', 'f1', 'precision', 'recall'.

    'weighted' averaging accounts for class imbalance. zero_division=0 scores
    never-predicted classes as 0 instead of warning.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    scores = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    precision, recall, f1 = scores[:3]  # 4th element (support) is None for averaged scores
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

## 4. Main Pipeline

This is the main execution block. It will perform all the steps from the original script in sequence.

In [None]:
# --- Run setup: output directory, device, and identifiers for output filenames ---
out_dir = Path(config.OUTPUT_DIR)
out_dir.mkdir(parents=True, exist_ok=True)
device = get_device()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = config.MODEL_ID.rsplit('/', 1)[-1]  # e.g. "esm2_t30_150M_UR50D"

print(f"Using device: {device}")
print(f"Using model: {config.MODEL_ID}")

# --- Load the CSV and clean it ---
print("--- Loading and Processing Data ---")
df = (
    pd.read_csv(config.CSV_PATH)
    # rows missing a sequence, label, or binding flag cannot be used
    .dropna(subset=[config.SEQ_COL, config.LABEL_COL, config.BINDING_COL])
    # NOTE(review): keeps only the first row per sequence; a sequence listed
    # under several targets keeps just its first label
    .drop_duplicates(subset=[config.SEQ_COL])
    # own the data so the .loc assignment below avoids SettingWithCopyWarning
    .copy()
)

# Non-binders form their own class, whatever target they were screened against.
df.loc[df[config.BINDING_COL] == 0, config.LABEL_COL] = 'Non-Binder'
print(f"Total sequences in dataset: {len(df)}")

# Integer ids follow sorted label order, so the mapping is reproducible across runs.
unique_labels = sorted(df[config.LABEL_COL].unique())
label2id = dict(zip(unique_labels, range(len(unique_labels))))
id2label = {idx: name for name, idx in label2id.items()}
df['numerical_labels'] = df[config.LABEL_COL].map(label2id)

num_labels = len(unique_labels)
print(f"Found {num_labels} unique classes for training: {', '.join(unique_labels)}")
print("\nClass distribution:")
print(df[config.LABEL_COL].value_counts())

sequences = df[config.SEQ_COL].tolist()
labels = df['numerical_labels'].tolist()
text_labels = df[config.LABEL_COL].tolist()  # string class names, used as plot hues

### Step 1: Pre-trained Model Embeddings

In [None]:
print("\n--- Step 1: Pre-trained Model Embeddings ---")

# Cache key for the embeddings. Sequences are joined with a newline, which is
# assumed never to occur inside a sequence. This makes the key unambiguous:
# a plain "".join gives ["AB", "C"] and ["A", "BC"] the same hash, and so the
# same cache file.
# NOTE: caches written under the old "".join key are not reused; they are
# recomputed once.
sequences_hash = hashlib.md5("\n".join(sequences).encode()).hexdigest()
emb_cache_file = out_dir / f"embeddings_before_{model_tag}_{sequences_hash}.npy"

if emb_cache_file.exists():
    print(f"Loading pre-trained embeddings from cache: {emb_cache_file}")
    embeddings_before = np.load(emb_cache_file)
else:
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
    # NOTE(security): trust_remote_code=True lets the Hub repository execute
    # arbitrary Python. It is enabled for every model id containing "esm";
    # ESM-2 checkpoints are supported natively and should not need it.
    trust_code = 'esm' in config.MODEL_ID.lower()
    # Only the encoder's last hidden states are read by compute_embeddings, so
    # the newly initialised classification head does not influence them.
    model = AutoModelForSequenceClassification.from_pretrained(
        config.MODEL_ID,
        num_labels=num_labels,
        trust_remote_code=trust_code
    )
    embeddings_before = compute_embeddings(sequences, model, tokenizer, device)
    np.save(emb_cache_file, embeddings_before)
    print(f"Saved pre-trained embeddings to cache: {emb_cache_file}")

# PCA and t-SNE on pre-trained embeddings
print("\nRunning PCA and t-SNE on pre-trained embeddings...")
pca_before = PCA(n_components=2).fit_transform(embeddings_before)
# t-SNE needs perplexity < n_samples. Size it from the sequences themselves
# rather than `df`, which a later cell may rebind.
tsne_before = TSNE(n_components=2, perplexity=min(30, len(sequences) - 1), random_state=42).fit_transform(embeddings_before)

# Save results
pca_before_csv = out_dir / f'pca_before_{model_tag}_{timestamp}.csv'
tsne_before_csv = out_dir / f'tsne_before_{model_tag}_{timestamp}.csv'
pd.DataFrame(pca_before, columns=['PC1', 'PC2']).to_csv(pca_before_csv, index=False)
pd.DataFrame(tsne_before, columns=['TSNE1', 'TSNE2']).to_csv(tsne_before_csv, index=False)
print("Saved PCA and t-SNE results for pre-trained embeddings.")

### Step 2: Fine-tuning the Model (ESM Encoder + Classification Head)

In [None]:
print("\n--- Step 2: Fine-tuning Classification Head ---")

finetuned_model_path = out_dir / f"finetuned_{model_tag}"

# We need the validation set later for the classification report.
# NOTE(review): the same split also drives early stopping and best-checkpoint
# selection below, so the Step 2b report is optimistically biased. A separate
# held-out test split would give an unbiased estimate.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    sequences, labels, test_size=config.TEST_SIZE, random_state=42, stratify=labels
)

saved_model_exists = finetuned_model_path.exists()

if not saved_model_exists or config.FORCE_RETRAIN or config.TUNE_AGAIN:
    if config.FORCE_RETRAIN and saved_model_exists:
        print("Forcing re-training.")

    # Resuming requires a previously saved model. Without the existence check,
    # a first run with TUNE_AGAIN=True crashes when it tries to load from a
    # directory that does not exist yet. In that case, fall back to the source
    # checkpoint.
    if config.TUNE_AGAIN and saved_model_exists:
        print("Using stored weights and re-training")
        tokenizer = AutoTokenizer.from_pretrained(finetuned_model_path)
        model = AutoModelForSequenceClassification.from_pretrained(finetuned_model_path)
    else:
        if config.TUNE_AGAIN:
            print(f"No stored weights found at {finetuned_model_path}; starting from {config.MODEL_ID}")
        print("Training model frome the source")
        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
        # NOTE(security): trust_remote_code=True lets the Hub repository execute
        # arbitrary Python. It is enabled for every model id containing "esm".
        trust_code = 'esm' in config.MODEL_ID.lower()
        model = AutoModelForSequenceClassification.from_pretrained(
            config.MODEL_ID,
            num_labels=num_labels,
            id2label=id2label,
            label2id=label2id,
            trust_remote_code=trust_code
        )

    train_encodings = tokenizer(train_texts, truncation=True, padding=True)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True)

    train_dataset = ProteinDataset(train_encodings, train_labels)
    val_dataset = ProteinDataset(val_encodings, val_labels)

    # NOTE: config.LEARNING_RATE / config.DECAY_TYPE are deliberately left
    # commented out below, so the Trainer's default learning rate and schedule
    # apply. NOTE(review): 1e-3 looks high for fine-tuning a full ESM encoder;
    # verify before enabling it.
    training_args = TrainingArguments(
        output_dir=str(out_dir / 'training_checkpoints'),
        num_train_epochs=config.EPOCHS,
        per_device_train_batch_size=config.BATCH_SIZE,
        per_device_eval_batch_size=config.BATCH_SIZE,
        warmup_steps=100,
        weight_decay=0.01,
        #learning_rate=config.LEARNING_RATE,
        #lr_scheduler_type=config.DECAY_TYPE,
        logging_dir=str(out_dir / 'logs'),
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1", # Using f1 for best model selection
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_multiclass, # Use the multi-class metrics function
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    trainer.train()
    trainer.save_model(finetuned_model_path)
    tokenizer.save_pretrained(finetuned_model_path)
    print(f"Fine-tuned model saved to {finetuned_model_path}")

    print("\n--- Generating Validation Curve ---")
    # Extract log history
    log_history = trainer.state.log_history

    # Use a dedicated name. Rebinding `df` here would overwrite the sequence
    # DataFrame that later cells still read (e.g. len(df) for t-SNE perplexity).
    log_df = pd.DataFrame(log_history)

    plt.figure(figsize=(10, 6))
    # Guard each series. With fewer optimisation steps than logging_steps, no
    # training-loss entry is logged, and indexing the missing 'loss' column
    # would raise KeyError.
    if 'loss' in log_df:
        train_logs = log_df[log_df['loss'].notna()].dropna(axis=1, how='all').reset_index(drop=True)
        plt.plot(train_logs['step'], train_logs['loss'], label='Training Loss')
    if 'eval_loss' in log_df:
        eval_logs = log_df[log_df['eval_loss'].notna()].dropna(axis=1, how='all').reset_index(drop=True)
        plt.plot(eval_logs['step'], eval_logs['eval_loss'], label='Validation Loss', marker='o')

    plt.title('Training and Validation Loss Curve')
    plt.xlabel('Training Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(out_dir / f"validation_curve_{model_tag}_{timestamp}.png")
    plt.show()
    plt.close()

else:
    print(f"Found existing fine-tuned model. Loading from {finetuned_model_path}")

### Step 2b: Detailed Classification Report

In [None]:
print("\n--- Step 2b: Detailed Classification Report ---")

# Load the best model saved by Step 2 (or by a previous run)
tokenizer = AutoTokenizer.from_pretrained(finetuned_model_path)
model = AutoModelForSequenceClassification.from_pretrained(finetuned_model_path)
model.to(device)

# Create a dataset for the validation texts
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
val_dataset = ProteinDataset(val_encodings, val_labels)

# Prediction-only Trainer with explicit arguments. A bare Trainer(model=model)
# falls back to default TrainingArguments: it writes to a default output folder
# in the current working directory and ignores config.BATCH_SIZE.
eval_args = TrainingArguments(
    output_dir=str(out_dir / 'eval_tmp'),
    per_device_eval_batch_size=config.BATCH_SIZE,
    report_to="none",
)
trainer = Trainer(model=model, args=eval_args)
predictions, _, _ = trainer.predict(val_dataset)
y_pred = np.argmax(predictions, axis=1)

print(f'Classification Report for model: {model_tag}\n')
# Pass every class id explicitly. label2id maps sorted labels to 0..n-1 in
# insertion order, so these ids line up with target_names. Without `labels`,
# sklearn raises ValueError whenever a class is absent from both val_labels
# and y_pred.
report_text = classification_report(
    val_labels,
    y_pred,
    labels=list(range(num_labels)),
    target_names=list(label2id.keys()),
    zero_division=0,
)
print(report_text)

# Save report to a file
class_report_path = out_dir / f'classification_report_{model_tag}_{timestamp}.txt'
with open(class_report_path, 'w', encoding='utf-8') as f:
    f.write(report_text)
print(f"\nClassification report saved to {class_report_path}")

### Step 3: Post-fine-tuning Model Embeddings

In [None]:
print("\n--- Step 3: Post-fine-tuning Model Embeddings ---")

emb_after_cache_file = out_dir / f"embeddings_after_{model_tag}_{sequences_hash}.npy"

# The cache key covers only the model name and the sequences, not the
# fine-tuned weights. A cache written before the model was last (re)trained,
# e.g. on any run with TUNE_AGAIN=True, would therefore be stale. Trust the
# cache only if it is at least as new as every file in the saved model
# directory. If that directory is missing, the freshness check is skipped.
emb_after_cache_is_fresh = emb_after_cache_file.exists() and (
    not finetuned_model_path.exists()
    or emb_after_cache_file.stat().st_mtime
    >= max((p.stat().st_mtime for p in finetuned_model_path.iterdir()), default=0.0)
)

if emb_after_cache_is_fresh and not config.FORCE_RETRAIN:
    print(f"Loading fine-tuned embeddings from cache: {emb_after_cache_file}")
    embeddings_after = np.load(emb_after_cache_file)
else:
    tokenizer = AutoTokenizer.from_pretrained(finetuned_model_path)
    # NOTE(security): trust_remote_code=True lets repository code execute; it is
    # enabled for every model id containing "esm".
    trust_code = 'esm' in config.MODEL_ID.lower()
    model = AutoModelForSequenceClassification.from_pretrained(
        finetuned_model_path,
        trust_remote_code=trust_code
    )
    embeddings_after = compute_embeddings(sequences, model, tokenizer, device)
    np.save(emb_after_cache_file, embeddings_after)
    print(f"Saved fine-tuned embeddings to cache: {emb_after_cache_file}")

# PCA and t-SNE on fine-tuned embeddings
print("\nRunning PCA and t-SNE on fine-tuned embeddings...")
pca_after = PCA(n_components=2).fit_transform(embeddings_after)
# Size perplexity from the sequences. `df` may have been rebound by an earlier
# cell (the original Step 2 reuses it for the training log).
tsne_after = TSNE(n_components=2, perplexity=min(30, len(sequences) - 1), random_state=42).fit_transform(embeddings_after)

# Save results
pca_after_csv = out_dir / f'pca_after_{model_tag}_{timestamp}.csv'
tsne_after_csv = out_dir / f'tsne_after_{model_tag}_{timestamp}.csv'
pd.DataFrame(pca_after, columns=['PC1', 'PC2']).to_csv(pca_after_csv, index=False)
pd.DataFrame(tsne_after, columns=['TSNE1', 'TSNE2']).to_csv(tsne_after_csv, index=False)
print("Saved PCA and t-SNE results for fine-tuned embeddings.")

### Step 4: Generating Plots and Comparison

In [None]:
print("\n--- Step 4: Generating Plots ---")
hue_order = list(label2id.keys())  # class names in integer-id order

# Paths for the individual plots, in (method, stage) order.
pca_before_png, pca_after_png, tsne_before_png, tsne_after_png = [
    out_dir / f"{method}_{stage}_{model_tag}_{timestamp}.png"
    for method in ("pca", "tsne")
    for stage in ("before", "after")
]
# The comparison paths are only echoed by the summary; plot_comparison()
# builds its own filename from the reduction-method name.
pca_comp_png, tsne_comp_png = [
    out_dir / f"{method}_comparison_{model_tag}_{timestamp}.png"
    for method in ("pca", "tsne")
]

# Individual plots
for heading, coords, title, png_path in (
    ("\nPCA before fine-tuning:", pca_before, "PCA - Before Fine-tuning", pca_before_png),
    ("\nPCA after fine-tuning:", pca_after, "PCA - After Fine-tuning", pca_after_png),
    ("\nt-SNE before fine-tuning:", tsne_before, "t-SNE - Before Fine-tuning", tsne_before_png),
    ("\nt-SNE after fine-tuning:", tsne_after, "t-SNE - After Fine-tuning", tsne_after_png),
):
    print(heading)
    plot_single(coords, text_labels, title, png_path, hue_order=hue_order)

# Side-by-side before/after comparisons
for heading, before, after, method_name in (
    ("\nPCA Comparison:", pca_before, pca_after, "PCA"),
    ("\nt-SNE Comparison:", tsne_before, tsne_after, "t-SNE"),
):
    print(heading)
    plot_comparison(before, after, text_labels, method_name, out_dir, model_tag, timestamp, hue_order=hue_order)

### Step 5: Final Summary

In [None]:
# Collect the summary, then emit it in one print. This is equivalent to one
# print per line, because "\n".join plus print's trailing newline reproduces
# the same output.
summary_lines = [
    "\n[DONE] All tasks completed!",
    "\n--- Summary of Output Files ---",
    f"- Fine-tuned model saved to: {finetuned_model_path}",
    f"- Classification report saved to: {class_report_path}",
    "\n- Embeddings:",
    f"  - Before fine-tuning (cached): {emb_cache_file}",
    f"  - After fine-tuning (cached): {emb_after_cache_file}",
    "\n- CSV Coordinates:",
    f"  - PCA (before): {pca_before_csv}",
    f"  - t-SNE (before): {tsne_before_csv}",
    f"  - PCA (after): {pca_after_csv}",
    f"  - t-SNE (after): {tsne_after_csv}",
    "\n- Plots:",
    f"  - PCA Comparison: {pca_comp_png}",
    f"  - t-SNE Comparison: {tsne_comp_png}",
]
print("\n".join(summary_lines))