## Setup

In [1]:
import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys

try:
    from google.colab import drive # type: ignore
    %pip install transformer_lens
    %pip install gdown
    # %pip install plotly
    # %pip install jaxtyping
    # %pip install einops
    # %pip install protobuf==3.20.*
    from pathlib import Path
    import gdown
    if not Path("ioi_dataset.py").exists():
        urls = {
            "ioi_dataset.py": "https://drive.google.com/uc?id=19UjxFnb6kztuhvz6dGAXjA9oRZmd84kC",
            "path_patching.py": "https://drive.google.com/uc?id=1duF7B3IjG_E5nGcjT_BuoSrSkynUhZI5",
        }
        for filename, url in urls.items():
            output = str(Path(filename).resolve())
            gdown.download(url, output)
except:
    # from IPython import get_ipython
    # ipython = get_ipython()
    # ipython.run_line_magic("load_ext", "autoreload")
    # ipython.run_line_magic("autoreload", "2")
    pass

import torch as t
from torch import Tensor
from tqdm.notebook import tqdm
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set
from rich import print as rprint
from transformer_lens import utils, HookedTransformer, ActivationCache

import torch as t
from typing import List, Union
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import re
from transformer_lens import utils
import circuitsvis as cv
import functools

t.set_grad_enabled(False)

from ioi_dataset import NAMES, IOIDataset

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

# Keyword arguments that the plotting helpers route to `fig.update_layout` rather than to
# the plotly-express constructor.
update_layout_set = {
    "xaxis_range", "yaxis_range",
    "hovermode",
    "xaxis_title", "yaxis_title",
    "colorbar", "colorscale", "coloraxis",
    "title_x", "title_y",
    "bargap", "bargroupgap",
    "xaxis_tickformat", "yaxis_tickformat",
    "legend_title_text",
    "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor",
    "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor",
    "showlegend",
    "xaxis_tickmode", "yaxis_tickmode",
    "margin",
    "xaxis_visible", "yaxis_visible",
}

def imshow(tensor, renderer=None, **kwargs):
    '''
    Display `tensor` as a heatmap using plotly express.

    Keyword arguments whose names are in `update_layout_set` are applied with
    `fig.update_layout`; all others are forwarded to `px.imshow`, except for two extras
    handled here:
        facet_labels: optional list of strings that replace the facet annotation texts.
        border: if True, draw a black frame around every subplot.

    Defaults to the diverging "RdBu" colour scale centred on 0. Both can be overridden by
    passing `color_continuous_scale` / `color_continuous_midpoint`.
    '''
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    facet_labels = kwargs_pre.pop("facet_labels", None)
    border = kwargs_pre.pop("border", False)
    kwargs_pre.setdefault("color_continuous_scale", "RdBu")
    # Set as a default rather than hard-coded in the px.imshow call, so callers can
    # override it without a "multiple values for keyword argument" TypeError.
    kwargs_pre.setdefault("color_continuous_midpoint", 0.0)
    if isinstance(kwargs_post.get("margin"), int):
        # A single int margin is shorthand for equal top/bottom/left/right margins.
        kwargs_post["margin"] = dict.fromkeys("tblr", kwargs_post["margin"])
    fig = px.imshow(utils.to_numpy(tensor), **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    if border:
        fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
        fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
    fig.show(renderer=renderer)


def hist(tensor, renderer=None, **kwargs):
    '''
    Display a histogram of `tensor` using plotly express.

    Keyword arguments whose names are in `update_layout_set` are applied with
    `fig.update_layout`; all others are forwarded to `px.histogram`, except:
        names: optional list of legend names, one per plotted variable (e.g. one per
            inner list when `tensor` is a list of lists).

    Defaults to overlaid bars (`barmode="overlay"`) with no gap between bars.
    '''
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    names = kwargs_pre.pop("names", None)
    if "barmode" not in kwargs_pre:
        # `barmode` is not in `update_layout_set`, so a caller-supplied value goes to
        # px.histogram via kwargs_pre. Only apply the overlay default when none was
        # given; otherwise update_layout would silently override the caller's choice.
        kwargs_post["barmode"] = "overlay"
    kwargs_post.setdefault("bargap", 0.0)
    if isinstance(kwargs_post.get("margin"), int):
        # A single int margin is shorthand for equal top/bottom/left/right margins.
        kwargs_post["margin"] = dict.fromkeys("tblr", kwargs_post["margin"])
    fig = px.histogram(x=tensor, **kwargs_pre).update_layout(**kwargs_post)
    if names:
        # One main trace per variable, plus one marginal trace per variable when
        # `marginal=` is set. This assumes px orders them as
        # [main_0, (marginal_0,) main_1, (marginal_1,) ...], so spread the names evenly
        # instead of hard-coding 2 traces per name.
        traces_per_name = max(1, len(fig.data) // len(names))
        for i, trace in enumerate(fig.data):
            trace["name"] = names[min(i // traces_per_name, len(names) - 1)]
    fig.show(renderer=renderer)

In [2]:
# GPT-2 small with TransformerLens weight processing enabled: layer-norm folding,
# centring of the writing / unembedding weights, and refactoring of the factored
# attention matrices.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


Helper functions to generate data and caches (when I generate a lot of data, I'll be calling these many times with different seeds).

In [3]:
def _logits_to_ave_logit_diff(logits: Float[Tensor, "batch seq d_vocab"], ioi_dataset: IOIDataset, per_prompt=False):
    '''
    Logit of the indirect-object (IO) token minus logit of the subject (S) token, read
    at each prompt's "end" position.

    Returns the per-prompt differences if per_prompt=True, otherwise their mean.
    '''
    batch_indices = range(logits.size(0))
    # Only the logits at the final ("end") position are relevant for the answer.
    final_logits: Float[Tensor, "batch d_vocab"] = logits[batch_indices, ioi_dataset.word_idx["end"]]
    diffs = (
        final_logits[batch_indices, ioi_dataset.io_tokenIDs]
        - final_logits[batch_indices, ioi_dataset.s_tokenIDs]
    )
    if per_prompt:
        return diffs
    return diffs.mean()



def _ioi_metric(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float,
    corrupted_logit_diff: float,
    ioi_dataset: IOIDataset,
) -> float:
    '''
    Normalised logit-difference metric. It is 0 when performance isn't harmed (the patched
    logit diff equals the clean IOI one) and -1 when performance has been destroyed (it
    equals the corrupted ABC one).

    Note: `_logits_to_ave_logit_diff` returns the result of `.mean()`, so despite the
    `-> float` annotation the value returned here is a tensor, not a Python float.
    '''
    normaliser = clean_logit_diff - corrupted_logit_diff
    shift_from_clean = _logits_to_ave_logit_diff(logits, ioi_dataset) - clean_logit_diff
    return shift_from_clean / normaliser



def generate_data_and_caches(N: int, verbose: bool = False, seed: int = 42):
    '''
    Build an IOI dataset of N prompts and its ABC-corrupted counterpart, cache the
    model's activations on both, and bind a normalised IOI metric.

    Uses the module-level `model` and `device`.

    Returns:
        (ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric), where
        `ioi_metric(logits)` is `_ioi_metric` with the clean / corrupted average logit
        diffs and the IOI dataset already bound.
    '''
    ioi_dataset = IOIDataset(
        prompt_type="mixed",
        N=N,
        tokenizer=model.tokenizer,
        prepend_bos=False,
        seed=seed,
        device=str(device),
    )
    # Corrupted counterpart of the same prompts.
    abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->XYZ, BAB->XYZ")

    # Remove any hooks (including permanent ones) before caching activations.
    model.reset_hooks(including_permanent=True)

    ioi_logits, ioi_cache = model.run_with_cache(ioi_dataset.toks)
    abc_logits, abc_cache = model.run_with_cache(abc_dataset.toks)

    # Both diffs use the IOI dataset's IO / S token ids (intentionally, for the ABC run too),
    # so the corrupted number is the model's IO-vs-S preference on the corrupted prompts.
    clean_diff, corrupted_diff = (
        _logits_to_ave_logit_diff(run_logits, ioi_dataset).item()
        for run_logits in (ioi_logits, abc_logits)
    )

    if verbose:
        print(f"Average logit diff (IOI dataset): {clean_diff:.4f}")
        print(f"Average logit diff (ABC dataset): {corrupted_diff:.4f}")

    ioi_metric = functools.partial(
        _ioi_metric,
        clean_logit_diff=clean_diff,
        corrupted_logit_diff=corrupted_diff,
        ioi_dataset=ioi_dataset,
    )
    return ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric



# Small dataset (default seed 42) used by the exploratory cells that follow.
N = 15
ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric = generate_data_and_caches(N=N, verbose=True)

Average logit diff (IOI dataset): 3.1277
Average logit diff (ABC dataset): 0.4074


## Path Patching experiments

In [4]:
from path_patching import Node, IterNode, path_patch

A feature I added to path patching: the patching metric can return any object, not necessarily a float. This is useful if you want, e.g., the new attention patterns you get after path patching.

To use it, pass the argument `apply_metric_to_cache=True`. The object returned by `path_patch` is then the result of applying `patching_metric` to the patched activation cache, rather than to the output logits.

In [5]:
NEG_NMH = (10, 7)  # (layer, head) of the negative name mover head studied here

def get_attn_pattern_for_neg_nmh(
    patched_cache: ActivationCache,
    neg_nmh = NEG_NMH
) -> Float[Tensor, "batch seq_Q seq_K"]:
    '''
    Return the attention pattern of head `neg_nmh` = (layer, head) from `patched_cache`.

    Used as a path-patching metric with `apply_metric_to_cache=True`, so it receives the
    patched activation cache rather than the output logits.
    '''
    return patched_cache["pattern", neg_nmh[0]][:, neg_nmh[1]]
    

NAME_MOVER_HEADS = [(9, 6), (9, 9), (10, 10)]

# Clean-run attention pattern of the negative name mover head, for comparison.
nmh_patterns_orig = get_attn_pattern_for_neg_nmh(ioi_cache)

# Patch the outputs (z) of all name mover heads into the query input of the negative name
# mover head. The metric is applied to the patched cache, so this returns the head's
# attention patterns after patching.
nmh_patterns_after_patching = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_components=[Node("z", layer, head=head) for layer, head in NAME_MOVER_HEADS],
    receiver_components=Node("q", NEG_NMH[0], head=NEG_NMH[1]),
    apply_metric_to_cache=True,
    patching_metric=get_attn_pattern_for_neg_nmh,
)

In [6]:
def compare_heads(batch_idx: Union[int, List[int], range]):
    '''
    For each prompt index in `batch_idx`, display the negative name mover head's attention
    pattern on the clean run next to its path-patched pattern.

    Reads the globals `nmh_patterns_orig`, `nmh_patterns_after_patching`, `ioi_dataset`,
    `model` and `NEG_NMH`. The patterns and `ioi_dataset` must come from the same dataset
    for the displayed tokens to line up with the patterns.
    '''
    if isinstance(batch_idx, int):
        batch_idx = [batch_idx]

    # Derive the label from NEG_NMH rather than hard-coding "10.7", so labels stay correct
    # if a different head is studied.
    head_label = f"{NEG_NMH[0]}.{NEG_NMH[1]}"
    for idx in batch_idx:
        display(cv.attention.attention_patterns(
            attention=t.stack([nmh_patterns_orig[idx], nmh_patterns_after_patching[idx]]),
            tokens=model.to_str_tokens(ioi_dataset.sentences[idx], prepend_bos=False),
            attention_head_names=[f"{head_label} clean", f"{head_label} path-patched"],
        ))

compare_heads(range(5))

From eyeballing, the first 5 examples seem to hold up. Attention probability from `end` to `IO` is high in the clean case, but it decreases a lot when we Q-patch (it doesn't for K- or V-patching).

Now let's get more quantitative. I'll calculate the decrease in attention probability from `end` to `IO` for every sequence in the dataset, and repeat this over many datasets (50 random seeds).

In [93]:
NEG_NMH = (10, 7)
NAME_MOVER_HEADS = [(9, 6), (9, 9), (10, 10)]


def get_io_attn_change_for_nmh(
    patched_cache: ActivationCache,
    ioi_dataset: IOIDataset,
    ioi_cache: ActivationCache,
    neg_nmh = NEG_NMH,
) -> Tuple[Float[Tensor, "batch"], Float[Tensor, "batch"]]:
    '''
    Attention probability from the "end" token to the IO token for head `neg_nmh`
    = (layer, head), in both the patched run and the clean run.

    Returns:
        (patched, clean): one value per prompt in each.
    '''
    layer, head = neg_nmh
    attn_pattern_patched = patched_cache["pattern", layer][:, head]
    attn_pattern_clean = ioi_cache["pattern", layer][:, head]
    # both are (batch, seq_Q, seq_K); pick out the "end -> IO" probability for each prompt

    N = ioi_dataset.toks.size(0)
    io_seq_pos = ioi_dataset.word_idx["IO"]
    end_seq_pos = ioi_dataset.word_idx["end"]

    return attn_pattern_patched[range(N), end_seq_pos, io_seq_pos], attn_pattern_clean[range(N), end_seq_pos, io_seq_pos]


results_patched = []
results_clean = []

for seed in tqdm(range(50)):

    # Seed-specific names, so this loop doesn't overwrite the seed-42 `ioi_dataset` /
    # `ioi_cache` globals that `nmh_patterns_orig` and `compare_heads` rely on.
    seed_ioi_dataset, seed_abc_dataset, seed_ioi_cache = generate_data_and_caches(15, seed=seed)[:3]

    # `Component` is not imported anywhere in this notebook (NameError); use `Node` from
    # path_patching, as the earlier path-patching cell does.
    result_patched, result_clean = path_patch(
        model,
        orig_input=seed_ioi_dataset.toks,
        new_input=seed_abc_dataset.toks,
        sender_components=[Node("z", layer, head=head) for layer, head in NAME_MOVER_HEADS], # Output of all name mover heads
        receiver_components=Node("q", NEG_NMH[0], head=NEG_NMH[1]), # To query input of negative name mover head
        patching_metric=functools.partial(get_io_attn_change_for_nmh, ioi_dataset=seed_ioi_dataset, ioi_cache=seed_ioi_cache),
        apply_metric_to_cache=True,
    )
    results_patched.extend(result_patched.tolist())
    results_clean.extend(result_clean.tolist())

    t.cuda.empty_cache()


hist(
    [results_patched, results_clean],
    labels={"variable": "Version", "value": "Attn"},
    title="Attention paid from END to IO (patched vs non-patched)",
    names=["Patched", "Clean"],
    width=800, 
    height=600,
    opacity=0.75,
    marginal="box",
    template="simple_white"
)

  0%|          | 0/50 [00:00<?, ?it/s]

Alternative measure: rather than looking at the decrease in attention to `IO` (and showing that it decreases), I'm going to look at the difference between attention to `IO` and attention to `S1`. I'll show that this difference is much larger for the clean run than for the patched run.

Importantly, **this result shows that the average attention paid to IO is very close to the attention paid to S1** in the patched run. So the attention that the negative name mover head 10.7 pays to the IO token is **fully explained by Q-composition from the name mover heads**.

In [119]:
NEG_NMH = (10, 7)
NAME_MOVER_HEADS = [(9, 6), (9, 9), (10, 10)]


def get_io_vs_s_attn_for_nmh(
    patched_cache: ActivationCache,
    ioi_dataset: IOIDataset,
    ioi_cache: ActivationCache,
    neg_nmh = NEG_NMH,
) -> Tuple[Float[Tensor, "batch"], Float[Tensor, "batch"]]:
    '''
    For head `neg_nmh` = (layer, head): attention from "end" to IO minus attention from
    "end" to S1, in both the patched run and the clean run. Positive values mean more
    attention is paid to IO than to S1.

    Returns:
        (patched, clean): one value per prompt in each.
    '''
    layer, head = neg_nmh
    attn_pattern_patched = patched_cache["pattern", layer][:, head]
    attn_pattern_clean = ioi_cache["pattern", layer][:, head]
    # both are (batch, seq_Q, seq_K); for each prompt take attn(end -> IO) - attn(end -> S1)

    N = ioi_dataset.toks.size(0)
    io_seq_pos = ioi_dataset.word_idx["IO"]
    s1_seq_pos = ioi_dataset.word_idx["S1"]
    end_seq_pos = ioi_dataset.word_idx["end"]
    
    return (
        attn_pattern_patched[range(N), end_seq_pos, io_seq_pos] - attn_pattern_patched[range(N), end_seq_pos, s1_seq_pos], 
        attn_pattern_clean[range(N), end_seq_pos, io_seq_pos] - attn_pattern_clean[range(N), end_seq_pos, s1_seq_pos], 
    )


results_patched = []
results_clean = []

for seed in tqdm(range(50)):

    # Seed-specific names, so this loop doesn't overwrite the seed-42 `ioi_dataset` /
    # `ioi_cache` globals that `nmh_patterns_orig` and `compare_heads` rely on.
    seed_ioi_dataset, seed_abc_dataset, seed_ioi_cache = generate_data_and_caches(20, seed=seed)[:3]

    # `Component` is not imported anywhere in this notebook (NameError); use `Node` from
    # path_patching, as the earlier path-patching cell does.
    result_patched, result_clean = path_patch(
        model,
        orig_input=seed_ioi_dataset.toks,
        new_input=seed_abc_dataset.toks,
        sender_components=[Node("z", layer, head=head) for layer, head in NAME_MOVER_HEADS], # Output of all name mover heads
        receiver_components=Node("q", NEG_NMH[0], head=NEG_NMH[1]), # To query input of negative name mover head
        patching_metric=functools.partial(get_io_vs_s_attn_for_nmh, ioi_dataset=seed_ioi_dataset, ioi_cache=seed_ioi_cache),
        apply_metric_to_cache=True,
    )
    results_patched.extend(result_patched.tolist())
    results_clean.extend(result_clean.tolist())

    t.cuda.empty_cache()


hist(
    [results_patched, results_clean],
    labels={"variable": "Version", "value": "Attn diff (positive => more attn paid to IO than S1)"},
    title="Difference in attn from END->IO and END->S1 (patched vs clean)",
    names=["Patched", "Clean"],
    width=800, 
    height=600,
    opacity=0.75,
    marginal="box",
    template="simple_white"
)

  0%|          | 0/50 [00:00<?, ?it/s]

Guideline: every plot should be understandable without reading its caption or the main text.

Note: Monday LW post (incrementation heads & copying suppressors). Zero-ablate all, then find a head with semantic meaning.

In [121]:
# Re-plot the IO-vs-S1 attention differences (results from the previous cell) with a
# reworded title.
hist(
    [results_patched, results_clean],
    title="Difference in attn from END->IO vs. END->S1 (patched vs clean)",
    labels={"variable": "Version", "value": "Attn diff (positive => more attn paid to IO than S1)"},
    names=["Patched", "Clean"],
    template="simple_white",
    marginal="box",
    opacity=0.75,
    width=800,
    height=600,
)