In [None]:
import json
import random

import torch
import transformer_lens.patching as patching
import numpy as np
import plotly.express as px
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer
from einops import rearrange

from cfg import ExperimentConfig

In [None]:
def imshow(
    tensor,
    x=None,
    y=None,
    title=None,
    xlabel=None,
    ylabel=None,
    facet_labels=None,
    facet_col=None,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    reverse_y=True,
    **kwargs,
):
    """Plot a tensor/array as a heatmap centred on ``color_continuous_midpoint``.

    Args:
        tensor: torch tensor or array-like to plot.
        x, y: optional column / row labels forwarded to ``px.imshow``.
        title: optional figure title.
        xlabel, ylabel: optional axis titles (applied to every facet's axes).
        facet_labels: optional panel titles in facet order; only used when
            ``facet_col`` is given.
        facet_col: axis of ``tensor`` to split into separate panels.
        color_continuous_midpoint: data value placed at the centre of the
            colour scale (0 by default, which suits a diverging scale).
        color_continuous_scale: plotly colour-scale name.
        reverse_y: if True, y axes use ``autorange="min reversed"``.
        **kwargs: forwarded unchanged to ``px.imshow``.

    Returns:
        The plotly ``Figure``; as the last expression of a cell it is displayed.

    Raises:
        ValueError: if more ``facet_labels`` are given than there are facets.
    """
    if isinstance(tensor, torch.Tensor):
        # Convert explicitly: .detach() makes this work for tensors that still
        # require grad, where an implicit numpy conversion inside plotly raises.
        data = tensor.detach().cpu().numpy()
    else:
        data = np.asarray(tensor)

    fig = px.imshow(
        data,
        x=x,
        y=y,
        facet_col=facet_col,
        color_continuous_midpoint=color_continuous_midpoint,
        color_continuous_scale=color_continuous_scale,
        **kwargs,
    )

    if reverse_y:
        fig.update_yaxes(autorange="min reversed")
    if title:
        fig.update_layout(title=title)
    if xlabel:
        fig.update_xaxes(title=xlabel)
    if ylabel:
        fig.update_yaxes(title=ylabel)
    if facet_labels and facet_col is not None:
        # Facet panel titles are layout annotations, one per facet, in order.
        annotations = fig.layout.annotations
        if len(facet_labels) > len(annotations):
            raise ValueError(
                f"got {len(facet_labels)} facet_labels for {len(annotations)} facets"
            )
        for annotation, label in zip(annotations, facet_labels):
            annotation.text = label

    return fig

In [None]:
# The notebook only runs forward passes, so autograd is switched off globally.
torch.set_grad_enabled(mode=False)

# Experiment settings used below (output dir, model id, prompt template, hook point, ...).
cfg = ExperimentConfig()

In [None]:
# True: download probes.npz / aliases.json into cfg.out_dir and load the model from
# the Hugging Face Hub. False: load the model from f"{cfg.out_dir}/final" and expect
# the probe/alias files to already exist in cfg.out_dir.
USE_HF = False

In [None]:
# Resolve where the fine-tuned weights come from; in Hub mode, also fetch the
# probe and alias artifacts the following cells read from cfg.out_dir.
if not USE_HF:
    model_source = f"{cfg.out_dir}/final"
else:
    repo_id = "cheeetoo/trainorder"
    for artifact in ("probes.npz", "aliases.json"):
        hf_hub_download(repo_id, filename=artifact, local_dir=cfg.out_dir)
    model_source = repo_id

hf_model = AutoModelForCausalLM.from_pretrained(model_source)

# Wrap the HF weights in TransformerLens so activations can be hooked and patched.
model = HookedTransformer.from_pretrained(cfg.model_id, hf_model=hf_model)

In [None]:
# Seed for sampling aliases, so the patched prompt set is reproducible across runs.
PATCHING_SEED = 0

# Probe arrays are keyed by the stringified stage pair, e.g. "(0, 3)_coef".
# Copy them out inside `with` so the .npz file handle is closed afterwards.
with np.load(f"{cfg.out_dir}/probes.npz") as npz:
    probes = dict(npz)
probe_key = str((0, cfg.num_stages - 1))
coef = probes[f"{probe_key}_coef"]
intercept = probes[f"{probe_key}_intercept"]

with open(f"{cfg.out_dir}/aliases.json") as f:
    aliases = json.load(f)

# Sample (with replacement) cfg.n_patching_prompts aliases from the first and last stage.
alias_rng = random.Random(PATCHING_SEED)
aliases_first = alias_rng.choices(aliases["stage_0"], k=cfg.n_patching_prompts)
aliases_last = alias_rng.choices(
    aliases[f"stage_{cfg.num_stages - 1}"], k=cfg.n_patching_prompts
)

In [None]:
# Clean run: last-stage aliases. Corrupted run: stage-0 aliases.
clean_prompts = [cfg.probe_prompt.format(alias) for alias in aliases_last]
corrupted_prompts = [cfg.probe_prompt.format(alias) for alias in aliases_first]

# Patching swaps activations position-by-position, and the probe hook reads the
# last position. Both therefore need every prompt to have the same token length.
# Otherwise the batch gets padded, and with right padding the last position of
# the shorter prompts is padding.
prompt_lengths = {
    model.to_tokens(prompt).shape[-1] for prompt in clean_prompts + corrupted_prompts
}
assert len(prompt_lengths) == 1, (
    f"prompts tokenize to different lengths: {sorted(prompt_lengths)}"
)

clean_tokens = model.to_tokens(clean_prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)

_, clean_cache = model.run_with_cache(clean_tokens)

In [None]:
# Shared slot that the permanent hook below overwrites on every forward pass.
probing_act = {}


def save_probing_act(act, hook):
    """Hook that records the last-position activation and returns ``act`` unchanged.

    Stores ``act[:, -1]`` (every batch row at the final position; presumably
    ``act`` is (batch, pos, d_model) at ``cfg.hook_point`` — TODO confirm) in
    ``probing_act["act"]``, replacing the previous forward pass's value.
    """
    probing_act["act"] = act[:, -1].detach()
    return act


# Remove all hooks, including permanent ones, so re-running this cell does not
# stack duplicates. Then register the probe hook as permanent so it keeps firing
# during the hooked runs made by the patching utilities.
model.reset_hooks(including_permanent=True)
model.add_perma_hook(cfg.hook_point, save_probing_act)


def get_probe_logit(act=probing_act, coef=coef, intercept=intercept):
    """Mean probe logit over the most recently captured activations.

    Args:
        act: dict whose "act" entry holds the captured activations. Defaults
            to the shared ``probing_act`` that ``save_probing_act`` fills.
        coef: probe weights, applied as ``acts @ coef.T`` (assumed 2-D,
            (n_outputs, d_model) — TODO confirm).
        intercept: probe bias added to every logit.

    Returns:
        0-d float64 torch tensor: the mean of all probe logits.
    """
    # Compute with explicit float64 CPU tensors instead of mixing a torch tensor
    # with numpy arrays. Implicit conversion can fail for half-precision
    # activations and does not reliably yield a torch result for torch.mean.
    acts = act["act"].detach().to(device="cpu", dtype=torch.float64)
    weight = torch.as_tensor(coef, dtype=torch.float64)
    bias = torch.as_tensor(intercept, dtype=torch.float64)
    return torch.mean(acts @ weight.T + bias)


_ = model(clean_tokens)
probe_logit_clean = get_probe_logit()
_ = model(corrupted_tokens)
probe_logit_corrupted = get_probe_logit()

# The metric divides by this gap; identical readings would make it inf/nan.
assert not torch.isclose(probe_logit_clean, probe_logit_corrupted), (
    "probe logit is identical on clean and corrupted runs; metric is undefined"
)


def probing_metric(_logits):
    """Normalised probe score for the forward pass that just ran.

    0 matches the corrupted run and 1 matches the clean run. The model logits
    are ignored. The score is read from the activation captured by the
    permanent hook during that same forward pass.
    """
    return (get_probe_logit() - probe_logit_corrupted) / (
        probe_logit_clean - probe_logit_corrupted
    )

In [None]:
# Patch each block component (residual stream, attention output, MLP output)
# at every layer and position from the clean run into the corrupted run,
# scoring each patch with the normalised probe metric.
block_every = patching.get_act_patch_block_every(
    metric=probing_metric,
    clean_cache=clean_cache,
    corrupted_tokens=corrupted_tokens,
    model=model,
)

In [None]:
# Column labels: tokens of the prompt template filled with a placeholder alias.
# NOTE(review): assumes "x y z" tokenizes to the same length as the real
# aliases — TODO confirm. Each token gets a position suffix because plotly
# treats x labels as categories and would merge repeated tokens into one column.
position_labels = [
    f"{token}_{position}"
    for position, token in enumerate(
        model.to_str_tokens(cfg.probe_prompt.format("x y z"))
    )
]
imshow(
    block_every,
    title="Activation Patching Per Block",
    x=position_labels,
    xlabel="Position",
    ylabel="Layer",
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
)

In [None]:
# Patch every attention head (output, q, k, v, pattern) across all positions at once.
attn_every = patching.get_act_patch_attn_head_all_pos_every(
    model=model,
    corrupted_tokens=corrupted_tokens,
    clean_cache=clean_cache,
    metric=probing_metric,
)

In [None]:
# One panel per patched head component, split along the leading axis of attn_every.
imshow(
    attn_every,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Activation Patching Per Head",
    xlabel="Head",
    ylabel="Layer",
)