In [1]:
import pandas as pd
import numpy as np

In [2]:
import torch
from transformers import pipeline
from transformers import AutoTokenizer, ModernBertForSequenceClassification

import torch.nn.functional as F

In [12]:
import torch
import torch.nn.functional as F
from transformers import ModernBertForSequenceClassification, AutoTokenizer


class ModernBERTBayesianEarlyExit:
    def __init__(
        self,
        model_name="answerdotai/ModernBERT-base",
        exit_layers=[5, 10, 15, 20],
        dropout_passes=10,          # Number of MC samples
        uncertainty_threshold=0.02, # Lower = more confident
        use_entropy=False           # use variance (default) or entropy
    ):
        self.device = "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.model = ModernBertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2
        )
        self.model.to(self.device)
        self.model.eval()  # BUT dropout will still work because we unfreeze it manually later

        # Extract structure
        self.embeddings = self.model.model.embeddings
        self.layers = self.model.model.layers
        self.final_norm = self.model.model.final_norm
        self.classifier = self.model.classifier

        self.exit_layers = sorted(exit_layers)
        self.dropout_passes = dropout_passes
        self.uncertainty_threshold = uncertainty_threshold
        self.use_entropy = use_entropy
        self.num_layers = len(self.layers)

        # Enable dropout modules even in eval mode
        self._enable_dropout(self.model)

    def _enable_dropout(self, module):
        """Force dropout modules to stay active even in eval mode (Bayesian trick)."""
        for m in module.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()

    # ---------- Bayesian uncertainty metrics ----------
    def predictive_variance(self, probs_mc):
        """
        probs_mc: shape [K, num_classes]
        Returns scalar uncertainty = variance of predicted probability for predicted class
        """
        mean_prob = probs_mc.mean(dim=0)             # [num_classes]
        var = ((probs_mc - mean_prob)**2).mean(dim=0)
        pred_class = torch.argmax(mean_prob).item()
        return var[pred_class].item(), mean_prob, pred_class

    def predictive_entropy(self, probs):
        """
        probs: mean probability vector
        Entropy = -sum(p log p)
        """
        entropy = -(probs * torch.log(probs + 1e-12)).sum().item()
        return entropy

    @torch.no_grad()
    def classify(self, text):
        enc = self.tokenizer(text, return_tensors="pt", truncation=True, padding=False)
        input_ids = enc["input_ids"].to(self.device)

        attention_mask = (enc["attention_mask"] == 1).to(self.device)
        position_ids = torch.arange(
            0, input_ids.size(1), dtype=torch.long, device=self.device
        ).unsqueeze(0)

        hidden = self.embeddings(input_ids)

        # Manual forward block by block
        for i, layer in enumerate(self.layers):

            hidden = layer(
                hidden_states=hidden,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )[0]

            # ---- BAYESIAN EXIT LAYER ----
            if i in self.exit_layers:

                cls = hidden[:, 0, :]  # CLS embedding

                # MONTE CARLO DROPOUT
                probs_mc = []
                for _ in range(self.dropout_passes):
                    logits = self.classifier(cls)
                    probs_mc.append(F.softmax(logits, dim=-1))

                probs_mc = torch.stack(probs_mc, dim=0).squeeze(1)   # FIX: remove batch dim

                # Bayesian uncertainty (variance)
                var, mean_prob, pred_class = self.predictive_variance(probs_mc)

                # Alternatively use entropy
                if self.use_entropy:
                    entropy = self.predictive_entropy(mean_prob)
                    unc_value = entropy
                else:
                    unc_value = var

                # EXIT RULE
                if unc_value <= self.uncertainty_threshold:
                    conf = mean_prob[pred_class].item()
                    return pred_class, i + 1, conf, unc_value

        # Final layer (if no early exit)
        hidden = self.final_norm(hidden)
        cls = hidden[:, 0, :]
        logits = self.classifier(cls)
        final_probs = F.softmax(logits, dim=-1)[0]
        conf, final_pred = torch.max(final_probs, dim=0)

        return final_pred.item(), self.num_layers, conf.item(), 0.0

In [4]:
class ModernBERTBaselineClassifier:
    def __init__(self, model_name="answerdotai/ModernBERT-base"):
        # Detect device: MPS (Apple Silicon) → else CPU
        self.device = "cpu"

        # Load tokenizer + classification model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ModernBertForSequenceClassification.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

        # ModernBERT layer count (22 for base)
        self.num_layers = self.model.config.num_hidden_layers

    @torch.no_grad()
    def classify(self, text):
        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=False
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Forward pass
        outputs = self.model(**inputs)
        logits = outputs.logits  # [1, num_labels]

        # Softmax for confidence
        probs = F.softmax(logits, dim=-1)[0]
        conf, pred = torch.max(probs, dim=0)

        return pred.item(), (self.num_layers - 1), conf.item(), None

In [5]:
import time
import numpy as np

from evaluation.dataset_loaders.sst2 import load_sst2
from evaluation.dataset_loaders.agnews import load_agnews
from evaluation.dataset_loaders.amazon import load_amazon_polarity
from evaluation.dataset_loaders.imdb import load_imdb
from evaluation.dataset_loaders.dbpedia import load_dbpedia
from evaluation.dataset_loaders.yanswers import load_yahoo

dataset_loaders = [
    ("sst2", load_sst2, "classification"),
    ("imdb", load_imdb, "classification"),
    ("amazon_polarity", load_amazon_polarity, "classification")
]

cached_datasets = {}
print("Loading datasets once...\n")

for name, loader, task in dataset_loaders:
    print(f"Loading {name}...")
    cached_datasets[name] = {
        "data": loader(number=500),
        "task": task
    }

print("\nAll datasets loaded.\n")

Loading datasets once...

Loading sst2...
Loading imdb...
Loading amazon_polarity...

All datasets loaded.



In [6]:
# Extract (text, label) from any format
def extract_text_label(sample):
    if isinstance(sample, dict):
        if "text" in sample:
            return sample["text"], sample["label"]
        elif "sentence" in sample:
            return sample["sentence"], sample["label"]
        elif "input_text" in sample:
            return sample["input_text"], sample["label"]
        else:
            raise ValueError("Unknown dict format:", sample)

    if isinstance(sample, (tuple, list)):
        return sample[0], sample[1]

    raise ValueError("Unknown sample format:", sample)

def evaluate_dataset(model, dataset, dataset_name):
    correct = 0
    total = 0
    layers_used = []
    total_tokens = 0
    uncertainties = []     # NEW

    start = time.time()

    for sample in dataset:
        text, label = extract_text_label(sample)

        # Bayesian model returns 4 outputs
        # Baseline model: pred, layer, conf, unc=None  (you already fixed baseline)
        pred, layer, conf, unc = model.classify(text)

        correct += (pred == label)
        total += 1
        layers_used.append(layer)
        uncertainties.append(unc if unc is not None else 0.0)

        total_tokens += len(model.tokenizer(text)["input_ids"])

    end = time.time()
    latency = (end - start) / total

    return {
        "metric": "accuracy",
        "score": correct / total,
        "avg_latency_sec": latency,
        "tokens_per_sec": total_tokens / (end - start),
        "avg_layers_used": float(np.mean(layers_used)),
        "avg_uncertainty": float(np.mean(uncertainties)),   # NEW
        "num_samples": total
    }

In [8]:
results_table = []

print("Running BASELINE")

baseline_model = ModernBERTBaselineClassifier()

for name, meta in cached_datasets.items():
    dataset = meta["data"]
    print(f"\nTesting BASELINE on {name}...")

    result = evaluate_dataset(baseline_model, dataset, name)
    print(name, result)

    results_table.append({
        "dataset": name,
        "threshold": None,
        "mode": "baseline",
        "metric": result["metric"],
        "score": float(result["score"]),
        "avg_latency_sec": float(result["avg_latency_sec"]),
        "tokens_per_sec": float(result["tokens_per_sec"]),
        "avg_layers_used": float(result["avg_layers_used"]),
        "num_samples": int(result["num_samples"]),
    })

Running BASELINE


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Testing BASELINE on sst2...
sst2 {'metric': 'accuracy', 'score': 0.508, 'avg_latency_sec': 0.04796639823913574, 'tokens_per_sec': 542.6715566640596, 'avg_layers_used': 21.0, 'avg_uncertainty': 0.0, 'num_samples': 500}

Testing BASELINE on imdb...
imdb {'metric': 'accuracy', 'score': 0.528, 'avg_latency_sec': 0.19430025577545165, 'tokens_per_sec': 1488.9944372196362, 'avg_layers_used': 21.0, 'avg_uncertainty': 0.0, 'num_samples': 500}

Testing BASELINE on amazon_polarity...
amazon_polarity {'metric': 'accuracy', 'score': 0.496, 'avg_latency_sec': 0.07481225967407226, 'tokens_per_sec': 1291.7134226529881, 'avg_layers_used': 21.0, 'avg_uncertainty': 0.0, 'num_samples': 500}


In [None]:
# -------------------------
# EARLY EXIT NEXT
# -------------------------
exit_layer_groups = [
    [2, 4, 6, 8, 10],
    [4, 8],
    [3, 6, 9],
    [6]
]

thresholds = [0.005, 0.01, 0.02, 0.05, 0.1]

for exit_layers in exit_layer_groups:

    print("\n=================================================")
    print(f"Testing EXIT LAYERS: {exit_layers}")
    print("=================================================")

    for th in thresholds:

        print(f"\n----------------------------")
        print(f"Threshold = {th}")
        print(f"Exit Layers = {exit_layers}")
        print("----------------------------")

        # Create modernbert early exit model
        model = ModernBERTBayesianEarlyExit(
            model_name="answerdotai/ModernBERT-base",
            exit_layers=exit_layers,
            dropout_passes=10,
            uncertainty_threshold=th,
            use_entropy=False
        )

        # Evaluate across datasets
        for name, meta in cached_datasets.items():
            dataset = meta["data"]

            print(f"\nTesting {name} (exit_layers={exit_layers}, threshold={th})...")

            result = evaluate_dataset(model, dataset, name)
            print(name, result)

            # Save results with exit_layers column added
            results_table.append({
                "dataset": name,
                "threshold": th,
                "exit_layers": str(exit_layers),
                "mode": "early_exit",
                "metric": result["metric"],
                "score": float(result["score"]),
                "tokens_per_sec": float(result["tokens_per_sec"]),
                "avg_latency_sec": float(result["avg_latency_sec"]),
                "avg_layers_used": float(result["avg_layers_used"]),
                "avg_uncertainty": float(result["avg_uncertainty"]),
                "num_samples": int(result["num_samples"]),
            })



Testing EXIT LAYERS: [2, 4, 6, 8, 10]

----------------------------
Threshold = 0.005
Exit Layers = [2, 4, 6, 8, 10]
----------------------------


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Testing sst2 (exit_layers=[2, 4, 6, 8, 10], threshold=0.005)...
sst2 {'metric': 'accuracy', 'score': 0.546, 'avg_latency_sec': 0.004455510139465332, 'tokens_per_sec': 5842.204188794336, 'avg_layers_used': 3.0, 'avg_uncertainty': 2.842170943040401e-17, 'num_samples': 500}

Testing imdb (exit_layers=[2, 4, 6, 8, 10], threshold=0.005)...
imdb {'metric': 'accuracy', 'score': 0.52, 'avg_latency_sec': 0.01574358606338501, 'tokens_per_sec': 18376.499409677403, 'avg_layers_used': 3.0, 'avg_uncertainty': 2.913225216616411e-16, 'num_samples': 500}

Testing amazon_polarity (exit_layers=[2, 4, 6, 8, 10], threshold=0.005)...


In [None]:
import pandas as pd
df=pd.DataFrame(results_table)
df_sorted = (
    df.groupby("dataset", group_keys=True)
      .apply(lambda g: g.sort_values("score", ascending=False))
      .reset_index(drop=True)
)
df_sorted

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math

df2 = df.copy()

# Replace NaN thresholds with 'baseline'
df2['threshold'] = df2['threshold'].apply(
    lambda x: "baseline" if pd.isna(x) else x
)

# Convert exit_layers list → string for grouping
df2['exit_layers_str'] = df2['exit_layers'].apply(lambda x: str(x))

datasets = df2['dataset'].unique()

for ds in datasets:

    df_ds = df2[df2['dataset'] == ds]

    # Get all exit_layers configs EXCEPT baseline
    exit_configs = sorted(df_ds['exit_layers_str'].unique())
    exit_configs = [cfg for cfg in exit_configs if cfg != "nan"]

    # Only keep configs that are real lists (not baseline)
    # ensure exactly 4 for plotting (or fewer → still works)
    exit_configs = exit_configs[:4]

    # Prepare 2×2 subplot grid
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    # Extract baseline row(s)
    df_base = df_ds[df_ds['mode'] == "baseline"]

    # ---------- LOOP OVER EXIT CONFIGS ----------
    for idx, exit_cfg in enumerate(exit_configs):

        ax = axes[idx]

        # Subset for this exit layers
        df_sub = df_ds[df_ds['exit_layers_str'] == exit_cfg]

        ###########################################
        # 1) Extract values (early-exit run)
        ###########################################
        # Early exit values
        thresholds = df_sub['threshold'].tolist()
        scores     = df_sub['score'].tolist()
        layers     = df_sub['avg_layers_used'].tolist()
        tps        = df_sub['tokens_per_sec'].tolist()
        
        # Add baseline explicitly (first element)
        if not df_base.empty:
            thresholds = ["baseline"] + thresholds
            scores     = [df_base['score'].iloc[0]] + scores
            layers     = [df_base['avg_layers_used'].iloc[0]] + layers
            tps        = [df_base['tokens_per_sec'].iloc[0]] + tps


        ###########################################
        # 2) Extract baseline values
        ###########################################
        base_thresholds = df_base['threshold'].tolist()
        base_scores     = df_base['score'].tolist()
        base_layers     = df_base['avg_layers_used'].tolist()
        base_tps        = df_base['tokens_per_sec'].tolist()

        # ---- Sorting helper: baseline always first ----
        def sort_key(x):
            return -1 if x == "baseline" else float(x)

        sorted_idx = sorted(
            range(len(thresholds)),
            key=lambda i: sort_key(thresholds[i])
        )

        thresholds = [thresholds[i] for i in sorted_idx]
        scores     = [scores[i]     for i in sorted_idx]
        layers     = [layers[i]     for i in sorted_idx]
        tps        = [tps[i]        for i in sorted_idx]

        x_pos = np.arange(len(thresholds))

        ################################################
        # ---- PLOTS (same style as your original) ----
        ################################################
        ax1 = ax

        # Accuracy curve
        ax1.plot(x_pos, scores, marker="o", color="tab:blue", label="Accuracy")
        ax1.set_ylabel("Accuracy", color="tab:blue")
        ax1.tick_params(axis="y", labelcolor="tab:blue")

        # Plot baseline point in gold
        for i, th in enumerate(thresholds):
            if th == "baseline":
                ax1.scatter(
                    x_pos[i], scores[i],
                    color="gold", s=140, edgecolor="black", zorder=6,
                    label="Baseline"
                )

        # Avg layers
        ax2 = ax1.twinx()
        ax2.plot(
            x_pos, layers,
            marker="s", linestyle="--",
            color="tab:red", label="Layers Used"
        )
        ax2.set_ylabel("Avg Layers Used", color="tab:red")
        ax2.tick_params(axis="y", labelcolor="tab:red")

        # Tokens/sec
        ax3 = ax1.twinx()
        ax3.spines["right"].set_position(("outward", 50))
        ax3.plot(
            x_pos, tps,
            marker="^", linestyle=":",
            color="tab:green", label="Tokens/sec"
        )
        ax3.set_ylabel("Tokens/sec", color="tab:green")
        ax3.tick_params(axis="y", labelcolor="tab:green")

        ax.set_xticks(x_pos)
        ax.set_xticklabels(thresholds)
        ax.set_title(f"exit_layers = {exit_cfg}")

        ax.grid(True, linestyle="--", alpha=0.3)

    # ------- FIGURE TITLE -------
    fig.suptitle(f"ModernBERT with Bayesian Early-Exit Trade-Off Curves: {ds}", fontsize=16)
    plt.tight_layout()
    plt.show()