### Not finished - working on

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

label_names = {
    "sst2": ["negative", "positive"],
    "agnews": ["world", "sports", "business", "tech"]
}

class GPT2EarlyExitClassifier(torch.nn.Module):
    def __init__(self, model_name, exit_layers, threshold, verbalizers, use_prompt):
        super().__init__()

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, output_hidden_states=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.exit_layers = exit_layers
        self.threshold = threshold
        self.verbalizers = verbalizers
        self.num_layers = len(self.model.transformer.h)
        self.use_prompt = use_prompt

        # Precompute verbalizer token ids
        self.verbalizer_token_ids = {}
        for dataset, class_map in verbalizers.items():
            ids = {}
            for cls, words in class_map.items():
                tok_lists = [self.tokenizer.encode(" " + w) for w in words]
                ids[cls] = tok_lists
            self.verbalizer_token_ids[dataset] = ids

    # Prompt builder
    def build_prompt(self, text, dataset_name):
        if not self.use_prompt:
            return text

        if dataset_name == "sst2":
            return (
                "To which category does the text belong?\n"
                "\"positive sentiment\", \"negative sentiment\"\n\n"
                f"Text: {text}\n"
            )

        if dataset_name == "agnews":
            return (
                "What is the topic of this article?\n"
                "\"world\", \"sports\", \"business\", \"technology\"\n\n"
                f"News Article: {text}\n"
            )

        return text

    @torch.no_grad()
    def classify(self, text, dataset_name):
        labels = label_names[dataset_name]
        prompt = self.build_prompt(text, dataset_name)

        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"]

        outputs = self.model(input_ids, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        class_verbalizers = self.verbalizer_token_ids[dataset_name]

        # ---- EARLY EXIT LOOP ----
        for layer_idx in self.exit_layers:
            h = hidden_states[layer_idx][:, -1, :]
            logits = self.model.lm_head(h)[0]

            class_scores = []
            for cls, tok_lists in class_verbalizers.items():
                vals = [logits[torch.tensor(toks)].mean().item() for toks in tok_lists]
                class_scores.append(float(np.mean(vals)))

            class_scores = torch.tensor(class_scores)
            probs = torch.softmax(class_scores, dim=-1)

            pred = int(torch.argmax(probs))
            conf = float(probs[pred])

            if conf >= self.threshold:
                return {
                    "layer": layer_idx,
                    "pred": labels[pred],
                    "conf": conf,
                    "scores": {labels[i]: float(class_scores[i]) for i in range(len(labels))},
                    "probs": {labels[i]: float(probs[i]) for i in range(len(labels))}
                }

        # ---- FINAL EXIT AT LAYER 12 ----
        final_layer = self.num_layers
        h = hidden_states[final_layer][:, -1, :]
        logits = self.model.lm_head(h)[0]

        class_scores = []
        for cls, tok_lists in class_verbalizers.items():
            vals = [logits[torch.tensor(toks)].mean().item() for toks in tok_lists]
            class_scores.append(float(np.mean(vals)))

        class_scores = torch.tensor(class_scores)
        probs = torch.softmax(class_scores, dim=-1)
        pred = int(torch.argmax(probs))

        return {
            "layer": final_layer,
            "pred": labels[pred],
            "conf": float(probs[pred]),
            "scores": {labels[i]: float(class_scores[i]) for i in range(len(labels))},
            "probs": {labels[i]: float(probs[i]) for i in range(len(labels))}
        }

In [12]:
from evaluation.dataset_loaders.sst2 import load_sst2
from evaluation.dataset_loaders.agnews import load_agnews
import pandas as pd
import time

verbalizers = {
    "sst2": {
        0: ["negative"],
        1: ["positive"]
    },
    "agnews": {
        0: ["international", "world", "global"],
        1: ["sports", "sport"],
        2: ["business", "finance", "market"],
        3: ["technology", "tech"]
    }
}

dataset_loaders = [
    ("sst2", load_sst2),
    ("agnews", load_agnews),
]


def extract(sample):
    if isinstance(sample, dict):
        for k in ["text", "sentence", "input_text"]:
            if k in sample:
                return sample[k], sample["label"]
    return sample[0], sample[1]

def run_experiment(use_prompt, threshold, samples):

    model = GPT2EarlyExitClassifier(
        model_name="gpt2",
        exit_layers=[2,4,6,8,10],
        threshold=threshold,
        verbalizers=verbalizers,
        use_prompt=use_prompt
    )

    records = []
    summary_records = []

    for name, loader in dataset_loaders:
        dataset = loader(number=samples)

        correct = 0
        total = 0
        layers = []
        total_tokens = 0

        t0 = time.time()

        for sample in dataset:
            text, gold = extract(sample)
            out = model.classify(text, name)

            gold_name = label_names[name][gold]

            # per-sample record
            row = {
                "dataset": name,
                "text": text,
                "gold": gold,
                "gold_name": gold_name,
                "pred": out["pred"],
                "layer": out["layer"],
                "conf": out["conf"],
            }

            # add class-wise scores
            for label, sc in out["scores"].items():
                row[f"{label}_score"] = sc
            for label, sc in out["probs"].items():
                row[f"{label}_prob"] = sc

            records.append(row)

            # summary stats
            layers.append(out["layer"])
            correct += (out["pred"] == gold_name)
            total += 1
            total_tokens += len(model.tokenizer(text)["input_ids"])

        t1 = time.time()

        summary = {
            "dataset": name,
            "mode": "with_prompt" if use_prompt else "without_prompt",
            "threshold": threshold,
            "metric": "accuracy",
            "score": correct/total,
            "avg_latency_sec": (t1-t0)/total,
            "tokens_per_sec": total_tokens/(t1-t0),
            "avg_layers_used": float(np.mean(layers)),
            "num_samples": total
        }

        summary_records.append(summary)

    return pd.DataFrame(records), pd.DataFrame(summary_records)

In [13]:
th = 0.8
sam = 500
df_with_prompt, df_sum_with_prompt = run_experiment(use_prompt=True, threshold=th, samples=sam)
df_without_prompt, df_sum_without_prompt = run_experiment(use_prompt=False, threshold=th, samples=sam)

df_with_prompt.head()
df_sum_with_prompt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Unnamed: 0,dataset,mode,threshold,metric,score,avg_latency_sec,tokens_per_sec,avg_layers_used,num_samples
0,sst2,with_prompt,0.8,accuracy,0.512,0.030532,772.110025,2.708,500
1,agnews,with_prompt,0.8,accuracy,0.38,0.038963,1325.457028,3.296,500


In [14]:
df_sum_without_prompt

Unnamed: 0,dataset,mode,threshold,metric,score,avg_latency_sec,tokens_per_sec,avg_layers_used,num_samples
0,sst2,without_prompt,0.8,accuracy,0.614,0.026552,887.840841,6.156,500
1,agnews,without_prompt,0.8,accuracy,0.31,0.033115,1559.544666,2.472,500


In [15]:
df_without_prompt

Unnamed: 0,dataset,text,gold,gold_name,pred,layer,conf,negative_score,positive_score,negative_prob,positive_prob,world_score,sports_score,business_score,tech_score,world_prob,sports_prob,business_prob,tech_prob
0,sst2,a science-fiction pastiche so lacking in origi...,0,negative,negative,12,0.658594,-79.474403,-80.131439,0.658594,0.341406,,,,,,,,
1,sst2,"it haunts you , you ca n't forget it , you adm...",1,positive,positive,6,0.950830,-4.314972,-1.352921,0.049170,0.950830,,,,,,,,
2,sst2,"nicks , seemingly uncertain what 's going to m...",0,negative,positive,6,0.809370,-3.339309,-1.893386,0.190630,0.809370,,,,,,,,
3,sst2,if there 's one thing this world needs less of...,0,negative,negative,10,0.801071,-24.434387,-25.827389,0.801071,0.198929,,,,,,,,
4,sst2,chokes on its own depiction of upper-crust dec...,0,negative,negative,10,0.816541,-22.957804,-24.450891,0.816541,0.183459,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,agnews,South Koreans Say Secret Work Refined Uranium ...,0,world,world,2,0.947984,,,,,-6.752152,-13.032355,-11.453743,-9.877625,0.947984,0.001776,0.008608,0.041632
996,agnews,Testaverde accepts Parcells #39; nomination Wh...,1,sports,world,2,0.993671,,,,,-1.000196,-7.019704,-11.387556,-6.545090,0.993671,0.002415,0.000031,0.003883
997,agnews,"Gray, Demon Deacons Swing Back In Action, Seek...",1,sports,world,2,0.975014,,,,,-6.519638,-13.623807,-11.563251,-10.517523,0.975014,0.000801,0.006289,0.017896
998,agnews,Still no beef resolution after latest talks NE...,2,business,world,2,0.996820,,,,,-1.210962,-7.580405,-11.728655,-7.747224,0.996820,0.001708,0.000027,0.001445
