In [2]:
import json
import random
from functools import partial

import torch
import transformer_lens.patching as patching
import numpy as np
import plotly.express as px
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer
from einops import rearrange

from cfg import ExperimentConfig

In [3]:
def imshow(
    tensor,
    x=None,
    y=None,
    title=None,
    xlabel=None,
    ylabel=None,
    facet_labels=None,
    facet_col=None,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    reverse_y=True,
    autotext=False,
    **kwargs,
):
    """Render an array-like as a plotly heatmap and return the figure.

    Args:
        tensor: Anything with a ``.cpu()`` method (converted through
            ``.cpu().numpy()``) or anything ``np.array`` accepts.
        x, y: Optional axis tick labels, forwarded to ``px.imshow``.
        title, xlabel, ylabel: Optional figure / axis titles; falsy values
            are skipped.
        facet_labels: Optional titles written over the facet annotations;
            only applied when ``facet_col`` is also given.
        facet_col: Axis of the data to facet over, forwarded to ``px.imshow``.
        color_continuous_midpoint, color_continuous_scale: Colour settings
            forwarded to ``px.imshow``.
        reverse_y: When true, the y-axis ``autorange`` is set to ``True``.
            NOTE(review): presumably this undoes the default image
            orientation of ``px.imshow`` -- confirm against the plots.
        autotext: When true, every cell is labelled with its value rounded to
            two decimals; cells that round to zero get an empty label.
        **kwargs: Passed through to ``px.imshow`` unchanged.

    Returns:
        The figure created by ``px.imshow`` after the updates above.
    """
    if hasattr(tensor, "cpu"):
        data = tensor.cpu().numpy()
    else:
        data = np.array(tensor)

    fig = px.imshow(
        data,
        x=x,
        y=y,
        facet_col=facet_col,
        color_continuous_midpoint=color_continuous_midpoint,
        color_continuous_scale=color_continuous_scale,
        **kwargs,
    )

    if reverse_y:
        fig.update_yaxes(autorange=True)

    if title:
        fig.update_layout(title=title)

    # Axis titles are only written when the caller supplied one.
    for update_axis, axis_title in (
        (fig.update_xaxes, xlabel),
        (fig.update_yaxes, ylabel),
    ):
        if axis_title:
            update_axis(title=axis_title)

    if facet_labels and facet_col is not None:
        # Overwrite the auto-generated facet captions in order.
        for annotation_index, caption in enumerate(facet_labels):
            fig.layout.annotations[annotation_index].text = caption

    if autotext:
        rounded = data.round(2)
        cell_text = np.where(rounded == 0, "", rounded.astype(str))
        fig.update_traces(text=cell_text, texttemplate="%{text}")

    return fig

In [4]:
def patch_at_locations(
    model: HookedTransformer,
    corrupted_tokens,
    clean_cache,
    patching_metric,
    patching_locations: list[tuple],
):
    """Run ``model`` on ``corrupted_tokens`` with selected activations replaced.

    Args:
        model: Model exposing ``run_with_hooks(tokens, fwd_hooks=...)``.
        corrupted_tokens: Input passed to ``run_with_hooks``.
        clean_cache: Mapping from hook name to the activation that values are
            copied from.
        patching_metric: Callable applied to the output of the patched run.
        patching_locations: Entries of the form ``(hook_name, pos)`` or
            ``(hook_name, pos, head)``. The first form copies everything at
            index ``pos`` of the second axis; the second form copies only
            index ``head`` of the third axis at that position.

    Returns:
        ``patching_metric(patched_logits)``.
    """
    # One hook per hook point, each handling all locations registered for it.
    grouped = {}
    for entry in patching_locations:
        grouped.setdefault(entry[0], []).append(entry[1:])  # (pos,) or (pos, head)

    # The second parameter must stay named ``hook``; it may be passed by keyword.
    def patch_hook(activation, hook, locations):
        source = clean_cache[hook.name]
        for target in locations:
            if len(target) == 1:
                index = (slice(None), target[0])
            else:
                index = (slice(None), target[0], target[1])
            activation[index] = source[index]
        return activation

    fwd_hooks = []
    for hook_name, hook_locations in grouped.items():
        fwd_hooks.append((hook_name, partial(patch_hook, locations=hook_locations)))

    return patching_metric(model.run_with_hooks(corrupted_tokens, fwd_hooks=fwd_hooks))


In [5]:
# Nothing in this notebook trains: switch autograd off globally.
torch.set_grad_enabled(False)
# Experiment settings (output dir, model id, prompt template, hook point, ...)
# come from the local `cfg` module.
cfg = ExperimentConfig()

In [6]:
# Truthy: download probes/aliases/weights from the Hugging Face Hub.
# Falsy: load the checkpoint from f"{cfg.out_dir}/final" instead.
USE_HF = 1

In [7]:
# Pick the source of the fine-tuned weights: the Hub repo or the local
# "final" checkpoint under cfg.out_dir.
if USE_HF:
    repo_id = "cheeetoo/trainorder"

    # The probe weights and alias lists are fetched into cfg.out_dir so the
    # later cells can read them from the same place in both modes.
    for artifact in ("probes.npz", "aliases.json"):
        hf_hub_download(repo_id, filename=artifact, local_dir=cfg.out_dir)

    weights_source = repo_id
else:
    weights_source = f"{cfg.out_dir}/final"

hf_model = AutoModelForCausalLM.from_pretrained(weights_source)

# Wrap the loaded weights in a HookedTransformer built for cfg.model_id.
model = HookedTransformer.from_pretrained(cfg.model_id, hf_model=hf_model)



Loaded pretrained model meta-llama/Llama-3.2-1B into HookedTransformer


In [8]:
# Seed for the alias sampling below, so the prompts (and every patching number
# derived from them) are identical after Restart & Run All.
PATCHING_SEED = 0

probes = np.load(f"{cfg.out_dir}/probes.npz")
# Arrays in probes.npz are keyed by the stage pair, e.g. "(0, 3)_coef".
# Presumably this is the probe trained on stage 0 vs. the final stage -- confirm.
probe_key = f"{(0, cfg.num_stages - 1)}"


def _probe_tensor(suffix):
    """Return probe array ``<probe_key>_<suffix>`` as a tensor on the model's dtype and device."""
    # Cast dtype first, then move: same order as the original conversion.
    return (
        torch.tensor(probes[f"{probe_key}_{suffix}"])
        .to(model.cfg.dtype)
        .to(model.cfg.device)
    )


coef = _probe_tensor("coef")
intercept = _probe_tensor("intercept")

with open(f"{cfg.out_dir}/aliases.json") as f:
    aliases = json.load(f)

# Draw (with replacement) the aliases used for the corrupted (first stage) and
# clean (last stage) prompts from a dedicated, seeded generator.
alias_rng = random.Random(PATCHING_SEED)
aliases_first = alias_rng.choices(aliases["stage_0"], k=cfg.n_patching_prompts)
aliases_last = alias_rng.choices(
    aliases[f"stage_{cfg.num_stages - 1}"], k=cfg.n_patching_prompts
)

In [9]:
# "Clean" prompts use last-stage aliases, "corrupted" prompts use first-stage
# aliases; both fill the same cfg.probe_prompt template.
clean_prompts = list(map(cfg.probe_prompt.format, aliases_last))
corrupted_prompts = list(map(cfg.probe_prompt.format, aliases_first))

clean_tokens, corrupted_tokens = (
    model.to_tokens(clean_prompts),
    model.to_tokens(corrupted_prompts),
)

# Only the activation caches are kept; the model outputs are not needed here.
# NOTE(review): position-wise patching assumes both token batches have the
# same shape -- confirm the aliases tokenize to equal lengths.
clean_cache = model.run_with_cache(clean_tokens)[1]
corrupted_cache = model.run_with_cache(corrupted_tokens)[1]

In [None]:
# Shared holder: the hook below overwrites probing_act["act"] on every call.
probing_act = {}


def save_probing_act(act, hook):
    """Hook that records the last-position slice of ``act``.

    Stores ``act[:, -1].detach()`` under ``probing_act["act"]`` and returns
    ``act`` itself, so the forward pass is not altered. ``hook`` is unused.
    """
    last_position = act[:, -1]
    probing_act.update(act=last_position.detach())
    return act


# Remove every hook already on the model (permanent ones included) before
# registering the capture hook.
model.reset_hooks(including_permanent=True)
# Registered as a permanent hook at cfg.hook_point; the metrics below rely on
# it having run during the most recent forward pass.
model.add_perma_hook(cfg.hook_point, save_probing_act)


def get_probe_logit(act=probing_act, coef=coef, intercept=intercept):
    """Mean of the linear probe's output on the captured activations.

    Computes ``act["act"] @ coef.T + intercept`` and averages over all
    elements. The defaults are bound when this function is defined:
    ``probing_act`` is mutated in place by the capture hook, so it stays
    current, while ``coef``/``intercept`` are the tensors that existed at
    definition time.
    """
    probe_logits = act["act"] @ coef.T + intercept
    return probe_logits.mean()


# Baseline probe logits for the two unpatched runs. Order matters: each
# get_probe_logit() call reads the activations captured by the forward pass
# immediately before it.
_ = model(clean_tokens)
probe_logit_clean = get_probe_logit()
_ = model(corrupted_tokens)
probe_logit_corrupted = get_probe_logit()


def act_patch_metric(_logits):
    """Normalised probe logit for patching clean activations into a corrupted run.

    ``_logits`` is ignored; the value is read from the activations captured
    during the latest forward pass. Returns
    ``(current - corrupted) / (clean - corrupted)``, i.e. 0 at the corrupted
    baseline and 1 at the clean baseline.
    """
    recovered = get_probe_logit() - probe_logit_corrupted
    baseline_gap = probe_logit_clean - probe_logit_corrupted
    return recovered / baseline_gap


def resample_ablate_metric(_logits):
    """Normalised probe logit for patching corrupted activations into a clean run.

    ``_logits`` is ignored; the value is read from the activations captured
    during the latest forward pass. Returns
    ``(clean - current) / (clean - corrupted)``, i.e. 0 at the clean baseline
    and 1 at the corrupted baseline.
    """
    lost = probe_logit_clean - get_probe_logit()
    baseline_gap = probe_logit_clean - probe_logit_corrupted
    return lost / baseline_gap

In [11]:
# Token strings used as axis labels in the plots; "x y z" stands in for an alias.
# NOTE(review): assumes real aliases tokenize to as many tokens as "x y z" -- confirm.
str_toks = model.to_str_tokens(cfg.probe_prompt.format("x y z"))

In [None]:
# Sweep over every layer/position for the residual stream, attention output
# and MLP output (the three facets plotted afterwards).
# Activation patching: clean activations are patched into the corrupted run.
patched_block = patching.get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, act_patch_metric
)
# Resample ablation: the roles are swapped, so corrupted activations are
# patched into the clean run.
resampled_block = patching.get_act_patch_block_every(
    model, clean_tokens, corrupted_cache, resample_ablate_metric
)

  0%|          | 0/208 [00:00<?, ?it/s]

100%|██████████| 208/208 [00:14<00:00, 14.16it/s]
100%|██████████| 208/208 [00:14<00:00, 14.18it/s]
100%|██████████| 208/208 [00:14<00:00, 13.89it/s]
100%|██████████| 208/208 [00:15<00:00, 13.50it/s]
100%|██████████| 208/208 [00:15<00:00, 13.32it/s]
100%|██████████| 208/208 [00:15<00:00, 13.15it/s]


In [None]:
# One facet per leading index of patched_block, labelled via facet_labels.
imshow(
    patched_block,
    title="Activation Patching Per Block Per Position",
    x=str_toks,
    xlabel="Position",
    ylabel="Layer",
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
)

In [None]:
# Same layout as the activation-patching plot, for the resample-ablation sweep.
imshow(
    resampled_block,
    title="Resample Ablation Per Block Per Token",
    x=str_toks,
    xlabel="Position",
    ylabel="Layer",
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
)

In [None]:
# MLP-output sweep in both directions: clean -> corrupted (activation
# patching) and corrupted -> clean (resample ablation).
patched_mlp = patching.get_act_patch_mlp_out(
    model, corrupted_tokens, clean_cache, act_patch_metric
)
resampled_mlp = patching.get_act_patch_mlp_out(
    model, clean_tokens, corrupted_cache, resample_ablate_metric
)

  0%|          | 0/208 [00:00<?, ?it/s]

100%|██████████| 208/208 [00:14<00:00, 14.15it/s]
100%|██████████| 208/208 [00:14<00:00, 14.14it/s]


In [None]:
# Heatmap of the MLP activation-patching sweep (rows: layers, columns: token positions).
imshow(
    patched_mlp,
    title="Activation Patching Per MLP Per Position",
    xlabel="Position",
    ylabel="Layer",
    x=str_toks,
)

In [None]:
# Heatmap of the MLP resample-ablation sweep (rows: layers, columns: token positions).
imshow(
    resampled_mlp,
    title="Resample Ablation Per MLP Per Position",
    xlabel="Position",
    ylabel="Layer",
    x=str_toks,
)

In [None]:
# Per-head, per-position attention-output sweep in both directions.
# Slow: the recorded run took roughly 8 minutes per call.
patched_attn = patching.get_act_patch_attn_head_out_by_pos(
    model, corrupted_tokens, clean_cache, act_patch_metric
)
resampled_attn = patching.get_act_patch_attn_head_out_by_pos(
    model, clean_tokens, corrupted_cache, resample_ablate_metric
)

100%|██████████| 6656/6656 [08:20<00:00, 13.29it/s]
100%|██████████| 6656/6656 [08:29<00:00, 13.06it/s]


In [None]:
# Layers shown in the per-head plots below.
# NOTE(review): presumably hand-picked from the sweeps above -- confirm.
interesting_layers = [0, 8, 9, 11, 12]

In [None]:
# Flatten (layer, head) into one column axis; the x labels are generated in
# the same layer-major order as the "(layer head)" grouping.
imshow(
    rearrange(patched_attn[interesting_layers], "layer pos head -> pos (layer head)"),
    title="Activation Patching Per Head Per Position",
    xlabel="Head",
    ylabel="Position",
    x=[f"L{i}H{j}" for i in interesting_layers for j in range(model.cfg.n_heads)],
    y=str_toks,
    height=400,
    width=3000,
)

In [None]:
# Flatten (layer, head) into one column axis; the x labels are generated in
# the same layer-major order as the "(layer head)" grouping.
imshow(
    rearrange(resampled_attn[interesting_layers], "layer pos head -> pos (layer head)"),
    title="Resample Ablation Per Head Per Position",
    xlabel="Head",
    ylabel="Position",
    x=[f"L{i}H{j}" for i in interesting_layers for j in range(model.cfg.n_heads)],
    y=str_toks,
    height=400,
    width=3000,
)

In [13]:
# Candidate patch sets in the format patch_at_locations expects:
# (hook_name, pos) patches a whole position, (hook_name, pos, head) one head.
# Commented-out entries are locations excluded from the set.
# NOTE(review): "over95" presumably names a set reaching >95% of the metric -- confirm.
patching_locations_over95 = [
    ("blocks.0.attn.hook_z", 5, 3),
    # ("blocks.8.attn.hook_z", 12, 24),
    # ("blocks.8.attn.hook_z", 12, 25),
    # ("blocks.8.attn.hook_z", 12, 27),
    ("blocks.0.hook_mlp_out", 4),
    # ("blocks.0.hook_mlp_out", 6),
    ("blocks.0.hook_mlp_out", 5),
    ("blocks.1.hook_mlp_out", 5),
    ("blocks.2.hook_mlp_out", 5),
    # ("blocks.3.hook_mlp_out", 5),
    ("blocks.4.hook_mlp_out", 5),
    # ("blocks.5.hook_mlp_out", 5),
    # ("blocks.6.hook_mlp_out", 5),
    # ("blocks.7.hook_mlp_out", 5),
    # ("blocks.8.hook_mlp_out", 12),
    ("blocks.10.hook_mlp_out", 12),
    ("blocks.11.hook_mlp_out", 12),
    # ("blocks.12.hook_mlp_out", 12),
]
# Set actually evaluated by the analysis that follows (MLP outputs only).
patching_locations = [
    # ("blocks.0.attn.hook_z", 5, 3),
    # ("blocks.8.attn.hook_z", 12, 24),
    # ("blocks.8.attn.hook_z", 12, 25),
    # ("blocks.8.attn.hook_z", 12, 27),
    ("blocks.0.hook_mlp_out", 4),
    # ("blocks.0.hook_mlp_out", 6),
    ("blocks.0.hook_mlp_out", 5),
    ("blocks.1.hook_mlp_out", 5),
    ("blocks.2.hook_mlp_out", 5),
    # ("blocks.3.hook_mlp_out", 5),
    # ("blocks.4.hook_mlp_out", 5),
    # ("blocks.5.hook_mlp_out", 5),
    # ("blocks.6.hook_mlp_out", 5),
    # ("blocks.7.hook_mlp_out", 5),
    # ("blocks.8.hook_mlp_out", 12),
    ("blocks.10.hook_mlp_out", 12),
    ("blocks.11.hook_mlp_out", 12),
    # ("blocks.12.hook_mlp_out", 12),
]

# Joint effect of patching every location in patching_locations at once,
# in both directions (clean -> corrupted and corrupted -> clean).
base_act = patch_at_locations(
    model, corrupted_tokens, clean_cache, act_patch_metric, patching_locations
)
base_resample = patch_at_locations(
    model, clean_tokens, corrupted_cache, resample_ablate_metric, patching_locations
)

print(f"baselines: {base_act} {base_resample}")

# Leave-one-out: drop each location in turn and report how much each metric
# falls relative to the full set (a larger drop means the location matters more).
for i, held_out in enumerate(patching_locations):
    locs = patching_locations[:i] + patching_locations[i + 1 :]

    new_act = patch_at_locations(
        model, corrupted_tokens, clean_cache, act_patch_metric, locs
    )
    new_resample = patch_at_locations(
        model, clean_tokens, corrupted_cache, resample_ablate_metric, locs
    )

    # Both metrics stay tensors until this single conversion to Python floats.
    act_drop = (base_act - new_act).item()
    resample_drop = (base_resample - new_resample).item()

    print(held_out, act_drop, resample_drop)

KeyboardInterrupt: 

In [None]:
# Resample-ablate two attention heads only: layer 0 head 3 at position 5 and
# layer 8 head 24 at position 12.
attn_locs = [("blocks.0.attn.hook_z", 5, 3), ("blocks.8.attn.hook_z", 12, 24)]
patch_at_locations(
    model, clean_tokens, corrupted_cache, resample_ablate_metric, attn_locs
)

tensor(1.1302, device='mps:0')

In [29]:
# Layer-0 attention pattern at index 3 of the second axis, averaged over the first axis.
# NOTE(review): assumes the cached pattern is (batch, head, query, key) -- confirm.
imshow(
    clean_cache["pattern", 0][:, 3].mean(0),
    title="Attention Pattern for L0H3",
    x=str_toks,
    y=str_toks
)

In [None]:
# Layer-11 attention pattern at index 27 of the second axis, averaged over the first axis.
# NOTE(review): assumes the cached pattern is (batch, head, query, key) -- confirm.
# The title previously read "L0H24", which did not match the plotted indices.
# If L8H24 (the head listed in attn_locs) was the intended head, change the
# indices and the title together.
imshow(
    clean_cache["pattern", 11][:, 27].mean(0),
    title="Attention Pattern for L11H27",
    x=str_toks,
    y=str_toks
)