In [4]:
import json
from functools import partial

import torch
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer

# HF model name
HF_MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Device (GPU preferred)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Path to your tokenized dataset file
TOKENIZED_PATH = "/content/token_segmentation_metadata.json"


In [5]:
from huggingface_hub import login

login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
def load_tokenizer(model_name: str = HF_MODEL_NAME):
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Ensure pad token exists
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer

tokenizer = load_tokenizer()
print("Tokenizer loaded.")


Tokenizer loaded.


In [7]:
import torch
torch.cuda.empty_cache()  # Clear GPU memory cache


In [8]:
def load_model(model_name: str = HF_MODEL_NAME, device: str = DEVICE):
    model = HookedTransformer.from_pretrained(
        model_name,
        device=device,
        dtype=torch.float16,  # Use FP16 for reduced memory usage
    )
    model.eval()
    return model

model = load_model()
print("Model loaded on:", DEVICE)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Model loaded on: cuda


In [9]:
def make_attention_store():
    """
    Creates an empty Python dictionary that will be filled
    during the forward pass with attention matrices.
    """
    return {}

def save_attention_hook(attn, hook, store):
    """
    Hook called by TransformerLens every time attention is computed
    in a specific layer.
    """
    layer_idx = hook.layer()
    store[layer_idx] = attn.detach().cpu()

print("Attention hook system ready.")


Attention hook system ready.


In [10]:
def get_attention_hooks(model, store):
    """
    Attach hooks to the attention *pattern* (softmax probs) in every layer.
    """
    hooks = []
    all_hooks = set(hp.name for hp in model.hook_points())

    for layer in range(model.cfg.n_layers):
        hook_name = f"blocks.{layer}.attn.hook_pattern"

        if hook_name not in all_hooks:
            raise KeyError(
                f"{hook_name} not found in model.hook_points(). "
                f"Available attention hooks include: "
                f"{[h for h in all_hooks if 'attn' in h]}"
            )

        hook_fn = partial(save_attention_hook, store=store)
        hooks.append((hook_name, hook_fn))

    return hooks


In [11]:
def run_and_get_attention(model, row):
    """
    Core function to extract attention weights.

    Parameters
    ----------
    model : HookedTransformer
        The TransformerLens model loaded earlier.

    row : dict
        A dictionary containing tokenized information, including 'input_ids'.

    Returns
    -------
    final_store : dict
        Mapping layer_idx -> attention tensor of shape [n_heads, L, L],
        all moved to CPU for easy downstream processing.
    """

    # Extract the input_ids from the row (which contains the concatenated tokens)
    input_ids = row['input_ids']

    # Ensure tokens = [1, L]
    tokens = torch.tensor([input_ids], device=DEVICE, dtype=torch.long)  # Create a tensor with shape [1, L]

    assert tokens.ndim == 2 and tokens.shape[0] == 1, \
        f"Expected tokens of shape [1, L], got {tokens.shape}"

    # 1. Create empty store
    store = make_attention_store()

    # 2. Build hooks (one for each layer)
    hooks = get_attention_hooks(model, store)

    # 3. Run forward pass with hooks
    with torch.no_grad():
        _ = model.run_with_hooks(
            tokens,
            fwd_hooks=hooks
        )

    # 4. Post-process:
    # Each entry is [1, n_heads, L, L] → squeeze batch dim to [n_heads, L, L]
    final_store = {}
    for layer_idx, attn_tensor in store.items():
        final_store[layer_idx] = attn_tensor[0]  # remove batch dimension

    return final_store


In [12]:
def load_tokenized_example(path: str):
    """
    Returns the first non-empty JSON line as a Python dict.
    """
    with open(path, "r") as f:
        try:
            data = json.load(f)  # Load the entire JSON array
            if data:
                return data[0]  # Return the first example
            else:
                raise ValueError("The JSON array is empty.")
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            return None

# Assuming you've loaded the correct row using the function above
row = load_tokenized_example(TOKENIZED_PATH)

# Make sure row is a valid dictionary before processing
if isinstance(row, dict):
    # Run the attention extraction and checks
    attn_store = run_and_get_attention(model, row)

    # How many layers did we capture?
    layer_indices = sorted(attn_store.keys())
    print("Layers captured:", len(layer_indices))
    print("Model reports n_layers =", model.cfg.n_layers)

    # Show shapes for the first couple of layers
    for layer_idx in layer_indices[:2]:
        attn_layer = attn_store[layer_idx]  # shape [n_heads, L, L]
        print(f"Layer {layer_idx} attention shape:", attn_layer.shape)

    # Row-sum sanity check for one head in one layer
    some_layer = layer_indices[0]  # Select the first layer for the sanity check
    attn_layer0 = attn_store[some_layer]    # Attention for the first layer, shape [n_heads, L, L]
    head0 = attn_layer0[0]                  # Select the first head in this layer, shape [L, L]

    # Sum along the last dimension (for each row) to verify normalization
    row_sums = head0.sum(dim=-1)            # Sum along the last axis (columns), shape [L]

    # Print the mean and standard deviation of row sums
    print(f"\nRow sums for layer {some_layer}, head 0:")
    print(f"  mean: {row_sums.mean().item():.6f}")
    print(f"  std:  {row_sums.std().item():.6f}")

    print("\nIf mean ≈ 1.0 and std is small, attention extraction is working.")
else:
    print("Error: 'row' is not a valid dictionary.")


Layers captured: 32
Model reports n_layers = 32
Layer 0 attention shape: torch.Size([32, 22, 22])
Layer 1 attention shape: torch.Size([32, 22, 22])

Row sums for layer 0, head 0:
  mean: 1.000000
  std:  0.000000

If mean ≈ 1.0 and std is small, attention extraction is working.


In [13]:
for hp in model.hook_points():
    if ".attn." in hp.name:  # Check if the hook point involves attention
        print(hp.name)


blocks.0.attn.hook_k
blocks.0.attn.hook_q
blocks.0.attn.hook_v
blocks.0.attn.hook_z
blocks.0.attn.hook_attn_scores
blocks.0.attn.hook_pattern
blocks.0.attn.hook_result
blocks.0.attn.hook_rot_k
blocks.0.attn.hook_rot_q
blocks.1.attn.hook_k
blocks.1.attn.hook_q
blocks.1.attn.hook_v
blocks.1.attn.hook_z
blocks.1.attn.hook_attn_scores
blocks.1.attn.hook_pattern
blocks.1.attn.hook_result
blocks.1.attn.hook_rot_k
blocks.1.attn.hook_rot_q
blocks.2.attn.hook_k
blocks.2.attn.hook_q
blocks.2.attn.hook_v
blocks.2.attn.hook_z
blocks.2.attn.hook_attn_scores
blocks.2.attn.hook_pattern
blocks.2.attn.hook_result
blocks.2.attn.hook_rot_k
blocks.2.attn.hook_rot_q
blocks.3.attn.hook_k
blocks.3.attn.hook_q
blocks.3.attn.hook_v
blocks.3.attn.hook_z
blocks.3.attn.hook_attn_scores
blocks.3.attn.hook_pattern
blocks.3.attn.hook_result
blocks.3.attn.hook_rot_k
blocks.3.attn.hook_rot_q
blocks.4.attn.hook_k
blocks.4.attn.hook_q
blocks.4.attn.hook_v
blocks.4.attn.hook_z
blocks.4.attn.hook_attn_scores
blocks.4.attn

In [14]:
import json
from pathlib import Path

TOKENIZED_PATH = "/content/token_segmentation_metadata.json"  # adjust if needed

def load_tokenized_examples(path: str):
    """
    Loads the entire JSON array from the file.
    """
    with open(path, "r") as f:
        data = json.load(f)
    return data

examples = load_tokenized_examples(TOKENIZED_PATH)
print(f"Loaded {len(examples)} examples")


Loaded 300 examples


In [15]:
row = examples[0]

print("Keys:", row.keys())
print("Example id:", row.get("id"))
print("constraint_tags:", row.get("constraint_tags"))

input_ids = row["input_ids"]
p_span = row["p_span"]   # or row["persona_span"]
u_span = row["u_span"]
a_span = row["a_span"]

L = len(input_ids)

print(f"seq_len (L): {L}")
print("p_span:", p_span)
print("u_span:", u_span)
print("a_span:", a_span)

p_len = p_span[1] - p_span[0] + 1
u_len = u_span[1] - u_span[0] + 1
a_len = a_span[1] - a_span[0] + 1

print(f"p_len={p_len}, u_len={u_len}, a_len={a_len}, sum={p_len + u_len + a_len}")


Keys: dict_keys(['id', 'dataset', 'input_ids', 'p_span', 'u_span', 'a_span'])
Example id: alpaca:22052
constraint_tags: None
seq_len (L): 22
p_span: [0, 18]
u_span: [20, 20]
a_span: [22, 22]
p_len=19, u_len=1, a_len=1, sum=21


In [16]:
from functools import partial

def make_attention_store():
    return {}

def save_attention_hook(attn, hook, store):
    layer_idx = hook.layer()
    # attn shape: [batch, n_heads, L, L]
    store[layer_idx] = attn.detach().cpu()

def get_attention_hooks(model, store):
    hooks = []
    all_hooks = {hp.name for hp in model.hook_points()}

    for layer in range(model.cfg.n_layers):
        hook_name = f"blocks.{layer}.attn.hook_pattern"
        if hook_name not in all_hooks:
            raise KeyError(
                f"{hook_name} not found in model.hook_points(). "
                f"Available attention hooks include: "
                f"{[h for h in all_hooks if 'attn' in h]}"
            )
        hook_fn = partial(save_attention_hook, store=store)
        hooks.append((hook_name, hook_fn))
    return hooks


In [17]:
import torch

def run_and_get_attention(model, row):
    """
    Given a tokenized row (with 'input_ids'), return:
        dict[layer_idx] -> tensor [n_heads, L, L]
    """
    input_ids = row["input_ids"]
    tokens = torch.tensor([input_ids], device=None, dtype=torch.long)  # [1, L]

    store = make_attention_store()
    hooks = get_attention_hooks(model, store)

    with torch.no_grad():
        _ = model.run_with_hooks(tokens, fwd_hooks=hooks)

    final_store = {}
    for layer_idx, attn_tensor in store.items():
        # attn_tensor: [1, n_heads, L, L]
        final_store[layer_idx] = attn_tensor[0]  # [n_heads, L, L]

    return final_store

In [18]:
attn_store = run_and_get_attention(model, row)

print("Layers in attn_store:", sorted(attn_store.keys()))
layer0 = sorted(attn_store.keys())[0]
print("Layer 0 attention shape:", attn_store[layer0].shape)  # expect [n_heads, L, L]

# Check softmax normalization: row sums ~1
head0 = attn_store[layer0][0]  # [L, L]
row_sums = head0.sum(dim=-1)
print("Row sums mean:", row_sums.mean().item())
print("Row sums std:", row_sums.std().item())


Layers in attn_store: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
Layer 0 attention shape: torch.Size([32, 22, 22])
Row sums mean: 1.0
Row sums std: 3.4412757088375656e-08


In [19]:
# Extract span indices

def get_span_indices(row):
    L = len(row["input_ids"])

    def clamp_span(span):
        s, e = span
        s = max(0, min(s, L-1))
        e = max(0, min(e, L-1))
        if e < s:
            e = s
        return list(range(s, e+1))

    P = clamp_span(row["p_span"])
    U = clamp_span(row["u_span"])
    A = clamp_span(row["a_span"])

    return P, U, A


P, U, A = get_span_indices(row)


In [20]:
import torch

def region_score(attn_head, from_idx, to_idx):
    if len(from_idx) == 0 or len(to_idx) == 0:
        return 0.0

    from_t = torch.tensor(from_idx, dtype=torch.long)
    to_t = torch.tensor(to_idx, dtype=torch.long)

    # submatrix: rows = FROM tokens, columns = TO tokens
    sub = attn_head[from_t][:, to_t]       # [|FROM|, |TO|]
    num = sub.sum().item()

    total_from = attn_head[from_t].sum().item()  # ≈ |FROM| because rows sum to 1

    return num / total_from if total_from > 0 else 0.0


In [21]:
from statistics import mean

def compute_layer_metrics(attn_layer, P, U, A):
    """
    attn_layer: [n_heads, L, L]
    """

    heads = []
    for h in range(attn_layer.shape[0]):
        head = attn_layer[h]  # [L, L]

        pam = region_score(head, A, P)  # assistant → persona
        qam = region_score(head, A, U)  # assistant → user
        sam = region_score(head, A, A)  # assistant → assistant

        heads.append({
            "head": h,
            "PAM": pam,
            "QAM": qam,
            "SAM": sam
        })

    # Layer-level aggregates
    layer_pam = mean(h["PAM"] for h in heads)
    layer_qam = mean(h["QAM"] for h in heads)
    layer_sam = mean(h["SAM"] for h in heads)

    return {
        "PAM": layer_pam,
        "QAM": layer_qam,
        "SAM": layer_sam,
        "heads": heads,
    }


In [22]:
# Test the first layer

P, U, A = get_span_indices(row)
print("len P/U/A:", len(P), len(U), len(A))

layer0 = sorted(attn_store.keys())[0]      # you already used this above
attn_layer0 = attn_store[layer0]           # [n_heads, L, L]

metrics_layer0 = compute_layer_metrics(attn_layer0, P, U, A)

print("Layer 0 — PAM/QAM/SAM (averaged over heads):")
print(metrics_layer0["PAM"], metrics_layer0["QAM"], metrics_layer0["SAM"])

print("\nFirst few heads:")
for h in metrics_layer0["heads"][:5]:
    print(h)


len P/U/A: 19 1 1
Layer 0 — PAM/QAM/SAM (averaged over heads):
0.6838161382990233 0.09756367403000328 0.1226047438733739

First few heads:
{'head': 0, 'PAM': 0.5002431273460388, 'QAM': 0.16298165917396545, 'SAM': 0.15198931097984314}
{'head': 1, 'PAM': 0.00022489417460747063, 'QAM': 0.14286160469055176, 'SAM': 0.8524389863014221}
{'head': 2, 'PAM': 0.5554549694061279, 'QAM': 0.14663861691951752, 'SAM': 0.13089054822921753}
{'head': 3, 'PAM': 0.7453128099441528, 'QAM': 0.08623126149177551, 'SAM': 0.0899389386177063}
{'head': 4, 'PAM': 0.1423267275094986, 'QAM': 0.34588828682899475, 'SAM': 0.009316587820649147}


In [23]:
# Get spans once for this row
P, U, A = get_span_indices(row)

# Compute all the layers
all_layer_metrics = []

for layer_idx in sorted(attn_store.keys()):
    attn_layer = attn_store[layer_idx]   # [n_heads, L, L]
    metrics = compute_layer_metrics(attn_layer, P, U, A)

    all_layer_metrics.append({
        "layer": int(layer_idx),
        "PAM": metrics["PAM"],
        "QAM": metrics["QAM"],
        "SAM": metrics["SAM"],
        "heads": metrics["heads"],
    })

# Quick preview
print("Number of layers:", len(all_layer_metrics))
print("Layer 0 metrics summary:", all_layer_metrics[0])
print("Layer 1 metrics summary:", all_layer_metrics[1])

Number of layers: 32
Layer 0 metrics summary: {'layer': 0, 'PAM': 0.6838161382990233, 'QAM': 0.09756367403000328, 'SAM': 0.1226047438733739, 'heads': [{'head': 0, 'PAM': 0.5002431273460388, 'QAM': 0.16298165917396545, 'SAM': 0.15198931097984314}, {'head': 1, 'PAM': 0.00022489417460747063, 'QAM': 0.14286160469055176, 'SAM': 0.8524389863014221}, {'head': 2, 'PAM': 0.5554549694061279, 'QAM': 0.14663861691951752, 'SAM': 0.13089054822921753}, {'head': 3, 'PAM': 0.7453128099441528, 'QAM': 0.08623126149177551, 'SAM': 0.0899389386177063}, {'head': 4, 'PAM': 0.1423267275094986, 'QAM': 0.34588828682899475, 'SAM': 0.009316587820649147}, {'head': 5, 'PAM': 0.7184080481529236, 'QAM': 0.09277660399675369, 'SAM': 0.09983798861503601}, {'head': 6, 'PAM': 0.9020717740058899, 'QAM': 0.032831061631441116, 'SAM': 0.032264720648527145}, {'head': 7, 'PAM': 0.7081997990608215, 'QAM': 0.09755054861307144, 'SAM': 0.08535626530647278}, {'head': 8, 'PAM': 0.859629980303644, 'QAM': 0.041041376652799645, 'SAM': 0.

In [24]:
# Build the full JSON

import json

def build_example_record(row, attn_store):
    # spans + lengths
    P, U, A = get_span_indices(row)
    L = len(row["input_ids"])

    layers_list = []
    for layer_idx in sorted(attn_store.keys()):
        metrics = compute_layer_metrics(attn_store[layer_idx], P, U, A)
        layers_list.append({
            "layer": int(layer_idx),
            "PAM": metrics["PAM"],
            "QAM": metrics["QAM"],
            "SAM": metrics["SAM"],
            "heads": metrics["heads"],
        })

    record = {
        "id": row.get("id"),
        "dataset": row.get("dataset", "sharegpt"),
        "constraint_tags": row.get("constraint_tags", []),
        "seq_len": L,
        "p_len": len(P),
        "u_len": len(U),
        "a_len": len(A),
        "layers": layers_list,
    }
    return record

# Build and inspect for this row
record0 = build_example_record(row, attn_store)

print(json.dumps(record0, indent=2)[:1500], "...\n")

{
  "id": "alpaca:22052",
  "dataset": "alpaca",
  "constraint_tags": [],
  "seq_len": 22,
  "p_len": 19,
  "u_len": 1,
  "a_len": 1,
  "layers": [
    {
      "layer": 0,
      "PAM": 0.6838161382990233,
      "QAM": 0.09756367403000328,
      "SAM": 0.1226047438733739,
      "heads": [
        {
          "head": 0,
          "PAM": 0.5002431273460388,
          "QAM": 0.16298165917396545,
          "SAM": 0.15198931097984314
        },
        {
          "head": 1,
          "PAM": 0.00022489417460747063,
          "QAM": 0.14286160469055176,
          "SAM": 0.8524389863014221
        },
        {
          "head": 2,
          "PAM": 0.5554549694061279,
          "QAM": 0.14663861691951752,
          "SAM": 0.13089054822921753
        },
        {
          "head": 3,
          "PAM": 0.7453128099441528,
          "QAM": 0.08623126149177551,
          "SAM": 0.0899389386177063
        },
        {
          "head": 4,
          "PAM": 0.1423267275094986,
          "QAM": 0.345888

In [25]:
def region_scores_batch(attn_layer, from_idx, to_idx):
    """
    attn_layer: [batch, n_heads, L, L]
    from_idx, to_idx: 1D LongTensors
    Returns: [batch, n_heads] fraction of attention mass FROM -> TO.
    """
    if len(from_idx) == 0 or len(to_idx) == 0:
        return attn_layer.new_zeros(attn_layer.shape[0], attn_layer.shape[1])

    sub = attn_layer[:, :, from_idx][:, :, :, to_idx]    # [B, H, |FROM|, |TO|]
    num = sub.sum(dim=(-1, -2))                          # [B, H]

    total_from = attn_layer[:, :, from_idx].sum(dim=(-1, -2))  # [B, H]
    scores = num / (total_from + 1e-9)
    return scores



In [26]:
from functools import partial

def make_metric_hook(layer_idx, store, P, U, A):
    """
    store[layer_idx] will become:
      {
        "PAM": <float>,
        "QAM": <float>,
        "SAM": <float>,
        "heads": [
          {"head": h, "PAM": float, "QAM": float, "SAM": float}, ...
        ]
      }
    """
    def hook_fn(attn, hook):
        # attn: [batch, n_heads, L, L] (batch=1)
        pam = region_scores_batch(attn, A, P)[0]  # [n_heads]
        qam = region_scores_batch(attn, A, U)[0]  # [n_heads]
        sam = region_scores_batch(attn, A, A)[0]  # [n_heads]

        pam_list = pam.detach().cpu().tolist()
        qam_list = qam.detach().cpu().tolist()
        sam_list = sam.detach().cpu().tolist()

        heads = []
        for h, (p, q, s) in enumerate(zip(pam_list, qam_list, sam_list)):
            heads.append({"head": h, "PAM": p, "QAM": q, "SAM": s})

        layer_pam = float(sum(pam_list) / len(pam_list))
        layer_qam = float(sum(qam_list) / len(qam_list))
        layer_sam = float(sum(sam_list) / len(sam_list))

        store[layer_idx] = {
            "PAM": layer_pam,
            "QAM": layer_qam,
            "SAM": layer_sam,
            "heads": heads,
        }
    return hook_fn


In [33]:
def run_and_get_metrics(model, row, device=None):
    if device is None:
        device = next(model.parameters()).device

    input_ids = torch.tensor([row["input_ids"]], device=device)

    P_idx, U_idx, A_idx = get_span_indices(row)
    P = torch.tensor(P_idx, device=device, dtype=torch.long)
    U = torch.tensor(U_idx, device=device, dtype=torch.long)
    A = torch.tensor(A_idx, device=device, dtype=torch.long)

    store = {}

    all_hooks = {hp.name for hp in model.hook_points()}
    fwd_hooks = []
    for layer in range(model.cfg.n_layers):
        hook_name = f"blocks.{layer}.attn.hook_pattern"
        if hook_name not in all_hooks:
            raise KeyError(f"{hook_name} not in hook_points")
        hook_fn = make_metric_hook(layer, store, P, U, A)
        fwd_hooks.append((hook_name, hook_fn))

    # inference_mode uses less memory than no_grad
    with torch.inference_mode():
        _ = model.run_with_hooks(input_ids, fwd_hooks=fwd_hooks)

    return store


In [28]:
def build_example_record_from_store(row, layer_store):
    P_idx, U_idx, A_idx = get_span_indices(row)
    L = len(row["input_ids"])

    layers_list = []
    for layer_idx in sorted(layer_store.keys()):
        lm = layer_store[layer_idx]
        layers_list.append({
            "layer": int(layer_idx),
            "PAM": lm["PAM"],
            "QAM": lm["QAM"],
            "SAM": lm["SAM"],
            "heads": lm["heads"],
        })

    return {
        "id": row.get("id"),
        "dataset": row.get("dataset", "sharegpt"),
        "constraint_tags": row.get("constraint_tags", []),
        "seq_len": L,
        "p_len": len(P_idx),
        "u_len": len(U_idx),
        "a_len": len(A_idx),
        "layers": layers_list,
    }


In [29]:
metrics_store = run_and_get_metrics(model, row)
record0 = build_example_record_from_store(row, metrics_store)

import json
print(json.dumps(record0, indent=2)[:1500], "...\n")

{
  "id": "alpaca:22052",
  "dataset": "alpaca",
  "constraint_tags": [],
  "seq_len": 22,
  "p_len": 19,
  "u_len": 1,
  "a_len": 1,
  "layers": [
    {
      "layer": 0,
      "PAM": 0.6838161504788332,
      "QAM": 0.09756367398767907,
      "SAM": 0.12260474392314791,
      "heads": [
        {
          "head": 0,
          "PAM": 0.5002431869506836,
          "QAM": 0.16298165917396545,
          "SAM": 0.15198931097984314
        },
        {
          "head": 1,
          "PAM": 0.0002248941600555554,
          "QAM": 0.14286160469055176,
          "SAM": 0.8524389863014221
        },
        {
          "head": 2,
          "PAM": 0.5554550290107727,
          "QAM": 0.14663861691951752,
          "SAM": 0.13089054822921753
        },
        {
          "head": 3,
          "PAM": 0.7453128099441528,
          "QAM": 0.08623126149177551,
          "SAM": 0.0899389386177063
        },
        {
          "head": 4,
          "PAM": 0.1423267275094986,
          "QAM": 0.345888

In [31]:
MAX_CTX = model.cfg.n_ctx  # should be 4096 for Llama-2
print("Model max context:", MAX_CTX)

Model max context: 4096


In [34]:
import json, time, gc, torch

OUTPUT_PATH = "attention_metrics.jsonl"

MAX_CTX = model.cfg.n_ctx          # still keep this
print("Model n_ctx:", MAX_CTX)

start = time.time()
processed = 0
skipped_long = 0
skipped_oom = 0

with open(OUTPUT_PATH, "w") as f:
    for i, ex in enumerate(examples):
        L = len(ex["input_ids"])

        # 1) skip sequences that exceed model context
        if L > MAX_CTX:
            skipped_long += 1
            if skipped_long <= 10:
                print(f"Skipping example {i} (len={L} > MAX_CTX={MAX_CTX})")
            continue

        try:
            # 2) try full-length analysis
            metrics_store = run_and_get_metrics(model, ex)
            record = build_example_record_from_store(ex, metrics_store)
            f.write(json.dumps(record) + "\n")
            processed += 1

            del metrics_store, record
            gc.collect()

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                skipped_oom += 1
                print(f"OOM on example {i} (len={L}) – skipping. "
                      f"OOMS so far: {skipped_oom}")
                torch.cuda.empty_cache()
                continue
            else:
                # if it's some other error, re-raise
                raise

        if (processed + skipped_long + skipped_oom) % 50 == 0:
            print(f"{processed + skipped_long + skipped_oom}/{len(examples)} processed "
                  f"(kept {processed}, skipped_long {skipped_long}, skipped_oom {skipped_oom})")

elapsed = time.time() - start
print(f"\nDone. Kept {processed}, skipped_long {skipped_long}, "
      f"skipped_oom {skipped_oom} in {elapsed:.1f}s "
      f"(~{elapsed/max(processed,1):.3f}s per kept example)")


Model n_ctx: 4096
OOM on example 2 (len=3635) – skipping. OOMS so far: 1
Skipping example 6 (len=24291 > MAX_CTX=4096)
Skipping example 14 (len=10750 > MAX_CTX=4096)
Skipping example 15 (len=9342 > MAX_CTX=4096)
Skipping example 18 (len=7653 > MAX_CTX=4096)
Skipping example 21 (len=6829 > MAX_CTX=4096)
OOM on example 27 (len=3493) – skipping. OOMS so far: 2
Skipping example 34 (len=4221 > MAX_CTX=4096)
OOM on example 42 (len=3605) – skipping. OOMS so far: 3
Skipping example 48 (len=5533 > MAX_CTX=4096)
50/300 processed (kept 40, skipped_long 7, skipped_oom 3)
OOM on example 50 (len=3332) – skipping. OOMS so far: 4
Skipping example 78 (len=9460 > MAX_CTX=4096)
Skipping example 96 (len=7953 > MAX_CTX=4096)
100/300 processed (kept 87, skipped_long 9, skipped_oom 4)
Skipping example 102 (len=6303 > MAX_CTX=4096)
OOM on example 114 (len=2695) – skipping. OOMS so far: 5
OOM on example 121 (len=3668) – skipping. OOMS so far: 6
150/300 processed (kept 132, skipped_long 12, skipped_oom 6)
OOM o