In [1]:
!pip install transformer_lens

import json
from functools import partial
import torch
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer

# Hugging Face Hub identifier of the checkpoint used throughout the notebook
HF_MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Run on the GPU when CUDA is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Location of the JSON file holding the pre-tokenized examples
TOKENIZED_PATH = "/content/token_segmentation_metadata (1).json"





VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from huggingface_hub import login

# Start the interactive Hugging Face Hub login so later downloads are authenticated.
# NOTE(review): presumably required because the checkpoint is access-restricted — confirm.
login()


In [2]:
def load_tokenizer(model_name: str = HF_MODEL_NAME):
    """
    Load the Hugging Face tokenizer for `model_name`, requesting the fast
    implementation.

    If the loaded tokenizer defines no padding token, its end-of-sequence
    token is assigned as the padding token before it is returned.
    """
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Only fill in a pad token when the tokenizer does not ship one
    missing_pad = tok.pad_token is None
    if missing_pad:
        tok.pad_token = tok.eos_token

    return tok

# Build the tokenizer once at notebook level; later cells use this global for decoding.
tokenizer = load_tokenizer()
print("Tokenizer loaded.")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Tokenizer loaded.


In [3]:
import torch
torch.cuda.empty_cache()  # Release PyTorch's cached GPU memory before the model is loaded in the next cell


In [4]:
def load_model(model_name: str = HF_MODEL_NAME, device: str = DEVICE):
    """
    Load `model_name` as a TransformerLens HookedTransformer and switch it
    to eval mode.

    Parameters
    ----------
    model_name : str
        Checkpoint identifier forwarded to `HookedTransformer.from_pretrained`.
    device : str
        Device forwarded to `HookedTransformer.from_pretrained`.

    Returns
    -------
    The loaded model, requested in float16 and with `eval()` applied.
    """
    # FP16 halves the weight memory compared with FP32
    load_kwargs = {"device": device, "dtype": torch.float16}
    hooked_model = HookedTransformer.from_pretrained(model_name, **load_kwargs)
    hooked_model.eval()
    return hooked_model

# Load the model once at notebook level; the message echoes the configured DEVICE constant.
model = load_model()
print("Model loaded on:", DEVICE)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Model loaded on: cuda


In [8]:
def make_attention_store():
    """
    Return a fresh, empty dictionary for collecting attention matrices.

    The attention hooks fill it during the forward pass, one entry per
    layer index.
    """
    return dict()

def save_attention_hook(attn, hook, store):
    """
    Forward hook that records one layer's attention tensor.

    A detached CPU copy of `attn` is written into `store` under the layer
    index reported by `hook.layer()`. Nothing is returned.
    """
    store[hook.layer()] = attn.detach().cpu()

# Status message only: confirms this cell's helper definitions were executed.
print("Attention hook system ready.")


Attention hook system ready.


In [9]:
def get_attention_hooks(model, store):
    """
    Build one forward hook per layer on the attention *pattern* (softmax probs).

    Parameters
    ----------
    model : HookedTransformer
        Model whose `hook_points()` and `cfg.n_layers` are inspected.
    store : dict
        Dictionary each hook writes its layer's attention into.

    Returns
    -------
    list of (hook_name, hook_fn) tuples, ordered by layer index.

    Raises
    ------
    KeyError
        If a layer's `hook_pattern` hook point is not present on the model.
    """
    all_hooks = {point.name for point in model.hook_points()}
    registered = []

    for layer in range(model.cfg.n_layers):
        hook_name = f"blocks.{layer}.attn.hook_pattern"

        if hook_name in all_hooks:
            # Bind the shared store so the hook keeps the (tensor, hook) call signature
            registered.append((hook_name, partial(save_attention_hook, store=store)))
            continue

        # Fail on the first missing layer and list the attention hooks that do exist
        raise KeyError(
            f"{hook_name} not found in model.hook_points(). "
            f"Available attention hooks include: "
            f"{[h for h in all_hooks if 'attn' in h]}"
        )

    return registered


In [11]:
def run_and_get_attention(model, row):
    """
    Run one forward pass and collect the attention pattern of every layer.

    Parameters
    ----------
    model : HookedTransformer
        The TransformerLens model loaded earlier.

    row : dict
        A dictionary containing tokenized information; its 'input_ids'
        entry must be a non-empty flat sequence of token ids.

    Returns
    -------
    final_store : dict
        Mapping layer_idx -> attention tensor of shape [n_heads, L, L],
        all moved to CPU for easy downstream processing.

    Raises
    ------
    ValueError
        If row['input_ids'] is empty.

    Notes
    -----
    Every layer's full [n_heads, L, L] pattern is kept, so memory grows
    quadratically with the sequence length L.
    """
    input_ids = row['input_ids']
    if len(input_ids) == 0:
        raise ValueError("row['input_ids'] is empty; nothing to run through the model.")

    # Wrap in an outer list to add the batch dimension: tokens = [1, L]
    tokens = torch.tensor([input_ids], device=DEVICE, dtype=torch.long)

    # A nested 'input_ids' would produce an extra dimension; catch that early
    assert tokens.ndim == 2 and tokens.shape[0] == 1, \
        f"Expected tokens of shape [1, L], got {tokens.shape}"

    # The hooks write into this dictionary as a side effect of the forward pass
    store = make_attention_store()
    hooks = get_attention_hooks(model, store)

    # Only the hooked activations are needed, so return_type=None tells the
    # model not to compute logits at all.
    with torch.no_grad():
        model.run_with_hooks(
            tokens,
            return_type=None,
            fwd_hooks=hooks,
        )

    # Each stored entry is [1, n_heads, L, L]; index 0 drops the batch dimension
    return {layer_idx: attn_tensor[0] for layer_idx, attn_tensor in store.items()}


In [12]:
def load_tokenized_example(path: str):
    """
    Return the first element of the JSON array stored in the file at `path`.

    The file is read as UTF-8 and must contain a single top-level JSON array.

    Returns
    -------
    The first element of the array, or None when the file is not valid
    JSON (the decode error is printed rather than raised).

    Raises
    ------
    ValueError
        If the top-level JSON value is not an array, or the array is empty.
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8
    with open(path, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)  # Load the entire JSON array
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            return None

    # Indexing a non-array with [0] would raise an unrelated KeyError/TypeError
    # or silently return a character of a string, so reject it explicitly.
    if not isinstance(data, list):
        raise ValueError(
            f"Expected a top-level JSON array, got {type(data).__name__}."
        )
    if not data:
        raise ValueError("The JSON array is empty.")

    return data[0]

# Pull the first example out of the tokenized file (None if the JSON could not be decoded)
row = load_tokenized_example(TOKENIZED_PATH)

if not isinstance(row, dict):
    # Anything other than a dict cannot be passed to the extraction routine
    print("Error: 'row' is not a valid dictionary.")
else:
    # Extract the per-layer attention for this example
    attn_store = run_and_get_attention(model, row)

    # Compare the number of captured layers with what the model reports
    layer_indices = sorted(attn_store)
    print("Layers captured:", len(layer_indices))
    print("Model reports n_layers =", model.cfg.n_layers)

    # Shapes of the first two captured layers, each [n_heads, L, L]
    for layer_idx in layer_indices[:2]:
        attn_layer = attn_store[layer_idx]
        print(f"Layer {layer_idx} attention shape:", attn_layer.shape)

    # Normalization check on head 0 of the lowest captured layer:
    # every row of a softmax pattern should sum to 1
    some_layer = layer_indices[0]
    attn_layer0 = attn_store[some_layer]   # [n_heads, L, L]
    head0 = attn_layer0[0]                 # [L, L]
    row_sums = head0.sum(dim=-1)           # [L]

    print(f"\nRow sums for layer {some_layer}, head 0:")
    print(f"  mean: {row_sums.mean().item():.6f}")
    print(f"  std:  {row_sums.std().item():.6f}")

    print("\nIf mean ≈ 1.0 and std is small, attention extraction is working.")


Layers captured: 32
Model reports n_layers = 32
Layer 0 attention shape: torch.Size([32, 22, 22])
Layer 1 attention shape: torch.Size([32, 22, 22])

Row sums for layer 0, head 0:
  mean: 1.000000
  std:  0.000000

If mean ≈ 1.0 and std is small, attention extraction is working.


In [13]:
# Print the name of every hook point that lives inside an attention module
for hp in model.hook_points():
    if ".attn." not in hp.name:
        continue  # not an attention hook point
    print(hp.name)


blocks.0.attn.hook_k
blocks.0.attn.hook_q
blocks.0.attn.hook_v
blocks.0.attn.hook_z
blocks.0.attn.hook_attn_scores
blocks.0.attn.hook_pattern
blocks.0.attn.hook_result
blocks.0.attn.hook_rot_k
blocks.0.attn.hook_rot_q
blocks.1.attn.hook_k
blocks.1.attn.hook_q
blocks.1.attn.hook_v
blocks.1.attn.hook_z
blocks.1.attn.hook_attn_scores
blocks.1.attn.hook_pattern
blocks.1.attn.hook_result
blocks.1.attn.hook_rot_k
blocks.1.attn.hook_rot_q
blocks.2.attn.hook_k
blocks.2.attn.hook_q
blocks.2.attn.hook_v
blocks.2.attn.hook_z
blocks.2.attn.hook_attn_scores
blocks.2.attn.hook_pattern
blocks.2.attn.hook_result
blocks.2.attn.hook_rot_k
blocks.2.attn.hook_rot_q
blocks.3.attn.hook_k
blocks.3.attn.hook_q
blocks.3.attn.hook_v
blocks.3.attn.hook_z
blocks.3.attn.hook_attn_scores
blocks.3.attn.hook_pattern
blocks.3.attn.hook_result
blocks.3.attn.hook_rot_k
blocks.3.attn.hook_rot_q
blocks.4.attn.hook_k
blocks.4.attn.hook_q
blocks.4.attn.hook_v
blocks.4.attn.hook_z
blocks.4.attn.hook_attn_scores
blocks.4.attn

In [16]:
import json

# Function to load all tokenized examples from the file
def load_tokenized_examples(path: str):
    """
    Load and return the entire JSON document stored in the file at `path`.

    The file is read as UTF-8. The parsed value is returned as-is; the
    notebook expects it to be an array of example dictionaries.

    Returns
    -------
    The parsed JSON value, or an empty list when the file is not valid
    JSON (the decode error is printed rather than raised).
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8
    with open(path, "r", encoding="utf-8") as f:
        try:
            return json.load(f)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            return []

# Read every tokenized example from disk
TOKENIZED_PATH = "/content/token_segmentation_metadata (1).json"  # Adjust path as needed
examples = load_tokenized_examples(TOKENIZED_PATH)

if examples:
    # Print a short report for each example in the file
    for example in examples:
        input_ids = example["input_ids"]
        p_span = example["p_span"]
        u_span = example["u_span"]
        a_span = example["a_span"]

        L = len(input_ids)

        # Spans are [start, end] index pairs with an inclusive end, hence the +1
        print(f"\nProcessing example {example['id']}:")
        print(f"Token lengths → P={p_span[1] - p_span[0] + 1}, U={u_span[1] - u_span[0] + 1}, A={a_span[1] - a_span[0] + 1}, Total={L}")

        tokens = torch.tensor([input_ids], device=DEVICE, dtype=torch.long)
        print("Token tensor shape:", tokens.shape)

        # Decode each span and show at most 200 characters, with newlines made visible
        for _header, _span in (
            ("\n[Decoded P snippet]", p_span),
            ("\n[Decoded U snippet]", u_span),
            ("\n[Decoded A snippet]", a_span),
        ):
            print(_header)
            print(tokenizer.decode(input_ids[_span[0]:_span[1] + 1])[:200].replace("\n", "\\n"))
else:
    print("No examples loaded from the file.")



Processing example alpaca:22052:
Token lengths → P=19, U=1, A=1, Total=22
Token tensor shape: torch.Size([1, 22])

[Decoded P snippet]
Summarize the our goals with GPT model in no more than 8 words.\n

[Decoded U snippet]
\n

[Decoded A snippet]


Processing example alpaca:24364:
Token lengths → P=17, U=1, A=1, Total=20
Token tensor shape: torch.Size([1, 20])

[Decoded P snippet]
Write a poem about spring. Output should be less than 80 words.\n

[Decoded U snippet]
\n

[Decoded A snippet]


Processing example sharegpt_en:8239:
Token lengths → P=264, U=537, A=2833, Total=3635
Token tensor shape: torch.Size([1, 3635])

[Decoded P snippet]
You are a pregnancy health &amp; nutrition expert and a mother of 3 children. \n\nYou have strong knowledge and hands-on experience on pregnancy topics.\n\nYou have your own column in a major media.\n\nYou 

[Decoded U snippet]
[topik]\nciri-ciri kehamilan tidak berkembang\n\n[outline]\nI. Introduction\na. Definition of "kehamilan tidak berkembang"\nb.