In [1]:
# Import stuff
import torch as t
import numpy as np
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.express as px
import einops
import plotly.graph_objects as go 
from functools import partial
import tqdm.auto as tqdm
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from fancy_einsum import einsum
from jaxtyping import Float, Int, Bool
import re

#from plotly_utils import imshow, line, scatter, bar


In [2]:
## plotting functions 
update_layout_set = {
    "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat",
    "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor",
    "showlegend", "xaxis_tickmode", "yaxis_tickmode", "xaxis_tickangle", "yaxis_tickangle", "margin", "xaxis_visible", "yaxis_visible", "bargap", "bargroupgap"
}

def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
## turn off AD to save memory, since we're focusing on model inference here 
t.set_grad_enabled(False)

device = 'cuda' if t.cuda.is_available() else 'cpu'
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


Generate example prompts for IOI along with clean and corrupted answers. It's important that they're all the same length (taken from exploratory analysis demo )

In [4]:

prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]

names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]
# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
for i in range(len(prompt_format)):
    for j in range(2):
        answers.append((names[i][j], names[i][1 - j]))
        answer_tokens.append(
            (
                model.to_single_token(answers[-1][0]),
                model.to_single_token(answers[-1][1]),
            )
        )

        prompts.append(prompt_format[i].format(answers[-1][1]))
answer_tokens = t.tensor(answer_tokens).to(device)

### check that all the prompts have the same number of tokens 
prompt_len = len(model.to_str_tokens(prompts[1]))
assert len(set([len(model.to_str_tokens(prompt)) for prompt in prompts])) == 1


In [5]:
### print all prompts in a table (learned from Keith's notebook! )
from rich.table import Table, Column
from rich import print as rprint


prompt_tab = Table('prompt', 'clean', 'corrupted', title = 'prompts and answers')

for i in range(len(prompts)):
    prompt_tab.add_row(prompts[i], answers[i][0], answers[i][1])

rprint(prompt_tab)

cache the logits and model internals for all the prompts

In [6]:
n_ex = len(prompts)


tokens = model.to_tokens(prompts, prepend_bos = True).to(device)
og_logits, cache = model.run_with_cache(tokens)


Design a metric to test model performance. In this case, we'll use the logit difference between the indirect object (correct answer) and the subject (incorrect answer)

In [7]:
d_vocab = model.cfg.d_vocab
d_model = model.cfg.d_model
n_heads = model.cfg.n_heads
n_layers = model.cfg.n_layers



assert og_logits.shape == t.Size([n_ex, prompt_len, d_vocab])

def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt = False):
    # take the last logit for every prompt (only these are relevant to the answer)
    final_logits = logits[:,-1,:]
    # get the logits corresponding to the IO/ sub tokens 
    answer_logits = final_logits.gather(dim=-1, index = answer_tokens)
    answer_logit_diff = answer_logits[:,0] - answer_logits[:,1]
    ## If per_prompt = True, return an array of the per_prompt difference, instead of the average 
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()
    

og_logit_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=True)
og_logit_avg_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=False)

cols = [
    "Prompt", 
    Column("Correct", style="rgb(0,200,0) bold"), 
    Column("Incorrect", style="rgb(255,0,0) bold"), 
    Column("Logit Difference", style="bold"), Column("Avg Logit Difference", style="bold")
]
logit_diff_table = Table(*cols, title="Logit differences")

for prompt, ans, logit_diff in zip(prompts, answers,og_logit_diff):
    logit_diff_table.add_row(prompt, ans[0], ans[1], f"{logit_diff.item():.3f}")
rprint(logit_diff_table)



In [33]:
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])

In [9]:
from transformer_lens import patching

clean_tokens = tokens 
idx_swap = [i+1 if i % 2 == 0 else i-1 for i in range(len(tokens))]
corrupted_tokens = clean_tokens[idx_swap]

print(
    "Clean string 0:    ", model.to_string(clean_tokens[0]), "\n"
    "Corrupted string 0:", model.to_string(corrupted_tokens[0])
)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean string 0:     <|endoftext|>When John and Mary went to the shops, John gave the bag to 
Corrupted string 0: <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Clean logit diff: 3.5519
Corrupted logit diff: -3.5519


In [10]:
def ioi_metric(
    logits: Float[t.Tensor, "batch seq d_vocab"],
    answer_tokens: Float[t.Tensor, "batch 2"] = answer_tokens,
    corrupted_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> Float[t.Tensor, ""]:

    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff  - corrupted_logit_diff)

In [11]:
# checking that this does what we want 

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [12]:
import copy

def get_lin_comb_cache(clean_cache, corrupted_cache, lh_list):
    aux_cache = copy.deepcopy(clean_cache)
    clean_comb = 0
    corrupted_comb = 0
    count = 0

    for layer, head in lh_list:
        count += 1
        clean_comb += clean_cache[utils.get_act_name("z", layer)][:,:,head,:]
        corrupted_comb += corrupted_cache[utils.get_act_name("z", layer)][:,:,head,:]
        #print(f"1:{layer}")
    comb = (clean_comb - corrupted_comb)/count
    for layer, head in lh_list:

        #print(f"2:{layer,head}")
        aux_cache[utils.get_act_name("z", layer)][:,:,head,:] += comb

    return aux_cache

lh_list_pos = [[9,6],[9,9], [10,0]]
lh_list_neg = [[10,7], [11,10]]
lh_list_tot = lh_list_neg + lh_list_pos

pos_cache = get_lin_comb_cache(clean_cache, corrupted_cache, lh_list_pos)
neg_cache = get_lin_comb_cache(clean_cache, corrupted_cache, lh_list_neg)
total_cache = get_lin_comb_cache(clean_cache, corrupted_cache, lh_list_tot)

In [13]:

from typing import List, Optional, Callable, Tuple, Dict, Literal, Set


def patch_head_vector(
    corrupted_head_vector: Float[t.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    head_index: int,
    from_cache: ActivationCache
) -> Float[t.Tensor, "batch pos head_index d_head"]:
    '''
    Patches the output of a given head (before it's added to the residual stream) at
    every sequence position, using the value from the aux cache (either clean or linear comb).
    '''
    corrupted_head_vector[:, :, head_index] = from_cache[hook.name][:, :, head_index]
    return corrupted_head_vector

def get_act_patched(
    model: HookedTransformer,
    corrupted_tokens: Float[t.Tensor, "batch pos"],
    from_cache: ActivationCache,
    patching_metric: Callable,
    lh_list
) -> Float[t.Tensor, "layer head"]:
    '''
    Returns an array of results of patching at all positions for each head in each
    layer, using the value from the clean cache.

    The results are calculated using the patching_metric function, which should be
    called on the model's logit output.
    '''
    model.reset_hooks()
    result = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=t.float32)

    head_hooks = [(utils.get_act_name("z", layer),partial(patch_head_vector, head_index=head, from_cache=from_cache)) for layer, head in lh_list]


    patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = head_hooks,
            return_type="logits"
        )
    result = patching_metric(patched_logits)

    return result

In [14]:

len([(utils.get_act_name("z", layer),partial(patch_head_vector, head_index=head, from_cache=clean_cache)) for layer, head in lh_list_neg]), len(lh_list_neg)

(2, 2)

In [15]:
act_patch_attn_indiv_pos = get_act_patched(model, corrupted_tokens, clean_cache, ioi_metric, lh_list_pos)
act_patch_attn_lin_comb_pos = get_act_patched(model, corrupted_tokens, pos_cache, ioi_metric, lh_list_pos)

act_patch_attn_indiv_neg = get_act_patched(model, corrupted_tokens, clean_cache, ioi_metric, lh_list_neg)
act_patch_attn_lin_comb_neg = get_act_patched(model, corrupted_tokens, neg_cache, ioi_metric, lh_list_neg)

act_patch_attn_indiv_tot = get_act_patched(model, corrupted_tokens, clean_cache, ioi_metric, lh_list_tot)
act_patch_attn_lin_comb_tot = get_act_patched(model, corrupted_tokens, total_cache, ioi_metric, lh_list_tot)

print(f'+ Heads: result(lin comb) = {act_patch_attn_lin_comb_pos:.4f}, result(individual) = {act_patch_attn_indiv_pos:.4f}')
print(f'- Heads: result(lin comb) = {act_patch_attn_lin_comb_neg:.4f}, result(individual) = {act_patch_attn_indiv_neg:.4f}')
print(f'All Heads: result(lin comb) = {act_patch_attn_lin_comb_tot:.4f}, result(individual) = {act_patch_attn_indiv_tot:.4f}')

print(f'Fraction recovered by positive heads:{(act_patch_attn_lin_comb_pos)/t.abs(act_patch_attn_indiv_pos):.4f}')
print(f'Fraction recovered by negative heads:{(act_patch_attn_lin_comb_neg)/t.abs(act_patch_attn_indiv_neg):.4f}')
print(f'Fraction recovered by all heads:{(act_patch_attn_lin_comb_tot)/t.abs(act_patch_attn_indiv_tot):.4f}')

+ Heads: result(lin comb) = 0.6618, result(individual) = 0.3353
- Heads: result(lin comb) = -1.2503, result(individual) = -0.8568
All Heads: result(lin comb) = 0.3069, result(individual) = 0.1690
Fraction recovered by positive heads:1.9737
Fraction recovered by negative heads:-1.4592
Fraction recovered by all heads:1.8162


Add in backup name movers to see if it makes a difference: First compare individually patched heads + backups with linear combination of just +/- movers

In [16]:
lh_backup = [[10,10],[10,6],[10,2],[10,1],[11,2],[11,9],[9,0],[9,7]]

lh_neg_bu = lh_list_neg + lh_backup
lh_pos_bu = lh_list_pos + lh_backup
lh_tot_bu = lh_list_tot + lh_backup

act_patch_attn_indiv_pos_bu = get_act_patched(model, corrupted_tokens, clean_cache, ioi_metric, lh_pos_bu)
act_patch_attn_indiv_neg_bu = get_act_patched(model, corrupted_tokens, clean_cache, ioi_metric, lh_neg_bu)
act_patch_attn_indiv_tot_bu = get_act_patched(model, corrupted_tokens, clean_cache, ioi_metric, lh_tot_bu)

print(f'+ Heads: result(lin comb) = {act_patch_attn_lin_comb_pos:.4f}, result(individual+ backup) = {act_patch_attn_indiv_pos_bu:.4f}')
print(f'- Heads: result(lin comb) = {act_patch_attn_lin_comb_neg:.4f}, result(individual+backup) = {act_patch_attn_indiv_neg_bu:.4f}')
print(f'All Heads: result(lin comb) = {act_patch_attn_lin_comb_tot:.4f}, result(individual+backup) = {act_patch_attn_indiv_tot_bu:.4f}')

print(f'Fraction recovered by positive heads:{(act_patch_attn_lin_comb_pos)/t.abs(act_patch_attn_indiv_pos_bu):.4f}')
print(f'Fraction recovered by negative heads:{(act_patch_attn_lin_comb_neg)/t.abs(act_patch_attn_indiv_neg_bu):.4f}')
print(f'Fraction recovered by all heads:{(act_patch_attn_lin_comb_tot)/t.abs(act_patch_attn_indiv_tot_bu):.4f}')

+ Heads: result(lin comb) = 0.6618, result(individual+ backup) = 0.6764
- Heads: result(lin comb) = -1.2503, result(individual+backup) = -0.6919
All Heads: result(lin comb) = 0.3069, result(individual+backup) = 0.6992
Fraction recovered by positive heads:0.9783
Fraction recovered by negative heads:-1.8069
Fraction recovered by all heads:0.4389


In [17]:
### Now patch in the linear combination of heads + backup heads, and compare with above 

pos_cache_bu = get_lin_comb_cache(clean_cache, corrupted_cache, lh_pos_bu)
neg_cache_bu = get_lin_comb_cache(clean_cache, corrupted_cache, lh_neg_bu)
total_cache_bu = get_lin_comb_cache(clean_cache, corrupted_cache, lh_tot_bu)

act_patch_attn_lin_comb_pos_bu = get_act_patched(model, corrupted_tokens, pos_cache_bu, ioi_metric, lh_pos_bu)
act_patch_attn_lin_comb_neg_bu = get_act_patched(model, corrupted_tokens, pos_cache_bu, ioi_metric, lh_neg_bu)
act_patch_attn_lin_comb_tot_bu = get_act_patched(model, corrupted_tokens, pos_cache_bu, ioi_metric, lh_tot_bu)

print(f'+ Heads: result(lin comb + bu) = {act_patch_attn_lin_comb_pos_bu:.4f}, result(individual + bu) = {act_patch_attn_indiv_pos_bu:.4f}')
print(f'- Heads: result(lin comb + bu) = {act_patch_attn_lin_comb_neg_bu:.4f}, result(individual + bu) = {act_patch_attn_indiv_neg_bu:.4f}')
print(f'All Heads: result(lin comb + bu) = {act_patch_attn_lin_comb_tot_bu:.4f}, result(individual + bu) = {act_patch_attn_indiv_tot_bu:.4f}')

print(f'Fraction recovered by positive heads:{(act_patch_attn_lin_comb_pos_bu)/t.abs(act_patch_attn_indiv_pos_bu):.4f}')
print(f'Fraction recovered by negative heads:{(act_patch_attn_lin_comb_neg_bu)/t.abs(act_patch_attn_indiv_neg_bu):.4f}')
print(f'Fraction recovered by all heads:{(act_patch_attn_lin_comb_tot_bu)/t.abs(act_patch_attn_indiv_tot_bu):.4f}')

+ Heads: result(lin comb + bu) = 0.7651, result(individual + bu) = 0.6764
- Heads: result(lin comb + bu) = -0.7013, result(individual + bu) = -0.6919
All Heads: result(lin comb + bu) = 0.8322, result(individual + bu) = 0.6992
Fraction recovered by positive heads:1.1310
Fraction recovered by negative heads:-1.0135
Fraction recovered by all heads:1.1903


Above was fed through the rest of the network, but we want to study the direct logit attribution instead. In this case, we treat each head as if they're in the same layer (equivalent to changing the basis)

In [18]:
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])

In [19]:

def get_head_out(input_cache, layer, head):      
    attn_out = einops.einsum(model.W_O[layer,head], input_cache[utils.get_act_name("z", layer)][:,:,head],"d_h d_m, n s d_h -> n s d_m") + model.b_O[layer]
    return einops.einsum(model.W_U, attn_out, "d_m n_ctx, b s d_m -> b s n_ctx ")


def compare_patched_logits(indiv_cache, lin_comb_cache, lh_list, answer_tokens):
    indiv_logits = [logits_to_ave_logit_diff(get_head_out(indiv_cache, layer, head), answer_tokens) for layer, head in lh_list]
    lin_logits = [logits_to_ave_logit_diff(get_head_out(lin_comb_cache, layer, head), answer_tokens) for layer, head in lh_list]
    return np.sum(indiv_logits), np.sum(lin_logits)

compare_patched_logits(clean_cache, total_cache, lh_list_tot, answer_tokens)




(38.14662, 55.2978)

tracing the negative name mover (last one) through the MLP

In [20]:
d_head = model.cfg.d_head
assert model.W_O.shape == t.Size([n_layers,n_heads,d_head,d_model])
assert cache[utils.get_act_name("z", 11)].shape == t.Size([n_ex,prompt_len,n_heads,d_head])

# act with W_O
last_neg_out_no_bias = einops.einsum(model.W_O[11,10], cache[utils.get_act_name("z", 11)][:,:,10],"d_h d_m, n s d_h -> n s d_m")


Still to do: path patching and debug. Make sure the direct logit attribution for the uniform linear comb is the same as the individual patch 