# Setup

In [1]:
from transformer_lens.cautils.notebook import *

from transformer_lens.rs.callum2.ioi_and_bos.ioi_functions import (
    attn_scores_as_linear_func_of_keys,
    attn_scores_as_linear_func_of_queries,
    get_attn_scores_as_linear_func_of_queries_for_histogram,
    get_attn_scores_as_linear_func_of_keys_for_histogram,
    decompose_attn_scores,
    plot_contribution_to_attn_scores,
    project,
    decompose_attn_scores_full,
    create_fucking_massive_plot_1,
    create_fucking_massive_plot_2,
    get_nonspace_name_tokenIDs,
    get_nonspace_name_tokenIDs,
    get_lowercase_name_tokenIDs,
)
from transformer_lens.rs.callum2.utils import (
    get_effective_embedding,
)

# Clear any output printed by the imports above so the notebook stays readable
clear_output()

In [2]:
# GPT-2 small with LayerNorm folded into the weights and centred writing weights / unembedding.
# (refactor_factored_attn_matrices is intentionally left disabled.)
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

# Per-head attention results, and separate per-head q/k/v inputs so that a hook can edit the
# query input of a single head without touching its keys/values.
model.set_use_attn_result(True)
model.set_use_split_qkv_input(True)

clear_output()

In [3]:
effective_embeddings = get_effective_embedding(
    model,
    use_codys_without_attention_changes=False,
)

# Unembedding, plus the two effective-embedding variants returned above
W_U = model.W_U
W_EE, W_EE_subE = (
    effective_embeddings[key]
    for key in ("W_E (including MLPs)", "W_E (only MLPs)")
)

# Patching positions

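Below, I'm going to perform a patching experiment in which the query input (`q_input`) of head 10.7 is replaced, at every position, with the layer-10 residual stream values (`resid_pre`) from the flipped dataset. In the flipped dataset, the IO and S1 names swap positions (`ABB -> BAB`, `BAB -> ABB`), while the S2 and END tokens stay the same. Essentially, this means that head 10.7's *"look for token `IO` at position `pos(IO)`"* desire will be replaced with *"look for token `IO` at position `pos(S1)`"*. If positions don't matter, this experiment won't do anything. But if positions do matter, then I expect this to reduce the attention score diff (IO − S1). During this run the layer-10 ln1 scale is frozen to its clean value. As a comparison, I also project the query input at the END position onto the span of `W_U[IO]` and `W_U[S1]`.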

Why is this meaningfully different from things we've done before? Because before, I think we either didn't implement it well (I don't trust myself back at this point!) or we just replaced the positional embeddings `W_pos` rather than patching "the whole way up to head 10.7".

In [14]:
# Only cache the activations the patching experiment reads back from the clean run:
# layer-10 attention scores and the layer-10 ln1 scale.
clean_hook_names = [
    utils.get_act_name("attn_scores", 10),
    utils.get_act_name("scale", 10, "ln1"),
]

ioi_dataset, ioi_cache = generate_data_and_caches(
    model=model,
    N=150,
    seed=42,
    prepend_bos=True,
    only_ioi=True,
    symmetric=True,
    return_cache=True,
    names_filter=lambda name: name in clean_hook_names,
    verbose=False,
)
ioi_cache = cast(ActivationCache, ioi_cache)

In [5]:
# %pip install git+https://github.com/callummcdougall/eindex.git
from eindex import eindex

In [18]:
def get_attn_diff_from_cache(
    cache: ActivationCache,
    ioi_dataset: IOIDataset = ioi_dataset,
    NEG_HEAD: Tuple[int, int] = (10, 7),
) -> float:
    '''
    Returns the batch-mean of (attn score END -> IO) minus (attn score END -> S1) for head NEG_HEAD.

    Scores are read from the `attn_scores` activation of layer NEG_HEAD[0] in `cache`. The END, IO
    and S1 positions for each prompt come from `ioi_dataset.word_idx`.

    Note: the default `ioi_dataset` is whatever the global `ioi_dataset` was when this cell ran.
    '''
    layer, head = NEG_HEAD
    positions = ioi_dataset.word_idx
    query_posns, io_key_posns, s1_key_posns = positions["end"], positions["IO"], positions["S1"]

    # [batch seqQ seqK] scores of the chosen head only
    head_scores = cache["attn_scores", layer][:, head]

    # For each prompt b, pick head_scores[b, query_posns[b], key_posns[b]] -> shape [batch]
    pattern = "batch [batch] [batch]"
    io_scores = eindex(head_scores, query_posns, io_key_posns, pattern)
    s1_scores = eindex(head_scores, query_posns, s1_key_posns, pattern)

    return (io_scores - s1_scores).mean().item()
    


def run_patching_experiment(
    ioi_dataset: IOIDataset = ioi_dataset,
    ioi_cache: ActivationCache = ioi_cache,
    model: HookedTransformer = model,
    NEG_HEAD: Tuple[int, int] = (10, 7),
    verbose: bool = False,
):
    '''
    Runs the described patching experiment for head NEG_HEAD. Prints a table of the average attention
    score diff (END -> IO minus END -> S1) in three cases:

        Clean:      read directly from `ioi_cache`.
        Patched:    the head's query input (at every position) is replaced by `resid_pre` from the
                    flipped dataset, in which IO and S1 swap positions.
        Projected:  the head's query input at the END position is projected onto the span of
                    W_U[IO] and W_U[S1].

    In the patched and projected runs, the layer's ln1 scale is frozen to its clean value from
    `ioi_cache`, so only the direction of the query input changes.

    Args:
        ioi_dataset: the clean IOI dataset (its `toks`, `word_idx` and token IDs are used).
        ioi_cache: clean cache. It must contain the `attn_scores` and ln1 `scale` activations for
            layer NEG_HEAD[0]; the cache built above only stores these for layer 10.
        model: the model to run. The query hook indexes a per-head query input, which relies on
            `set_use_split_qkv_input(True)` having been called (as done above).
        NEG_HEAD: (layer, head) of the head being studied.
        verbose: if True, print example prompts and which hook is running.
    '''
    LAYER, HEAD = NEG_HEAD
    batch_size = ioi_dataset.toks.shape[0]
    end_posns = ioi_dataset.word_idx["end"]

    hook_name_resid_pre = utils.get_act_name("resid_pre", LAYER)
    hook_name_q_input = utils.get_act_name("q_input", LAYER)
    hook_name_attn_scores = utils.get_act_name("attn_scores", LAYER)
    # Must be the ln1 scale of LAYER (not a fixed layer) so the freeze applies to the studied head
    hook_name_scale = utils.get_act_name("scale", LAYER, "ln1")

    model.reset_hooks()

    # Clean attn score diff. Pass the dataset and head explicitly; otherwise the positions would come
    # from the function's default (notebook-global) dataset instead of the one passed in here.
    attn_score_diff_avg_clean = get_attn_diff_from_cache(ioi_cache, ioi_dataset=ioi_dataset, NEG_HEAD=NEG_HEAD)

    # Generate dataset with IO and S1 reversed (same names, but their positions are swapped)
    flipped_dataset = ioi_dataset.gen_flipped_prompts("ABB -> BAB, BAB -> ABB")
    # Sanity check
    if verbose:
        print("First 3 sentences of IOI dataset:")
        for i in range(3): print(ioi_dataset.sentences[i])
        print("\nFirst 3 sentences of IOI-flipped dataset:")
        for i in range(3): print(flipped_dataset.sentences[i])

    # Get resid_pre values on the flipped dataset (no hooks attached)
    _, flipped_cache = model.run_with_cache(
        flipped_dataset.toks,
        return_type = None,
        names_filter = lambda name: name == hook_name_resid_pre,
    )
    flipped_resid_pre = flipped_cache[hook_name_resid_pre] # [batch seq d_model]

    def hook_queries(query_input: Float[Tensor, "batch seq heads d_model"], hook: HookPoint, mode: Literal["project", "patch"]):
        '''Edits head HEAD's query input: projects it at END ("project") or patches every position ("patch").'''
        assert mode in ["project", "patch"]
        if mode == "project":
            if verbose: print("projecting")
            # Use the unembedding of the model passed in, not the notebook-global W_U
            W_U_IO = model.W_U.T[ioi_dataset.io_tokenIDs]
            W_U_S1 = model.W_U.T[ioi_dataset.s_tokenIDs]
            projection_directions = t.stack([W_U_IO, W_U_S1], dim=-1) # [batch d_model 2]
            query_input[range(batch_size), end_posns, HEAD] = project(
                query_input[range(batch_size), end_posns, HEAD],
                projection_directions,
            )
        elif mode == "patch":
            if verbose: print("patching")
            query_input[:, :, HEAD] = flipped_resid_pre
        return query_input

    def hook_freeze_scale(scale: Float[Tensor, "batch seq *heads 1"], hook: HookPoint):
        '''Replaces the ln1 scale with its value from the clean run.'''
        return ioi_cache[hook_name_scale]

    def _attn_diff_with_query_hook(mode: Literal["project", "patch"]) -> float:
        '''Runs the clean prompts with the query hook in `mode` and a frozen ln1 scale; returns the attn diff.'''
        model.reset_hooks()
        try:
            model.add_hook(hook_name_q_input, partial(hook_queries, mode=mode))
            model.add_hook(hook_name_scale, hook_freeze_scale)
            _, hooked_cache = model.run_with_cache(
                ioi_dataset.toks,
                return_type=None,
                names_filter=lambda name: name == hook_name_attn_scores,
            )
        finally:
            # Always remove the hooks, even if the forward pass raises, so later cells are unaffected
            model.reset_hooks()
        return get_attn_diff_from_cache(hooked_cache, ioi_dataset=ioi_dataset, NEG_HEAD=NEG_HEAD)

    attn_score_diff_avg_patched = _attn_diff_with_query_hook("patch")
    attn_score_diff_avg_projected = _attn_diff_with_query_hook("project")

    # Print all results
    table = Table("Intervention", "Attention Score Diff (IO - S1)", title=f"IOI patching, effect on {LAYER}.{HEAD} attn score diff")
    table.add_row("Clean", f"{attn_score_diff_avg_clean:.4f}")
    table.add_row("Flipped positional information (patched)", f"{attn_score_diff_avg_patched:.4f}")
    table.add_row("Project queries onto W_U[IO, S1]", f"{attn_score_diff_avg_projected:.4f}")
    rprint(table)

# Run on head 10.7 with the notebook's dataset, cache and model, printing example prompts
run_patching_experiment(NEG_HEAD=(10, 7), verbose=True)

First 3 sentences of IOI dataset:
Then, Alex and Anthony had a lot of fun at the school. Anthony gave a ring to Alex
Then, Anthony and Alex had a lot of fun at the school. Alex gave a ring to Anthony
Then, Connor and Roman were working at the house. Roman decided to give a basketball to Connor

First 3 sentences of IOI-flipped dataset:
Then, Anthony and Alex had a lot of fun at the school. Anthony gave a ring to Alex
Then, Alex and Anthony had a lot of fun at the school. Alex gave a ring to Anthony
Then, Roman and Connor were working at the house. Roman decided to give a basketball to Connor
patching
projecting


# Conclusions

Positional information is more important than we thought (maybe explaining about 5-10% of the attention), but not important enough to make a meaningful difference. I also anticipate it'll be hard to work into projections.