In [1]:
from collections.abc import MutableMapping
from collections import UserDict
import numpy as np
import torch
from torch import nn
import torch.utils.data as data_utils
from tqdm.auto import tqdm

from laplace import Laplace

import logging
import warnings

# Configure the root logger at ERROR level and ignore every Python warning.
# Both settings are process-wide and apply to all cells that run afterwards.
logging.basicConfig(level="ERROR")
warnings.filterwarnings("ignore")

from transformers import ( # noqa: E402
    GPT2Config,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    DataCollatorWithPadding,
    PreTrainedTokenizer,
)
from peft import LoraConfig, get_peft_model # noqa: E402
from datasets import Dataset, load_dataset# noqa: E402
# Fix the global PyTorch and NumPy RNG seeds so repeated runs draw the same
# random numbers. Keep `np.random.seed` last: it returns None, so the cell
# prints nothing.

torch.manual_seed(0)
np.random.seed(0)

In [3]:
# Checkpoint identifier shared by the tokenizer and the classifier wrapper.
model_name = "tingtone/go_emo_gpt"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the dataset and pull out its three splits.
dataset = load_dataset("go_emotions")
train_dataset, val_dataset, test_dataset = (
    dataset[split] for split in ("train", "validation", "test")
)

# The class count is read from the label feature of the (full) train split.
num_labels = train_dataset.features["labels"].feature.num_classes
print(f"Number of labels: {num_labels}")

# Keep only a fixed, seeded 10-example subset of the training split.
train_dataset = (
    train_dataset
    .shuffle(seed=42)
    .select(range(10))
)
def preprocess(batch):
    """
    Tokenize a batch of texts and attach multi-hot label rows.

    Args:
        batch: Mapping with a ``"text"`` column (passed to the module-level
            ``tokenizer``) and a ``"labels"`` column holding, per example,
            the indices of the active classes.

    Returns:
        The tokenizer output (truncated to 1024 tokens, not padded) with an
        extra ``"label"`` entry: one list of ``num_labels`` floats per
        example, 1.0 at each active class index and 0.0 elsewhere.
    """
    # Padding is left to the collator, so sequences keep their own length.
    encoded = tokenizer(
        batch["text"],
        truncation=True,
        max_length=1024,
        padding=False,
    )

    # Each `row` is a view into `multi_hot`, so the assignment writes in place.
    multi_hot = np.zeros((len(batch["labels"]), num_labels), dtype=np.float32)
    for row, label_ids in zip(multi_hot, batch["labels"]):
        row[label_ids] = 1

    encoded["label"] = multi_hot.tolist()
    return encoded

# Tokenize every split and drop the raw text column; tensors are requested
# only afterwards, via `set_format`.
train_dataset, val_dataset, test_dataset = (
    split.map(preprocess, batched=True, remove_columns=["text"])
    for split in (train_dataset, val_dataset, test_dataset)
)

# Expose just the model inputs and the multi-hot target as torch tensors.
train_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "label"]
)
val_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "label"]
)
test_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "label"]
)

# Sequences were left unpadded above, so padding happens per batch here.
collator = DataCollatorWithPadding(tokenizer)

# One loader per split, all with the same batch size and collator.
train_dataloader, val_dataloader, test_dataloader = (
    data_utils.DataLoader(split, batch_size=100, collate_fn=collator)
    for split in (train_dataset, val_dataset, test_dataset)
)

# Optional sanity check of one collated batch:
# data = next(iter(train_dataloader))
# print(
#     f"Huggingface data defaults to UserDict, which is a MutableMapping? {isinstance(data, UserDict)}"
# )
# for k, v in data.items():
#     print(k, v.shape)

Number of labels: 28


In [4]:
class MyGPT2(nn.Module):
    """
    Thin ``nn.Module`` adapter around a Huggingface GPT-2 sequence classifier.

    The wrapped model is loaded from the module-level ``model_name``
    checkpoint and configured with the module-level ``num_labels``.

    Args:
        tokenizer: The tokenizer used for preprocessing the text data. Its
            ``pad_token_id`` is copied into the model config.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        super().__init__()
        hf_config = GPT2Config.from_pretrained(model_name)
        hf_config.num_labels = num_labels
        hf_config.pad_token_id = tokenizer.pad_token_id
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            model_name, config=hf_config
        )

    def forward(self, data: MutableMapping) -> torch.Tensor:
        """
        Run the wrapped classifier on one collated batch.

        The input tensors are moved to whatever device the model parameters
        live on, so callers do not need to place the batch themselves.

        Args:
            data: A dict-like batch providing ``input_ids`` and
                ``attention_mask`` tensors; any other keys are ignored.

        Returns:
            logits: The ``logits`` attribute of the Huggingface model output.
        """
        target_device = next(self.parameters()).device
        model_inputs = {
            key: data[key].to(target_device)
            for key in ("input_ids", "attention_mask")
        }
        return self.hf_model(**model_inputs).logits


model = MyGPT2(tokenizer)

In [5]:
# Build a fresh wrapped classifier and put it in eval mode before fitting.
model = MyGPT2(tokenizer)
model.eval()

# Diagonal last-layer Laplace approximation around the loaded weights.
# NOTE(review): the targets built in `preprocess` are multi-hot vectors and the
# evaluation below thresholds per-class probabilities at 0.5, yet the
# likelihood here is "classification". Presumably that is a single-label
# (softmax) likelihood, which would explain why the recorded Laplace F1 and
# subset accuracy are 0.0000 — confirm against the laplace-torch docs.
la = Laplace(
    model,
    likelihood="classification",
    subset_of_weights="last_layer",
    hessian_structure="diag",
    # This must reflect faithfully the reduction technique used in the model
    # Otherwise, correctness is not guaranteed
    feature_reduction="pick_last",
)
# Fit on the training loader, then tune the prior precision with the
# library defaults.
la.fit(train_loader=train_dataloader)
la.optimize_prior_precision()



In [8]:
# Per-batch results; they are stacked into single arrays in the next cell.
all_logits = []
all_la_preds = []
all_labels = []
# Not updated in this loop (no loss is computed here); kept so the name
# stays defined for any later cell that reads it.
total_loss = 0.0

# Pick an accelerator only if it is actually usable, otherwise fall back to
# CPU. Defaulting to "mps" whenever CUDA was missing raised on machines
# without Apple-silicon support.
if torch.cuda.is_available():
    device = "cuda"
elif (
    getattr(torch.backends, "mps", None) is not None
    and torch.backends.mps.is_available()
):
    device = "mps"
else:
    device = "cpu"

for batch in tqdm(test_dataloader, desc="Evaluating"):
    # `model.forward` moves its inputs to the model's own device again, so
    # this placement is safe even when the model lives elsewhere.
    data = {k: v.to(device) for k, v in batch.items()}
    labels = batch["labels"].to(device).float()

    # Deterministic (MAP) logits of the base classifier.
    with torch.no_grad():
        logits = model(data)
    # Laplace predictive for the same, un-moved batch.
    la_pred = la(batch)

    all_logits.append(logits.cpu().numpy())
    all_la_preds.append(la_pred.cpu().numpy())
    all_labels.append(labels.cpu().numpy())

Evaluating:   0%|          | 0/55 [00:00<?, ?it/s]

In [10]:
# Merge the per-batch chunks row-wise into one array per quantity.
all_logits, all_labels, all_la_preds = (
    np.vstack(chunks) for chunks in (all_logits, all_labels, all_la_preds)
)

# Base model: independent per-class probabilities via a sigmoid on the logits.
probs = torch.as_tensor(all_logits).sigmoid().numpy()

# Hard decisions for both models use the same 0.5 threshold.
preds = (probs >= 0.5).astype(int)
la_preds = (all_la_preds >= 0.5).astype(int)

print(all_labels.shape, all_logits.shape)
print(probs.shape)
print(la_preds.shape)


(5427, 28) (5427, 28)
(5427, 28)
(5427, 28)


In [11]:
from metrics import (
    elementwise_accuracy,
    subset_accuracy,
    hamming_loss,
    f1_scores,
    roc_auc_scores,
    log_loss_multilabel,
    brier_score_multilabel,
    get_calibration,
)


def _print_multilabel_report(title, y_true, y_pred, y_prob):
    """
    Print one block of multi-label metrics and return the calibration results.

    Args:
        title: Header line printed before the metrics.
        y_true: Ground-truth label array passed to every metric.
        y_pred: Thresholded (0/1) predictions, used by the accuracy, Hamming
            and F1 metrics.
        y_prob: Probability scores, used by the ROC AUC, log loss, Brier and
            calibration metrics.

    Returns:
        The ``(per-class ECE, global ECE, global MCE)`` tuple exactly as
        returned by ``get_calibration(y_prob, y_true)``.
    """
    print(title)
    print(f"Element-wise Accuracy: {elementwise_accuracy(y_true, y_pred):.4f}")
    print(f"Subset (Exact) Acc    : {subset_accuracy(y_true, y_pred):.4f}")
    print(f"Hamming Loss          : {hamming_loss(y_true, y_pred):.4f}")
    # The ": .4f" spec (leading space for positives) is kept on the F1 and
    # ROC AUC lines so the printed report matches the earlier output exactly.
    print(f"F1 (micro)            : {f1_scores(y_true, y_pred, average='micro'): .4f}")
    print(f"F1 (macro)            : {f1_scores(y_true, y_pred, average='macro'): .4f}")
    print(f"ROC AUC (micro)       : {roc_auc_scores(y_true, y_prob, average='micro'): .4f}")
    print(f"ROC AUC (macro)       : {roc_auc_scores(y_true, y_prob, average='macro'): .4f}")
    print(f"Log Loss              : {log_loss_multilabel(y_true, y_prob):.4f}")
    print(f"Brier Score           : {brier_score_multilabel(y_true, y_prob):.4f}")
    per_class_ece, global_ece, global_mce = get_calibration(y_prob, y_true)
    print(f"Global ECE            : {global_ece:.4f}")
    print(f"Global MCE            : {global_mce:.4f}")
    print(f"Per‑class ECE         : {per_class_ece}")
    return per_class_ece, global_ece, global_mce


# ——— Base GPT‑2 classifier metrics ———
ece_cls, ece_glob, mce_glob = _print_multilabel_report(
    "=== Base GPT‑2 classifier ===", all_labels, preds, probs
)

# ——— Laplace‑augmented model metrics ———
ece_cls_la, ece_glob_la, mce_glob_la = _print_multilabel_report(
    "\n=== Last‑layer Laplace ===", all_labels, la_preds, all_la_preds
)

=== Base GPT‑2 classifier ===
Element-wise Accuracy: 0.9703
Subset (Exact) Acc    : 0.5043
Hamming Loss          : 0.0297
F1 (micro)            :  0.6024
F1 (macro)            :  0.4722
ROC AUC (micro)       :  0.9557
ROC AUC (macro)       :  0.9219
Log Loss              : 0.0921
Brier Score           : 0.0239
Global ECE            : 0.0138
Global MCE            : 0.2876
Per‑class ECE         : [0.03275771 0.0102767  0.01654598 0.02848275 0.02734285 0.01318049
 0.01122281 0.02363256 0.00763362 0.0155058  0.02236475 0.01278246
 0.00305376 0.01160106 0.00443426 0.00629778 0.00038578 0.01055669
 0.01219592 0.00305186 0.01285845 0.00185309 0.01476529 0.00080934
 0.00546548 0.01274358 0.01265493 0.10553359]

=== Last‑layer Laplace ===
Element-wise Accuracy: 0.9583
Subset (Exact) Acc    : 0.0000
Hamming Loss          : 0.0417
F1 (micro)            :  0.0000
F1 (macro)            :  0.0000
ROC AUC (micro)       :  0.9307
ROC AUC (macro)       :  0.9188
Log Loss              : 0.1427
Brier Sco