In [None]:
!pip install ipykernel transformer_lens plotly 

In [1]:
import transformer_lens
import torch
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
# Import stuff
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load a model (eg GPT-2 Small)
model = transformer_lens.HookedTransformer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
# Run the model and get logits and activations
logits, activations = model.run_with_cache("Hello World")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loaded pretrained model Qwen/Qwen2.5-1.5B-Instruct into HookedTransformer


# Convert HF Transformers weights to HookedTransformer weights

In [3]:
from transformer_lens.pretrained.weight_conversions.llama import convert_llama_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model = AutoModelForCausalLM.from_config(config)
model.load_state_dict(torch.load("checkpoints/old/llama_instruct_value_assign_finetuned.weights", weights_only=True))
hooked_config = transformer_lens.loading.get_pretrained_model_config("meta-llama/Llama-3.2-1B-Instruct")
new_state_dict = convert_llama_weights(model, hooked_config)
torch.save(new_state_dict, "checkpoints/old/llama_instruct_value_assign_finetuned_hooked.weights")


In [4]:
config = transformer_lens.loading.get_pretrained_model_config("meta-llama/Llama-3.2-1B-Instruct")
model = transformer_lens.HookedTransformer(config)
model.load_and_process_state_dict(torch.load("checkpoints/old/llama_instruct_value_assign_finetuned_hooked.weights", weights_only=True))
#config = transformer_lens.loading.get_pretrained_model_config("Qwen/Qwen2.5-1.5B-Instruct")
#model = transformer_lens.HookedTransformer(config)
#model.load_and_process_state_dict(torch.load("checkpoints/Qwen2.5-1.5B-Instruct_value_assign_finetuned_hooked.weights", weights_only=True))




In [5]:
import sys
sys.path.insert(0, "..")
sys.path.insert(0, "../generators")
sys.path.insert(0, "generators")


from generators.value_assignment import ValueAssignmentGenerator
import transformers
import itertools


data_generator = ValueAssignmentGenerator(tokenizer=model.tokenizer,
                                       seed=0,
                                       use_few_shot=True, 
                                       use_instruction=True, 
                                       apply_chat_template=True,
                                       length=(5,5),
                                       encoding_length=(5,5)
                                       )
samples = itertools.islice(data_generator.generate_samples(), 10)

In [6]:
samples = list(samples)
texts = [sample[0] for sample in samples]
input_ids = [sample[1] for sample in samples]
target_ids = [sample[2] for sample in samples]
attn_labels = [sample[3] for sample in samples]

In [7]:
texts[0]

'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 16 May 2025\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n\nYou are a great problem solver excellent in tasks requiring manipulation of symbols. You will be tasked with encoding a sequence of characters given a set of encodings. \nAn encoding is a tuple of characters, e.g. A1, where A is the character to be encoded, and 1 is its encoding. For example, given some string "A", its encoding is a sequence "1". \nGiven a set of these encodings and an input string, encode this string.\nDo not generate any words. You can only generate characters from the set of encodings.\nHere are some examples:\nH3K8U1D5G9F1 HKKHDFG=3883519\nABCDEFGHIJKL ACEGIK=BDFHJLUse them to solve the following task:\n\x04_WT&\x06CGc^ &c&WC=<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\x06^\x06TG<|eot_id|>'

In [8]:
import matplotlib.pyplot as plt
import numpy as np
import random

def visualize_attention_heads(attention_matrix, tokens=None, title="Attention Heatmap", save_path=None):
    """
    Visualize a multi-head attention matrix as heatmaps for each head.

    Parameters:
        attention_matrix (torch.Tensor): Tensor of shape (1, num_heads, seq_len, seq_len).
        tokens (list, optional): List of tokens corresponding to the rows/columns of the matrix.
        title (str): Title of the heatmap.
        save_path (str, optional): Path to save the heatmap images. If None, the heatmaps are not saved.
    """
    # Detach and convert to NumPy
    scores = attention_matrix.squeeze(0).cpu().detach().numpy()  # Shape: (num_heads, seq_len, seq_len)
    num_heads = scores.shape[0]
    for head in range(num_heads):
        head_scores = scores[head]  # Shape: (seq_len, seq_len)
        min_score = head_scores.min()
        max_score = head_scores.max()
        normalized = (head_scores - min_score) / (max_score - min_score)  # Normalize to [0, 1]
        plt.figure(figsize=(10, 8))
        plt.imshow(normalized, cmap="viridis", interpolation="nearest")
        plt.colorbar(label="Attention Score")
        plt.title(f"{title} - Head {head + 1}")
        
        if tokens:
            plt.xticks(ticks=np.arange(len(tokens)), labels=tokens, rotation=90)
            plt.yticks(ticks=np.arange(len(tokens)), labels=tokens)
        else:
            plt.xticks([])
            plt.yticks([])

        plt.xlabel("Tokens")
        plt.ylabel("Tokens")
        plt.tight_layout()

        if save_path:
            plt.savefig(f"{save_path}_head_{head + 1}.png")
        plt.show()

In [21]:
global noise
def compute_scores(logits, target):
    ignore_index = -100  # Optional: index to ignore in loss/accuracy computation
    # Compute loss
    loss = torch.nn.functional.cross_entropy(
    logits.view(-1, logits.size(-1)),  # Flatten logits to (batch_size * seq_len, vocab_size)
    target.view(-1),               # Flatten target_ids to (batch_size * seq_len)
    ignore_index=ignore_index          # Ignore padding tokens if applicable
)

    # Compute accuracy
    predictions = torch.argmax(logits, dim=-1)  # Get predicted token indices
    correct = (predictions == target) & (target != ignore_index)  # Ignore padding tokens
    accuracy = correct.sum().item() / (target != ignore_index).sum().item()

    print(f"Loss: {loss.item()}, Accuracy: {accuracy:.2%}")
    return loss.item(), accuracy

def generate_noise(value_shape, noise_level, sample_length):
    """
    Generate noise only in the lower triangle of the attention matrix for the last n tokens,
    where n is the number of tokens in the target_ids that are not -100.

    Parameters:
        value_shape (tuple): Shape of the attention matrix (batch, seq_len, head_index, d_head).
        noise_level (float): Level of noise to apply.
        target_ids (torch.Tensor): Target IDs tensor to determine the length of the sample.

    Returns:
        np.ndarray: Noise matrix with noise only in the lower triangle for the sample tokens.
    """
    batch_size, heads, seq_len, _ = value_shape
    global noise
    # Determine the length of the sample (number of tokens not equal to -100 in target_ids)
    # Initialize noise matrix
    noise = np.zeros(value_shape)

    for h in range(heads):
        # Generate noise only for the lower triangle of the sample tokens
        for i in range(seq_len):
            for j in range(i+1):  # Lower triangle condition (j < i)
                if j >= seq_len - sample_length:  # Ensure it's within the sample tokens
                    noise[0, h, i, j] = random.random() * noise_level

    return noise

def attn_noise_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    noise_level: float,
    sample_length
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    noise = generate_noise(value.shape, noise_level, sample_length)
    value[:, :, :, :] += torch.tensor(noise, dtype=torch.float32).cuda()
    #visualize_attention_heads(value, None, title="Example Attention Heatmap")
    return value

device = "cuda"

is_attn_scores = lambda name: name.endswith("hook_pattern")

model.to(device)
input_ids[0] = input_ids[0].to(device)
target_ids[0] = target_ids[0].to(device)

ablated_scores = []

for noise_level in np.linspace(0, 0.2, 100):
    print("Running ablated with noise", noise_level)
    cols_to_keep = ~np.all(attn_labels[0].cpu().numpy() == -100, axis=0)
    logits = model.run_with_hooks(input_ids[0], return_type="logits",
                                  fwd_hooks=[
                                      (is_attn_scores, partial(attn_noise_hook, noise_level=noise_level, sample_length=len(cols_to_keep)))
                                  ])
    ablated_scores.append(compute_scores(logits, target_ids[0]))
    del logits
    torch.cuda.empty_cache()

print(noise)
    

Moving model to device:  cuda
Running ablated with noise 0.0
Loss: 0.00016773004608694464, Accuracy: 100.00%
Running ablated with noise 0.00202020202020202
Loss: 0.047034427523612976, Accuracy: 100.00%
Running ablated with noise 0.00404040404040404
Loss: 2.879239082336426, Accuracy: 50.00%
Running ablated with noise 0.006060606060606061
Loss: 4.704164028167725, Accuracy: 16.67%
Running ablated with noise 0.00808080808080808
Loss: 5.344171047210693, Accuracy: 16.67%
Running ablated with noise 0.0101010101010101
Loss: 5.738537311553955, Accuracy: 16.67%
Running ablated with noise 0.012121212121212121


KeyboardInterrupt: 

In [None]:
noise[0][0]

In [14]:
noise[0, 0, 210:, 210:]

array([[0.02097557, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        ],
       [0.00995123, 0.00530414, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        ],
       [0.00084715, 0.00436938, 0.00925099, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
      