In [5]:
# Import stuff
import torch as t
import numpy as np
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.express as px
import einops
import plotly.graph_objects as go 
from functools import partial
import tqdm.auto as tqdm
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from fancy_einsum import einsum
from jaxtyping import Float, Int, Bool
import re
import random


In [2]:

def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
## turn off AD to save memory, since we're focusing on model inference here 
t.set_grad_enabled(False)

device = 'cuda' if t.cuda.is_available() else 'cpu'
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


Studying the head outputs on different distributions

Create the datasets 

In [87]:
proper_nouns = [
    " Goose", " Church",
    " Google", " Chair",
    " Bag", " Statue",
    " Lamp", " Flower"
]

nouns = [
    " goose", " church",
    " google", " chair",
    " bag", " statue",
    " lamp", " flower"
]

multi_names_religious = [
    " Mary", " Joseph",
    " Abraham", " Paul",
    " Isaac", " Noah",
    " Jacob", " Jesus"
]

multi_names_places = [
    " Paris", " London",
    " Madison", " Phoenix",
    " Devon", " Florence",
    " Austin", " Brooklyn"
]

test = proper_nouns
[model.to_single_token(test[i]) for i in range(len(test))]


[46317, 4564, 3012, 9369, 20127, 43330, 28607, 20025]

In [171]:



def get_dataset(N, names):
    prompts = []
    # List of answers, in the format (correct, incorrect)
    answers = []
    # List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
    answer_tokens = []
    for _ in range(N):
        S1, S2 = random.sample(names,2)
        answers.append((S1,S2))
        answers.append((S2,S1))
        prompt1 = f'When {S1} and {S2} went to the shops,{S2} gave the bag to'
        prompt2 = f'When {S1} and {S2} went to the shops,{S1} gave the bag to'
        prompts.append(prompt1)
        prompts.append(prompt2)

        answer_tokens.append((model.to_single_token(answers[-1][0]),model.to_single_token(answers[-1][1])))
        answer_tokens.append((model.to_single_token(answers[-1][1]),model.to_single_token(answers[-1][0])))
    assert len(set([len(model.to_str_tokens(prompt)) for prompt in prompts])) == 1
    answer_tokens = t.tensor(answer_tokens).to(device)
    return prompts, answers, answer_tokens



In [172]:
prompts_rel, ans_rel, ans_toks_rel = get_dataset(10,multi_names_religious)
prompts_pl, ans_pl, ans_toks_pl = get_dataset(10,multi_names_places)
prompts_n, ans_n, ans_toks_n = get_dataset(10,nouns)
prompts_pn, ans_pn, ans_toks_pn = get_dataset(10,proper_nouns)



In [173]:
from rich.table import Table, Column
from rich import print as rprint


In [174]:
control_prompts = ['When John and Mary went to the shops, John gave the bag to',
 'When John and Mary went to the shops, Mary gave the bag to',
 'When Tom and James went to the park, James gave the ball to',
 'When Tom and James went to the park, Tom gave the ball to',
 'When Dan and Sid went to the shops, Sid gave an apple to',
 'When Dan and Sid went to the shops, Dan gave an apple to',
 'After Martin and Amy went to the park, Amy gave a drink to',
 'After Martin and Amy went to the park, Martin gave a drink to']

control_answers = [(' Mary', ' John'),
 (' John', ' Mary'),
 (' Tom', ' James'),
 (' James', ' Tom'),
 (' Dan', ' Sid'),
 (' Sid', ' Dan'),
 (' Martin', ' Amy'),
 (' Amy', ' Martin')]

control_ans_prompts = t.tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])



In [175]:
print(model(prompts_n).shape)

def ave_logit_diff(prompts, answer_tokens, per_prompt = False):
    final_logits = model(prompts)[:,-1,:]
    answer_logits = final_logits.gather(dim = -1, index = answer_tokens)
    answer_logit_diff = answer_logits[:,0] - answer_logits[:,1]
    if per_prompt:
        return answer_logit_diff 
    else:
        return answer_logit_diff.mean()



torch.Size([20, 17, 50257])


In [176]:
def make_table(prompts, answers, answer_tokens, title):
    cols = [
        "Prompt", 
        Column("Correct", style="rgb(0,200,0) bold"), 
        Column("Incorrect", style="rgb(255,0,0) bold"), 
        Column("Logit Difference", style="bold")
    ]
    logit_diffs = ave_logit_diff(prompts,answer_tokens, per_prompt = True)
    
    ave_logits = ave_logit_diff(prompts,answer_tokens, per_prompt = False)
    logit_diff_table = Table(*cols, title=title + f": Ave logit diff = {ave_logits.item():.3f}")
    logit_diffs = ave_logit_diff(prompts,answer_tokens, per_prompt = True)
    
    ave_logits = ave_logit_diff(prompts,answer_tokens, per_prompt = False)
                                 
    ave_logit_diff(prompts, answer_tokens)
    for prompt, ans, logit_diff in zip(prompts, answers,logit_diffs):
        logit_diff_table.add_row(prompt, ans[0], ans[1], f"{logit_diff.item():.3f}")
    rprint(logit_diff_table)

In [177]:
make_table(control_prompts, control_answers, control_ans_prompts, "Control Names")

In [183]:
[(model.to_str_tokens(ans_toks_n[k],ans_n[k]), ans_n[k]) for k in range(len(ans_toks_n))]

[([' flower', ' church'], (' church', ' flower')),
 ([' church', ' flower'], (' flower', ' church')),
 ([' goose', ' chair'], (' chair', ' goose')),
 ([' chair', ' goose'], (' goose', ' chair')),
 ([' goose', ' flower'], (' flower', ' goose')),
 ([' flower', ' goose'], (' goose', ' flower')),
 ([' bag', ' statue'], (' statue', ' bag')),
 ([' statue', ' bag'], (' bag', ' statue')),
 ([' google', ' flower'], (' flower', ' google')),
 ([' flower', ' google'], (' google', ' flower')),
 ([' flower', ' goose'], (' goose', ' flower')),
 ([' goose', ' flower'], (' flower', ' goose')),
 ([' chair', ' statue'], (' statue', ' chair')),
 ([' statue', ' chair'], (' chair', ' statue')),
 ([' google', ' flower'], (' flower', ' google')),
 ([' flower', ' google'], (' google', ' flower')),
 ([' statue', ' lamp'], (' lamp', ' statue')),
 ([' lamp', ' statue'], (' statue', ' lamp')),
 ([' lamp', ' church'], (' church', ' lamp')),
 ([' church', ' lamp'], (' lamp', ' church'))]

In [184]:
make_table(prompts_n, ans_n, ans_toks_n, "nouns")
make_table(prompts_pn, ans_pn, ans_toks_pn, "proper nouns")
make_table(prompts_rel, ans_rel, ans_toks_rel, "Religious Names")
make_table(prompts_pl, ans_pl, ans_toks_pl, "Place Names")

Results look pretty bad across the board ... let's look at the top_k tokens to see 

In [153]:
sample_prompt = prompts_rel[5]
sample_logits = model(sample_prompt)
sample_probs = t.log_softmax(sample_logits[0, -1], dim = -1)
vals, ids = t.topk(sample_probs,5)
print(vals, sample_probs)
print([model.tokenizer.decode(ids[j]) for j in range(len(ids))])
print(sample_probs[ids[0]])



tensor([-1.2884, -1.5528, -2.2378, -2.4683, -3.5000]) tensor([-12.5240, -12.0724, -14.4456,  ..., -16.4893, -17.4218, -12.1294])
[' Abraham', ' ', ' them', ' the', ' Paul']
tensor(-1.2884)
