In [10]:
import json
import random

import torch
import transformer_lens.patching as patching
import numpy as np
import plotly.express as px
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer
from einops import rearrange

from cfg import ExperimentConfig

In [11]:
def imshow(
    tensor,
    x=None,
    y=None,
    title=None,
    xlabel=None,
    ylabel=None,
    facet_labels=None,
    facet_col=None,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    reverse_y=True,
    text_auto=False,
    **kwargs,
):
    data = tensor.cpu().numpy() if hasattr(tensor, "cpu") else np.array(tensor)

    fig = px.imshow(
        data,
        x=x,
        y=y,
        facet_col=facet_col,
        color_continuous_midpoint=color_continuous_midpoint,
        color_continuous_scale=color_continuous_scale,
        **kwargs,
    )

    if reverse_y:
        fig.update_yaxes(autorange=True)
    if title:
        fig.update_layout(title=title)
    if xlabel:
        fig.update_xaxes(title=xlabel)
    if ylabel:
        fig.update_yaxes(title=ylabel)
    if facet_labels and facet_col is not None:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i].text = label
    if text_auto is True:
        text = np.where(data.round(2) == 0, '', np.round(data, 2).astype(str))
        fig.update_traces(text=text, texttemplate="%{text}")

    return fig

In [12]:
torch.set_grad_enabled(False)
cfg = ExperimentConfig()

In [13]:
USE_HF = 1

In [14]:
if USE_HF:
    repo_id = "cheeetoo/trainorder"

    hf_hub_download(repo_id, filename="probes.npz", local_dir=cfg.out_dir)
    hf_hub_download(repo_id, filename="aliases.json", local_dir=cfg.out_dir)

    hf_model = AutoModelForCausalLM.from_pretrained(repo_id)
else:
    hf_model = AutoModelForCausalLM.from_pretrained(f"{cfg.out_dir}/final")

model = HookedTransformer.from_pretrained(cfg.model_id, hf_model=hf_model)



Loaded pretrained model meta-llama/Llama-3.2-1B into HookedTransformer


In [15]:
probes = np.load(f"{cfg.out_dir}/probes.npz")
coef = probes[f"{(0, cfg.num_stages - 1)}_coef"]
intercept = probes[f"{(0, cfg.num_stages - 1)}_intercept"]

with open(f"{cfg.out_dir}/aliases.json") as f:
    aliases = json.load(f)

aliases_first = aliases["stage_0"]
aliases_last = aliases[f"stage_{cfg.num_stages - 1}"]
aliases_first = random.choices(aliases_first, k=cfg.n_patching_prompts)
aliases_last = random.choices(aliases_last, k=cfg.n_patching_prompts)

In [17]:
clean_prompts = [cfg.probe_prompt.format(a) for a in aliases_last]
corrupted_prompts = [cfg.probe_prompt.format(a) for a in aliases_first]

clean_tokens = model.to_tokens(clean_prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)

_, clean_cache = model.run_with_cache(clean_tokens)

OutOfMemoryError: CUDA out of memory. Tried to allocate 26.00 MiB. GPU 0 has a total capacity of 31.37 GiB of which 23.81 MiB is free. Including non-PyTorch memory, this process has 31.34 GiB memory in use. Of the allocated memory 30.63 GiB is allocated by PyTorch, and 117.45 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
probing_act = {}


def save_probing_act(act, hook):
    probing_act["act"] = act[:, -1].detach()
    return act


model.reset_hooks(including_permanent=True)
model.add_perma_hook(cfg.hook_point, save_probing_act)


def get_probe_logit(act=probing_act, coef=coef, intercept=intercept):
    return torch.mean(act["act"].cpu() @ coef.T + intercept)


_ = model(clean_tokens)
probe_logit_clean = get_probe_logit()
_ = model(corrupted_tokens)
probe_logit_corrupted = get_probe_logit()


def probing_metric(_logits):
    return (get_probe_logit() - probe_logit_corrupted) / (
        probe_logit_clean - probe_logit_corrupted
    )

  return torch.mean(act["act"].cpu() @ coef.T + intercept)


In [None]:
patched_block = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, probing_metric)


__array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)

  0%|          | 1/208 [00:00<00:30,  6.81it/s]

100%|██████████| 208/208 [00:33<00:00,  6.15it/s]
100%|██████████| 208/208 [00:34<00:00,  6.09it/s]
100%|██████████| 208/208 [00:34<00:00,  6.06it/s]


In [None]:
imshow(
    patched_block,
    title="Activation Patching Per Block",
    x=model.to_str_tokens(cfg.probe_prompt.format("x y z")),
    xlabel="Position",
    ylabel="layer",
    facet_col=0,
    facet_labels=["Residual Stream", "Attention", "MLP"]
)

In [None]:
patched_mlp = patching.get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, probing_metric)

  return torch.mean(act["act"].cpu() @ coef.T + intercept)
100%|██████████| 208/208 [00:33<00:00,  6.13it/s]


In [None]:
imshow(
    patched_mlp,
    title="Activation Patching MLP Output per Layer per Token",
    x=model.to_str_tokens(cfg.probe_prompt.format("x y z")),
    xlabel="Position",
    ylabel="Layer",
    text_auto=True,
)

In [None]:
patched_attn = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, probing_metric)


__array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)

  3%|▎         | 225/6656 [00:37<17:42,  6.05it/s]

In [None]:
imshow(
    rearrange(patched_attn, "... layer pos head -> ... (layer head) pos"),
    title="Activation Patching Per Head Per Token",
    x=model.to_str_tokens(cfg.probe_prompt.format("x y z")),
    y=[f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)],
    ylabel="Head",
    xlabel="Position",
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
)