In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# --- Experiment configuration ---
N_PROBLEMS = 500                       # number of dataset items to evaluate
MODEL_NAME = "Qwen/Qwen2.5-Math-1.5B"  # Hugging Face model identifier
DEVICE = "cpu"
if torch.backends.mps.is_available():
    # Use the MPS backend whenever PyTorch reports it as available.
    DEVICE = "mps"
DTYPE = torch.float32
MAX_LENGTH = 256                       # tokenizer truncation limit (tokens)

# --- Data: first N_PROBLEMS test items, each wrapped as a "Q: ... A:" prompt ---
ds = load_dataset("HuggingFaceH4/MATH-500", split=f"test[:{N_PROBLEMS}]")
problems = [f"Q: {row['problem']} A:" for row in ds]

# --- Model and tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map=None,
)
model = model.to(DEVICE)
model.eval()  # inference mode: disables dropout and similar train-only behavior

# --- Single-prompt KL computation ---
def kl_per_prompt(text):
    """Compute a logit-lens KL curve for a single prompt.

    A forward hook on every block in ``model.model.layers`` takes that
    block's output hidden state at the last token position, applies the
    model's final norm and output head, and records the resulting logits.
    Each recorded distribution is then compared against the model's real
    final-layer distribution at the same position.

    Args:
        text: Prompt string; it is tokenized with truncation to
            ``MAX_LENGTH`` tokens.

    Returns:
        A list of floats, one per block in ``model.model.layers`` in forward
        order. Each value is ``KL(p_final || q_layer)`` over the vocabulary at
        the last token position (``F.kl_div`` with ``reduction="batchmean"``).
    """
    last_token_logits = []

    def hook(module, inputs, outputs):
        # A block may return a tuple whose first element is the hidden state.
        hidden = outputs[0] if isinstance(outputs, tuple) else outputs
        # Only the last position is scored below, so project just that token.
        # Projecting the whole sequence to the vocabulary for every layer
        # costs memory and host transfers for values that are never read.
        hidden = model.model.norm(hidden[:, -1:, :])
        # Use the real output head so the lens matches the final projection.
        logits = model.lm_head(hidden)
        last_token_logits.append(logits[:, -1, :].detach().cpu())

    # Tokenize before attaching hooks so a tokenizer error cannot leak hooks.
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH
    ).to(DEVICE)

    hooks = [blk.register_forward_hook(hook) for blk in model.model.layers]
    try:
        with torch.no_grad():
            final_logits = model(**inputs).logits[:, -1, :].detach().cpu()
    finally:
        # Always detach the hooks, even on an exception or KeyboardInterrupt.
        # Leftover hooks would pile up on the shared model and keep running
        # on every later forward pass.
        for h in hooks:
            h.remove()

    # The target distribution is the same for every layer; compute it once.
    p = F.softmax(final_logits, dim=-1)
    return [
        F.kl_div(F.log_softmax(logits, dim=-1), p, reduction="batchmean").item()
        for logits in last_token_logits
    ]

# Run every prompt through the model and collect its per-layer KL curve.
all_curves = []
for idx, prompt in enumerate(problems, start=1):
    print(f"Processing problem {idx}/{N_PROBLEMS}...")
    all_curves.append(kl_per_prompt(prompt))

# Stack the curves into a (num_prompts, max_len) matrix. Shorter curves are
# right-padded with NaN so the nan-aware reductions below skip the padding.
max_len = max(map(len, all_curves))
arr = np.full((len(all_curves), max_len), np.nan)
for row, curve in zip(arr, all_curves):
    row[: len(curve)] = curve  # `row` is a view, so this writes into `arr`

# Per-layer mean and standard deviation across prompts.
mean_curve = np.nanmean(arr, axis=0)
std_curve = np.nanstd(arr, axis=0)

# Plot the mean KL curve with a shaded band of one standard deviation.
fig, ax = plt.subplots(figsize=(8, 4))
x = np.arange(len(mean_curve))  # one tick per entry of the mean curve

lower_band = mean_curve - std_curve
upper_band = mean_curve + std_curve

ax.plot(x, mean_curve, color="blue", label="Mean KL-Divergence")
ax.fill_between(x, lower_band, upper_band, color="blue", alpha=0.2, label="±1 Std")

ax.set_title(
    f"{MODEL_NAME} – Average LogitLens KL vs Layer ({N_PROBLEMS} MATH-500 Problems)"
)
ax.set_xlabel("Layer index")
ax.set_ylabel("Average KL Divergence to Final Logits")
ax.set_xticks(x)
ax.grid(True)
ax.legend()

fig.tight_layout()
plt.show()

  from .autonotebook import tqdm as notebook_tqdm
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Processing problem 1/500...
Processing problem 2/500...
Processing problem 3/500...
Processing problem 4/500...
Processing problem 5/500...
Processing problem 6/500...
Processing problem 7/500...
Processing problem 8/500...
Processing problem 9/500...
Processing problem 10/500...
Processing problem 11/500...
Processing problem 12/500...
Processing problem 13/500...
Processing problem 14/500...
Processing problem 15/500...
Processing problem 16/500...
Processing problem 17/500...
Processing problem 18/500...
Processing problem 19/500...
Processing problem 20/500...
Processing problem 21/500...
Processing problem 22/500...
Processing problem 23/500...
Processing problem 24/500...
Processing problem 25/500...
Processing problem 26/500...
Processing problem 27/500...
Processing problem 28/500...
Processing problem 29/500...
Processing problem 30/500...
Processing problem 31/500...
Processing problem 32/500...
Processing problem 33/500...
Processing problem 34/500...
Processing problem 35/5

KeyboardInterrupt: 