In [None]:
# pip install transformers accelerate scikit-learn matplotlib
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, math
import numpy as np
from sklearn.metrics import roc_auc_score

In [None]:
# Checkpoint to evaluate. NOTE(review): this is a gated Hugging Face repo, so
# downloading it presumably requires an accepted license / logged-in token.
model_name = "meta-llama/Llama-3.1-8B"

# `use_fast=False` was dropped: AutoTokenizer falls back to the tokenizer class
# declared by the checkpoint when no slow variant exists, so the flag had no
# effect and only suggested a slow tokenizer was in use.
# NOTE(review): confirm Llama 3.1 ships only a fast tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Half precision, with layers placed automatically across available devices.
# NOTE(review): Llama 3.x is commonly run in bfloat16; float16 is kept here for
# hardware without bf16 support — switch if you see inf/NaN scores.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
# Tiny multiple-choice evaluation set. Each item holds the question text, its
# answer options, and the 0-based index of the correct option.
examples = [
    dict(
        question="In the context of time complexity, which of the following sorting algorithms has a worst-case performance of O(n log n)?",
        choices=["QuickSort", "MergeSort", "BubbleSort", "InsertionSort"],
        answer_idx=1,
    ),
    dict(
        question="The Treaty of Versailles, which formally ended World War I, was signed in which year?",
        choices=["1917", "1918", "1919", "1920"],
        answer_idx=2,
    ),
    dict(
        question="In the U.S. legal system, the 'exclusionary rule' primarily prohibits the use of what in a criminal trial?",
        choices=[
            "Hearsay evidence",
            "Evidence obtained through an unlawful search and seizure",
            "Testimony from an accomplice",
            "Circumstantial evidence",
        ],
        answer_idx=1,
    ),
    dict(
        question="Which of the following is a prime number?",
        choices=["21", "27", "29", "33"],
        answer_idx=2,
    ),
]

# Option labels: letters[i] labels choices[i].
letters = "ABCD"

In [None]:
def build_prompt(q, choices):
    """Render a multiple-choice question as a plain-text completion prompt.

    Each choice is labelled with the matching character of the global
    ``letters`` string ("A) ...", "B) ...", ...), one per line. The prompt
    ends with "Answer:" so the model's completion follows directly after it.
    More choices than ``len(letters)`` raises IndexError.
    """
    labelled = []
    for idx, text in enumerate(choices):
        labelled.append(f"{letters[idx]}) {text}")
    choices_str = "\n".join(labelled)
    return f"Question: {q}\n\nChoices:\n{choices_str}\n\nAnswer with a single letter (A, B, C, ...).\nAnswer:"

# One completion prompt per evaluation example, in the same order as `examples`.
prompts = [
    build_prompt(example["question"], example["choices"])
    for example in examples
]

In [None]:
# Decoder-only models must be LEFT-padded for batched generation; with right
# padding, shorter prompts would be continued from pad tokens. Both settings
# are idempotent, so re-running this cell is safe.
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS. The attention_mask passed to generate()
    # keeps padded positions out of attention.
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Greedy decoding is selected with do_sample=False. `temperature=0.0` does not
# mean "greedy" in generate(): temperature only applies when sampling, where it
# must be strictly positive. NOTE(review): the checkpoint's generation_config
# may enable sampling by default, so greedy is requested explicitly here.
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    output_scores=True,
)

# transition_scores: (batch, gen_len) — log-probability of each GENERATED token
# only; prompt positions are not included.
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)

In [None]:
# Collect per-example generated-token log-probs and decoded completions.
# generate() appends new tokens after the full (padded) input, so in every row
# of `outputs.sequences` the generated part starts at column `prompt_width`.
# `transition_scores` holds ONLY generated positions, so it is indexed from 0;
# offsetting it by the prompt length would yield empty slices.
prompt_width = inputs["input_ids"].shape[1]
generated_ids = outputs.sequences[:, prompt_width:]

# Token id(s) that terminate generation (generation_config may store a list).
eos_ids = model.generation_config.eos_token_id
if eos_ids is None:
    eos_ids = tokenizer.eos_token_id
eos_ids = set(eos_ids) if isinstance(eos_ids, (list, tuple)) else {eos_ids}

gen_logprobs_batch = []
decoded_batch = []
for b in range(len(prompts)):
    row_ids = generated_ids[b].tolist()
    # Rows that stop early are filled with pad tokens up to the longest row.
    # Keep only tokens before the first EOS, so the filler scores (and the EOS
    # itself) do not enter the TPU average.
    n_keep = next((i for i, tok in enumerate(row_ids) if tok in eos_ids), len(row_ids))
    # Cast to float32 before .numpy(): NumPy has no bfloat16 dtype.
    gen_logprobs_batch.append(transition_scores[b, :n_keep].float().cpu().numpy())
    decoded = tokenizer.decode(row_ids[:n_keep], skip_special_tokens=True)
    decoded_batch.append(decoded.strip())


In [None]:
def compute_tpu(logprobs):
    """Token-probability uncertainty (TPU) of one generated sequence.

    TPU = 1 - exp(mean(logprobs)), i.e. one minus the geometric-mean
    probability of the generated tokens. Returns NaN when no tokens were scored.
    """
    if not len(logprobs):
        return float("nan")
    geo_mean_prob = math.exp(np.mean(logprobs))
    return 1.0 - geo_mean_prob

# One TPU value per example (NaN where no generated tokens were scored).
tpu_values = list(map(compute_tpu, gen_logprobs_batch))

In [None]:
# Parse model answers
def parse_answer(text):
    """Map a model completion to a choice index using its first non-space character.

    The character is upper-cased and looked up in the global ``letters``
    string. Returns its index, or None when ``text`` is None, empty,
    whitespace-only, or starts with a character that is not a choice letter.
    """
    # Strip BEFORE checking emptiness: whitespace-only input would otherwise
    # pass the emptiness check and crash on stripped[0] with IndexError.
    stripped = (text or "").strip()
    if not stripped:
        return None
    first = stripped[0].upper()
    return letters.index(first) if first in letters else None

# Predicted choice index per example (None if the output was unparseable),
# and whether each prediction matches the ground-truth index.
pred_indices = list(map(parse_answer, decoded_batch))
correct_mask = [
    pred is not None and pred == ex["answer_idx"]
    for ex, pred in zip(examples, pred_indices)
]

In [None]:
# Metrics: ECE & AUROC
def expected_calibration_error(unc, correct, n_bins=10):
    """Expected Calibration Error (ECE) of confidence = 1 - uncertainty.

    Args:
        unc: per-example uncertainty values, expected in [0, 1]; NaN entries
            are dropped together with their ``correct`` entry.
        correct: per-example correctness (bool or 0/1), same length as ``unc``.
        n_bins: number of equal-width bins over [0, 1].

    Returns:
        sum_b (|B_b| / n) * |accuracy(B_b) - mean_confidence(B_b)| as a float,
        or NaN if no valid (non-NaN) uncertainty remains.
    """
    unc = np.asarray(unc, dtype=float)
    correct = np.asarray(correct, dtype=float)
    valid = ~np.isnan(unc)
    unc, correct = unc[valid], correct[valid]
    n = len(unc)
    if n == 0:
        return float("nan")

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # np.digitize places x == 1.0 past the last edge (index n_bins). Clip so the
    # last bin is right-inclusive; otherwise those samples were silently dropped
    # from every bin while still counted in n, biasing the ECE downward.
    bin_idx = np.clip(np.digitize(unc, edges) - 1, 0, n_bins - 1)

    ece = 0.0
    for b in range(n_bins):
        mask = bin_idx == b
        count = mask.sum()
        if count == 0:
            continue
        acc = correct[mask].mean()
        conf = 1.0 - unc[mask].mean()  # predicted confidence = 1 - uncertainty
        ece += (count / n) * abs(acc - conf)
    return float(ece)

In [None]:
ece = expected_calibration_error(tpu_values, correct_mask, n_bins=10)

# AUROC of TPU as a detector of wrong answers: positive class = incorrect (1).
y_true = (~np.asarray(correct_mask, dtype=bool)).astype(int)
tpu_arr = np.asarray(tpu_values, dtype=float)
# Drop NaN uncertainties, as expected_calibration_error does, instead of
# zero-filling them; TPU = 0 would mark those examples as maximally confident.
has_tpu = ~np.isnan(tpu_arr)
if np.unique(y_true[has_tpu]).size < 2:
    # AUROC is undefined with a single class (roc_auc_score raises ValueError);
    # report NaN explicitly rather than swallowing every possible exception.
    auroc = float("nan")
else:
    auroc = roc_auc_score(y_true[has_tpu], tpu_arr[has_tpu])

In [None]:
# Per-example report (question, raw output, parsed vs. gold index, TPU),
# followed by the aggregate calibration metrics.
print("\nResults per example:")
for i, ex in enumerate(examples):
    print(
        f"Q: {ex['question']}",
        f"Model output: {decoded_batch[i]}",
        f"Pred idx: {pred_indices[i]}, GT idx: {ex['answer_idx']}, Correct: {correct_mask[i]}",
        f"TPU: {tpu_values[i]:.4f}\n",
        sep="\n",
    )

print("Metrics:", f"ECE:   {ece:.4f}", f"AUROC: {auroc:.4f}", sep="\n")

In [None]:
import matplotlib.pyplot as plt

def plot_reliability_diagram(unc, correct, n_bins=10):
    """Plot a reliability diagram for confidence = 1 - uncertainty.

    Args:
        unc: per-example uncertainty values, expected in [0, 1]; NaN entries
            are dropped together with their ``correct`` entry.
        correct: per-example correctness (bool or 0/1), same length as ``unc``.
        n_bins: number of equal-width confidence bins over [0, 1].

    Empty bins are skipped. Prints a message and returns early if no valid
    data remains. Shows the figure and returns None.
    """
    unc = np.asarray(unc, dtype=float)
    correct = np.asarray(correct, dtype=float)
    valid = ~np.isnan(unc)
    unc, correct = unc[valid], correct[valid]
    if len(unc) == 0:
        print("No valid data to plot")
        return

    conf = 1.0 - unc

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # np.digitize places conf == 1.0 (TPU == 0, the most confident predictions)
    # past the last edge; clip so the last bin is right-inclusive and those
    # points are not silently dropped.
    bin_idx = np.clip(np.digitize(conf, edges) - 1, 0, n_bins - 1)

    bin_acc = []
    bin_conf = []
    for b in range(n_bins):
        mask = bin_idx == b
        if not mask.any():
            continue
        bin_acc.append(correct[mask].mean())
        bin_conf.append(conf[mask].mean())

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot([0, 1], [0, 1], "k--", label="Perfect calibration")
    ax.plot(bin_conf, bin_acc, marker="o", label="TPU")
    ax.set_xlabel("Predicted confidence (1 - TPU)")
    ax.set_ylabel("Empirical accuracy")
    ax.set_title("Reliability Diagram (TPU)")
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.legend()
    ax.grid(True)
    plt.show()

# Reliability diagram of TPU-derived confidence against observed accuracy.
plot_reliability_diagram(unc=tpu_values, correct=correct_mask, n_bins=10)
