In [1]:
%load_ext autoreload
%autoreload 2

import json
import random
import string
import re
from pathlib import Path
from collections import Counter

import torch
import nnsight
from transformers import AutoTokenizer

# --- Parameters ---
# Experiment configuration (matches the plan's requirements).
DATASET_SIZE = 32
LEXICON_CONDITION = "mixed"
SEED = 23
MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"
# The dataset filename is derived from the parameters above so the two can
# never drift apart; with the current values it resolves to
# ../data/patching_dataset_N32_mixed.jsonl.
OUTPUT_FILE = Path(f"../data/patching_dataset_N{DATASET_SIZE}_{LEXICON_CONDITION}.jsonl")

# For reproducibility: seed both Python's `random` module and PyTorch.
random.seed(SEED)
torch.manual_seed(SEED)

# Make sure the output directory exists
OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)

In [2]:
import gc

def free_gpu_cache():
    """Run a Python garbage-collection pass, then call ``torch.cuda.empty_cache()``.

    Returns ``None``; the function is used purely for its side effects.
    """
    # Collect first, so that objects which are no longer reachable are gone
    # before PyTorch is asked to release its cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()


# Start the session from a clean slate before the model is loaded.
free_gpu_cache()

In [3]:
from transformers import AutoTokenizer
import nnsight
import torch

# --- Load Model and Tokenizer (Revised) ---

# 1. Load the "fast" tokenizer explicitly.
# This is crucial because the fast version has the .char_to_token() method we need.
# trust_remote_code is deliberately NOT enabled: Llama is implemented natively in
# transformers, and the flag would allow code shipped in the hub repo to be executed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# use_fast=True falls back to a slow (pure-Python) tokenizer when no fast one
# is available, so check explicitly instead of assuming.
assert tokenizer.is_fast, f"Expected a fast tokenizer for {MODEL_NAME}"

# Llama-3 tokenizers don't have a pad_token by default. We'll set it to the eos_token.
# This is good practice and prevents potential issues later.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 2. Load the model via nnsight, passing our pre-loaded tokenizer.
# This ensures nnsight uses the exact tokenizer object we need.
# NOTE(review): attn_implementation="eager" is presumably required so that
# output_attentions=True (used by get_model_out) yields attention weights — confirm.
model = nnsight.LanguageModel(
    MODEL_NAME,
    tokenizer=tokenizer,  # Pass the tokenizer here
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager"
)

print("Model and fast tokenizer loaded successfully.")

Model and fast tokenizer loaded successfully.


In [37]:
# Attention heads given as (layer_idx, head_idx, score) triples.
# verify_head_patterns unpacks the first two fields as layer and head index and
# ignores the third. Entries are listed in non-increasing order of the third field.
# NOTE(review): "FV" presumably stands for "function vector" and the third field
# for a per-head importance score — confirm against whatever produced this list.
FV_HEADS = [
    (33, 52, 0.0297),
    (37, 2, 0.0224),
    (34, 2, 0.0209),
    (34, 24, 0.0194),
    (66, 7, 0.0185),
    (34, 25, 0.0167),
    (76, 10, 0.0155),
    (69, 39, 0.0152),
    (31, 18, 0.0147),
    (76, 63, 0.0142),
    (34, 41, 0.0116),
    (38, 49, 0.0112),
    (49, 1, 0.0107),
    (32, 3, 0.0102),
    (36, 22, 0.0102),
    (22, 55, 0.0101),
    (34, 51, 0.01),
    (67, 30, 0.0098),
    (69, 41, 0.0097),
    (16, 37, 0.0097),
    (34, 52, 0.0096),
    (55, 20, 0.0094),
    (40, 39, 0.0093),
    (37, 29, 0.0091),
    (76, 57, 0.009),
    (38, 23, 0.0087),
    (33, 31, 0.0087),
    (77, 42, 0.0087),
    (35, 57, 0.0084),
    (44, 7, 0.0081),
    (34, 31, 0.0076),
    (18, 7, 0.0075),
    (38, 55, 0.0073),
    (55, 17, 0.0069),
    (79, 1, 0.0068),
    (33, 3, 0.0067),
    (54, 8, 0.0067),
    (35, 47, 0.0066),
    (74, 9, 0.0065),
    (45, 6, 0.0064),
    (48, 56, 0.0063),
    (38, 48, 0.0061),
    (33, 11, 0.0061),
    (38, 28, 0.006),
    (39, 44, 0.0059),
    (38, 62, 0.0058),
    (21, 34, 0.0057),
    (33, 9, 0.0057),
    (75, 28, 0.0057),
    (70, 40, 0.0053),
    (35, 21, 0.0051),
    (34, 5, 0.0051),
    (44, 1, 0.0049),
    (77, 43, 0.0047),
    (34, 54, 0.0047),
    (38, 61, 0.0046),
    (39, 41, 0.0046),
    (38, 18, 0.0046),
    (14, 61, 0.0045),
    (77, 21, 0.0044),
    (73, 15, 0.0044),
    (33, 32, 0.0043),
    (72, 40, 0.004),
    (35, 16, 0.004),
    (36, 3, 0.0039),
    (78, 51, 0.0037),
    (31, 7, 0.0036),
    (17, 14, 0.0035),
    (47, 18, 0.0034),
    (20, 53, 0.0034),
    (49, 4, 0.0033),
    (63, 59, 0.0033),
    (33, 42, 0.0033),
    (32, 22, 0.0033),
    (36, 4, 0.0033),
    (24, 7, 0.0033),
    (27, 22, 0.0032),
    (75, 1, 0.0032),
    (73, 61, 0.0031),
    (35, 44, 0.0031),
    (49, 6, 0.003),
    (30, 24, 0.003),
    (41, 22, 0.003),
    (31, 17, 0.003),
    (60, 28, 0.003),
    (79, 11, 0.003),
    (34, 40, 0.003),
    (31, 6, 0.0029),
    (73, 14, 0.0028),
    (73, 7, 0.0028),
    (32, 5, 0.0028),
    (75, 54, 0.0028),
    (77, 6, 0.0028),
    (35, 45, 0.0027),
    (31, 55, 0.0027),
    (26, 31, 0.0027),
    (34, 47, 0.0027),
    (79, 32, 0.0027),
    (74, 13, 0.0027),
    (33, 48, 0.0026),
]

In [4]:
# Attention heads given as (layer_idx, head_idx) pairs, ten per row.
# The order is significant: the notebook visualises the slice ANALOGY_HEADS[:50].
# NOTE(review): presumably ranked by some analogy-relevance score — confirm the source.
ANALOGY_HEADS = [
    (33, 16), (14, 21), (27, 21), (35, 19), (34, 4), (30, 52), (39, 43), (34, 48), (33, 42), (33, 14),
    (35, 18), (28, 1), (34, 44), (27, 38), (49, 7), (28, 45), (35, 17), (34, 43), (27, 22), (31, 38),
    (31, 35), (16, 48), (20, 8), (38, 48), (32, 18), (28, 40), (29, 30), (35, 34), (30, 33), (30, 24),
    (24, 6), (33, 41), (2, 60), (5, 46), (27, 20), (18, 43), (33, 54), (25, 60), (37, 30), (28, 41),
    (38, 52), (26, 37), (32, 43), (39, 40), (49, 2), (33, 45), (31, 20), (28, 6), (24, 30), (54, 11),
    (31, 39), (18, 37), (8, 43), (30, 29), (28, 46), (32, 7), (15, 19), (15, 21), (17, 61), (6, 6),
    (4, 45), (28, 47), (30, 51), (33, 15), (21, 26), (19, 51), (22, 7), (32, 21), (33, 9), (11, 34),
    (16, 31), (26, 23), (31, 28), (31, 0), (12, 3), (22, 29), (1, 47), (34, 46), (24, 7), (23, 42),
]

In [40]:
# Attention heads given as (layer_idx, head_idx) pairs, ten per row, kept in
# ascending (layer, head) order.
# NOTE(review): the selection criterion for these heads is not visible in this notebook.
FILTER_HEADS = [
    (29, 60), (29, 62), (30, 48), (30, 52), (30, 62), (31, 0), (31, 32), (31, 34), (31, 36), (31, 37),
    (31, 38), (31, 40), (31, 43), (31, 49), (32, 0), (32, 6), (33, 0), (33, 16), (33, 18), (33, 41),
    (33, 45), (34, 1), (34, 6), (34, 7), (34, 37), (34, 43), (34, 46), (35, 0), (35, 7), (35, 17),
    (35, 18), (35, 19), (35, 20), (35, 22), (35, 27), (35, 28), (35, 36), (35, 40), (35, 43), (35, 49),
    (36, 17), (36, 22), (36, 40), (36, 44), (36, 47), (36, 52), (36, 54), (37, 0), (37, 4), (37, 7),
    (37, 16), (37, 30), (37, 32), (37, 35), (37, 36), (37, 38), (37, 39), (38, 19), (38, 49), (38, 50),
    (38, 51), (39, 35), (39, 36), (39, 41), (39, 42), (39, 44), (39, 45), (41, 19), (42, 28), (42, 30),
    (42, 31), (45, 1), (45, 59), (45, 62), (47, 17), (47, 18), (49, 1), (49, 4), (49, 5), (49, 7),
]

In [14]:
import json

def read_jsonl_to_list(file_path):
    """Parse a UTF-8 JSONL file into a list, one decoded JSON value per line.

    Lines that are empty or contain only whitespace are skipped; every other
    line is stripped and handed to ``json.loads``.
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if not payload:
                # Blank separator line: nothing to decode.
                continue
            records.append(json.loads(payload))
    return records

# Load the dataset records from OUTPUT_FILE (one parsed JSON value per non-blank line).
# NOTE(review): no cell visible in this notebook writes OUTPUT_FILE; presumably an
# earlier step generated it — confirm before relying on Restart & Run All.
data = read_jsonl_to_list(OUTPUT_FILE)

In [19]:
# One prompt per record: the story, a newline, then the analogy stem.
prompts = ["\n".join((record['story'], record['analogy'])) for record in data]
# Show the first prompt as a sanity check.
prompts[0]

'The feclur is made of placlugav. The artifact is crafted out of resin. The feclur is within the frifeslo. The artifact is within the library.\nartifact is to feclur as resin is to'

In [20]:
from circuitsvis.tokens import colored_tokens

def patch(
    model,
    tokenizer,
    inputs,
    model_kwargs,
):
    """Unfinished placeholder for a patching routine.

    At present the body only reads ``inputs.input_ids`` into a local and then
    falls off the end, so the function returns ``None`` and never touches
    ``model``, ``tokenizer`` or ``model_kwargs``.
    """
    # NOTE(review): stub with no caller in the visible cells — implement or remove.
    input_ids = inputs.input_ids

def get_model_out(
    input,
    model,
    tokenizer
):
    """Trace ``model`` on ``input`` and return two saved module outputs.

    Inside ``model.trace(input, output_attentions=True)`` the function calls
    ``.save()`` on ``model.model.output`` and on ``model.lm_head.output``, and
    returns the two saved objects as a ``(output, logits)`` tuple.

    ``tokenizer`` is accepted but not used by this function.
    """
    with model.trace(input, output_attentions=True):
        # Keep this order: the inner model's output first, then the LM head's.
        saved_backbone_output = model.model.output.save()
        saved_lm_head_output = model.lm_head.output.save()
    return saved_backbone_output, saved_lm_head_output

def visualize_attn_matrix(
    attn_matrix: torch.Tensor,
    tokens: list[str],
    q_index: int = -1,
    start_from: int = 1,
):
    """Display one row of ``attn_matrix`` as colour-coded tokens.

    Args:
        attn_matrix: Indexed first by ``q_index`` and then sliced along the
            remaining axis; its last dimension must equal ``len(tokens)``.
        tokens: Token strings, one per position of the last axis.
        q_index: Row to show (default ``-1``, the last row).
        start_from: Number of leading positions dropped from both the values
            and the tokens before display (default ``1``).

    Raises:
        AssertionError: If ``len(tokens)`` differs from ``attn_matrix.shape[-1]``.
    """
    assert len(tokens) == attn_matrix.shape[-1], (
        f"{len(tokens)=}, {attn_matrix.shape[-1]=}"
    )
    selected_row = attn_matrix[q_index]
    shown_values = selected_row[start_from:]
    shown_tokens = tokens[start_from:]
    rendering = colored_tokens(tokens=shown_tokens, values=shown_values)
    display(rendering)

def verify_head_patterns(
    prompt,
    model,
    tokenizer,
    heads = None,
    query_idx = -1,
    start_from = 1
):
    """Collect every layer's attention for ``prompt`` and optionally display a head average.

    Args:
        prompt: Text passed both to ``tokenizer`` (to obtain display tokens)
            and to ``get_model_out`` (for the forward pass).
        model: Model object forwarded to ``get_model_out``.
        tokenizer: Callable tokenizer that also provides ``convert_ids_to_tokens``.
        heads: Optional sequence of ``(layer_idx, head_idx)`` or
            ``(layer_idx, head_idx, score)`` tuples. Only the first two fields
            are used. ``None`` or an empty sequence skips the visualisation.
        query_idx: Query position whose attention row is displayed.
        start_from: Number of leading tokens hidden from the display.

    Returns:
        The per-layer attentions stacked along a new leading axis. For the
        runs recorded in this notebook the printed shape is
        ``(80, 64, seq_len, seq_len)``.
    """
    token_ids = tokenizer(prompt).input_ids
    # Replace the byte-level BPE markers for space ("Ġ") and newline ("Ċ")
    # with a plain space so the rendered tokens are readable.
    str_tokens = [
        token.replace("Ġ", " ").replace("Ċ", " ")
        for token in tokenizer.convert_ids_to_tokens(token_ids)
    ]

    # Only the first saved object (which carries `.attentions`) is needed here.
    output, _ = get_model_out(
        input=prompt,
        model=model,
        tokenizer=tokenizer,
    )

    # Index 0 selects the first entry along each layer's leading axis
    # (presumably the batch axis — a single prompt is traced).
    per_layer = [attn[0] for attn in output.attentions]

    # Layers may live on different devices, so gather them on one device
    # before stacking. Use CUDA when present (as before) and otherwise fall
    # back to wherever the first layer's attention already is.
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = per_layer[0].device
    attn_matrices = torch.stack(
        [layer_attn.to(target_device) for layer_attn in per_layer], dim=0
    )

    print(attn_matrices.shape)

    if heads is not None and len(heads) > 0:
        # `*_` accepts both 2-tuples and 3-tuples (a trailing score is ignored).
        head_matrices = [
            attn_matrices[layer_idx, head_idx]
            for layer_idx, head_idx, *_ in heads
        ]
        # Element-wise mean over the selected heads.
        combined_matrix = torch.stack(head_matrices).mean(dim=0)

        visualize_attn_matrix(
            attn_matrix=combined_matrix,
            tokens=str_tokens,
            q_index=query_idx,
            start_from=start_from,
        )

    return attn_matrices

# Hand-written example prompt that includes an explicit instruction line.
# NOTE(review): the live code below never uses this value — the `for` loop
# re-binds `prompt` on every iteration, so once the cell has run (and `prompts`
# is non-empty) `prompt` holds the last dataset prompt, not this literal.
prompt = """Finish the analogy in one word.
The feclur is made of placlugav. The artifact is crafted out of resin. The feclur is within the frifeslo. The artifact is within the library.

artifact is to feclur as resin is to"""

# Disabled alternatives that would visualise the other head sets on `prompt`:
#verify_head_patterns(prompt, model, tokenizer, heads=FV_HEADS, query_idx=-3)
#verify_head_patterns(prompt, model, tokenizer, heads=FILTER_HEADS, query_idx=-3)
# For every dataset prompt, show the mean attention of the first 50
# ANALOGY_HEADS, read from the third-from-last query position (query_idx=-3).
for prompt in prompts:
    verify_head_patterns(prompt, model, tokenizer, heads=ANALOGY_HEADS[:50], query_idx=-3)

torch.Size([80, 64, 49, 49])


torch.Size([80, 64, 51, 51])


torch.Size([80, 64, 55, 55])


torch.Size([80, 64, 48, 48])


torch.Size([80, 64, 47, 47])


torch.Size([80, 64, 46, 46])


torch.Size([80, 64, 48, 48])


torch.Size([80, 64, 59, 59])


torch.Size([80, 64, 49, 49])


torch.Size([80, 64, 50, 50])


torch.Size([80, 64, 62, 62])


torch.Size([80, 64, 55, 55])


torch.Size([80, 64, 47, 47])


torch.Size([80, 64, 52, 52])


torch.Size([80, 64, 58, 58])


torch.Size([80, 64, 54, 54])


torch.Size([80, 64, 51, 51])


torch.Size([80, 64, 56, 56])


torch.Size([80, 64, 52, 52])


torch.Size([80, 64, 55, 55])


torch.Size([80, 64, 54, 54])


torch.Size([80, 64, 48, 48])


torch.Size([80, 64, 56, 56])


torch.Size([80, 64, 56, 56])


torch.Size([80, 64, 48, 48])


torch.Size([80, 64, 53, 53])


torch.Size([80, 64, 50, 50])


torch.Size([80, 64, 55, 55])


torch.Size([80, 64, 56, 56])


torch.Size([80, 64, 50, 50])


torch.Size([80, 64, 55, 55])


torch.Size([80, 64, 47, 47])
