In [1]:
# Import stuff
import torch as t
import numpy as np
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.express as px
import einops
import plotly.graph_objects as go 
from functools import partial
import tqdm.auto as tqdm
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from fancy_einsum import einsum
from jaxtyping import Float, Int, Bool
import re

#from plotly_utils import imshow, line, scatter, bar


In [2]:
def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a heatmap on a red-blue colour scale centred at zero.

    Extra keyword arguments are forwarded to `px.imshow`; `renderer` is passed
    to the figure's `show` call.
    """
    values = utils.to_numpy(tensor)
    heatmap = px.imshow(
        values,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    heatmap.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot the values of `tensor` as a line chart.

    Extra keyword arguments are forwarded to `px.line`; `renderer` is passed
    to the figure's `show` call.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`.

    `xaxis`, `yaxis` and `caxis` become the labels for the x, y and colour
    dimensions. Extra keyword arguments are forwarded to `px.scatter`;
    `renderer` is passed to the figure's `show` call.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_titles = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_titles, **kwargs)
    figure.show(renderer)

In [4]:
## Inference only: switch autograd off globally so no gradient state is kept
t.set_grad_enabled(False)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [5]:

prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]

names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# Each template produces two prompts, one per ordering of its name pair.
prompts = []        # filled-in prompt strings
answers = []        # (correct, incorrect) name strings, one tuple per prompt
answer_tokens = []  # (correct_token, incorrect_token) integer ids, one tuple per prompt
for template, name_pair in zip(prompt_format, names):
    for correct, incorrect in (name_pair, name_pair[::-1]):
        answers.append((correct, incorrect))
        answer_tokens.append(
            (model.to_single_token(correct), model.to_single_token(incorrect))
        )
        # The name inserted into the template is the incorrect completion.
        prompts.append(template.format(incorrect))
answer_tokens = t.tensor(answer_tokens).to(device)

### every prompt must tokenize to the same length
prompt_len = len(model.to_str_tokens(prompts[1]))
token_counts = {len(model.to_str_tokens(prompt)) for prompt in prompts}
assert len(token_counts) == 1


In [6]:
### print all prompts in a table (learned from Keith's notebook! )
from rich.table import Table, Column
from rich import print as rprint


# One row per prompt, followed by its (correct, incorrect) answer pair.
prompt_tab = Table('prompt', 'clean', 'corrupted', title='prompts and answers')

for prompt_text, (first_answer, second_answer) in zip(prompts, answers):
    prompt_tab.add_row(prompt_text, first_answer, second_answer)

rprint(prompt_tab)

In [7]:
# Number of prompts in the batch.
n_ex = len(prompts)

# Tokenize the whole batch with a BOS token prepended, then run the model once
# on the unmodified prompts, keeping both the logits and the activation cache.
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)
og_logits, cache = model.run_with_cache(tokens)


Design a metric to test model performance. In this case, we'll use the logit difference, at the final token position, between the indirect object (correct answer) and the subject (incorrect answer).

In [8]:
# Model dimensions read from the config.
d_vocab, d_model = model.cfg.d_vocab, model.cfg.d_model
n_heads, n_layers = model.cfg.n_heads, model.cfg.n_layers

# The logits must hold one vocabulary-sized vector per prompt and position.
assert tuple(og_logits.shape) == (n_ex, prompt_len, d_vocab)

def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt = False):
    """Logit difference between the first and second answer token of each prompt.

    Only the logits at the final sequence position are used. `answer_tokens`
    holds two token ids per prompt; the result is the logit of the first id
    minus the logit of the second.

    Returns one difference per prompt when `per_prompt` is True, otherwise the
    mean over the batch.
    """
    last_position_logits = logits[:, -1, :]
    # Select the two answer logits for every prompt, then split them apart.
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    first_logit, second_logit = picked.unbind(dim=-1)
    diff = first_logit - second_logit
    return diff if per_prompt else diff.mean()
    

# Logit difference of the unmodified model: one value per prompt, plus the batch mean.
og_logit_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=True)
og_logit_avg_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=False)

cols = [
    "Prompt", 
    Column("Correct", style="rgb(0,200,0) bold"), 
    Column("Incorrect", style="rgb(255,0,0) bold"), 
    Column("Logit Difference", style="bold"), Column("Avg Logit Difference", style="bold")
]
logit_diff_table = Table(*cols, title="Logit differences")

# The table declares five columns, so every row must supply five cells; the
# batch average is the same for all rows and is formatted once up front.
avg_cell = f"{og_logit_avg_diff.item():.3f}"
for prompt, ans, logit_diff in zip(prompts, answers, og_logit_diff):
    logit_diff_table.add_row(prompt, ans[0], ans[1], f"{logit_diff.item():.3f}", avg_cell)
rprint(logit_diff_table)



In [9]:
from transformer_lens import patching

# Prompts were built in consecutive pairs that differ only in which name was
# inserted into the template, so each prompt's partner (index XOR 1) is used
# as its corrupted counterpart.
clean_tokens = tokens
idx_swap = [i ^ 1 for i in range(len(tokens))]
corrupted_tokens = clean_tokens[idx_swap]

print(
    "Clean string 0:    ", model.to_string(clean_tokens[0]), "\n"
    "Corrupted string 0:", model.to_string(corrupted_tokens[0])
)

# Baseline on the clean prompts.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

# Baseline on the corrupted prompts, scored against the same answer tokens.
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean string 0:     <|endoftext|>When John and Mary went to the shops, John gave the bag to 
Corrupted string 0: <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Clean logit diff: 3.5519
Corrupted logit diff: -3.5519


In [10]:
def ioi_metric(
    logits: Float[t.Tensor, "batch seq d_vocab"],
    answer_tokens: Int[t.Tensor, "batch 2"] = answer_tokens,
    corrupted_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> Float[t.Tensor, ""]:
    '''
    Rescales the average logit difference of `logits` linearly so that the
    corrupted baseline maps to 0 and the clean baseline maps to 1.

    logits: model output to score.
    answer_tokens: integer token ids per prompt, used as a gather index by
        logits_to_ave_logit_diff.
    corrupted_logit_diff / clean_logit_diff: the two baseline values.

    The defaults are bound when this cell is executed; re-run the cell after
    recomputing the baselines or the answer tokens, or pass them explicitly.
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff  - corrupted_logit_diff)

In [11]:
# checking that this does what we want 

# The metric should score the clean run at 1 and the corrupted run at 0.
for label, expected, baseline_logits in (
    ("Clean", 1, clean_logits),
    ("Corrupted", 0, corrupted_logits),
):
    print(f"{label} Baseline is {expected}: {ioi_metric(baseline_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [13]:
# Display the first transformer block to list the hook points it exposes.
model.blocks[0]

TransformerBlock(
  (ln1): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (ln2): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (attn): Attention(
    (hook_k): HookPoint()
    (hook_q): HookPoint()
    (hook_v): HookPoint()
    (hook_z): HookPoint()
    (hook_attn_scores): HookPoint()
    (hook_pattern): HookPoint()
    (hook_result): HookPoint()
  )
  (mlp): MLP(
    (hook_pre): HookPoint()
    (hook_post): HookPoint()
  )
  (hook_q_input): HookPoint()
  (hook_k_input): HookPoint()
  (hook_v_input): HookPoint()
  (hook_attn_out): HookPoint()
  (hook_mlp_out): HookPoint()
  (hook_resid_pre): HookPoint()
  (hook_resid_mid): HookPoint()
  (hook_resid_post): HookPoint()
)

In [18]:


## lin comb of attention outputs, patch whole thing to residual stream 
def get_lin_comb_cache(clean_cache, corrupted_cache, lh_list):
    '''
    Builds an activation cache from a fresh run on the global `corrupted_tokens`
    and shifts its `attn_out` activations, at the last sequence position only,
    for every layer named in `lh_list`.

    clean_cache / corrupted_cache: caches holding the per-head "z" activations
        of the clean and corrupted runs.
    lh_list: non-empty iterable of (layer, head) pairs.

    For each listed head the output is computed as z @ W_O[layer, head] + b_O[layer].
    `comb` is the mean over the listed heads of (clean output - corrupted output).
    Each listed head then adds (its corrupted output + comb) to the `attn_out`
    entry of its layer in the new cache.

    Returns the modified cache. Raises ValueError if `lh_list` is empty.
    Relies on the notebook globals `model`, `corrupted_tokens` and `device`.
    '''
    lh_list = list(lh_list)
    if not lh_list:
        raise ValueError("lh_list must contain at least one (layer, head) pair")

    def head_output(cache, layer, head):
        # Map one head's z vectors through that head's slice of W_O, plus the layer bias.
        z = cache[utils.get_act_name("z", layer)][:, :, head]
        return einops.einsum(model.W_O[layer, head], z, "d_h d_m, n s d_h -> n s d_m") + model.b_O[layer]

    aux_logits, aux_cache = model.run_with_cache(corrupted_tokens, device = device)

    # Compute each head output once; the corrupted ones are reused when patching below.
    clean_outs = [head_output(clean_cache, layer, head) for layer, head in lh_list]
    corr_outs = [head_output(corrupted_cache, layer, head) for layer, head in lh_list]

    # The bias term is present in both sums and cancels in the difference.
    comb = (sum(clean_outs) - sum(corr_outs)) / len(lh_list)
    print(comb.shape, clean_outs[-1].shape)

    # NOTE(review): attn_out from the corrupted run presumably already contains
    # each head's corrupted output and b_O, so adding attn_out_corr again looks
    # like double counting; heads that share a layer also accumulate into the
    # same tensor. Behaviour is kept as-is because the recorded metric values
    # depend on it — confirm the intended patch before changing.
    for (layer, head), attn_out_corr in zip(lh_list, corr_outs):
        aux_cache[utils.get_act_name("attn_out", layer)][:,-1] += attn_out_corr[:,-1] + comb[:,-1]

    return aux_cache



# (layer, head) pairs for the two head groups, and both groups together.
lh_list_pos = [[9, 6], [9, 9], [10, 0]]
lh_list_neg = [[10, 7], [11, 10]]
lh_list_tot = [*lh_list_neg, *lh_list_pos]

# One linear-combination cache per group, built in the order pos, neg, total.
pos_cache, neg_cache, total_cache = (
    get_lin_comb_cache(clean_cache, corrupted_cache, head_group)
    for head_group in (lh_list_pos, lh_list_neg, lh_list_tot)
)

torch.Size([8, 15, 768]) torch.Size([8, 15, 768])
torch.Size([8, 15, 768]) torch.Size([8, 15, 768])
torch.Size([8, 15, 768]) torch.Size([8, 15, 768])


In [19]:

from typing import List, Optional, Callable, Tuple, Dict, Literal, Set


def patch_head_output(
    corrupted_head_out: Float[t.Tensor, "batch pos d_model"],
    hook: HookPoint,
    from_cache: ActivationCache
) -> Float[t.Tensor, "batch pos d_model"]:
    '''
    Overwrites the hooked activation at the last sequence position with the
    value stored under the same hook name in `from_cache` (either the clean
    cache or a linear-combination cache).

    The whole d_model vector at that position is replaced, not a single head's
    contribution. The input tensor is modified in place and returned with its
    shape unchanged.
    '''
    corrupted_head_out[:, -1, :] = from_cache[hook.name][:, -1, :]
    return corrupted_head_out

def patch_head_vector(
    corrupted_head_vector: Float[t.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    head_index: int,
    from_cache: ActivationCache
) -> Float[t.Tensor, "batch pos head_index d_head"]:
    '''
    Overwrites the vector of head `head_index` at the last sequence position
    with the value stored under the same hook name in `from_cache` (either the
    clean cache or a linear-combination cache).

    All other heads and positions are left untouched. The input tensor is
    modified in place and returned.
    '''
    # A plain slice over the batch axis selects the same elements as indexing
    # with range(batch), without building an index tensor.
    corrupted_head_vector[:, -1, head_index] = from_cache[hook.name][:, -1, head_index]
    return corrupted_head_vector

def get_act_patched_z(
    model: HookedTransformer,
    corrupted_tokens: Int[t.Tensor, "batch pos"],
    from_cache: ActivationCache,
    patching_metric: Callable,
    lh_list
) -> Float[t.Tensor, ""]:
    '''
    Runs the model once on `corrupted_tokens` while patching, for every
    (layer, head) pair in `lh_list` simultaneously, that head's "z" vector at
    the last sequence position with the value from `from_cache`.

    Returns `patching_metric` applied to the patched logits, i.e. a single
    value for the whole set of heads rather than one value per head.
    '''
    model.reset_hooks()

    # One hook per listed head; heads in the same layer each patch their own slice.
    head_hooks = [
        (
            utils.get_act_name("z", layer),
            partial(patch_head_vector, head_index=head, from_cache=from_cache),
        )
        for layer, head in lh_list
    ]

    patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = head_hooks,
            return_type="logits"
        )
    return patching_metric(patched_logits)

def get_act_patched_attn(
    model: HookedTransformer,
    corrupted_tokens: Int[t.Tensor, "batch pos"],
    from_cache: ActivationCache,
    patching_metric: Callable,
    lh_list
) -> Float[t.Tensor, ""]:
    '''
    Runs the model once on `corrupted_tokens` while replacing, for every layer
    that appears in `lh_list`, the "attn_out" activation at the last sequence
    position with the value from `from_cache` (see patch_head_output).

    Only the layer of each (layer, head) pair is used: the whole attention
    output of the layer is overwritten, so head-specific behaviour has to be
    encoded in `from_cache` beforehand.

    Returns `patching_metric` applied to the patched logits, i.e. a single
    value for the whole set of layers.
    '''
    model.reset_hooks()

    # Several heads can share a layer; the hook is identical for all of them,
    # so register it once per layer, keeping first-seen order.
    layers = list(dict.fromkeys(layer for layer, _ in lh_list))
    head_hooks = [
        (
            utils.get_act_name("attn_out", layer),
            partial(patch_head_output, from_cache=from_cache),
        )
        for layer in layers
    ]

    patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = head_hooks,
            return_type="logits"
        )
    return patching_metric(patched_logits)

In [22]:
# Compare the two patching methods on the same set of heads.
test = lh_list_pos

# Patch the per-head z vectors straight from the clean cache.
z_patch_score = get_act_patched_z(model, corrupted_tokens, clean_cache, ioi_metric, test)
print(z_patch_score)

# Patch attn_out from the linear-combination cache built for the same heads.
test_cache = get_lin_comb_cache(clean_cache, corrupted_cache, test)
attn_patch_score = get_act_patched_attn(model, corrupted_tokens, test_cache, ioi_metric, test)
print(attn_patch_score)

tensor(0.3267)
torch.Size([8, 15, 768]) torch.Size([8, 15, 768])
tensor(0.4481)
