In [58]:
import json
import random

import torch
import torch.nn.functional as F
import transformer_lens.patching as patching
import numpy as np
import plotly.express as px
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer
from einops import rearrange

from cfg import ExperimentConfig

In [119]:
def imshow(
    tensor,
    x=None,
    y=None,
    title=None,
    xlabel=None,
    ylabel=None,
    facet_labels=None,
    facet_col=None,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    reverse_y=True,
    **kwargs,
):
    data = tensor.cpu() if hasattr(tensor, "cpu") else np.array(tensor)
    fig = px.imshow(
        data,
        x=x,
        y=y,
        facet_col=facet_col,
        color_continuous_midpoint=color_continuous_midpoint,
        color_continuous_scale=color_continuous_scale,
        **kwargs,
    )

    if reverse_y:
        fig.update_yaxes(autorange=True)
    if title:
        fig.update_layout(title=title)
    if xlabel:
        fig.update_xaxes(title=xlabel)
    if ylabel:
        fig.update_yaxes(title=ylabel)
    if facet_labels and facet_col is not None:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i].text = label

    return fig

In [59]:
torch.set_grad_enabled(False)
cfg = ExperimentConfig()

In [None]:
# hf_model = AutoModelForCausalLM.from_pretrained(f"{cfg.out_dir}/final")
hf_model = AutoModelForCausalLM.from_pretrained("cheeetoo/trainorder")
model = HookedTransformer.from_pretrained(cfg.model_id, hf_model=hf_model)



Loaded pretrained model meta-llama/Llama-3.2-1B into HookedTransformer


In [60]:
probes = np.load(f"{cfg.out_dir}/probes.npz")
coef = probes[f"{(0, cfg.num_stages - 1)}_coef"]
intercept = probes[f"{(0, cfg.num_stages - 1)}_intercept"]

with open(f"{cfg.out_dir}/aliases.json") as f:
    aliases = json.load(f)

aliases_first = aliases["stage_0"]
aliases_last = aliases[f"stage_{cfg.num_stages - 1}"]
aliases_first = random.choices(aliases_first, k=cfg.n_patching_prompts)
aliases_last = random.choices(aliases_last, k=cfg.n_patching_prompts)

In [61]:
clean_prompts = [cfg.probe_prompt.format(a) for a in aliases_last]
corrupted_prompts = [cfg.probe_prompt.format(a) for a in aliases_first]

clean_tokens = model.to_tokens(clean_prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)

_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [62]:
probing_act = {}


def save_probing_act(act, hook):
    probing_act["act"] = act[:, -1].detach()
    return act


model.reset_hooks(including_permanent=True)
model.add_perma_hook(cfg.hook_point, save_probing_act)


def get_probe_logit(act=probing_act, coef=coef, intercept=intercept):
    return torch.mean(act["act"].cpu() @ coef.T + intercept)


_ = model(clean_tokens)
probe_logit_clean = get_probe_logit()
_ = model(corrupted_tokens)
probe_logit_corrupted = get_probe_logit()


def probing_metric(_logits):
    return (get_probe_logit() - probe_logit_corrupted) / (
        probe_logit_clean - probe_logit_corrupted
    )


__array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)



In [63]:
block_every = patching.get_act_patch_block_every(
    model=model,
    corrupted_tokens=corrupted_tokens,
    clean_cache=clean_cache,
    metric=probing_metric,
)


__array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)

100%|██████████| 208/208 [00:05<00:00, 36.65it/s]
100%|██████████| 208/208 [00:05<00:00, 35.46it/s]
100%|██████████| 208/208 [00:05<00:00, 36.16it/s]


In [120]:
imshow(
    block_every,
    title="Activation Patching Per Block",
    x=model.to_str_tokens(cfg.probe_prompt.format("x y z")),
    xlabel="Position",
    ylabel="Layer",
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
)

In [None]:
mlps = patching.get_act_patch_mlp_out(
    model, corrupted_tokens, clean_cache, probing_metric
)


__array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)

100%|██████████| 208/208 [00:05<00:00, 35.86it/s]


In [None]:
imshow(
    mlps,
    ylabel="Layer",
    xlabel="Position",
    title="Activation Patching Per MLP Per Position",
    x=model.to_str_tokens(cfg.probe_prompt.format("x y z")),
    text_auto=".2f",
)

In [118]:
imshow(
    block_every,
    title="Activation Patching Per Block",
    x=model.to_str_tokens(cfg.probe_prompt.format("x y z")),
    ylabel="Layer",
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
)

In [None]:
attn_every = patching.get_act_patch_attn_head_by_pos_every(
    model, corrupted_tokens, clean_cache, probing_metric
)


__array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)

100%|██████████| 6656/6656 [03:06<00:00, 35.69it/s]
100%|██████████| 6656/6656 [03:11<00:00, 34.82it/s]
100%|██████████| 1664/1664 [00:49<00:00, 33.62it/s]
100%|██████████| 1664/1664 [00:47<00:00, 35.21it/s]
100%|██████████| 6656/6656 [03:03<00:00, 36.24it/s]


In [128]:
interesting_layers = [0, 8, 9, 10, 11]
imshow(
    rearrange(
        attn_every[:, interesting_layers],
        "facet layer pos head -> (layer head) (facet pos)",
    ),
    title="Output",
    # height=10000,
    height=2500,
    x=[
        f"{f}{t}"
        for f in ["Output", "Query", "Key", "Value", "Pattern"]
        for t in model.to_str_tokens(cfg.probe_prompt.format("x y z"))
    ],
    # y=[f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)],
    y=[f"L{i}H{j}" for i in interesting_layers for j in range(model.cfg.n_heads)],
)

In [19]:
print(model.to_str_tokens(cfg.probe_prompt.format("x y z")))

['<|begin_of_text|>', 'What', ' does', ' <|', 'x', ' y', ' z', '|', '>', ' mean', '?\n', ' A', ':']


In [130]:
things = [(0, [3, 29], [6]), (8, [27], [-1]), (9, [1], [-1]), (11, [24], [-1])]

In [139]:
for layer, heads, toks in things:
    # layer = 11
    # heads = [24]
    # toks = [-1]
    pattern = corrupted_cache["pattern", layer][:, heads][:, :, toks]
    pattern = rearrange(pattern, "batch head q k -> batch (head q) k").mean(0)
    imshow(
        pattern,
        x=model.to_str_tokens(cfg.probe_prompt.format("x y z")),
        y=[f"h{h}t{t}" for h in heads for t in toks],
        height=200,
        width=600,
        text_auto=".2f",
    ).show()

In [None]:
model.cfg.d_head

64

In [None]:
model.cfg.n_heads, model.cfg.n_key_value_heads

(32, 8)

In [None]:
imshow(
    attn_every,
    title="Activation Patching Per Head",
    xlabel="Head",
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
)