In [5]:
# --- Imports + Config ---

!pip install transformer_lens

import json
from functools import partial

import torch
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer

# Hugging Face Hub identifier of the checkpoint used by the tokenizer and model cells.
HF_MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Path to the tokenized dataset file (JSON Lines), relative to the notebook's
# working directory. This is the same value the example-loading cell assigns
# and reads from; the earlier absolute "/mnt/data/..." value was
# machine-specific and was overridden before it was ever used.
TOKENIZED_PATH = "selected_all_tokenized (1).jsonl"




In [3]:
from huggingface_hub import login

# Interactive Hugging Face Hub authentication; in a notebook this renders a
# login widget as the cell output.
# NOTE(review): presumably needed because the Llama-2 checkpoint requires an
# authenticated download — confirm.
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
# Load the tokenizer (for debugging only)
# We’ll use the token IDs as inputs, but we still want a tokenizer to decode tokens for sanity checks

def load_tokenizer(model_name: str = HF_MODEL_NAME):
    """
    Build the fast Hugging Face tokenizer for `model_name`.

    The tokenizer is only needed to decode token IDs during sanity checks.
    When the checkpoint defines no padding token, the end-of-sequence token
    is reused as the pad token.
    """
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Fall back to EOS only when no pad token is configured.
    pad_missing = tok.pad_token is None
    if pad_missing:
        tok.pad_token = tok.eos_token

    return tok

tokenizer = load_tokenizer()
print("Tokenizer loaded.")

tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

Tokenizer loaded.


In [7]:
# Load model

def load_model(model_name: str = HF_MODEL_NAME, device: str = DEVICE):
    """
    Load `model_name` as a TransformerLens HookedTransformer in float16.

    The model is created on `device` and switched to eval mode before it is
    returned.
    """
    load_kwargs = {"device": device, "dtype": "float16"}
    hooked_model = HookedTransformer.from_pretrained(model_name, **load_kwargs)
    hooked_model.eval()
    return hooked_model

model = load_model()
print("Model loaded on:", DEVICE)

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]



Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Model loaded on: cuda


In [8]:
# --- Attention Store Dictionary + Hook Function ---

def make_attention_store():
    """
    Return a fresh, empty dict used to collect per-layer attention.

    The attention hooks fill it during a forward pass, one entry per layer:
        {
          0: tensor([batch, n_heads, L, L]),
          1: tensor([batch, n_heads, L, L]),
          ...
          n_layers-1: tensor([batch, n_heads, L, L])
        }
    (run_and_get_attention later indexes away the batch dimension, giving
    [n_heads, L, L] per layer.)
    """
    return dict()


def save_attention_hook(attn, hook, store):
    """
    Forward hook that records one layer's attention tensor into `store`.

    Parameters
    ----------
    attn : torch.Tensor
        Activation handed to the hook; expected shape [batch, n_heads, L, L].

    hook : HookPoint
        The hook point being run; hook.layer() supplies the layer index.

    store : dict
        Destination dictionary, updated in place.

    Behavior
    --------
    - Reads the layer index from the hook.
    - Detaches `attn` from the compute graph and moves it to CPU.
    - Writes the result to store[layer_idx], replacing any earlier entry.
    - Returns nothing.
    """
    # A dict literal evaluates the key before the value, so the layer index
    # is looked up before the tensor is detached and moved.
    store.update({hook.layer(): attn.detach().cpu()})


print("Attention hook system ready.")

Attention hook system ready.


In [16]:
from functools import partial

def get_attention_hooks(model, store):
    """
    Build the (hook_name, hook_fn) pairs that capture every layer's
    attention *pattern* (softmax probs).

    One pair is produced per layer in range(model.cfg.n_layers), targeting
    "blocks.{layer}.attn.hook_pattern". Each hook_fn is save_attention_hook
    bound to `store`, so running the model with these hooks fills
       store[layer_idx] = [batch, n_heads, L, L] tensor on CPU

    Raises KeyError if an expected hook name is not among the names reported
    by model.hook_points().
    """
    known_names = {hp.name for hp in model.hook_points()}
    pairs = []

    for layer in range(model.cfg.n_layers):
        hook_name = f"blocks.{layer}.attn.hook_pattern"

        if hook_name in known_names:
            pairs.append((hook_name, partial(save_attention_hook, store=store)))
            continue

        # Fail loudly and list the attention hooks that do exist.
        raise KeyError(
            f"{hook_name} not found in model.hook_points(). "
            f"Available attention hooks include: "
            f"{[h for h in known_names if 'attn' in h]}"
        )

    return pairs


In [10]:
# --- run_and_get_attention(model, tokens) ---

def run_and_get_attention(model, tokens):
    """
    Run one hooked forward pass and collect every layer's attention pattern.

    Parameters
    ----------
    model : HookedTransformer
        The TransformerLens model to run.

    tokens : torch.Tensor
        Token IDs of shape [1, L]: a single example (P||U||A concatenated).

    Returns
    -------
    final_store : dict
        layer_idx -> attention tensor with the batch dimension indexed away,
        i.e. shape [n_heads, L, L]. The tensors are on CPU (moved there by
        the hook).
    """

    # Only a single-example batch is supported.
    assert tokens.ndim == 2 and tokens.shape[0] == 1, \
        f"Expected tokens of shape [1, L], got {tokens.shape}"

    # Fresh store plus one capture hook per layer.
    captured = make_attention_store()
    layer_hooks = get_attention_hooks(model, captured)

    # The forward pass is run only for its hook side effects; output is discarded.
    with torch.no_grad():
        model.run_with_hooks(tokens, fwd_hooks=layer_hooks)

    # Each captured entry is [1, n_heads, L, L]; index 0 drops the batch dim.
    return {layer_idx: layer_attn[0] for layer_idx, layer_attn in captured.items()}

print("run_and_get_attention() defined.")


run_and_get_attention() defined.


In [11]:
# --- Load one tokenized example from the finished file ---

# Re-assigns TOKENIZED_PATH (first set in the config cell) to a path relative
# to the working directory; this is the value the load below reads from.
TOKENIZED_PATH = "selected_all_tokenized (1).jsonl"

def load_one_tokenized_example(path: str):
    """
    Return the first non-blank line of a JSON Lines file, parsed with json.loads.

    Used just for sanity-testing the attention extraction pipeline.

    Parameters
    ----------
    path : str
        Location of the .jsonl file.

    Returns
    -------
    The decoded JSON value of the first line that is non-empty after
    stripping whitespace (the callers in this notebook treat it as a dict).

    Raises
    ------
    ValueError
        If the file contains no non-empty lines. json.loads additionally
        raises json.JSONDecodeError (a ValueError subclass) when that line
        is not valid JSON.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's locale
    # default, which can mis-decode or reject non-ASCII content.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                return json.loads(line)
    raise ValueError(f"No non-empty lines found in {path}")

row = load_one_tokenized_example(TOKENIZED_PATH)

# Report which example was loaded; absent fields fall back to placeholders.
example_id = row.get("id", "<no-id>")
example_dataset = row.get("dataset", "<no-dataset>")

print("Loaded example:")
print("  id:     ", example_id)
print("  dataset:", example_dataset)

Loaded example:
  id:      flan:664
  dataset: flan


In [12]:
# --- Build [P || U || A] token sequence ---

# These fields should already exist in selected_all_tokenized.jsonl
p_ids, u_ids, a_ids = (
    row[field] for field in ("p_token_ids", "u_token_ids", "a_token_ids")
)

# Concatenate the three segments in P, U, A order.
full_ids = p_ids + u_ids + a_ids
L = len(full_ids)

print(f"Token lengths → P={len(p_ids)}, U={len(u_ids)}, A={len(a_ids)}, Total={L}")

# Wrap the IDs in an outer list so the tensor has shape [1, L]; it is created
# on DEVICE, the same device the model was loaded on.
tokens = torch.tensor([full_ids], device=DEVICE, dtype=torch.long)
print("Token tensor shape:", tokens.shape)

# Optional: sanity decode snippets for each segment
decoded_sections = (
    ("\n[Decoded P snippet]", p_ids),
    ("\n[Decoded U snippet]", u_ids),
    ("\n[Decoded A snippet]", a_ids),
)
for header, segment_ids in decoded_sections:
    print(header)
    # At most 200 characters, with newlines escaped so the snippet stays on one line.
    print(tokenizer.decode(segment_ids)[:200].replace("\n", "\\n"))


Token lengths → P=279, U=0, A=0, Total=279
Token tensor shape: torch.Size([1, 279])

[Decoded P snippet]
Can we conclude from "A man is jumping into a screened-in outdoor pool." that "The man is crazy."? Options: - yes - no - it is not possible to tell it is not possible to tell Explanation: Not only cra

[Decoded U snippet]


[Decoded A snippet]



In [15]:
# Inspect available hook points that involve attention
# Lazily pick out the hook-point names that belong to an attention sub-module.
attention_hook_names = (
    hp.name for hp in model.hook_points() if ".attn." in hp.name
)
for hook_name in attention_hook_names:
    print(hook_name)

blocks.0.attn.hook_k
blocks.0.attn.hook_q
blocks.0.attn.hook_v
blocks.0.attn.hook_z
blocks.0.attn.hook_attn_scores
blocks.0.attn.hook_pattern
blocks.0.attn.hook_result
blocks.0.attn.hook_rot_k
blocks.0.attn.hook_rot_q
blocks.1.attn.hook_k
blocks.1.attn.hook_q
blocks.1.attn.hook_v
blocks.1.attn.hook_z
blocks.1.attn.hook_attn_scores
blocks.1.attn.hook_pattern
blocks.1.attn.hook_result
blocks.1.attn.hook_rot_k
blocks.1.attn.hook_rot_q
blocks.2.attn.hook_k
blocks.2.attn.hook_q
blocks.2.attn.hook_v
blocks.2.attn.hook_z
blocks.2.attn.hook_attn_scores
blocks.2.attn.hook_pattern
blocks.2.attn.hook_result
blocks.2.attn.hook_rot_k
blocks.2.attn.hook_rot_q
blocks.3.attn.hook_k
blocks.3.attn.hook_q
blocks.3.attn.hook_v
blocks.3.attn.hook_z
blocks.3.attn.hook_attn_scores
blocks.3.attn.hook_pattern
blocks.3.attn.hook_result
blocks.3.attn.hook_rot_k
blocks.3.attn.hook_rot_q
blocks.4.attn.hook_k
blocks.4.attn.hook_q
blocks.4.attn.hook_v
blocks.4.attn.hook_z
blocks.4.attn.hook_attn_scores
blocks.4.attn

In [17]:
# --- Run attention extraction + sanity checks ---

attn_store = run_and_get_attention(model, tokens)

# How many layers did we capture?
layer_indices = sorted(attn_store)
print("Layers captured:", len(layer_indices))
print("Model reports n_layers =", model.cfg.n_layers)

# Show shapes for the first couple of layers
for layer_idx in layer_indices[:2]:
    attn_layer = attn_store[layer_idx]  # shape [n_heads, L, L]
    print(f"Layer {layer_idx} attention shape:", attn_layer.shape)

# Row-sum sanity check for one head in one layer: every query row of a
# softmax attention pattern should sum to 1.
some_layer = layer_indices[0]
attn_layer0 = attn_store[some_layer]    # [n_heads, L, L]
head0 = attn_layer0[0]                  # [L, L]

row_sums = torch.sum(head0, dim=-1)     # [L]

summary_lines = (
    f"\nRow sums for layer {some_layer}, head 0:",
    f"  mean: {row_sums.mean().item():.6f}",
    f"  std:  {row_sums.std().item():.6f}",
    "\nIf mean ≈ 1.0 and std is small, attention extraction is working.",
)
for summary_line in summary_lines:
    print(summary_line)


Layers captured: 32
Model reports n_layers = 32
Layer 0 attention shape: torch.Size([32, 279, 279])
Layer 1 attention shape: torch.Size([32, 279, 279])

Row sums for layer 0, head 0:
  mean: 1.000000
  std:  0.000000

If mean ≈ 1.0 and std is small, attention extraction is working.
