In [1]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import random
import itertools
# -----------------------------
# Load datasets
# -----------------------------
def _train_and_eval_splits(eval_split, *dataset_args):
    """Load a Hugging Face dataset and return its ``(train, <eval_split>)`` pair.

    ``dataset_args`` are forwarded positionally to ``load_dataset``.
    """
    splits = load_dataset(*dataset_args)
    return splits["train"], splits[eval_split]


def load_sst2():
    """Return the GLUE SST-2 ``(train, validation)`` splits.

    NOTE(review): the validation split is used as the held-out set —
    presumably because GLUE's SST-2 test split has no public labels; confirm.
    """
    return _train_and_eval_splits("validation", "glue", "sst2")


def load_rotten():
    """Return the Rotten Tomatoes ``(train, test)`` splits."""
    return _train_and_eval_splits("test", "rotten_tomatoes")


# The second element of each pair is the split evaluated on later
# (SST-2 validation, Rotten Tomatoes test).
sst2_train, sst2_test = load_sst2()
rt_train, rt_test = load_rotten()


# -----------------------------
# Convert datasets to (text,label)
# -----------------------------
def convert_sst2(sample):
    """Map one GLUE SST-2 record to a ``(text, label)`` tuple with an ``int`` label."""
    text = sample["sentence"]
    label = int(sample["label"])
    return text, label

def convert_rotten(sample):
    """Map one Rotten Tomatoes record to a ``(text, label)`` tuple with an ``int`` label."""
    text, raw_label = sample["text"], sample["label"]
    return text, int(raw_label)

# Seed for the pair shuffle below; fixed so Restart & Run All reproduces the
# same sample order.
PAIR_SHUFFLE_SEED = 42

# Merge both corpora into flat lists of (text, label) pairs.
train_pairs = [convert_sst2(x) for x in sst2_train] + [
    convert_rotten(x) for x in rt_train
]
test_pairs = [convert_sst2(x) for x in sst2_test] + [
    convert_rotten(x) for x in rt_test
]

# Optional shuffle (interleaves the two sources). A dedicated seeded RNG makes
# the order reproducible without touching Python's global `random` state.
_pair_rng = random.Random(PAIR_SHUFFLE_SEED)
_pair_rng.shuffle(train_pairs)
_pair_rng.shuffle(test_pairs)

print(f"Total train samples: {len(train_pairs)}")
print(f"Total test samples: {len(test_pairs)}")


# -----------------------------
# Dataset class
# -----------------------------
# Tokenizer used by collate_fn (via this global) and passed to SentimentDataset
# in the next cell.
# NOTE(review): the model cell loads "gpt2-medium" while this loads "gpt2";
# presumably fine because the GPT-2 family shares one BPE vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse EOS as the padding token so tokenizer.pad_token_id is set; collate_fn
# reads it to fill padded positions. After this line pad id == EOS id.
tokenizer.pad_token = tokenizer.eos_token

class SentimentDataset(Dataset):
    """Map-style dataset over ``(text, label)`` pairs, tokenized lazily per item.

    Args:
        data: Indexable sequence of ``(text, label)`` tuples.
        tokenizer: Callable invoked as
            ``tokenizer(text, add_special_tokens=False, return_tensors="pt")``;
            its result must expose an ``"input_ids"`` tensor of shape
            ``(1, seq_len)``.
        max_length: Optional cap on the number of tokens kept per example
            (leading tokens are kept). GPT-2 has a fixed context window
            (``n_positions``), so unbounded inputs can overflow it. ``None``
            (default) keeps every token.

    Raises:
        ValueError: if ``max_length`` is given and smaller than 1.

    Each item is a dict with a 1-D unpadded ``"input_ids"`` tensor (batching
    and padding happen in the collate function) and a scalar long
    ``"labels"`` tensor.
    """

    def __init__(self, data, tokenizer, max_length=None):
        if max_length is not None and max_length < 1:
            raise ValueError(
                f"max_length must be a positive int or None, got {max_length!r}"
            )
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, label = self.data[idx]
        enc = self.tokenizer(
            text,
            add_special_tokens=False,
            return_tensors="pt"
        )
        # (1, seq_len) -> (seq_len,)
        input_ids = enc["input_ids"].squeeze(0)
        if self.max_length is not None:
            # Right-side truncation: keep the first max_length tokens.
            input_ids = input_ids[: self.max_length]
        return {
            "input_ids": input_ids,
            "labels": torch.tensor(label, dtype=torch.long)
        }


# -----------------------------
# Dynamic padding collate_fn
# -----------------------------
def collate_fn(batch, pad_token_id=None):
    """Right-pad a list of dataset items into one batch (dynamic padding).

    Args:
        batch: List of dicts with a 1-D ``"input_ids"`` tensor and a scalar
            ``"labels"`` tensor (as yielded by ``SentimentDataset``).
        pad_token_id: Id used to fill padded positions. Defaults to the
            notebook-level ``tokenizer.pad_token_id``.

    Returns:
        Dict with ``"input_ids"`` (long, ``[batch, max_len]``, right-padded),
        ``"attention_mask"`` (float, ``[batch, max_len]``: 1.0 on real tokens,
        0.0 on padding) and ``"labels"`` (``[batch]``).
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    input_ids = [b["input_ids"] for b in batch]
    labels = torch.stack([torch.as_tensor(b["labels"]) for b in batch])

    lengths = torch.tensor([len(ids) for ids in input_ids], dtype=torch.long)
    max_len = int(lengths.max())

    padded = torch.full((len(batch), max_len), pad_token_id, dtype=torch.long)
    for i, ids in enumerate(input_ids):
        padded[i, :len(ids)] = ids

    # Build the mask from sequence lengths, not by comparing token ids with
    # the pad id: the pad token is EOS here, so a genuine EOS inside a text
    # would otherwise be masked out as if it were padding.
    positions = torch.arange(max_len)
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

    return {
        "input_ids": padded,
        "attention_mask": attention_mask,
        "labels": labels
    }

Total train samples: 75879
Total test samples: 1938


In [2]:
# -----------------------------
# Create datasets
# -----------------------------
# Batch size for both loaders. NOTE(review): the hyperparameter grid also has a
# "batch_size" entry, but nothing feeds it into this cell — keep them in sync.
BATCH_SIZE = 16
# Seed for the training-loader shuffle so batch order is reproducible.
LOADER_SEED = 42

# -----------------------------
# Create datasets
# -----------------------------
train_ds = SentimentDataset(train_pairs, tokenizer)
test_ds = SentimentDataset(test_pairs, tokenizer)

# -----------------------------
# Create dataloaders
# -----------------------------
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    # A dedicated generator makes the shuffle order deterministic without
    # depending on whatever global torch RNG state earlier cells left behind.
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

print("Train loader batches:", len(train_loader))
print("Test loader batches:", len(test_loader))

# Optional: Inspect first batch
example = next(iter(train_loader))
print(example["input_ids"].shape)
print(example["attention_mask"].shape)
print(example["labels"])

Train loader batches: 4743
Test loader batches: 122
torch.Size([16, 44])
torch.Size([16, 44])
tensor([0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1])


In [3]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class GPT2EarlyExitClassifier(nn.Module):
    """GPT-2 backbone with a classification head attached at several layers.

    Each exit takes the hidden state of the *last real (non-padding) token*
    from one entry of the backbone's ``hidden_states`` tuple and maps it to
    ``num_labels`` logits. With labels, the per-exit cross-entropy losses are
    summed using ``exit_loss_weights``.

    Args:
        model_name: Hugging Face model id loaded with ``AutoModelForCausalLM``.
        exit_layers: Indices into ``hidden_states`` (``0`` = embedding output,
            ``k`` = output of block ``k``). Stored sorted.
        hyperparameters: Dict read for ``num_labels`` (default 2), ``dropout``
            (default 0.0) and ``exit_loss_weights`` (default 1.0 per exit,
            matched to exits in sorted order).

    Raises:
        ValueError: if an exit layer cannot index ``hidden_states`` or fewer
            loss weights than exit layers are supplied.
    """

    def __init__(self, model_name, exit_layers, hyperparameters):
        super().__init__()

        # Load GPT-2 as a causal LM; only its hidden states are used.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            output_hidden_states=True,   # <-- IMPORTANT
            return_dict=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.exit_layers = sorted(exit_layers)
        self.hp = hyperparameters
        self.num_labels = self.hp.get("num_labels", 2)
        dropout_rate = self.hp.get("dropout", 0.0)

        # Loss weights λ_e, one per exit, aligned with the sorted exit_layers.
        self.exit_loss_weights = self.hp.get(
            "exit_loss_weights",
            [1.0] * len(self.exit_layers)
        )

        # Validate up front rather than failing with an IndexError mid-training.
        # hidden_states has num_hidden_layers + 1 entries (embeddings first), so
        # valid indices (negative ones included) lie in [-(n + 1), n].
        num_layers = getattr(self.model.config, "num_hidden_layers", None)
        if num_layers is not None:
            out_of_range = [
                layer for layer in self.exit_layers
                if not -(num_layers + 1) <= layer <= num_layers
            ]
            if out_of_range:
                raise ValueError(
                    f"exit_layers {out_of_range} out of range for a model with "
                    f"{num_layers} layers (valid: 0..{num_layers})"
                )
        if len(self.exit_loss_weights) < len(self.exit_layers):
            raise ValueError(
                f"got {len(self.exit_loss_weights)} exit_loss_weights for "
                f"{len(self.exit_layers)} exit layers"
            )

        hidden_size = self.model.config.hidden_size

        # One dropout + linear classification head per exit layer.
        self.exit_heads = nn.ModuleDict()
        for layer in self.exit_layers:
            self.exit_heads[str(layer)] = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size, self.num_labels)
            )

        self.ce = nn.CrossEntropyLoss()

    @staticmethod
    def _last_token_index(attention_mask):
        """Return, per row, the largest position whose mask entry is nonzero.

        Works for right- and left-padded batches; rows with no attended
        position yield 0.
        """
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        return (attention_mask.ne(0).long() * positions).amax(dim=1)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the backbone and every exit head.

        Args:
            input_ids: ``[batch, seq_len]`` token ids.
            attention_mask: Optional ``[batch, seq_len]`` mask, nonzero on real
                tokens. Passed to the backbone and used to find each row's
                last real token; when ``None`` the final position is used.
            labels: Optional ``[batch]`` class ids. When given, the weighted
                sum of per-exit cross-entropy losses is returned.

        Returns:
            Dict with ``"loss"`` (tensor, or ``None`` without labels) and
            ``"logits"`` mapping each exit layer to ``[batch, num_labels]``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        # hidden_states is a tuple:
        # [0] = embedding output, [k] = output of transformer block k.
        hidden_states = outputs.hidden_states

        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        if attention_mask is None:
            last_idx = torch.full(
                (batch_size,), seq_len - 1, dtype=torch.long, device=input_ids.device
            )
        else:
            # With right padding, position -1 is padding for every sequence
            # shorter than the longest in the batch. Pool the last *attended*
            # position instead so each head classifies real text.
            last_idx = self._last_token_index(attention_mask)
        row_idx = torch.arange(batch_size, device=last_idx.device)

        logits_dict = {}
        total_loss = 0.0

        for i, layer in enumerate(self.exit_layers):
            # [batch, seq_len, hidden] -> [batch, hidden] at each row's last token.
            pooled = hidden_states[layer][row_idx, last_idx]

            logits = self.exit_heads[str(layer)](pooled)
            logits_dict[layer] = logits

            if labels is not None:
                weight = self.exit_loss_weights[i]
                total_loss += weight * self.ce(logits, labels)

        return {
            "loss": total_loss if labels is not None else None,
            "logits": logits_dict
        }

    def save_pretrained(self, save_directory):
        """Save backbone, tokenizer, exit heads and exit config to a directory.

        Writes the backbone and tokenizer via their own ``save_pretrained``,
        the exit heads as ``exit_heads.pt`` and ``exit_layers`` plus the
        hyperparameters as ``early_exit_config.json``. Creates the directory
        if needed.
        """
        import json
        import os

        os.makedirs(save_directory, exist_ok=True)
        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        torch.save(
            self.exit_heads.state_dict(),
            os.path.join(save_directory, "exit_heads.pt"),
        )
        with open(os.path.join(save_directory, "early_exit_config.json"), "w") as fh:
            json.dump(
                {"exit_layers": self.exit_layers, "hyperparameters": self.hp},
                fh,
                indent=2,
                default=str,
            )

In [9]:
# Search space: every key maps to a list of candidate values, and the
# grid-search cell runs the Cartesian product of all lists.
hyperparameter_grid = dict(
    # model hyperparams
    num_labels=[2],                    # fixed (binary sentiment)
    dropout=[0.0],
    # One weight per exit layer, matched to exits in sorted layer order.
    exit_loss_weights=[
        [1, 1, 1, 1, 1, 1, 1],
    ],

    # training hyperparams
    learning_rate=[1e-5],
    weight_decay=[0.01],
    num_epochs=[3],
    max_grad_norm=[1.0],
    # NOTE(review): train() never reads batch_size; the DataLoaders are built
    # in an earlier cell with their own batch size.
    batch_size=[16],

    # logging
    log_every=[1000],
)

In [10]:
def make_experiment_name(hp):
    """Build a run/checkpoint name from a hyperparameter dict.

    The name always starts with ``lr.._wd.._ep.._drop..``. It then appends
    every other grid dimension present in ``hp`` that can differ between runs
    (exit-loss weights, gradient-clip norm, batch size). Without these, grid
    points that differ only in, e.g., ``exit_loss_weights`` would share one
    name, and the later run's checkpoint would overwrite the earlier one.
    Optional keys are skipped when absent.

    Args:
        hp: Dict with at least ``learning_rate``, ``weight_decay``,
            ``num_epochs`` and ``dropout``.

    Returns:
        Underscore-joined name string.
    """
    parts = [
        f"lr{hp['learning_rate']}",
        f"wd{hp['weight_decay']}",
        f"ep{hp['num_epochs']}",
        f"drop{hp['dropout']}",
    ]
    if "exit_loss_weights" in hp:
        parts.append("w" + "-".join(str(w) for w in hp["exit_loss_weights"]))
    if "max_grad_norm" in hp:
        parts.append(f"gn{hp['max_grad_norm']}")
    if "batch_size" in hp:
        parts.append(f"bs{hp['batch_size']}")
    return "_".join(parts)

In [11]:
def train(model, train_loader, test_loader, hp, device, evaluate_fn=None):
    """Fine-tune ``model`` with AdamW and gradient clipping, evaluating each epoch.

    Args:
        model: Module called with ``input_ids``, ``attention_mask`` and
            ``labels`` that returns a dict with a scalar ``"loss"`` tensor.
        train_loader: Iterable of batch dicts (tensors), supporting ``len``.
        test_loader: Passed unchanged to ``evaluate_fn`` after every epoch.
        hp: Dict read for ``learning_rate``, ``weight_decay``, ``num_epochs``,
            ``max_grad_norm`` and ``log_every``.
        device: Device every batch tensor is moved to.
        evaluate_fn: Callable ``(model, test_loader, device)`` run after each
            epoch. Defaults to the notebook-level ``evaluate``.
            NOTE(review): no ``evaluate`` is defined in the visible cells
            (execution counts jump from In[3] to In[9]); make sure the cell
            defining it is kept in the notebook.

    Returns:
        List with the average training loss of each epoch.
    """
    # Resolve the evaluator before training starts, so a missing `evaluate`
    # fails immediately instead of after a full (expensive) epoch.
    if evaluate_fn is None:
        evaluate_fn = evaluate

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=hp["learning_rate"],
        weight_decay=hp["weight_decay"]
    )
    log_every = hp["log_every"]
    epoch_losses = []

    for epoch in range(hp["num_epochs"]):
        # Re-enable dropout every epoch; evaluate_fn may switch to eval mode.
        model.train()
        total_train_loss = 0.0

        for step, batch in enumerate(train_loader, start=1):
            batch = {k: v.to(device) for k, v in batch.items()}

            out = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"]
            )

            loss = out["loss"]
            total_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), hp["max_grad_norm"])
            optimizer.step()

            if step % log_every == 0:
                print(f"Epoch {epoch+1} | Step {step} | Loss: {loss.item():.4f}")

        # max(..., 1) guards against ZeroDivisionError on an empty loader.
        avg_train_loss = total_train_loss / max(len(train_loader), 1)
        epoch_losses.append(avg_train_loss)
        print(f"\n>>> Epoch {epoch+1} completed | Avg Train Loss: {avg_train_loss:.4f}")

        # Run evaluation
        evaluate_fn(model, test_loader, device)

    return epoch_losses

In [None]:
import os

device = "cuda" if torch.cuda.is_available() else "cpu"

# Indices into the backbone's hidden_states tuple (0 = embedding output).
# Needs exactly as many entries as each exit_loss_weights list in the grid.
exit_layers = [3, 6, 9, 12, 15, 18, 21]

# Same seed for every run, so configurations start from identical head init.
RUN_SEED = 42
CHECKPOINT_DIR = "./checkpoints"

# generate all hyperparameter combinations
keys = list(hyperparameter_grid.keys())
values = list(hyperparameter_grid.values())

for combo in itertools.product(*values):

    hp = dict(zip(keys, combo))
    hp["num_labels"] = 2  # fixed

    exp_name = make_experiment_name(hp)
    print("\n==============================")
    print("RUNNING EXPERIMENT:", exp_name)
    print("Hyperparameters:", hp)
    print("==============================")

    # Release the previous run's model before allocating the next one, so two
    # copies never coexist on the GPU. The last run's model stays bound to
    # `model` after the loop.
    model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    torch.manual_seed(RUN_SEED)

    # 1. Create fresh model for this configuration
    model = GPT2EarlyExitClassifier(
        model_name="gpt2-medium",
        exit_layers=exit_layers,
        hyperparameters=hp
    ).to(device)

    # 2. Train
    train(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        hp=hp,
        device=device
    )

    # 3. Save checkpoint. torch.save on the state_dict works for any
    # nn.Module; hp and exit_layers are stored alongside for reloading.
    save_path = os.path.join(CHECKPOINT_DIR, exp_name)
    os.makedirs(save_path, exist_ok=True)
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "hyperparameters": hp,
            "exit_layers": exit_layers,
        },
        os.path.join(save_path, "model.pt"),
    )
    print(f">>> Saved model to {save_path}")


RUNNING EXPERIMENT: lr1e-05_wd0.01_ep3_drop0.0
Hyperparameters: {'num_labels': 2, 'dropout': 0.0, 'exit_loss_weights': [1, 1, 1, 1, 1, 1, 1], 'learning_rate': 1e-05, 'weight_decay': 0.01, 'num_epochs': 3, 'max_grad_norm': 1.0, 'batch_size': 16, 'log_every': 1000}
