In [74]:
# Import stuff
import torch as t
import numpy as np
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.express as px
import einops
import plotly.graph_objects as go 
from functools import partial
import tqdm.auto as tqdm
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from fancy_einsum import einsum
#from plotly_utils import imshow, line, scatter, bar


In [4]:
# Globally disable autograd: this notebook only runs inference, so tracking
# gradients would just cost memory.
t.set_grad_enabled(False)

# Use the GPU when torch reports one as available; otherwise fall back to the CPU.
device = 'cuda' if t.cuda.is_available() else 'cpu'
# Load the pretrained 'gpt2-small' weights as a HookedTransformer on that device.
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


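Generate example prompts for the indirect object identification (IOI) task, along with the correct ("clean") and incorrect ("corrupted") answer for each prompt. It's important that all prompts have the same number of tokens (the prompts are taken from the exploratory analysis demo).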

In [36]:
# Sentence templates. The "{}" slot is filled with the name that is repeated
# in the prompt (the incorrect answer for that prompt).
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
# The pair of names used by each template, aligned by index with prompt_format.
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# Parallel lists, one entry per generated prompt:
#   prompts[k]       - the prompt string
#   answers[k]       - (correct, incorrect) name strings
#   answer_tokens[k] - (correct_token, incorrect_token) as integer token ids
prompts, answers, answer_tokens = [], [], []
for template, name_pair in zip(prompt_format, names):
    # Each template yields two prompts: the pair in its listed order, then swapped.
    for correct, incorrect in (name_pair, name_pair[::-1]):
        answers.append((correct, incorrect))
        answer_tokens.append(
            (model.to_single_token(correct), model.to_single_token(incorrect))
        )
        prompts.append(template.format(incorrect))
# Replace the list of pairs with a tensor on the model's device.
answer_tokens = t.tensor(answer_tokens).to(device)

# Every prompt must tokenize to the same number of tokens; record that length.
prompt_len = len(model.to_str_tokens(prompts[1]))
assert len({len(model.to_str_tokens(prompt)) for prompt in prompts}) == 1


In [24]:
### print all prompts in a table (learned from Keith's notebook! )
from rich.table import Table, Column
from rich import print as rprint


# One table row per prompt: the prompt text followed by its two answers.
prompt_tab = Table('prompt', 'clean', 'corrupted', title = 'prompts and answers')

for prompt_text, (first_answer, second_answer) in zip(prompts, answers):
    prompt_tab.add_row(prompt_text, first_answer, second_answer)

rprint(prompt_tab)

Cache the logits and the model's internal activations for all the prompts.

In [26]:
# Tokenize all prompts as one batch (prepend_bos = True) and move the result to the model's device.
tokens = model.to_tokens(prompts, prepend_bos = True).to(device)
# One forward pass over the batch; keeps both the output logits and the cache
# of intermediate activations that later cells index into.
og_logits, cache = model.run_with_cache(tokens)


Design a metric to measure model performance on the task. Here we use the logit difference between the indirect object (the correct answer) and the subject (the incorrect answer).

In [65]:
# Model dimensions read off the config, plus the number of example prompts.
d_vocab, d_model = model.cfg.d_vocab, model.cfg.d_model
n_heads, n_layers = model.cfg.n_heads, model.cfg.n_layers
n_ex = len(prompts)

# Sanity check: the logits hold one d_vocab-sized vector per (prompt, position).
assert tuple(og_logits.shape) == (n_ex, prompt_len, d_vocab)

def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt = False):
    """Logit difference between the first and second answer token of each prompt.

    Args:
        logits: tensor indexed as [prompt, position, vocab]; only the last
            position of each prompt is used.
        answer_tokens: integer tensor indexed as [prompt, answer]; column 0 is
            the correct token id and column 1 the incorrect token id.
        per_prompt: when True, return one difference per prompt; otherwise
            return their mean.

    Returns:
        A 1-D tensor of per-prompt differences if ``per_prompt`` is True,
        else a scalar tensor holding the mean difference.
    """
    # Only the prediction at the final position answers the prompt.
    last_position_logits = logits[:, -1, :]
    # Select, for each prompt, the logits of its answer tokens.
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    per_prompt_diff = picked[:, 0] - picked[:, 1]
    return per_prompt_diff if per_prompt else per_prompt_diff.mean()
    

# Per-prompt logit differences (one value per prompt) and their mean over all prompts.
og_logit_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=True)
og_logit_avg_diff = logits_to_ave_logit_diff(og_logits, answer_tokens, per_prompt=False)

cols = [
    "Prompt", 
    Column("Correct", style="rgb(0,200,0) bold"), 
    Column("Incorrect", style="rgb(255,0,0) bold"), 
    Column("Logit Difference", style="bold"),
    Column("Avg Logit Difference", style="bold"),
]
logit_diff_table = Table(*cols, title="Logit differences")

# The table declares five columns, so every row must supply five cells.
# The average is the same for all prompts, so it is formatted once and reused;
# leaving it out would render the "Avg Logit Difference" column empty.
avg_logit_diff_text = f"{og_logit_avg_diff.item():.3f}"
for prompt, ans, logit_diff in zip(prompts, answers, og_logit_diff):
    logit_diff_table.add_row(
        prompt,
        ans[0],
        ans[1],
        f"{logit_diff.item():.3f}",
        avg_logit_diff_text,
    )
rprint(logit_diff_table)



In [63]:
# Display the mean logit difference over all prompts as this cell's output.
og_logit_avg_diff

tensor(3.5519)

What is going on in IOI? There are several ways to check. 

First, Direct Logit Attribution (DLA): every component reads from and writes to the residual stream through linear maps (plus a final LayerNorm, which is linear once its scale factor is held fixed), so the final logits can be decomposed into a sum of contributions, one per component. Working backwards from the end of the model (logits = U(LN(final_residual))), we can see which components contribute most to the logit of the correct token.

- The logit difference is a convenient metric for IOI: log-softmax subtracts the same normalizing constant from every logit, so the difference in log probabilities between two tokens equals the difference in their logits.

- Getting an output logit = projecting the (layer-normed) residual stream onto that token's unembedding direction


In [80]:
# map answer tokens to the d_model residual stream directions
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print(answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:,0] - answer_residual_directions[:,1]

## Make sure this works by applying U and LN to the residual stream 

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 

final_residual = cache["resid_post", -1]
assert final_residual.shape == t.Size([n_ex, prompt_len, d_model])

final_tok_residual = final_residual[:,-1,:]

# divide by the LN scale. pos_slice are the positions considered (final token of each prompt)
scaled_final_token_residual = cache.apply_ln_to_stack(final_tok_residual, layer = -1, pos_slice=-1)

## Get the average logit diff by projecting onto the answer residual stream directions and summing over each direction and example
average_avg_logit_diff = (einsum("b d_model, b d_model -> ", scaled_final_token_residual, logit_diff_directions).item())/n_ex

# small enough! 
print(average_avg_logit_diff - og_logit_avg_diff)



torch.Size([8, 2, 768])
tensor(-1.1921e-06)
