In [1]:
import json
import random

import torch
import transformer_lens.patching as patching
import numpy as np
import plotly.express as px
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer
from einops import rearrange

from cfg import ExperimentConfig

  from .autonotebook import tqdm as notebook_tqdm


In [59]:
def imshow(
    tensor,
    x=None,
    y=None,
    title=None,
    xlabel=None,
    ylabel=None,
    facet_labels=None,
    facet_col=None,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    reverse_y=True,
    **kwargs,
):
    """Render a tensor/array as a diverging-colour plotly heatmap.

    Args:
        tensor: torch tensor or array-like; 2D, or 3D when ``facet_col`` is set.
        x, y: optional column / row tick labels. Plotly treats string labels as
            categories, so repeated labels collapse into one column/row - keep
            them unique.
        title: figure title.
        xlabel: x-axis title.
        ylabel: y-axis title.
        facet_labels: titles for the facet panels, in facet order; only used
            when ``facet_col`` is given. Surplus labels are ignored.
        facet_col: axis of ``tensor`` to split into side-by-side panels.
        color_continuous_midpoint: data value mapped to the centre of the
            colour scale (0 so positive / negative effects get opposite hues).
        color_continuous_scale: plotly colour-scale name.
        reverse_y: if True, replace ``px.imshow``'s default top-to-bottom row
            order with a normal autorange, so row 0 is drawn at the bottom.
        **kwargs: forwarded to ``plotly.express.imshow`` (e.g. ``height``).

    Returns:
        The plotly ``Figure``; it is displayed when it is a cell's last expression.
    """
    if isinstance(tensor, torch.Tensor):
        # numpy has no bfloat16, so upcast before handing the data to plotly.
        data = tensor.detach().cpu().float().numpy()
    elif hasattr(tensor, "cpu"):
        data = tensor.cpu()
    else:
        data = np.array(tensor)

    fig = px.imshow(
        data,
        x=x,
        y=y,
        facet_col=facet_col,
        color_continuous_midpoint=color_continuous_midpoint,
        color_continuous_scale=color_continuous_scale,
        **kwargs,
    )

    if reverse_y:
        fig.update_yaxes(autorange=True)
    if title:
        fig.update_layout(title=title)
    if xlabel:
        fig.update_xaxes(title=xlabel)
    if ylabel:
        # The y-axis title belongs on the y axes; update_xaxes would clobber xlabel.
        fig.update_yaxes(title=ylabel)
    if facet_labels and facet_col is not None:
        # Each facet panel is titled by a layout annotation (default text like
        # "facet_col=0"); zip stops at the shorter sequence, so a label-count
        # mismatch cannot raise IndexError.
        for annotation, label in zip(fig.layout.annotations, facet_labels):
            annotation.text = label

    return fig

In [3]:
# Analysis only, never training: turn autograd off for the whole notebook.
cfg = ExperimentConfig()
torch.set_grad_enabled(mode=False)

In [4]:
# Fine-tuned checkpoint to analyse. Set this to f"{cfg.out_dir}/final" to load the
# local training run instead of the copy published on the Hugging Face Hub.
HF_CHECKPOINT = "cheeetoo/trainorder"
hf_model = AutoModelForCausalLM.from_pretrained(HF_CHECKPOINT)
# cfg.model_id selects the TransformerLens architecture config; hf_model supplies the weights.
model = HookedTransformer.from_pretrained(cfg.model_id, hf_model=hf_model)



Loaded pretrained model meta-llama/Llama-3.2-1B into HookedTransformer


In [5]:
# Linear probe separating the first and last training stages; archive keys look like "(0, N)_coef".
probe_key = str((0, cfg.num_stages - 1))
with np.load(f"{cfg.out_dir}/probes.npz") as probes:
    # Index inside the `with` so the arrays are read before the archive is closed.
    coef = probes[f"{probe_key}_coef"]
    intercept = probes[f"{probe_key}_intercept"]

with open(f"{cfg.out_dir}/aliases.json") as f:
    aliases = json.load(f)

# Seeded local RNG so the sampled patching prompts are reproducible across runs.
ALIAS_SAMPLING_SEED = 0
alias_rng = random.Random(ALIAS_SAMPLING_SEED)
# choices() samples WITH replacement, so an alias may appear more than once.
aliases_first = alias_rng.choices(aliases["stage_0"], k=cfg.n_patching_prompts)
aliases_last = alias_rng.choices(
    aliases[f"stage_{cfg.num_stages - 1}"], k=cfg.n_patching_prompts
)

In [6]:
# Clean run uses last-stage aliases, corrupted run uses first-stage aliases.
clean_prompts = [cfg.probe_prompt.format(alias) for alias in aliases_last]
corrupted_prompts = [cfg.probe_prompt.format(alias) for alias in aliases_first]

clean_tokens = model.to_tokens(clean_prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)
# Activation patching copies activations position-by-position, so both batches must
# line up exactly. Aliases of different token lengths would also get padded, moving
# the real final token away from the [:, -1] position the probe reads.
assert clean_tokens.shape == corrupted_tokens.shape, (
    f"clean {tuple(clean_tokens.shape)} vs corrupted {tuple(corrupted_tokens.shape)}: "
    "aliases tokenize to different lengths"
)

_, clean_cache = model.run_with_cache(clean_tokens)

In [7]:
# Latest final-position activation at cfg.hook_point, refreshed on every forward pass
# by the permanent hook below; the patching metric reads this instead of the logits.
probing_act = {}


def save_probing_act(act, hook):
    """Forward hook: stash the last-position activation for the probe; returns act unchanged."""
    probing_act["act"] = act[:, -1].detach()
    return act


model.reset_hooks(including_permanent=True)
model.add_perma_hook(cfg.hook_point, save_probing_act)


def get_probe_logit(act=probing_act, coef=coef, intercept=intercept):
    """Batch-mean probe logit for the most recently cached activation.

    The probe parameters are converted to float64 torch tensors so the matmul runs in
    torch. The alternative is a torch @ numpy fallback through numpy's deprecated
    ``__array_wrap__``, which also cannot handle bfloat16.
    """
    acts = act["act"].cpu().to(torch.float64)
    weight = torch.as_tensor(coef, dtype=torch.float64)
    bias = torch.as_tensor(intercept, dtype=torch.float64)
    return torch.mean(acts @ weight.T + bias)


# Baselines: each forward pass refreshes probing_act through the hook.
_ = model(clean_tokens)
probe_logit_clean = get_probe_logit()
_ = model(corrupted_tokens)
probe_logit_corrupted = get_probe_logit()

probe_logit_range = probe_logit_clean - probe_logit_corrupted
if probe_logit_range == 0:
    raise ValueError(
        "Probe logit is identical on clean and corrupted prompts; "
        "the normalised patching metric would divide by zero."
    )


def probing_metric(_logits):
    """Normalised patching score: 0 = corrupted baseline, 1 = clean baseline.

    The logits passed in by the patching utilities are ignored; the score comes from
    the probe applied to the activation captured during that same forward pass.
    """
    return (get_probe_logit() - probe_logit_corrupted) / probe_logit_range

  return torch.mean(act["act"].cpu() @ coef.T + intercept)


In [49]:
# Patch clean activations into the corrupted run for every (component, layer, position),
# scoring each patch with the probe-based metric.
block_every = patching.get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, probing_metric
)

  0%|          | 0/208 [00:00<?, ?it/s]


__array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)

100%|██████████| 208/208 [00:04<00:00, 42.21it/s]
100%|██████████| 208/208 [00:04<00:00, 42.05it/s]
100%|██████████| 208/208 [00:04<00:00, 41.96it/s]


In [60]:
# Plotly merges repeated string ticks into one category, so suffix each token with its
# position to keep every column distinct.
position_labels = [
    f"{token} ({pos})"
    for pos, token in enumerate(model.to_str_tokens(cfg.probe_prompt.format("x y z")))
]
imshow(
    block_every,
    title="Activation Patching Per Block",
    x=position_labels,
    xlabel="Position",
    ylabel="Layer",
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
)

In [68]:
attn_every = patching.get_act_patch_attn_head_by_pos_every(model=model, corrupted_tokens=corrupted_tokens, clean_cache=clean_cache, metric=probing_metric)  # per-head, per-position patching for each attention component


__array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)

100%|██████████| 6656/6656 [02:41<00:00, 41.12it/s]
100%|██████████| 6656/6656 [02:43<00:00, 40.66it/s]
100%|██████████| 1664/1664 [00:40<00:00, 40.68it/s]
100%|██████████| 1664/1664 [00:40<00:00, 40.73it/s]
100%|██████████| 6656/6656 [02:43<00:00, 40.79it/s]


In [69]:
attn_every.size()  # raw patching result shape

torch.Size([5, 16, 13, 32])

In [70]:
clean_tokens.size()  # (batch, pos) of the clean prompts

torch.Size([25, 13])

In [73]:
# Flatten (layer, head) into one layer-major row axis: (..., layer, pos, head) -> (..., layer*head, pos).
# Guarded so that re-running this cell does not silently reinterpret the already-flattened tensor.
if attn_every.ndim == 4:
    attn_every = rearrange(attn_every, "... layer pos head -> ... (layer head) pos")

In [111]:
attn_every.size()  # shape after flattening (layer, head)

torch.Size([5, 512, 13])

In [127]:
# One wide heatmap: the five patched components side by side, each spanning all positions
# (facet-major columns, matching "(facet pos)"). Rows are L{layer}H{head} in layer-major order.
# NOTE(review): the progress bars show K and V were patched for only 16*13*8 = 1664 cells
# (n_key_value_heads) vs 6656 for the others, so in the Key/Value blocks H0-H7 presumably
# index KV heads and H8+ are unfilled - confirm against TransformerLens before interpreting.
component_names = ["Output", "Query", "Key", "Value", "Pattern"]
prompt_tokens = model.to_str_tokens(cfg.probe_prompt.format("x y z"))
imshow(
    rearrange(attn_every, "facet lh pos -> lh (facet pos)"),
    title="Activation Patching Per Head (Output | Query | Key | Value | Pattern)",
    height=20000,
    # Position index keeps repeated tokens distinct as plotly categories.
    x=[
        f"{name}: {token} ({pos})"
        for name in component_names
        for pos, token in enumerate(prompt_tokens)
    ],
    y=[f"L{layer}H{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)],
    xlabel="Component: token (position)",
)

In [129]:
(model.cfg.n_heads, model.cfg.n_key_value_heads)  # query heads vs. shared key/value heads (grouped-query attention)

(32, 8)

In [67]:
# After flattening, attn_every is (component, layer*head, pos): columns are token positions
# and rows are L{layer}H{head}, one facet per patched component.
# NOTE(review): K/V were only patched for n_key_value_heads heads (see the 1664-iteration
# progress bars); rows H8+ in those facets are presumably unfilled - confirm before reading them.
head_labels = [f"L{layer}H{head}" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
position_labels = [
    f"{token} ({pos})"
    for pos, token in enumerate(model.to_str_tokens(cfg.probe_prompt.format("x y z")))
]
imshow(
    attn_every,
    title="Activation Patching Per Head",
    x=position_labels,
    y=head_labels,
    xlabel="Position",
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
)