In [1]:
from google.colab import drive
import os
# Mount Google Drive at /content/drive; every project path below lives under it.
drive.mount('/content/drive')

# Root folder (on the mounted Drive) for all early-exit experiment artefacts.
PROJECT_ROOT = os.path.join(
    "/content/drive/MyDrive",
    "Efficient_AI_Project",
    "EarlyExit_Experiments",
)

Mounted at /content/drive


In [2]:
# Fine-tuned multi-exit classifier whose weights are shared by every MLP-gate run.
CLASSIFIER_MODEL_NAME = "komal/gpt2-medium-classifier_finetuning_24layer"
# Sub-folder naming the hyper-parameter configuration of the chosen run.
MODEL_FOLDER_NAME = "lr5e-05_wd0.01_ep3_drop0.0_lossW0.9-0.8-0.7-0.6-0.5-0.4-0.3-0.2"

# <PROJECT_ROOT>/<classifier>/<run>/best_model/pytorch_model.bin
CHECKPOINT_PATH = os.path.join(
    PROJECT_ROOT,
    CLASSIFIER_MODEL_NAME,
    MODEL_FOLDER_NAME,
    "best_model",
    "pytorch_model.bin",
)

# Fail fast (and say which path was checked) instead of failing later inside torch.load.
assert os.path.exists(CHECKPOINT_PATH), (
    f"❌ Classifier checkpoint not found: {CHECKPOINT_PATH}"
)

In [3]:
# Name of the folder that groups all trained MLP-gate runs.
MLP_MODEL_NAME = "komal/gpt2-medium-MLP"

# Directory scanned for MLP-gate checkpoints and used to store the results CSV.
PERSISTENT_BASE_DIR = "/".join((PROJECT_ROOT, MLP_MODEL_NAME))
os.makedirs(PERSISTENT_BASE_DIR, exist_ok=True)

In [4]:
import sys

!git clone https://github.com/komalniraula/adaptive-inference-llm

repo_name = 'adaptive-inference-llm' # Must match the folder created by git clone
project_path = os.path.join('/content', repo_name)

# Append the project root directory to the system path

sys.path.append(project_path)

Cloning into 'adaptive-inference-llm'...
remote: Enumerating objects: 281, done.[K
remote: Counting objects: 100% (281/281), done.[K
remote: Compressing objects: 100% (202/202), done.[K
remote: Total 281 (delta 137), reused 213 (delta 71), pack-reused 0 (from 0)[K
Receiving objects: 100% (281/281), 6.71 MiB | 20.58 MiB/s, done.
Resolving deltas: 100% (137/137), done.


In [5]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
import random
import itertools
import json
import time
import pandas as pd

from transformers import AutoModelForCausalLM, AutoTokenizer

from evaluation.dataset_loaders.sst2 import load_sst2

In [6]:
# -----------------------------
# Load datasets
# -----------------------------
# task='train' -> labelled SST-2 training split (67k samples).
# task='test'  -> SST-2 validation split (872 samples), used here as the test set.
sst2_train_data, sst2_test_data = (
    load_sst2(task=split, fraction=1) for split in ("train", "test")
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

sst2/train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

sst2/validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

sst2/test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

In [7]:
# -----------------------------
# Convert datasets to (text, label)
# -----------------------------
# `load_sst2` is expected to return an iterable of dict-like samples.
# `convert_sst2` turns each sample into a (text, label) tuple, accepting either
# the preprocessed "text" key or the raw SST-2 "sentence" key.

def convert_sst2(sample):
    """Return ``(text, label)`` for one SST-2 sample.

    The sentence is taken from the ``"text"`` key when present, otherwise from
    the raw SST-2 ``"sentence"`` key. The label is coerced with ``int``.

    Raises:
        ValueError: if the sample has neither ``"text"`` nor ``"sentence"``.
    """
    # Keys are tried in priority order; the first one present wins.
    for text_key in ("text", "sentence"):
        if text_key in sample:
            return sample[text_key], int(sample["label"])
    raise ValueError("SST-2 sample missing expected keys for conversion.")

# Materialise the evaluation split as a list of (text, label) tuples.
test_pairs = list(map(convert_sst2, sst2_test_data))

print(f"Total test samples: {len(test_pairs)}")


# -----------------------------
# Tokenizer and dataset class
# -----------------------------
# The GPT-2 tokenizer is loaded once; EOS is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

class SentimentDataset(Dataset):
    """Map-style dataset over ``(text, label)`` pairs, tokenised on access.

    Each item is a dict with:
      * ``"input_ids"`` - the tokenizer's ``input_ids`` with the leading
        dimension squeezed away (no special tokens are added);
      * ``"labels"``    - a scalar ``torch.long`` tensor holding the label.
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sentence, target = self.data[idx]
        encoded = self.tokenizer(
            sentence, add_special_tokens=False, return_tensors="pt"
        )
        # The tokenizer output carries a leading dimension of size 1; drop it.
        token_ids = encoded["input_ids"].squeeze(0)
        return dict(
            input_ids=token_ids,
            labels=torch.tensor(target, dtype=torch.long),
        )


# -----------------------------
# Dynamic padding collate_fn
# -----------------------------
def collate_fn(batch, pad_token_id=None):
    """Pad a list of dataset items to the longest sequence in the batch.

    Args:
        batch: list of dicts with ``"input_ids"`` (1-D tensor) and ``"labels"``
            (scalar tensor or int), as produced by ``SentimentDataset``.
        pad_token_id: id used to fill padded positions. Defaults to the
            module-level ``tokenizer.pad_token_id``.

    Returns:
        dict with
          * ``"input_ids"``      - ``(batch, max_len)`` long tensor, right-padded;
          * ``"attention_mask"`` - ``(batch, max_len)`` float tensor, 1.0 on real
            tokens and 0.0 on padding;
          * ``"labels"``         - ``(batch,)`` tensor of labels.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    sequences = [item["input_ids"] for item in batch]
    labels = torch.stack([torch.as_tensor(item["labels"]) for item in batch])

    lengths = torch.tensor([len(seq) for seq in sequences])
    max_len = int(lengths.max())

    padded = torch.full((len(batch), max_len), pad_token_id, dtype=torch.long)
    for row, seq in enumerate(sequences):
        padded[row, :len(seq)] = seq

    # Build the mask from the true lengths rather than from `padded != eos_id`:
    # pad and EOS share one id here, so a value comparison would also mask any
    # genuine token that happens to equal it.
    positions = torch.arange(max_len).unsqueeze(0)
    attention_mask = (positions < lengths.unsqueeze(1)).float()

    return {
        "input_ids": padded,
        "attention_mask": attention_mask,
        "labels": labels
    }

Total test samples: 872


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [8]:
# Wrap the (text, label) pairs. batch_size=1 because the evaluation loop
# decides the exit layer separately for every sample.
test_ds = SentimentDataset(test_pairs, tokenizer)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_fn)

print("Test loader batches:", len(test_loader))

# Optional: peek at the first batch to confirm the tensor shapes.
example = next(iter(test_loader))
print(example["input_ids"].shape, example["attention_mask"].shape, sep="\n")
print(example["labels"])

Test loader batches: 872
torch.Size([1, 27])
torch.Size([1, 27])
tensor([0])


In [9]:
class ExitMLP(nn.Module):
    """Exit gate: Linear -> ReLU -> Dropout -> Linear(1), followed by a sigmoid.

    Args:
        input_dim: size of the gate's input feature vector.
        hidden_dim: width of the single hidden layer.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, input_dim, hidden_dim, dropout):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        # The attribute name `net` and the layer order fix the state-dict keys
        # (net.0.*, net.3.*) that saved gate checkpoints are loaded into.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the exit probability, shape ``(..., 1)``, in the range (0, 1)."""
        gate_logit = self.net(x)
        return gate_logit.sigmoid()


class GPT2EarlyExitClassifier(nn.Module):
    """GPT-2 backbone with one classifier head and one MLP exit gate per exit layer.

    Attribute names (``model``, ``exit_heads``, ``exit_mlps``) and the layout of
    each head define the state-dict keys, so they must match the checkpoints
    that are loaded into this module.

    Args:
        model_name: Hugging Face identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.
        exit_layers: 1-based indices of the transformer blocks after which an
            exit is possible. Stored sorted in ``self.exit_layers``.
        mlp_hidden_dim: hidden width of every ``ExitMLP`` gate.
        num_labels: number of output classes of each classifier head.

    Raises:
        ValueError: if ``exit_layers`` is empty.
    """

    def __init__(self, model_name, exit_layers, mlp_hidden_dim, num_labels=2):
        super().__init__()

        ordered_layers = sorted(exit_layers)
        if not ordered_layers:
            # Without this check the failure only shows up later as an opaque
            # IndexError on `self.exit_layers[-1]`.
            raise ValueError("exit_layers must contain at least one layer index")

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            return_dict=True
        )

        self.exit_layers = ordered_layers
        self.num_labels = num_labels

        hidden_size = self.model.config.hidden_size

        # ---- classifier heads (must match training) ----
        self.exit_heads = nn.ModuleDict()
        for layer in self.exit_layers:
            self.exit_heads[str(layer)] = nn.Sequential(
                nn.Dropout(0.0),
                nn.Linear(hidden_size, self.num_labels)
            )

        # ---- MLP gates (must match training) ----
        # Gate input = hidden state (hidden_size) + max probability + entropy.
        self.exit_mlps = nn.ModuleDict()
        for layer in self.exit_layers:
            self.exit_mlps[str(layer)] = ExitMLP(
                input_dim=hidden_size + 2,
                hidden_dim=mlp_hidden_dim,      # must match training
                dropout=0.0
            )

    @torch.no_grad()
    def early_exit_forward(
        self,
        input_ids,
        attention_mask,
        gate_threshold=0.5
    ):
        """Run the backbone block by block and stop at the first accepted exit.

        At every exit layer the classifier head is applied to the hidden state
        of the last sequence position. The gate receives that hidden state, the
        maximum class probability and the prediction entropy. Inference stops
        when the gate output is >= ``gate_threshold`` or the last configured
        exit layer is reached.

        Args:
            input_ids: ``(1, seq_len)`` token ids (single sample only).
            attention_mask: ``(1, seq_len)`` mask with 1 on real tokens. It is
                used to derive position ids and is also forwarded to each block.
            gate_threshold: minimum gate probability that triggers an exit.

        Returns:
            dict with ``prediction`` (int), ``confidence`` (float),
            ``exit_layer`` (int, the layer that actually produced the
            prediction), ``p_exit`` (float) and ``logits`` (tensor).

        Raises:
            ValueError: if the input is not a single-sample batch.
            RuntimeError: if the backbone has fewer blocks than the smallest
                configured exit layer, so no exit head was ever evaluated.
        """
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            # `.item()` below only works for exactly one sample.
            raise ValueError(
                "early_exit_forward expects input_ids of shape (1, seq_len), "
                f"got {tuple(input_ids.shape)}"
            )

        transformer = self.model.transformer
        exit_set = set(self.exit_layers)
        max_layer = self.exit_layers[-1]

        # ---- embeddings (HF-consistent) ----
        # Positions count real tokens only; masked positions are pinned to 0.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)

        hidden = transformer.wte(input_ids) + transformer.wpe(position_ids)

        last_exit = None  # result of the most recent exit layer evaluated

        # ---- layer-wise execution ----
        for layer_idx, block in enumerate(transformer.h, start=1):
            # NOTE(review): the raw 0/1 mask is forwarded unchanged; whether the
            # blocks expect this or an extended additive mask depends on the
            # transformers version and must match how the heads were trained.
            # TODO confirm.
            hidden = block(hidden, attention_mask=attention_mask)[0]

            if layer_idx not in exit_set:
                continue

            # Last sequence position acts as the sentence representation.
            cls_vec = hidden[:, -1, :]

            # ---- classifier ----
            logits = self.exit_heads[str(layer_idx)](cls_vec)
            probs = torch.softmax(logits, dim=-1)
            conf, pred = probs.max(dim=-1)

            # ---- MLP gate ----
            # 1e-8 keeps log() finite when a probability is exactly 0.
            entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1, keepdim=True)
            mlp_input = torch.cat(
                [cls_vec, conf.unsqueeze(-1), entropy],
                dim=-1
            )

            p_exit = self.exit_mlps[str(layer_idx)](mlp_input)

            last_exit = {
                "prediction": pred.item(),
                "confidence": conf.item(),
                "exit_layer": layer_idx,
                "p_exit": p_exit.item(),
                "logits": logits
            }

            if last_exit["p_exit"] >= gate_threshold or layer_idx == max_layer:
                return last_exit

        # ---- safety fallback ----
        # Only reachable when the backbone has fewer blocks than `max_layer`.
        if last_exit is None:
            raise RuntimeError(
                "No exit layer was reached: the backbone has fewer blocks than "
                f"the smallest configured exit layer ({self.exit_layers[0]})"
            )
        return last_exit

In [10]:
@torch.no_grad()
def adaptive_early_exit_eval_mlp(
    model,
    data_loader,
    device,
    gate_threshold=0.5,
):
    """Evaluate MLP-gated early exit, one sample at a time.

    For every sample the backbone blocks are executed in order. At each exit
    layer the classifier head scores the last sequence position, and the gate
    decides whether to stop. A sample stops at the first exit layer whose gate
    output is >= ``gate_threshold``, or at the last configured exit layer.

    Args:
        model: module exposing ``model.transformer`` (with ``wte``, ``wpe``,
            ``h``), ``exit_layers``, ``exit_heads`` and ``exit_mlps``.
        data_loader: iterable of batches of size 1 with ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"``.
        device: device the batch tensors are moved to.
        gate_threshold: minimum gate probability that triggers an exit.

    Returns:
        dict with ``accuracy``, ``avg_layers_used``, ``avg_latency_sec``,
        ``tokens_per_sec``, ``cost_saving_pct`` (relative to the last exit
        layer) and ``num_samples``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
        RuntimeError: if a sample runs through every block without exiting
            (the last exit layer exceeds the backbone depth).
    """
    model.eval()

    correct = 0
    total = 0
    total_layers_used = 0
    total_tokens = 0
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()

    transformer = model.model.transformer
    exit_set = set(model.exit_layers)
    max_layer = model.exit_layers[-1]

    for batch in data_loader:
        assert batch["input_ids"].size(0) == 1

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        label = batch["labels"].item()

        total_tokens += attention_mask.sum().item()

        # ---- embeddings ----
        # Positions count real tokens only; masked positions are pinned to 0.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)

        hidden = transformer.wte(input_ids) + transformer.wpe(position_ids)

        for layer_idx, block in enumerate(transformer.h, start=1):
            hidden = block(hidden, attention_mask=attention_mask)[0]

            if layer_idx not in exit_set:
                continue

            # Last sequence position acts as the sentence representation.
            cls_vec = hidden[:, -1, :]

            # classifier head
            logits = model.exit_heads[str(layer_idx)](cls_vec)
            probs = torch.softmax(logits, dim=-1)
            pred = probs.argmax(dim=-1).item()

            # MLP gate: hidden state + max probability + entropy
            max_prob = probs.max(dim=-1, keepdim=True)[0]
            entropy = -(probs * torch.log(probs + 1e-8)).sum(
                dim=-1, keepdim=True
            )

            mlp_input = torch.cat(
                [cls_vec, max_prob, entropy],
                dim=-1
            )

            p_exit = model.exit_mlps[str(layer_idx)](mlp_input).item()

            if p_exit >= gate_threshold or layer_idx == max_layer:
                correct += int(pred == label)
                total_layers_used += layer_idx
                total += 1
                break
        else:
            # The loop ran through every block without hitting `break`.
            raise RuntimeError("No exit triggered — should never happen")

    elapsed = time.perf_counter() - start_time

    if total == 0:
        # Previously surfaced as a bare ZeroDivisionError in the metrics below.
        raise ValueError("data_loader yielded no samples; nothing to evaluate")

    avg_layers_used = total_layers_used / total

    return {
        "accuracy": correct / total,
        "avg_layers_used": avg_layers_used,
        "avg_latency_sec": elapsed / total,
        "tokens_per_sec": total_tokens / elapsed if elapsed > 0 else float("inf"),
        "cost_saving_pct": 100 * (1 - avg_layers_used / max_layer),
        "num_samples": total
    }

In [11]:
def evaluate_all_models_mlp(
    persistent_base_dir,
    exit_layers,
    test_loader,
    device,
    gate_thresholds,
    checkpoint_path=None,
):
    """Evaluate one fine-tuned classifier with every trained set of MLP gates.

    Each sub-folder of ``persistent_base_dir`` that contains
    ``best_model/mlp_gates.bin`` and whose name includes ``mlpH<digits>`` is
    treated as one gate configuration. Other entries are skipped. Every
    configuration is evaluated once per threshold.

    Args:
        persistent_base_dir: directory holding one sub-folder per gate run.
        exit_layers: exit layer indices passed to ``GPT2EarlyExitClassifier``.
        test_loader: batch-size-1 loader passed to
            ``adaptive_early_exit_eval_mlp``.
        device: device for the model and the evaluation.
        gate_thresholds: iterable of gate thresholds to evaluate.
        checkpoint_path: classifier ``state_dict`` file shared by all runs.
            Defaults to the module-level ``CHECKPOINT_PATH``.

    Returns:
        ``pandas.DataFrame`` with one row per (gate run, threshold), sorted by
        ``avg_latency_sec``. Empty if no valid gate run was found.
    """
    # Imported locally so the function does not rely on a later cell having
    # imported `re` before it is called.
    import re

    if checkpoint_path is None:
        checkpoint_path = CHECKPOINT_PATH

    all_rows = []

    # ---------- sanity check ----------
    assert os.path.exists(checkpoint_path), "Classifier checkpoint not found"

    # The classifier weights are identical for every gate run: read them from
    # disk once (lazily, on CPU) instead of once per run.
    classifier_state = None

    for model_name in sorted(os.listdir(persistent_base_dir)):
        mlp_ckpt = os.path.join(
            persistent_base_dir, model_name, "best_model", "mlp_gates.bin"
        )

        if not os.path.exists(mlp_ckpt):
            continue

        # parse hidden dim from folder name
        m = re.search(r"mlpH(\d+)", model_name)
        if m is None:
            print(f"⚠️ Skipping {model_name} (cannot parse mlpH)")
            continue

        mlp_hidden_dim = int(m.group(1))
        print(f"\n🚀 Evaluating MLP gates: {model_name} | H={mlp_hidden_dim}")

        # ---------- load model ----------
        model = GPT2EarlyExitClassifier(
            model_name="gpt2-medium",
            exit_layers=exit_layers,
            mlp_hidden_dim=mlp_hidden_dim,
            num_labels=2,
        ).to(device)

        # load shared classifier
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        if classifier_state is None:
            classifier_state = torch.load(checkpoint_path, map_location="cpu")

        load_report = model.load_state_dict(classifier_state, strict=False)

        # strict=False tolerates the gate weights being absent from the
        # classifier checkpoint, but any *other* missing key would leave part
        # of the classifier un-finetuned - report it instead of hiding it.
        missing_non_gate = [
            key for key in load_report.missing_keys
            if not key.startswith("exit_mlps.")
        ]
        if missing_non_gate:
            print(
                f"⚠️ {len(missing_non_gate)} classifier weights not found in "
                f"checkpoint (e.g. {missing_non_gate[0]})"
            )

        # load MLP gates
        model.exit_mlps.load_state_dict(
            torch.load(mlp_ckpt, map_location=device)
        )

        model.eval()

        # ---------- evaluate ----------
        for th in gate_thresholds:
            r = adaptive_early_exit_eval_mlp(
                model=model,
                data_loader=test_loader,
                device=device,
                gate_threshold=th,
            )

            all_rows.append({
                "model_name": model_name,
                "mlp_hidden_dim": mlp_hidden_dim,
                "gate_threshold": th,
                "accuracy": r["accuracy"],
                "avg_layers_used": r["avg_layers_used"],
                "avg_latency_sec": r["avg_latency_sec"],
                "cost_saving_pct": r["cost_saving_pct"],
                "tokens_per_sec": r["tokens_per_sec"],
            })

        # Drop this model before the next one is built so two full backbones
        # are never resident at the same time.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(all_rows)

    if df.empty:
        print("⚠️ No valid MLP models evaluated.")
        return df

    return df.sort_values("avg_latency_sec").reset_index(drop=True)

In [13]:
import re
# Exit points after every third transformer block, up to block 24.
exit_layers = list(range(3, 25, 3))
# Gate thresholds to sweep; a gate output at or above the threshold exits.
THRESHOLDS = [0.5, 0.6, 0.7, 0.8, 0.9]

# Prefer the GPU when one is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# One row per (MLP-gate run, threshold).
df_results = evaluate_all_models_mlp(
    PERSISTENT_BASE_DIR,
    exit_layers,
    test_loader,
    device,
    THRESHOLDS,
)


üöÄ Evaluating MLP gates: mlpH128_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0 | H=128


config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]


üöÄ Evaluating MLP gates: mlpH128_lr0.0003_wd0.0_ep2_drop0.0_batch_size16_gateW1.0 | H=128

üöÄ Evaluating MLP gates: mlpH128_lr3e-05_wd0.0_ep2_drop0.0_batch_size16_gateW1.0 | H=128

üöÄ Evaluating MLP gates: mlpH64_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0 | H=64

üöÄ Evaluating MLP gates: mlpH64_lr0.0003_wd0.0_ep2_drop0.0_batch_size16_gateW1.0 | H=64

üöÄ Evaluating MLP gates: mlpH64_lr3e-05_wd0.0_ep2_drop0.0_batch_size16_gateW1.0 | H=64


In [16]:
# Show full run names and do not wrap wide tables.
pd.set_option("display.width", None)
pd.set_option("display.max_colwidth", None)

# Fastest configurations first, with a fresh 0..n-1 index.
df_results = df_results.sort_values("avg_latency_sec").reset_index(drop=True)

df_results

Unnamed: 0,model_name,mlp_hidden_dim,gate_threshold,accuracy,avg_layers_used,avg_latency_sec,cost_saving_pct,tokens_per_sec
0,mlpH128_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,128,0.6,0.840596,3.0,0.003448,87.5,6963.389493
1,mlpH128_lr3e-05_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,128,0.5,0.841743,3.006881,0.003463,87.47133,6933.757557
2,mlpH128_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,128,0.7,0.83945,3.048165,0.003475,87.299312,6909.540576
3,mlpH128_lr3e-05_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,128,0.6,0.83945,3.034404,0.003499,87.356651,6861.552672
4,mlpH64_lr0.0003_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,64,0.5,0.840596,3.0,0.003503,87.5,6853.884035
5,mlpH64_lr0.0003_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,64,0.6,0.841743,3.010321,0.003509,87.456995,6842.966905
6,mlpH128_lr0.0003_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,128,0.5,0.840596,3.0,0.003513,87.5,6834.555554
7,mlpH64_lr0.0003_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,64,0.7,0.840596,3.024083,0.003514,87.399656,6832.778256
8,mlpH64_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,64,0.6,0.840596,3.013761,0.003528,87.442661,6806.271998
9,mlpH64_lr0.0001_wd0.0_ep2_drop0.0_batch_size16_gateW1.0,64,0.5,0.840596,3.0,0.003536,87.5,6790.482689


In [17]:
# Make sure the destination folder exists before writing.
os.makedirs(PERSISTENT_BASE_DIR, exist_ok=True)

# Persist the results table next to the MLP-gate runs (row index omitted).
csv_path = os.path.join(PERSISTENT_BASE_DIR, "finetuned_model_early_exit_stats.csv")
df_results.to_csv(csv_path, index=False)

print(f"Saved df_results to: {csv_path}")

Saved df_results to: /content/drive/MyDrive/Efficient_AI_Project_/EarlyExit_Experiments/komal/gpt2-medium-MLP/finetuned_model_early_exit_stats.csv
