In [None]:
# ==============================================================================
# 1. SETUP AND IMPORTS
# ==============================================================================
# Install necessary libraries (Hugging Face Datasets, Transformers, PyTorch, scikit-learn)
!pip install -q datasets transformers torch scikit-learn tqdm

import os
import time
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix, roc_auc_score
from tqdm.auto import tqdm

# Fix every RNG this notebook draws from (Python, NumPy, PyTorch) to the
# project-standard seed so repeated runs sample the same noise.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Pick the compute backend in priority order: CUDA, then MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Hardware device set to: {device}")

In [None]:
# ==============================================================================
# 2. CONFIGURATION (HYPERPARAMETERS)
# ==============================================================================
MODEL_NAME = "roberta-base"
MAX_LENGTH = 256  # Maximum number of tokens kept per review (longer reviews are truncated)
# The on-disk token cache depends on the truncation length, so the directory
# name is derived from MAX_LENGTH: changing MAX_LENGTH then triggers a fresh
# tokenization instead of silently reusing a stale cache.
DATA_DIR   = f"imdb_tokenized_{MAX_LENGTH}"
HEAD_PATH  = "es_head_best.pt" # Caching file for best model head weights

# Evolution Strategies (ES) Hyperparameters
POP = 64          # Population size (number of parameter candidates per iteration)
SIGMA = 0.02      # Mutation noise scale (standard deviation of Gaussian noise)
ITERS = 200       # Total number of ES iterations (generations)
LR = 0.01         # Learning rate for the Adam optimizer used to update theta
EVAL_STEPS_PER_CAND = 8 # Number of batches to evaluate each candidate on
BATCH_TRAIN = 64  # Batch size of the training DataLoader (fitness evaluation)
BATCH_VAL = 128   # Batch size of the validation DataLoader

# The training loop samples POP // 2 noise vectors and evaluates each one with
# both signs, so an odd (or < 2) population would silently drop a candidate.
if POP < 2 or POP % 2 != 0:
    raise ValueError(f"POP must be an even integer >= 2 for symmetric sampling, got {POP}")


In [None]:
# ==============================================================================
# 3. DATA PREPARATION AND CACHING
# ==============================================================================
# The tokenizer is needed on both paths: for tokenizing on a cache miss and for
# dynamic padding in the collator below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

if not os.path.exists(DATA_DIR):
    print("STATUS: Data cache not found. Loading and tokenizing IMDb dataset.")
    # Load dataset directly from Hugging Face for simplicity in Colab
    raw_dset = load_dataset("imdb")

    def tokenize_function(examples):
        """Tokenize a batch of reviews, truncating each one to MAX_LENGTH tokens."""
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=MAX_LENGTH
        )

    # Create the 50/50 train/validation split.
    # IMDb's "train" split holds only 25,000 labelled reviews, so indexing rows
    # 25,000-49,999 of it is out of range. The other 25,000 labelled reviews
    # are the "test" split, which is used as the validation set here.
    # Selecting the two labelled splits up front also avoids tokenizing the
    # unlabelled "unsupervised" split, which this notebook never uses.
    labelled_dset = DatasetDict({
        "train": raw_dset["train"],
        "validation": raw_dset["test"],
    })

    # Tokenize and format the dataset
    tokenized_dset = labelled_dset.map(tokenize_function, batched=True)
    tokenized_dset = tokenized_dset.rename_column("label", "labels")
    tokenized_dset = tokenized_dset.remove_columns(["text"])
    dset = tokenized_dset

    # Save to disk (disk-based caching)
    dset.save_to_disk(DATA_DIR)
    print(f"STATUS: Dataset saved and cached to disk at '{DATA_DIR}'.")
else:
    print(f"STATUS: Loading cached dataset from '{DATA_DIR}'.")
    dset = DatasetDict.load_from_disk(DATA_DIR)

# Prepare DataLoaders. Examples are stored unpadded; the collator pads each
# batch to the length of its longest sequence.
collate = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(dset["train"], batch_size=BATCH_TRAIN, shuffle=True, collate_fn=collate, num_workers=2)
val_loader = DataLoader(dset["validation"], batch_size=BATCH_VAL, shuffle=False, collate_fn=collate, num_workers=2)


In [None]:
# ==============================================================================
# 4. MODEL SETUP AND ES UTILITIES
# ==============================================================================
# Load the base model and freeze the main body (RoBERTa encoder)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2).to(device)
for param in model.base_model.parameters():
    param.requires_grad = False
# Inference mode for the whole notebook: ES only needs forward passes.
model.eval()

# Flatten the classification head parameters (ES-trainable parameters).
# The head parameters still have requires_grad=True (they are what the
# "trainable parameter" count below reports), so each piece is detached before
# concatenation. Without the detach, `theta` would be a non-leaf tensor carrying
# a grad_fn, and torch.optim.Adam refuses to optimize non-leaf tensors.
# torch.cat allocates new storage, so `theta` does not alias the model weights.
head_params = list(model.classifier.parameters())
theta = torch.cat([p.detach().flatten() for p in head_params]).to(device)
DIM = theta.numel()  # Dimensionality of the ES search space

@torch.no_grad()
def unpack_head(theta_vec):
    """Write a flat parameter vector back into ``model.classifier``.

    Consecutive slices of ``theta_vec`` are copied, in the order yielded by
    ``model.classifier.parameters()``, into the corresponding parameter, each
    slice reshaped to that parameter's shape.

    Returns the (module-level) ``model`` so calls can be chained.
    """
    offset = 0
    for param in model.classifier.parameters():
        end = offset + param.numel()
        param.data.copy_(theta_vec[offset:end].view_as(param))
        offset = end
    return model

def batch_loss(batch):
    """Run ``model`` on one collated batch and return the loss it reports.

    The batch's ``input_ids``, ``attention_mask`` and ``labels`` entries are
    moved to ``device`` before the forward pass.
    """
    inputs = {
        key: batch[key].to(device)
        for key in ("input_ids", "attention_mask", "labels")
    }
    return model(**inputs).loss

@torch.no_grad()
def val_accuracy():
    """Return the accuracy of ``model`` over every batch in ``val_loader``.

    Predictions are the argmax of the logits; they are compared against the
    batch's ``labels`` on the CPU.
    """
    n_hit = 0
    n_seen = 0
    for batch in val_loader:
        logits = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        ).logits
        targets = batch["labels"]
        predicted = logits.argmax(-1).cpu()
        n_hit += (predicted == targets).sum().item()
        n_seen += targets.shape[0]
    return n_hit / n_seen

# --- NEW METRIC: TRAINABLE PARAMETER FRACTION ---
# Share of the model's scalar parameters that still have requires_grad=True.
trainable_params = sum(w.numel() for w in model.parameters() if w.requires_grad)
total_params = sum(w.numel() for w in model.parameters())
trainable_fraction = trainable_params / total_params

print("\n".join([
    f"\nModel Configuration:",
    f"  Trainable Parameter Count: {trainable_params:,}",
    f"  Total Parameter Count:     {total_params:,}",
    f"  Trainable Fraction:        {trainable_fraction:.6f}",
]))



In [None]:
# ==============================================================================
# 5. EVOLUTION STRATEGIES TRAINING LOOP
# ==============================================================================
print(f"\nSTATUS: Starting ES Optimization ({ITERS} Iterations)...")

# Adam only accepts leaf tensors. Rebinding `theta` to a detached copy makes it
# a leaf regardless of how it was built, and requires_grad_(True) lets us
# assign the ES gradient estimate to `theta.grad` below.
theta = theta.detach().clone().requires_grad_(True)
optimizer = torch.optim.Adam([theta], lr=LR)

best_theta = theta.detach().clone()  # Snapshot of the best-validating parameters so far
best_acc = 0.0
num_pairs = POP // 2                 # Each noise vector is evaluated with both signs

start_time = time.time()
train_iterator = iter(train_loader)

for t in tqdm(range(1, ITERS + 1), desc="ES Generations"):
    noises = torch.randn(num_pairs, DIM).to(device)
    rewards = []

    # Fitness evaluation needs forward passes only; without no_grad every call
    # would build an autograd graph through the (trainable) head for nothing.
    with torch.no_grad():
        for eps in noises:
            # Evaluate for +epsilon and -epsilon (Symmetric Sampling)
            for sign in (+1.0, -1.0):
                cand = theta.detach() + sign * SIGMA * eps
                unpack_head(cand)

                loss_sum = 0.0
                for _ in range(EVAL_STEPS_PER_CAND):
                    try:
                        batch = next(train_iterator)
                    except StopIteration:
                        # Epoch exhausted: start a new (reshuffled) pass over the data.
                        train_iterator = iter(train_loader)
                        batch = next(train_iterator)
                    loss_sum += batch_loss(batch).item()
                avg_loss = loss_sum / EVAL_STEPS_PER_CAND
                rewards.append(-avg_loss) # Reward is the negative of the loss

    # Rewards were appended as (r+, r-) per noise vector, so row i of the
    # reshaped matrix holds the pair belonging to noises[i].
    rewards = torch.tensor(rewards, dtype=noises.dtype, device=device)
    rewards_matrix = rewards.view(num_pairs, 2)
    diff = rewards_matrix[:, 0] - rewards_matrix[:, 1] # r+ - r-

    # Antithetic ES gradient estimate: sum_i (r+_i - r-_i) * eps_i / (2 * n * sigma),
    # where n is the number of noise vectors (POP // 2), not the population size.
    grad_estimate = torch.matmul(diff, noises) / (2 * num_pairs * SIGMA)

    # Update theta using Adam
    theta.grad = -grad_estimate # Minimize -Reward = Loss
    optimizer.step()

    # Periodically check validation accuracy and save the best weights
    if t % 10 == 0 or t == ITERS:
        unpack_head(theta.detach())
        current_acc = val_accuracy()

        improved = current_acc > best_acc
        if improved:
            best_acc = current_acc
            best_theta.copy_(theta.detach())

        # Printed after the update so "Best Acc" includes the current result.
        print(f"Iteration {t}/{ITERS} | Val Acc: {current_acc:.4f} | Best Acc: {best_acc:.4f}")

        if improved:
            # Model Weights Caching
            torch.save(unpack_head(best_theta).classifier.state_dict(), HEAD_PATH)
            print(f"INFO: New best head saved to: {HEAD_PATH}")

elapsed_train_time = time.time() - start_time
print(f"STATUS: ES Training complete. Duration: {elapsed_train_time:.1f} seconds.")


In [None]:
# ==============================================================================
# 6. FINAL EVALUATION AND REPORTING
# ==============================================================================
print("\nSTATUS: Starting Final Evaluation on Full Validation Set...")

# Load the best cached weights for final evaluation.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict contains, so a tampered checkpoint file cannot run code.
if os.path.exists(HEAD_PATH):
    model.classifier.load_state_dict(torch.load(HEAD_PATH, map_location=device, weights_only=True))
    print(f"INFO: Loaded best ES head from: {HEAD_PATH}")
else:
    print(f"WARNING: Best head weights file '{HEAD_PATH}' not found. Evaluating the final, un-cached ES head.")

model.eval()

# Inference loop: Collect labels and LOGITS (for ROC AUC)
labels, logits = [], []
with torch.no_grad():
    for batch in val_loader:
        out = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        logits.append(out.logits.cpu().numpy())
        labels.append(batch["labels"].numpy())

labels = np.concatenate(labels)
logits = np.concatenate(logits)
preds = logits.argmax(-1) # Hard predictions

# --- NEW METRIC: ROC AUC ---
# Calculate the probability of the positive class (class 1)
probs = torch.softmax(torch.tensor(logits), dim=1)[:, 1].numpy()
roc_auc = roc_auc_score(labels, probs)

# Standard Metrics
acc = float(accuracy_score(labels, preds))
f1  = float(f1_score(labels, preds, average="macro"))
report = classification_report(labels, preds, digits=4)
cm = confusion_matrix(labels, preds)

# Print Final Summary Report
print("\n" + "="*60)
print(f"FINAL PERFORMANCE SUMMARY: {MODEL_NAME} + Evolution Strategies (Head-Only)")
print("="*60)
print(f"Metric: Val Accuracy:                {acc:.4f}")
print(f"Metric: Macro-F1:                    {f1:.4f}")
print(f"Metric: ROC AUC:                     {roc_auc:.4f}")
print("-" * 30)
print(f"Efficiency: ES Optimization Time (s):  {elapsed_train_time:.1f}")
print(f"Efficiency: Trainable Param Fraction: {trainable_fraction:.6f}")
print("\nClassification Report:\n", report)
print("Confusion Matrix:\n", cm)

# Save Outputs (Report and Predictions Caching)
pd.DataFrame({"label": labels, "pred": preds}).to_csv("es_val_predictions.csv", index=False)
# Explicit encoding so the report is written identically on every platform.
with open("es_report.txt", "w", encoding="utf-8") as f:
    f.write(f"Val Accuracy: {acc:.6f}\n")
    f.write(f"Macro-F1: {f1:.6f}\n")
    f.write(f"ROC AUC: {roc_auc:.6f}\n")
    f.write(f"Trainable Parameters: {trainable_params:,}\n")
    f.write(f"Trainable Fraction: {trainable_fraction:.6f}\n")
    f.write(f"ElapsedTrainSec: {elapsed_train_time:.1f}\n\n")
    f.write("Classification report:\n")
    f.write(report + "\n")
    f.write("Confusion matrix:\n")
    f.write(str(cm))

print("\nSTATUS: Final report saved to 'es_report.txt'")
print("STATUS: Validation predictions saved to 'es_val_predictions.csv'")