In [5]:
# Import stuff
import torch as t
import numpy as np
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.express as px
import einops
import plotly.graph_objects as go 
from functools import partial
import tqdm.auto as tqdm
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from fancy_einsum import einsum
from jaxtyping import Float, Int, Bool
import re
import random


In [2]:

def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
## turn off AD to save memory, since we're focusing on model inference here 
t.set_grad_enabled(False)

device = 'cuda' if t.cuda.is_available() else 'cpu'
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


Studying the head outputs on different distributions

Create the datasets 

In [87]:
proper_nouns = [
    " Goose", " Church",
    " Google", " Chair",
    " Bag", " Statue",
    " Lamp", " Flower"
]

nouns = [
    " goose", " church",
    " google", " chair",
    " bag", " statue",
    " lamp", " flower"
]

multi_names_religious = [
    " Mary", " Joseph",
    " Abraham", " Paul",
    " Isaac", " Noah",
    " Jacob", " Jesus"
]

multi_names_places = [
    " Paris", " London",
    " Madison", " Phoenix",
    " Devon", " Florence",
    " Austin", " Brooklyn"
]

test = proper_nouns
[model.to_single_token(test[i]) for i in range(len(test))]


[46317, 4564, 3012, 9369, 20127, 43330, 28607, 20025]

In [88]:



def get_dataset(N, names):
    prompts = []
    # List of answers, in the format (correct, incorrect)
    answers = []
    # List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
    answer_tokens = []
    for _ in range(N):
        S1, S2 = random.sample(names,2)
        answers.append((S1,S2))
        answers.append((S2,S1))
        prompt1 = f'When {S1} and {S2} went to the shops,{S2} gave the bag to'
        prompt2 = f'When {S1} and {S2} went to the shops,{S1} gave the bag to'
        prompts.append(prompt1)
        prompts.append(prompt2)

        answer_tokens.append((model.to_single_token(answers[-1][0]),model.to_single_token(answers[-1][1])))
    assert len(set([len(model.to_str_tokens(prompt)) for prompt in prompts])) == 1
    answer_tokens = t.tensor(answer_tokens).to(device)
    return prompts, answers, answer_tokens



In [89]:
prompts_rel, ans_rel, ans_toks_rel = get_dataset(10,multi_names_religious)
prompts_pl, ans_pl, ans_toks_pl = get_dataset(10,multi_names_places)
prompts_n, ans_n, ans_toks_n = get_dataset(10,nouns)
prompts_pn, ans_pn, ans_toks_pn = get_dataset(10,proper_nouns)



In [90]:
from rich.table import Table, Column
from rich import print as rprint


In [91]:
print(model(prompts_n).shape)

def ave_logit_diff(prompts, answer_tokens, per_prompt = False):
    final_logits = model(prompts)[:,-1,:]
    answer_logits = final_logits.gather(dim = -1, index = answer_tokens)
    answer_logit_diff = answer_logits[:,0] - answer_logits[:,1]
    if per_prompt:
        return answer_logit_diff 
    else:
        return answer_logit_diff.mean()



torch.Size([20, 17, 50257])


In [92]:
def make_table(prompts, answers, answer_tokens, title):
    cols = [
        "Prompt", 
        Column("Correct", style="rgb(0,200,0) bold"), 
        Column("Incorrect", style="rgb(255,0,0) bold"), 
        Column("Logit Difference", style="bold")
    ]
    logit_diffs = ave_logit_diff(prompts,answer_tokens, per_prompt = True)
    
    ave_logits = ave_logit_diff(prompts,answer_tokens, per_prompt = False)
    logit_diff_table = Table(*cols, title=title + f": Ave logit diff = {ave_logits.item():.3f}")
    logit_diffs = ave_logit_diff(prompts,answer_tokens, per_prompt = True)
    
    ave_logits = ave_logit_diff(prompts,answer_tokens, per_prompt = False)
                                 
    ave_logit_diff(prompts, answer_tokens)
    for prompt, ans, logit_diff in zip(prompts, answers,logit_diffs):
        logit_diff_table.add_row(prompt, ans[0], ans[1], f"{logit_diff.item():.3f}")
    rprint(logit_diff_table)

In [94]:
make_table(prompts_n, ans_n, ans_toks_n, "nouns")
make_table(prompts_pn, ans_pn, ans_toks_pn, "proper nouns")
make_table(prompts_rel, ans_rel, ans_toks_rel, "Religious Names")
make_table(prompts_pl, ans_pl, ans_toks_pl, "Place Names")

Results look pretty bad across the board ... let's look at the top_k tokens to see 

In [141]:
def get_top_preds(prompt, answer, top_n):
    prompt_str_toks = model.to_str_tokens(prompt)
    answer_str_toks = model.to_str_tokens(answer)
    both_str_toks = model.to_str_tokens(prompt + answer)
    print('tokenized prompt: ', prompt_str_toks)
    print('tokenized answer: ', answer_str_toks)
    #print('tokenized total: ', both_str_toks)
    prompt_len = len(prompt_str_toks)
    answer_len = len(answer_str_toks)
    logits = model(prompt + answer) #logits for full sentence 
    ## loop over the answer tokens
    for idx in range(prompt_len, prompt_len + answer_len):
        print("Logits for token:", answer_str_toks[idx - prompt_len])
        # get prediction of next token from the token index before idx
        token_logits = logits[0, idx - 1]
        probs = t.nn.functional.softmax(token_logits, dim = -1)
        # sort the probabilities in descending order 
        vals, ids = token_logits.sort(descending = True)
        for i in range(top_n):
            print(f"Top {i}th logit. Logit = {vals[i]}, prob = {probs[ids[i]].item():.2%}, token = {model.tokenizer.decode(ids[i])}")

In [148]:
sample_prompt = prompts_n[5]
sample_logits = model(sample_prompt)
sample_probs = t.softmax(sample_logits[0, -1], dim = -1)
vals, ids = t.topk(sample_probs,5)
print(vals, sample_probs)
print([model.tokenizer.decode(ids[j]) for j in range(len(ids))])

get_top_preds(prompts_n[5], ans_n[5][0], 5)

tensor([0.2637, 0.1414, 0.0587, 0.0267, 0.0259]) tensor([1.5437e-05, 1.6269e-05, 3.5443e-07,  ..., 1.0786e-07, 6.0058e-07,
        1.7738e-05])
[' the', ' ', ' statue', ' them', ' a']
tokenized prompt:  ['<|endoftext|>', 'When', ' ', ' goose', ' and', ' ', ' statue', ' went', ' to', ' the', ' shops', ',', ' goose', ' gave', ' the', ' bag', ' to']
tokenized answer:  ['<|endoftext|>', ' statue']
Logits for token: <|endoftext|>
Top 0th logit. Logit = 14.341024398803711, prob = 26.37%, token =  the
Top 1th logit. Logit = 13.71792221069336, prob = 14.14%, token =  
Top 2th logit. Logit = 12.838592529296875, prob = 5.87%, token =  statue
Top 3th logit. Logit = 12.051135063171387, prob = 2.67%, token =  them
Top 4th logit. Logit = 12.022403717041016, prob = 2.59%, token =  a
Logits for token:  statue
Top 0th logit. Logit = 16.110246658325195, prob = 31.59%, token = .
Top 1th logit. Logit = 15.645147323608398, prob = 19.84%, token = ,
Top 2th logit. Logit = 15.59281063079834, prob = 18.83%, to