In [1]:
from google.colab import drive
import os
drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Efficient_AI_Project/EarlyExit_Experiments"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
CLASSIFIER_MODEL_NAME = "komal/gpt2-medium-classifier_finetuning_24layer"
MODEL_FOLDER_NAME = "lr5e-05_wd0.01_ep3_drop0.0_lossW0.9-0.8-0.7-0.6-0.5-0.4-0.3-0.2"
CHECKPOINT_PATH = (
    f"{PROJECT_ROOT}/"
    f"{CLASSIFIER_MODEL_NAME}/"
    f"{MODEL_FOLDER_NAME}/"
    f"best_model/pytorch_model.bin"
)

assert os.path.exists(CHECKPOINT_PATH), "❌ Classifier checkpoint not found!"

In [3]:
MLP_MODEL_NAME = "komal/gpt2-medium-MLP"

PERSISTENT_BASE_DIR = f"{PROJECT_ROOT}/{MLP_MODEL_NAME}"
os.makedirs(PERSISTENT_BASE_DIR, exist_ok=True)

In [4]:
import sys

!git clone https://github.com/komalniraula/adaptive-inference-llm

repo_name = 'adaptive-inference-llm' # Must match the folder created by git clone
project_path = os.path.join('/content', repo_name)

# Append the project root directory to the system path

sys.path.append(project_path)

fatal: destination path 'adaptive-inference-llm' already exists and is not an empty directory.


In [5]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
import random
import itertools
import json
import time
import pandas as pd

from transformers import AutoModelForCausalLM, AutoTokenizer

from evaluation.dataset_loaders.sst2 import load_sst2

In [6]:
# -----------------------------
# Load datasets
# -----------------------------
# Load the labeled SST-2 train split (67k samples)
sst2_train_data = load_sst2(task='train', fraction=1)

# Load the SST-2 validation split (872 samples) for testing
# This is equivalent to your previous sst2_test = ds["validation"]
sst2_test_data = load_sst2(task='test', fraction=1)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [7]:
# -----------------------------
# Convert datasets to (text,label)
# -----------------------------
# The external load_sst2 function should already return a list/Dataset
# where elements can be converted.
# We adapt the conversion based on the expected output of your loader module.

def convert_sst2(sample):
    # Assuming your module returns samples with 'text' and 'label' keys
    # based on the preprocess function in your original loader.
    if "text" in sample:
        return sample["text"], int(sample["label"])
    elif "sentence" in sample:
        # Fallback for the raw SST-2 key
        return sample["sentence"], int(sample["label"])
    else:
         raise ValueError("SST-2 sample missing expected keys for conversion.")

# Create data pairs
train_pairs = [convert_sst2(x) for x in sst2_train_data]
test_pairs  = [convert_sst2(x) for x in sst2_test_data]

print(f"Total train samples: {len(train_pairs)}")
print(f"Total test samples: {len(test_pairs)}")


# -----------------------------
# Dataset class
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

class SentimentDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, label = self.data[idx]
        enc = self.tokenizer(
            text,
            add_special_tokens=False,
            return_tensors="pt"
        )
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }


# -----------------------------
# Dynamic padding collate_fn
# -----------------------------
def collate_fn(batch):
    input_ids = [b["input_ids"] for b in batch]
    labels = torch.tensor([b["labels"] for b in batch])

    max_len = max(len(ids) for ids in input_ids)
    pad_id = tokenizer.pad_token_id

    padded = torch.full((len(batch), max_len), pad_id)
    for i, ids in enumerate(input_ids):
        padded[i, :len(ids)] = ids

    attention_mask = (padded != tokenizer.eos_token_id).float()


    return {
        "input_ids": padded,
        "attention_mask": attention_mask,
        "labels": labels
    }

Total train samples: 67349
Total test samples: 872


In [8]:
# -----------------------------
# Create datasets
# -----------------------------
train_ds = SentimentDataset(train_pairs, tokenizer)
test_ds = SentimentDataset(test_pairs, tokenizer)

# -----------------------------
# Create dataloaders
# -----------------------------
train_loader = DataLoader(
    train_ds,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_ds,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn
)

print("Train loader batches:", len(train_loader))
print("Test loader batches:", len(test_loader))

# Optional: Inspect first batch
example = next(iter(train_loader))
print(example["input_ids"].shape)
print(example["attention_mask"].shape)
print(example["labels"])

Train loader batches: 4210
Test loader batches: 55
torch.Size([16, 30])
torch.Size([16, 30])
tensor([0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0])


In [9]:
class ExitMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)   # binary: exit vs continue
        )

    def forward(self, x):
        return torch.sigmoid(self.net(x))

class GPT2EarlyExitClassifier(nn.Module):
    def __init__(self, model_name, exit_layers, hyperparameters):
        super().__init__()

        # -----------------------------
        # Backbone (already fine-tuned)
        # -----------------------------
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            output_hidden_states=True,
            return_dict=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.exit_layers = sorted(exit_layers)
        self.hp = hyperparameters
        self.num_labels = self.hp.get("num_labels", 2)
        dropout_rate = self.hp.get("dropout", 0.0)

        hidden_size = self.model.config.hidden_size

        # -----------------------------
        # Pre-trained exit classifiers
        # -----------------------------
        self.exit_heads = nn.ModuleDict()
        for layer in self.exit_layers:
            self.exit_heads[str(layer)] = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size, self.num_labels)
            )

        # -----------------------------
        # Predictive gating MLPs
        # input = [hidden_state, max_prob, entropy]
        # -----------------------------
        mlp_input_dim = hidden_size + 2
        self.exit_mlps = nn.ModuleDict()

        for layer in self.exit_layers:
            self.exit_mlps[str(layer)] = ExitMLP(
                input_dim=mlp_input_dim,
                hidden_dim=self.hp.get("mlp_hidden_dim", 128),
                dropout=dropout_rate
            )

        self.gate_loss_weight = self.hp.get("gate_loss_weight", 1.0)

        # -----------------------------
        # Freeze backbone + exit heads
        # -----------------------------
        for p in self.model.parameters():
            p.requires_grad = False

        for p in self.exit_heads.parameters():
            p.requires_grad = False

    def forward(self, input_ids, attention_mask=None, exit_labels=None):
        """
        Args:
            exit_labels: dict[layer] -> binary labels
                         1 = EXIT, 0 = CONTINUE

        Returns:
            gate loss
            logits per exit (for analysis)
            exit probabilities per exit
        """

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        hidden_states = outputs.hidden_states

        logits_dict = {}
        exit_probs_dict = {}
        total_loss = 0.0

        for layer in self.exit_layers:

            # [B, seq_len, H] → [B, H]
            cls_vec = hidden_states[layer][:, -1, :]

            # ---- Pretrained classifier (frozen) ----
            logits = self.exit_heads[str(layer)](cls_vec)
            logits_dict[layer] = logits

            # ---- Confidence features ----
            probs = torch.softmax(logits, dim=-1)
            max_prob, _ = probs.max(dim=-1, keepdim=True)
            entropy = -(probs * torch.log(probs + 1e-8)).sum(
                dim=-1, keepdim=True
            )

            # ---- Gating MLP ----
            mlp_input = torch.cat(
                [cls_vec, max_prob, entropy],
                dim=-1
            )

            p_exit = self.exit_mlps[str(layer)](mlp_input)
            exit_probs_dict[layer] = p_exit

            # ---- Gate loss only ----
            if exit_labels is not None:
                gate_loss = F.binary_cross_entropy(
                    p_exit.squeeze(-1),
                    exit_labels[layer].float()
                )
                total_loss += self.gate_loss_weight * gate_loss

        return {
            "loss": total_loss if exit_labels is not None else None,
            "logits": logits_dict,
            "exit_probs": exit_probs_dict
        }

In [10]:
hyperparameter_grid = {
    # model config (fixed)
    "num_labels": [2],
    "dropout": [0.0],

    # MLP gate architecture
    "mlp_hidden_dim": [64, 128],

    # gate loss scaling
    "gate_loss_weight": [1.0],

    # training hyperparams (gate-only)
    "learning_rate": [1e-4, 3e-4, 3e-5],
    "weight_decay": [0.0],          # often unnecessary for tiny MLPs
    "num_epochs": [2],
    "batch_size": [16],

    # stability
    "max_grad_norm": [1.0],

    # logging
    "log_every": [500]
}


In [11]:
def make_experiment_name(hp):
    return (
        f"mlpH{hp['mlp_hidden_dim']}_"
        f"lr{hp['learning_rate']}_"
        f"wd{hp['weight_decay']}_"
        f"ep{hp['num_epochs']}_"
        f"drop{hp['dropout']}_"
        f"batch_size{hp['batch_size']}_"
        f"gateW{hp['gate_loss_weight']}"
    )

In [12]:
@torch.no_grad()
def evaluate_mlp_gates(model, data_loader, device):

    model.eval()
    total_loss = 0.0
    correct, total = 0, 0

    for batch in data_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        y_true = batch["labels"]

        logits_out = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            exit_labels=None
        )

        exit_labels = {}
        for layer in model.exit_layers:
            preds = logits_out["logits"][layer].argmax(dim=-1)
            exit_labels[layer] = (preds == y_true).long()

        out = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            exit_labels=exit_labels
        )

        total_loss += out["loss"].item()

        for layer in model.exit_layers:
            p_exit = out["exit_probs"][layer].squeeze(-1)
            pred_exit = (p_exit > 0.5).long()
            correct += (pred_exit == exit_labels[layer]).sum().item()
            total += exit_labels[layer].numel()

    return total_loss / len(data_loader), correct / total

In [13]:
def train_mlp_gates(model, train_loader, val_loader, hp, device):

    exp_name = make_experiment_name(hp)
    epoch_history = []

    optimizer = torch.optim.AdamW(
        model.exit_mlps.parameters(),
        lr=hp["learning_rate"],
        weight_decay=hp["weight_decay"]
    )

    exp_dir = os.path.join(PERSISTENT_BASE_DIR, exp_name)
    os.makedirs(exp_dir, exist_ok=True)

    best_val_gate_loss = float("inf")
    best_epoch_metrics = None

    for epoch in range(hp["num_epochs"]):
        model.train()
        total_train_loss = 0.0

        for step, batch in enumerate(train_loader, start=1):
            batch = {k: v.to(device) for k, v in batch.items()}
            y_true = batch["labels"]

            # -------------------------------
            # Generate exit labels (deterministic)
            # -------------------------------
            model.eval()
            with torch.no_grad():
                logits_out = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    exit_labels=None
                )

                exit_labels = {}
                for layer in model.exit_layers:
                    preds = logits_out["logits"][layer].argmax(dim=-1)
                    exit_labels[layer] = (preds == y_true).long()
            model.train()

            # -------------------------------
            # Train gates
            # -------------------------------
            out = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                exit_labels=exit_labels
            )

            loss = out["loss"]
            total_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.exit_mlps.parameters(),
                hp["max_grad_norm"]
            )
            optimizer.step()

            if step % hp["log_every"] == 0:
              print(
                  f"[Epoch {epoch+1}/{hp['num_epochs']}] "
                  f"Step {step}/{len(train_loader)} | "
                  f"Gate Loss: {loss.item():.4f}"
              )

        avg_train_gate_loss = total_train_loss / len(train_loader)

        val_gate_loss, val_gate_accuracy = evaluate_mlp_gates(
            model, val_loader, device
        )

        epoch_metrics = {
            "epoch": epoch + 1,
            "train_gate_loss": avg_train_gate_loss,
            "val_gate_loss": val_gate_loss,
            "val_gate_accuracy": val_gate_accuracy
        }
        epoch_history.append(epoch_metrics)

        if val_gate_loss < best_val_gate_loss:
            best_val_gate_loss = val_gate_loss
            best_epoch_metrics = epoch_metrics

            best_model_dir = os.path.join(exp_dir, "best_model")
            os.makedirs(best_model_dir, exist_ok=True)
            torch.save(
                model.exit_mlps.state_dict(),
                os.path.join(best_model_dir, "mlp_gates.bin")
            )

    return epoch_history, best_epoch_metrics

In [14]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# generate all hyperparameter combinations
keys = list(hyperparameter_grid.keys())
values = list(hyperparameter_grid.values())

all_experiment_results = {}

exit_layers = [3, 6, 9, 12, 15, 18, 21, 24]
device = "cuda" if torch.cuda.is_available() else "cpu"

for combo in itertools.product(*values):

    hp = dict(zip(keys, combo))
    hp["num_labels"] = 2

    exp_name = make_experiment_name(hp)
    print("\n==============================")
    print("RUNNING MLP GATE EXPERIMENT:", exp_name)
    print("Hyperparameters:", hp)
    print("==============================")

    # 1️⃣ Create model
    model = GPT2EarlyExitClassifier(
        model_name="gpt2-medium",
        exit_layers=exit_layers,
        hyperparameters=hp
    )

    # 2️⃣ Load fine-tuned classifier checkpoint
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(state_dict, strict=False)

    # 3️⃣ Move to GPU
    model = model.to(device)

    # 4️⃣ Freeze backbone + exit heads (extra safety)
    for p in model.model.parameters():
        p.requires_grad = False
    for p in model.exit_heads.parameters():
        p.requires_grad = False

    # 5️⃣ Train MLP gates
    train_history, best_epoch_metrics = train_mlp_gates(
        model=model,
        train_loader=train_loader,
        val_loader=test_loader,
        hp=hp,
        device=device
    )

    # 6️⃣ Store results
    all_experiment_results[exp_name] = {
        "hyperparameters": hp,
        "train_history": train_history,
        "best_epoch_metrics": best_epoch_metrics,
        "best_val_gate_loss": best_epoch_metrics["val_gate_loss"],
        "best_val_gate_accuracy": best_epoch_metrics["val_gate_accuracy"]
    }

    print(
        f">>> Finished {exp_name} | "
        f"Best Val Gate Loss: {best_epoch_metrics['val_gate_loss']:.4f} | "
        f"Best Val Gate Acc: {best_epoch_metrics['val_gate_accuracy']:.4f}"
    )


RUNNING MLP GATE EXPERIMENT: mlpH64_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0
Hyperparameters: {'num_labels': 2, 'dropout': 0.0, 'mlp_hidden_dim': 64, 'gate_loss_weight': 1.0, 'learning_rate': 0.0001, 'weight_decay': 0.0, 'num_epochs': 2, 'batch_size': 16, 'max_grad_norm': 1.0, 'log_every': 500}


The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[Epoch 1/2] Step 500/4210 | Gate Loss: 0.7335
[Epoch 1/2] Step 1000/4210 | Gate Loss: 0.3618
[Epoch 1/2] Step 1500/4210 | Gate Loss: 1.8584
[Epoch 1/2] Step 2000/4210 | Gate Loss: 2.4675
[Epoch 1/2] Step 2500/4210 | Gate Loss: 0.5358
[Epoch 1/2] Step 3000/4210 | Gate Loss: 0.4075
[Epoch 1/2] Step 3500/4210 | Gate Loss: 2.4481
[Epoch 1/2] Step 4000/4210 | Gate Loss: 0.2797
[Epoch 2/2] Step 500/4210 | Gate Loss: 1.0572
[Epoch 2/2] Step 1000/4210 | Gate Loss: 0.6933
[Epoch 2/2] Step 1500/4210 | Gate Loss: 0.1141
[Epoch 2/2] Step 2000/4210 | Gate Loss: 2.5462
[Epoch 2/2] Step 2500/4210 | Gate Loss: 1.4532
[Epoch 2/2] Step 3000/4210 | Gate Loss: 2.2469
[Epoch 2/2] Step 3500/4210 | Gate Loss: 0.5971
[Epoch 2/2] Step 4000/4210 | Gate Loss: 0.3990
>>> Finished mlpH64_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0 | Best Val Gate Loss: 2.6939 | Best Val Gate Acc: 0.8949

RUNNING MLP GATE EXPERIMENT: mlpH64_lr0.0003_wd0.0_ep2_drop0.0_batch_size16_gateW1.0
Hyperparameters: {'num_labels': 2, 'dr

In [15]:
all_experiment_results

{'mlpH64_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0': {'hyperparameters': {'num_labels': 2,
   'dropout': 0.0,
   'mlp_hidden_dim': 64,
   'gate_loss_weight': 1.0,
   'learning_rate': 0.0001,
   'weight_decay': 0.0,
   'num_epochs': 2,
   'batch_size': 16,
   'max_grad_norm': 1.0,
   'log_every': 500},
  'train_history': [{'epoch': 1,
    'train_gate_loss': 1.2374307141814054,
    'val_gate_loss': 2.6938558592037722,
    'val_gate_accuracy': 0.8949254587155964},
   {'epoch': 2,
    'train_gate_loss': 1.1649047161198074,
    'val_gate_loss': 2.8530091236260806,
    'val_gate_accuracy': 0.8946387614678899}],
  'best_epoch_metrics': {'epoch': 1,
   'train_gate_loss': 1.2374307141814054,
   'val_gate_loss': 2.6938558592037722,
   'val_gate_accuracy': 0.8949254587155964},
  'best_val_gate_loss': 2.6938558592037722,
  'best_val_gate_accuracy': 0.8949254587155964},
 'mlpH64_lr0.0003_wd0.0_ep2_drop0.0_batch_size16_gateW1.0': {'hyperparameters': {'num_labels': 2,
   'dropout': 0.0,
   'ml

In [16]:
def flatten_experiment_results(all_results):
    """
    Flattens MLP-gate experiment results into a CSV-friendly list of dicts.
    Each row corresponds to ONE hyperparameter configuration
    and its BEST saved gate model.
    """
    flat_data = []

    for exp_name, data in all_results.items():
        row = {
            "experiment_name": exp_name
        }

        # -----------------------------
        # Best model metrics (MLP gate)
        # -----------------------------
        if "best_train_metrics" in data and data["best_train_metrics"] is not None:
            row["best_epoch"] = data["best_train_metrics"].get("epoch")
            row["best_train_gate_loss"] = data["best_train_metrics"].get("gate_loss")
        else:
            row["best_epoch"] = None
            row["best_train_gate_loss"] = None

        # Validation metrics (if available)
        row["val_gate_loss"] = data.get("val_gate_loss")
        row["val_gate_accuracy"] = data.get("val_gate_accuracy")

        # -----------------------------
        # Hyperparameters
        # -----------------------------
        for hp_key, hp_value in data["hyperparameters"].items():
            if isinstance(hp_value, list):
                row[f"hp_{hp_key}"] = str(hp_value)
            else:
                row[f"hp_{hp_key}"] = hp_value

        flat_data.append(row)

    return flat_data

print("\n\n--- Saving All Experiment Results to CSV ---")

flat_results = flatten_experiment_results(all_experiment_results)

results_df = pd.DataFrame(flat_results)

csv_filename = "all_mlp_gate_results.csv"
csv_path = os.path.join(PERSISTENT_BASE_DIR, csv_filename)

results_df.to_csv(csv_path, index=False)

print(f"✅ All experiment results saved to: {csv_path}")



--- Saving All Experiment Results to CSV ---
✅ All experiment results saved to: /content/drive/MyDrive/Efficient_AI_Project/EarlyExit_Experiments/komal/gpt2-medium-MLP/all_mlp_gate_results.csv


In [17]:
results_df

Unnamed: 0,experiment_name,best_epoch,best_train_gate_loss,val_gate_loss,val_gate_accuracy,hp_num_labels,hp_dropout,hp_mlp_hidden_dim,hp_gate_loss_weight,hp_learning_rate,hp_weight_decay,hp_num_epochs,hp_batch_size,hp_max_grad_norm,hp_log_every
0,mlpH64_lr0.0001_wd0.0_ep2_drop0.0_batch_size16...,,,,,2,0.0,64,1.0,0.0001,0.0,2,16,1.0,500
1,mlpH64_lr0.0003_wd0.0_ep2_drop0.0_batch_size16...,,,,,2,0.0,64,1.0,0.0003,0.0,2,16,1.0,500
2,mlpH64_lr3e-05_wd0.0_ep2_drop0.0_batch_size16_...,,,,,2,0.0,64,1.0,3e-05,0.0,2,16,1.0,500
3,mlpH128_lr0.0001_wd0.0_ep2_drop0.0_batch_size1...,,,,,2,0.0,128,1.0,0.0001,0.0,2,16,1.0,500
4,mlpH128_lr0.0003_wd0.0_ep2_drop0.0_batch_size1...,,,,,2,0.0,128,1.0,0.0003,0.0,2,16,1.0,500
5,mlpH128_lr3e-05_wd0.0_ep2_drop0.0_batch_size16...,,,,,2,0.0,128,1.0,3e-05,0.0,2,16,1.0,500
