In [5]:
# Import stuff
import torch as t
import numpy as np
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.express as px
import einops
import plotly.graph_objects as go 
from functools import partial
import tqdm.auto as tqdm
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from fancy_einsum import einsum
from jaxtyping import Float, Int, Bool
import re

#from plotly_utils import imshow, line, scatter, bar


In [6]:
## plotting functions 
# Names of plot keyword arguments, grouped by what they configure.
# NOTE(review): presumably the kwargs meant to be forwarded to plotly's
# `update_layout`; none of the plotting helpers in this notebook read this set.
update_layout_set = {
    # x axis
    "xaxis_range", "xaxis_title", "xaxis_tickformat", "xaxis_tickmode", "xaxis_tickangle",
    "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "xaxis_visible",
    # y axis
    "yaxis_range", "yaxis_title", "yaxis_tickformat", "yaxis_tickmode", "yaxis_tickangle",
    "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor", "yaxis_visible",
    # colour handling
    "colorbar", "colorscale", "coloraxis",
    # title and legend
    "title_x", "title_y", "legend_title_text", "showlegend",
    # bars and miscellaneous
    "bargap", "bargroupgap", "hovermode", "margin",
}

def imshow(tensor, renderer=None, **kwargs):
    '''
    Show `tensor` as a plotly heatmap using the diverging "RdBu" colour scale with its
    midpoint pinned at 0.

    `tensor` is converted with `utils.to_numpy`; any extra keyword arguments are passed
    straight to `px.imshow`, and `renderer` is handed to the figure's `show`.
    '''
    heatmap = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    heatmap.show(renderer)

def line(tensor, renderer=None, **kwargs):
    '''
    Show `tensor` as a plotly line chart, using its values (converted with
    `utils.to_numpy`) as the y data.

    Extra keyword arguments are passed to `px.line`; `renderer` is handed to `show`.
    '''
    y_values = utils.to_numpy(tensor)
    chart = px.line(y=y_values, **kwargs)
    chart.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    '''
    Show a plotly scatter plot of `y` against `x`.

    Both inputs are converted with `utils.to_numpy`. `xaxis`, `yaxis` and `caxis` are the
    display labels for the x, y and colour dimensions; extra keyword arguments are passed
    to `px.scatter`, and `renderer` is handed to `show`.
    '''
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [7]:
## Autograd is switched off globally to save memory: this notebook only runs inference.
t.set_grad_enabled(False)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


Generate example prompts for the IOI (indirect object identification) task, along with the clean and corrupted answers. It's important that all prompts have the same number of tokens (adapted from the exploratory analysis demo).

In [8]:

# Prompt templates; the `{}` slot receives the subject name (with its leading space).
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]

# The pair of names used by the template at the same index.
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# prompts:       the filled-in prompt strings
# answers:       (correct, incorrect) name strings, one pair per prompt
# answer_tokens: (correct_token, incorrect_token) integer ids, one pair per prompt
prompts = []
answers = []
answer_tokens = []
for template, name_pair in zip(prompt_format, names):
    # Each template yields two prompts: one per ordering of the name pair.
    for correct, incorrect in (name_pair, name_pair[::-1]):
        answers.append((correct, incorrect))
        answer_tokens.append(
            (model.to_single_token(correct), model.to_single_token(incorrect))
        )
        # The incorrect answer is the name repeated as the subject of the second clause.
        prompts.append(template.format(incorrect))
answer_tokens = t.tensor(answer_tokens).to(device)

### Every prompt must tokenise to the same length.
prompt_len = len(model.to_str_tokens(prompts[1]))
assert len({len(model.to_str_tokens(prompt)) for prompt in prompts}) == 1


In [9]:
### print all prompts in a table (learned from Keith's notebook! )
from rich.table import Table, Column
from rich import print as rprint


# One table row per prompt: the prompt text followed by its (correct, incorrect) answers.
prompt_tab = Table('prompt', 'clean', 'corrupted', title = 'prompts and answers')

for prompt, (first_answer, second_answer) in zip(prompts, answers):
    prompt_tab.add_row(prompt, first_answer, second_answer)

rprint(prompt_tab)

Cache the logits and model internals for all the prompts.

In [10]:
# Number of prompts in the batch.
n_ex = len(prompts)

# Tokenise all prompts (BOS token prepended), move them to the model's device, and run
# a single forward pass that returns both the logits and the activation cache.
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)
og_logits, cache = model.run_with_cache(tokens)


Design a metric to test model performance. In this case, we'll use the logit difference between the indirect object (correct answer) and the subject (incorrect answer)

In [11]:
# Model dimensions used by the shape checks throughout the notebook.
d_vocab, d_model, n_heads, n_layers = (
    model.cfg.d_vocab,
    model.cfg.d_model,
    model.cfg.n_heads,
    model.cfg.n_layers,
)

# The logits must be (number of prompts, tokens per prompt, vocabulary size).
assert tuple(og_logits.shape) == (n_ex, prompt_len, d_vocab)

def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt = False):
    '''
    Logit difference between the first and second answer token at the final position.

    Args:
        logits: tensor indexed as [prompt, position, vocab]; only the last position is read.
        answer_tokens: integer tensor indexed as [prompt, k]; column 0 is the correct
            token id and column 1 the incorrect one.
        per_prompt: when True, return one difference per prompt; otherwise return their mean.

    Returns:
        A 1-D tensor of per-prompt differences, or a 0-dim tensor holding the mean.
    '''
    # Only the prediction after the final token matters for the answer.
    last_position_logits = logits[:, -1, :]
    # Pick out the logits of the two candidate answers for every prompt.
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    per_prompt_diff = picked[:, 0] - picked[:, 1]
    return per_prompt_diff if per_prompt else per_prompt_diff.mean()
    

# Per-prompt logit differences and their mean on the unmodified model.
og_logit_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=True)
og_logit_avg_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=False)

cols = [
    "Prompt", 
    Column("Correct", style="rgb(0,200,0) bold"), 
    Column("Incorrect", style="rgb(255,0,0) bold"), 
    Column("Logit Difference", style="bold"), Column("Avg Logit Difference", style="bold")
]
logit_diff_table = Table(*cols, title="Logit differences")

for prompt, ans, logit_diff in zip(prompts, answers,og_logit_diff):
    logit_diff_table.add_row(prompt, ans[0], ans[1], f"{logit_diff.item():.3f}")
# The "Avg Logit Difference" column was declared but never filled in; report the mean
# once, in a closing summary row, instead of leaving the column empty.
logit_diff_table.add_row("(mean over prompts)", "", "", "", f"{og_logit_avg_diff.item():.3f}")
rprint(logit_diff_table)



In [12]:
# Display the (correct_token, incorrect_token) ids, one row per prompt.
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])

In [13]:
from transformer_lens import patching

clean_tokens = tokens
# Prompts were built in consecutive pairs (2k, 2k+1) that differ only in which name is
# repeated, so flipping the lowest bit of an index selects the partner prompt.
idx_swap = [i ^ 1 for i in range(len(tokens))]
corrupted_tokens = clean_tokens[idx_swap]

print(
    "Clean string 0:    ",
    model.to_string(clean_tokens[0]),
    "\nCorrupted string 0:",
    model.to_string(corrupted_tokens[0]),
)

# Forward passes (with activation caches) on both versions of the batch.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Both baselines are scored against the *clean* answers.
clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)

print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean string 0:     <|endoftext|>When John and Mary went to the shops, John gave the bag to 
Corrupted string 0: <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Clean logit diff: 3.5519
Corrupted logit diff: -3.5519


In [14]:
def ioi_metric(
    logits: Float[t.Tensor, "batch seq d_vocab"],
    answer_tokens: Int[t.Tensor, "batch 2"] = answer_tokens,
    corrupted_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> Float[t.Tensor, ""]:
    '''
    Average logit difference of `logits`, rescaled linearly so that the corrupted
    baseline maps to 0 and the clean baseline maps to 1.

    Args:
        logits: model output to score.
        answer_tokens: (correct, incorrect) token ids per prompt. These are integer ids,
            hence the `Int` annotation.
        corrupted_logit_diff: average logit difference of the corrupted run.
        clean_logit_diff: average logit difference of the clean run.

    The three defaults are bound to the notebook-level values when this cell is executed;
    re-run the cell after recomputing any of them. The two baselines are passed around as
    whatever `logits_to_ave_logit_diff` returned (0-dim tensors in this notebook).

    Returns:
        A 0-dim tensor: (patched - corrupted) / (clean - corrupted).
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

In [15]:
# checking that this does what we want 

# By construction the metric is 1 on the clean logits and 0 on the corrupted logits.
print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [162]:
# Distinct layers that appear in `lh_list_pos`.
# NOTE(review): `lh_list_pos` is only defined in a later cell of this notebook, so this
# cell fails on a fresh kernel when the notebook is run top to bottom.
list({layer for layer, _ in lh_list_pos})

[9, 10]

In [164]:
import copy

### for just the last position

# def get_lin_comb_cache(clean_cache, corrupted_cache, lh_list):
#     aux_cache = copy.deepcopy(corrupted_cache)
#     clean_comb = 0
#     corrupted_comb = 0
#     count = 0

#     for layer, head in lh_list:
#         count += 1
#         clean_comb += clean_cache[utils.get_act_name("z", layer)][:,-1,head,:]
#         corrupted_comb += corrupted_cache[utils.get_act_name("z", layer)][:,-1,head,:]
#         #print(f"1:{layer}")
#     comb = (clean_comb - corrupted_comb)/count
#     for layer, head in lh_list:

#         #print(f"2:{layer,head}")
#         aux_cache[utils.get_act_name("z", layer)][:,-1,head,:] += comb

#     return aux_cache


## lin comb of attention outputs 
def get_lin_comb_cache(clean_cache, corrupted_cache, lh_list):
    '''
    Build a copy of `corrupted_cache` whose final-position `attn_out` activations are
    shifted by the mean clean-minus-corrupted output of the heads in `lh_list`.

    For each (layer, head) pair, the head's `z` at the last position is projected with
    `model.W_O[layer, head]`. The clean-minus-corrupted differences are averaged over the
    listed heads, and that one vector is added to `attn_out[:, -1]` of every layer that
    appears in `lh_list` (once per layer, even if several of its heads are listed).

    Args:
        clean_cache: activation cache of the clean run; read only.
        corrupted_cache: activation cache of the corrupted run; deep-copied, not modified.
        lh_list: non-empty sequence of (layer, head) pairs.

    Returns:
        The modified deep copy of `corrupted_cache`.

    Raises:
        ValueError: if `lh_list` is empty (there is nothing to average over).
    '''
    if len(lh_list) == 0:
        raise ValueError("lh_list must contain at least one (layer, head) pair")

    aux_cache = copy.deepcopy(corrupted_cache)

    mean_delta = 0
    for layer, head in lh_list:
        z_name = utils.get_act_name("z", layer)
        # Only the final position is patched, so only that slice is projected. The output
        # bias b_O is the same in both runs and cancels in the difference, so it is not added.
        z_delta = clean_cache[z_name][:, -1, head] - corrupted_cache[z_name][:, -1, head]
        mean_delta = mean_delta + einops.einsum(
            model.W_O[layer, head], z_delta, "d_h d_m, n d_h -> n d_m"
        )
    mean_delta = mean_delta / len(lh_list)

    # dict.fromkeys de-duplicates the layers while keeping their first-seen order.
    for layer in dict.fromkeys(layer for layer, _ in lh_list):
        aux_cache[utils.get_act_name("attn_out", layer)][:, -1] += mean_delta

    return aux_cache



# (layer, head) pairs under study.
# NOTE(review): presumably the positive and negative name mover heads — confirm.
lh_list_pos = [[9, 6], [9, 9], [10, 0]]
lh_list_neg = [[10, 7], [11, 10]]
lh_list_tot = [*lh_list_neg, *lh_list_pos]

# One uniformly-averaged patch cache per head group.
pos_cache = get_lin_comb_cache(clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_list_pos)
neg_cache = get_lin_comb_cache(clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_list_neg)
total_cache = get_lin_comb_cache(clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_list_tot)

In [169]:
# Shape of the layer-9 attention output stored in the clean cache.
clean_cache["attn_out", 9].shape

torch.Size([8, 15, 768])

In [176]:

from typing import List, Optional, Callable, Tuple, Dict, Literal, Set


def patch_head_output(
    corrupted_head_out: Float[t.Tensor, "batch pos d_model"],
    hook: HookPoint,
    from_cache: ActivationCache
) -> Float[t.Tensor, "batch pos d_model"]:
    '''
    Forward hook that overwrites the final sequence position of the hooked activation
    with the value stored under the same hook name in `from_cache` (either the clean
    cache or a linear-combination cache). All other positions are left untouched.

    As registered by `get_act_patched`, the hooked activation is a layer's `attn_out`,
    i.e. the combined output of all heads in that layer rather than a single head.

    The input tensor is modified in place and also returned.
    '''
    corrupted_head_out[:, -1, :] = from_cache[hook.name][:, -1, :]
    return corrupted_head_out

def get_act_patched(
    model: HookedTransformer,
    corrupted_tokens: Int[t.Tensor, "batch pos"],
    from_cache: ActivationCache,
    patching_metric: Callable,
    lh_list
) -> Float[t.Tensor, ""]:
    '''
    Runs the model on `corrupted_tokens` while replacing, at the final position, the
    `attn_out` activation of every layer named in `lh_list` with the value stored in
    `from_cache`, and returns `patching_metric` evaluated on the resulting logits.

    Args:
        model: the hooked transformer; its existing hooks are reset first.
        corrupted_tokens: token ids of the corrupted prompts.
        from_cache: cache supplying the patched-in `attn_out` values.
        patching_metric: callable applied to the patched logits.
        lh_list: (layer, head) pairs. Only the layer of each pair is used: the whole
            layer's `attn_out` is patched, so the head index has no effect here.

    Returns:
        Whatever `patching_metric` returns for the patched logits (a single value, not a
        per-layer or per-head array).
    '''
    model.reset_hooks()

    patch_hook = partial(patch_head_output, from_cache=from_cache)
    # One hook per distinct layer: registering the same whole-layer patch once for each
    # listed head would just repeat an identical assignment.
    layers_to_patch = dict.fromkeys(layer for layer, _ in lh_list)
    layer_hooks = [(utils.get_act_name("attn_out", layer), patch_hook) for layer in layers_to_patch]

    patched_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=layer_hooks,
        return_type="logits"
    )
    return patching_metric(patched_logits)

In [177]:
# For each head group, patch from the clean cache ("indiv") and from the matching
# uniformly-averaged cache ("lin_comb").
act_patch_attn_indiv_pos = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_list_pos
)
act_patch_attn_lin_comb_pos = get_act_patched(
    model, corrupted_tokens, from_cache=pos_cache, patching_metric=ioi_metric, lh_list=lh_list_pos
)

act_patch_attn_indiv_neg = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_list_neg
)
act_patch_attn_lin_comb_neg = get_act_patched(
    model, corrupted_tokens, from_cache=neg_cache, patching_metric=ioi_metric, lh_list=lh_list_neg
)

act_patch_attn_indiv_tot = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_list_tot
)
act_patch_attn_lin_comb_tot = get_act_patched(
    model, corrupted_tokens, from_cache=total_cache, patching_metric=ioi_metric, lh_list=lh_list_tot
)



fraction of IO - S diff recovered by patching + Heads: (lin comb) 0.5298, (individual) = 0.6377
fraction of IO - S diff recovered by patching - Heads: (lin comb) -0.8113, (individual) = -0.5163
fraction of IO - S diff recovered by patching all Heads: (lin comb) 0.3477, (individual) = 0.8607


In [178]:
# Uniform linear combination vs. patching from the clean cache, per head group.
print(f'fraction of IO - S diff recovered by patching + Heads: (uni lin comb) {act_patch_attn_lin_comb_pos:.4f},(individual) = {act_patch_attn_indiv_pos:.4f}')
print(f' fraction of IO - S diff recovered by patching - Heads: (uni lin comb) {act_patch_attn_lin_comb_neg:.4f},(individual) = {act_patch_attn_indiv_neg:.4f}')
print(f' fraction of IO - S diff recovered by patching all heads: (uni lin comb) = {act_patch_attn_lin_comb_tot:.4f}, (individual) = {act_patch_attn_indiv_tot:.4f}')

# Attention-weighted linear combination vs. patching from the clean cache.
# NOTE(review): the `*_weighted_attn` values are only assigned in a later cell of this
# notebook, so these three lines fail on a fresh top-to-bottom run.
print(f'fraction of IO - S diff recovered by patching + Heads: (attn weighted lin comb) {act_patch_attn_lin_comb_pos_weighted_attn:.4f},(individual) = {act_patch_attn_indiv_pos:.4f}')
print(f' fraction of IO - S diff recovered by patching - Heads: (attn weighted lin comb) {act_patch_attn_lin_comb_neg_weighted_attn:.4f},(individual) = {act_patch_attn_indiv_neg:.4f}')
print(f' fraction of IO - S diff recovered by patching all heads: (attn weighted lin comb) = {act_patch_attn_lin_comb_tot_weighted_attn:.4f}, (individual) = {act_patch_attn_indiv_tot:.4f}')

fraction of IO - S diff recovered by patching + Heads: (uni lin comb) 0.5298,(individual) = 0.6377
 fraction of IO - S diff recovered by patching - Heads: (uni lin comb) -0.8113,(individual) = -0.5163
 fraction of IO - S diff recovered by patching all heads: (uni lin comb) = 0.3477, (individual) = 0.8607
fraction of IO - S diff recovered by patching + Heads: (attn weighted lin comb) 0.2017,(individual) = 0.6377
 fraction of IO - S diff recovered by patching - Heads: (attn weighted lin comb) -0.4857,(individual) = -0.5163
 fraction of IO - S diff recovered by patching all heads: (attn weighted lin comb) = 0.0855, (individual) = 0.8607


Add in the backup name movers to see if they make a difference: compare the individually patched heads + backups with the linear combination of the backups and the +/- movers.

In [21]:
# Backup name mover heads as (layer, head) pairs.
lh_backup = [
    [10, 10], [10, 6], [10, 2], [10, 1],
    [11, 2], [11, 9],
    [9, 0], [9, 7],
]

# Each head group extended with the backup heads.
lh_neg_bu = [*lh_list_neg, *lh_backup]
lh_pos_bu = [*lh_list_pos, *lh_backup]
lh_tot_bu = [*lh_list_tot, *lh_backup]

# Uniformly-averaged patch caches for the extended groups.
pos_cache_bu = get_lin_comb_cache(clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_pos_bu)
neg_cache_bu = get_lin_comb_cache(clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_neg_bu)
total_cache_bu = get_lin_comb_cache(clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_tot_bu)

# Patch from the averaged caches ...
act_patch_attn_lin_comb_pos_bu = get_act_patched(
    model, corrupted_tokens, from_cache=pos_cache_bu, patching_metric=ioi_metric, lh_list=lh_pos_bu
)
act_patch_attn_lin_comb_neg_bu = get_act_patched(
    model, corrupted_tokens, from_cache=neg_cache_bu, patching_metric=ioi_metric, lh_list=lh_neg_bu
)
act_patch_attn_lin_comb_tot_bu = get_act_patched(
    model, corrupted_tokens, from_cache=total_cache_bu, patching_metric=ioi_metric, lh_list=lh_tot_bu
)

# ... and from the clean cache, over the same extended groups.
act_patch_attn_indiv_pos_bu = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_pos_bu
)
act_patch_attn_indiv_neg_bu = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_neg_bu
)
act_patch_attn_indiv_tot_bu = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_tot_bu
)

# Summary for the head groups extended with the backup heads.
# NOTE(review): the `*_weighted_attn_bu` values are only assigned in a later cell of this
# notebook, so these lines fail on a fresh top-to-bottom run.
print(f'fraction of IO - S diff recovered by patching + Heads + bu: (uni lin comb) {act_patch_attn_lin_comb_pos_bu:.4f},(individual) = {act_patch_attn_indiv_pos_bu:.4f}, (attn weighted lin comb) {act_patch_attn_lin_comb_pos_weighted_attn_bu:.4f}')
print(f' fraction of IO - S diff recovered by patching - Heads + bu: (uni lin comb) {act_patch_attn_lin_comb_neg_bu:.4f},(individual) = {act_patch_attn_indiv_neg_bu:.4f}, (attn weighted lin comb) {act_patch_attn_lin_comb_neg_weighted_attn_bu:.4f}')
print(f' fraction of IO - S diff recovered by patching all heads + bu: (uni lin comb) = {act_patch_attn_lin_comb_tot_bu:.4f}, (individual) = {act_patch_attn_indiv_tot_bu:.4f}, (attn weighted lin comb) {act_patch_attn_lin_comb_tot_weighted_attn_bu:.4f}')



fraction of IO - S diff recovered by patching + Heads + bu: (lin comb) 0.0569,(individual+ backup) = 0.6625
 fraction of IO - S diff recovered by patching - Heads + bu: (lin comb) -0.0739,(individual+backup) = -0.6918
 fraction of IO - S diff recovered by patching all heads + bu: (lin comb) = 0.0352, (individual+backup) = 0.6837


In [125]:
# Recomputes the same six backup-head patching results as the earlier backup-head cell,
# from the same caches and head lists.
act_patch_attn_lin_comb_pos_bu = get_act_patched(
    model, corrupted_tokens, from_cache=pos_cache_bu, patching_metric=ioi_metric, lh_list=lh_pos_bu
)
act_patch_attn_lin_comb_neg_bu = get_act_patched(
    model, corrupted_tokens, from_cache=neg_cache_bu, patching_metric=ioi_metric, lh_list=lh_neg_bu
)
act_patch_attn_lin_comb_tot_bu = get_act_patched(
    model, corrupted_tokens, from_cache=total_cache_bu, patching_metric=ioi_metric, lh_list=lh_tot_bu
)

act_patch_attn_indiv_pos_bu = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_pos_bu
)
act_patch_attn_indiv_neg_bu = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_neg_bu
)
act_patch_attn_indiv_tot_bu = get_act_patched(
    model, corrupted_tokens, from_cache=clean_cache, patching_metric=ioi_metric, lh_list=lh_tot_bu
)

# Uniform linear combination vs. clean-cache patching for the groups that include backups.
print(f'fraction of IO - S diff recovered by patching + Heads + bu: (lin comb) {act_patch_attn_lin_comb_pos_bu:.4f},(individual+ backup) = {act_patch_attn_indiv_pos_bu:.4f}')
print(f' fraction of IO - S diff recovered by patching - Heads + bu: (lin comb) {act_patch_attn_lin_comb_neg_bu:.4f},(individual+backup) = {act_patch_attn_indiv_neg_bu:.4f}')
print(f' fraction of IO - S diff recovered by patching all heads + bu: (lin comb) = {act_patch_attn_lin_comb_tot_bu:.4f}, (individual+backup) = {act_patch_attn_indiv_tot_bu:.4f}')

fraction of IO - S diff recovered by patching + Heads + bu: (lin comb) 0.0569,(individual+ backup) = 0.6625
 fraction of IO - S diff recovered by patching - Heads + bu: (lin comb) -0.0739,(individual+backup) = -0.6918
 fraction of IO - S diff recovered by patching all heads + bu: (lin comb) = 0.0352, (individual+backup) = 0.6837


Tweak the linear combination so that each head is weighted according to the attention it pays to the IO token.

In [119]:
# For every prompt: index of the first string token equal to its correct answer (the IO name).
io_pos = t.tensor([
    [idx for idx, tok in enumerate(model.to_str_tokens(prompt)) if tok == answer[0]][0]
    for prompt, answer in zip(prompts, answers)
])
# For every prompt: the final position, expressed as the index -1.
end_pos = t.tensor([-1 for _ in prompts])

def get_attn_weight(cache, lh_list, io_positions=None, end_positions=None):
    '''
    Softmax-normalised weights, one per head in `lh_list`, derived from how much each
    head attends from the end position to the IO token.

    For each (layer, head), the attention probability from `end_positions` (query) to
    `io_positions` (key) is read from `cache["pattern", layer]` for every prompt and
    averaged over prompts; the per-head averages are then passed through a softmax.

    Args:
        cache: indexable by `("pattern", layer)`, giving a tensor indexed as
            [prompt, head, query position, key position].
        lh_list: sequence of (layer, head) pairs.
        io_positions: integer tensor with one key position per prompt. Defaults to the
            notebook-level `io_pos`.
        end_positions: integer tensor with one query position per prompt. Defaults to
            the notebook-level `end_pos`.

    Returns:
        A 1-D tensor with `len(lh_list)` entries summing to 1 (empty if `lh_list` is empty).
    '''
    if io_positions is None:
        io_positions = io_pos
    if end_positions is None:
        end_positions = end_pos

    attn_list = []
    for l, h in lh_list:
        patt = cache["pattern", l][:, h]
        batch_idx = range(patt.size(0))
        # One attention probability per prompt: end-position query -> IO-position key.
        attn_end_to_io = patt[batch_idx, end_positions, io_positions].mean(dim=-1)
        attn_list.append(attn_end_to_io)

    if not attn_list:
        # t.stack rejects an empty list; keep returning an empty tensor in that case.
        return t.tensor(attn_list).softmax(dim=0)
    # t.stack keeps the values on their device instead of rebuilding them on the host.
    return t.stack(attn_list).softmax(dim=0)


def get_lin_comb_cache_weighted_attn(clean_cache, corrupted_cache, lh_list):
    '''
    Attention-weighted counterpart of `get_lin_comb_cache`.

    Builds a copy of `corrupted_cache` whose final-position `attn_out` activations are
    shifted by a weighted sum of the clean-minus-corrupted outputs of the heads in
    `lh_list`. The weights come from `get_attn_weight(clean_cache, lh_list)` and sum to
    1, so no further normalisation is applied.

    The shift is written to `attn_out` because that is the activation `get_act_patched`
    reads from the cache it is given; a change written to the per-head `z` activations
    would never reach the patched forward pass. Each head's `z` difference is therefore
    first projected with `model.W_O[layer, head]` and combined in that common output space.

    Args:
        clean_cache: activation cache of the clean run; read only.
        corrupted_cache: activation cache of the corrupted run; deep-copied, not modified.
        lh_list: sequence of (layer, head) pairs.

    Returns:
        The modified deep copy of `corrupted_cache`.
    '''
    aux_cache = copy.deepcopy(corrupted_cache)
    weight_list = get_attn_weight(clean_cache, lh_list)

    comb = 0
    for weight, (layer, head) in zip(weight_list, lh_list):
        z_name = utils.get_act_name("z", layer)
        # Only the final position is patched, so only that slice is projected.
        z_delta = clean_cache[z_name][:, -1, head] - corrupted_cache[z_name][:, -1, head]
        comb = comb + weight * einops.einsum(
            model.W_O[layer, head], z_delta, "d_h d_m, n d_h -> n d_m"
        )

    # Apply the shift once per distinct layer, in first-seen order.
    for layer in dict.fromkeys(layer for layer, _ in lh_list):
        aux_cache[utils.get_act_name("attn_out", layer)][:, -1] += comb

    return aux_cache


# One attention-weighted patch cache per head group.
pos_cache_weighted_attn = get_lin_comb_cache_weighted_attn(
    clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_list_pos
)
neg_cache_weighted_attn = get_lin_comb_cache_weighted_attn(
    clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_list_neg
)
total_cache_weighted_attn = get_lin_comb_cache_weighted_attn(
    clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_list_tot
)

In [120]:
# Softmax of three example values, shown as a quick check.
t.softmax(t.tensor([0.6936, 0.8170, 0.3626]), dim=0)

tensor([0.3509, 0.3970, 0.2520])

In [121]:
# Patch from the attention-weighted caches, one run per head group.
act_patch_attn_lin_comb_pos_weighted_attn = get_act_patched(
    model, corrupted_tokens, from_cache=pos_cache_weighted_attn, patching_metric=ioi_metric, lh_list=lh_list_pos
)
act_patch_attn_lin_comb_neg_weighted_attn = get_act_patched(
    model, corrupted_tokens, from_cache=neg_cache_weighted_attn, patching_metric=ioi_metric, lh_list=lh_list_neg
)
act_patch_attn_lin_comb_tot_weighted_attn = get_act_patched(
    model, corrupted_tokens, from_cache=total_cache_weighted_attn, patching_metric=ioi_metric, lh_list=lh_list_tot
)

# Attention-weighted linear combination vs. patching from the clean cache, per head group.
# (The "(lin comb)" label here refers to the attention-weighted combination.)
print(f'fraction of IO - S diff recovered by patching + Heads: (lin comb) {act_patch_attn_lin_comb_pos_weighted_attn:.4f}, (individual) = {act_patch_attn_indiv_pos:.4f}')
print(f'fraction of IO - S diff recovered by patching - Heads: (lin comb) {act_patch_attn_lin_comb_neg_weighted_attn:.4f}, (individual) = {act_patch_attn_indiv_neg:.4f}')
print(f'fraction of IO - S diff recovered by patching all Heads: (lin comb) {act_patch_attn_lin_comb_tot_weighted_attn:.4f}, (individual) = {act_patch_attn_indiv_tot:.4f}')

fraction of IO - S diff recovered by patching + Heads: (lin comb) 0.2017, (individual) = 0.3267
fraction of IO - S diff recovered by patching - Heads: (lin comb) -0.4857, (individual) = -0.8536
fraction of IO - S diff recovered by patching all Heads: (lin comb) 0.0855, (individual) = 0.1591


In [130]:
# Attention-weighted patch caches for the head groups that include the backup heads.
pos_cache_weighted_attn_bu = get_lin_comb_cache_weighted_attn(
    clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_pos_bu
)
neg_cache_weighted_attn_bu = get_lin_comb_cache_weighted_attn(
    clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_neg_bu
)
total_cache_weighted_attn_bu = get_lin_comb_cache_weighted_attn(
    clean_cache=clean_cache, corrupted_cache=corrupted_cache, lh_list=lh_tot_bu
)

# Patch from each of those caches.
act_patch_attn_lin_comb_pos_weighted_attn_bu = get_act_patched(
    model, corrupted_tokens, from_cache=pos_cache_weighted_attn_bu, patching_metric=ioi_metric, lh_list=lh_pos_bu
)
act_patch_attn_lin_comb_neg_weighted_attn_bu = get_act_patched(
    model, corrupted_tokens, from_cache=neg_cache_weighted_attn_bu, patching_metric=ioi_metric, lh_list=lh_neg_bu
)
act_patch_attn_lin_comb_tot_weighted_attn_bu = get_act_patched(
    model, corrupted_tokens, from_cache=total_cache_weighted_attn_bu, patching_metric=ioi_metric, lh_list=lh_tot_bu
)

In [132]:

# Summary table: one row per head group, one column per patching variant.
prompt_tab = Table('name mover heads', 'individually patched', 'uniform linear', 'weighted by average attention to IO', title = 'fraction of IO - S diff recovered by patching for IOI (flipped) - > IOI')

for row_label, row_values in [
    ('pos', (act_patch_attn_indiv_pos, act_patch_attn_lin_comb_pos, act_patch_attn_lin_comb_pos_weighted_attn)),
    ('neg', (act_patch_attn_indiv_neg, act_patch_attn_lin_comb_neg, act_patch_attn_lin_comb_neg_weighted_attn)),
    ('pos + neg', (act_patch_attn_indiv_tot, act_patch_attn_lin_comb_tot, act_patch_attn_lin_comb_tot_weighted_attn)),
    ('pos + b/u', (act_patch_attn_indiv_pos_bu, act_patch_attn_lin_comb_pos_bu, act_patch_attn_lin_comb_pos_weighted_attn_bu)),
    ('neg + b/u', (act_patch_attn_indiv_neg_bu, act_patch_attn_lin_comb_neg_bu, act_patch_attn_lin_comb_neg_weighted_attn_bu)),
    ('pos + neg + b/u', (act_patch_attn_indiv_tot_bu, act_patch_attn_lin_comb_tot_bu, act_patch_attn_lin_comb_tot_weighted_attn_bu)),
]:
    # Every metric is rendered with four decimal places.
    prompt_tab.add_row(row_label, *[f'{value:.4f}' for value in row_values])


rprint(prompt_tab)

The patched activations above were fed through the rest of the network, but we want to study the direct logit attribution instead. In this case, we treat all the heads as if they were in the same layer (equivalent to changing the basis).

In [22]:
# Display the (correct_token, incorrect_token) ids again, one row per prompt.
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])

In [23]:

def get_head_out(input_cache, layer, head):
    '''
    Map one head's cached `z` activation directly to logits.

    The head's `z` is read from `input_cache` at every position, projected with
    `model.W_O[layer, head]`, the layer's full output bias `model.b_O[layer]` is added,
    and the result is multiplied by `model.W_U`.
    '''
    z_head = input_cache[utils.get_act_name("z", layer)][:, :, head]
    head_resid = einops.einsum(model.W_O[layer, head], z_head, "d_h d_m, n s d_h -> n s d_m")
    head_resid = head_resid + model.b_O[layer]
    # NOTE(review): the axis labelled "n_ctx" is W_U's second axis — presumably the
    # vocabulary rather than a context length; confirm.
    return einops.einsum(model.W_U, head_resid, "d_m n_ctx, b s d_m -> b s n_ctx ")


def compare_patched_logits(indiv_cache, lin_comb_cache, lh_list, answer_tokens):
    '''
    Sum, over the heads in `lh_list`, of each head's direct average logit difference,
    computed once from `indiv_cache` and once from `lin_comb_cache`.

    Args:
        indiv_cache: cache used for the first sum.
        lin_comb_cache: cache used for the second sum.
        lh_list: sequence of (layer, head) pairs.
        answer_tokens: (correct, incorrect) token ids per prompt.

    Returns:
        A pair `(indiv_sum, lin_comb_sum)` of numpy scalars.

    NOTE(review): `get_head_out` reads the per-head `z` activations, whereas
    `get_lin_comb_cache` only alters `attn_out`; a cache built by that function therefore
    yields the same values here as the corrupted cache it was copied from.
    '''
    def summed_logit_diff(source_cache):
        # Each per-head value is moved to host memory first, so the sum also works when
        # the activations live on the GPU.
        per_head = [
            utils.to_numpy(
                logits_to_ave_logit_diff(get_head_out(source_cache, layer, head), answer_tokens)
            )
            for layer, head in lh_list
        ]
        return np.sum(per_head)

    return summed_logit_diff(indiv_cache), summed_logit_diff(lin_comb_cache)

# Direct logit attribution over all studied heads: clean cache vs. uniform-combination cache.
compare_patched_logits(
    indiv_cache=clean_cache,
    lin_comb_cache=total_cache,
    lh_list=lh_list_tot,
    answer_tokens=answer_tokens,
)




(38.14662, -20.995476)

Tracing the negative name mover (the last one) through the MLP.

In [24]:
d_head = model.cfg.d_head
# Shape checks on the output weights and on the cached layer-11 `z` activations.
assert tuple(model.W_O.shape) == (n_layers, n_heads, d_head, d_model)
assert tuple(cache[utils.get_act_name("z", 11)].shape) == (n_ex, prompt_len, n_heads, d_head)

# Apply W_O to head 10 of layer 11 at every position; no output bias is added.
last_neg_out_no_bias = einops.einsum(
    model.W_O[11, 10],
    cache[utils.get_act_name("z", 11)][:, :, 10],
    "d_h d_m, n s d_h -> n s d_m",
)


Still to do: path patching and debugging. Make sure the direct logit attribution for the uniform linear combination is the same as for the individual patch.