In [1]:
# Import stuff
import torch as t
import numpy as np
# Plotly needs a different renderer for VSCode/Notebooks vs Colab
import plotly.io as pio
pio.renderers.default = "notebook_connected"
import plotly.express as px
import einops
import plotly.graph_objects as go 
from functools import partial
import tqdm.auto as tqdm
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from fancy_einsum import einsum
from jaxtyping import Float, Int, Bool
import re
import random
from IPython.display import display

In [2]:

def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
## turn off AD to save memory, since we're focusing on model inference here 
t.set_grad_enabled(False)

device = 'cuda' if t.cuda.is_available() else 'cpu'
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


Studying the head outputs on different distributions

Create the datasets 

In [4]:
proper_nouns = [
    " Goose", " Church",
    " Google", " Chair",
    " Bag", " Statue",
    " Lamp", " Flower"
]

nouns = [
    " goose", " church",
    " google", " chair",
    " bag", " statue",
    " lamp", " flower"
]

multi_names_religious = [
    " Mary", " Joseph",
    " Abraham", " Paul",
    " Isaac", " Noah",
    " Jacob", " Jesus"
]

multi_names_places = [
    " Paris", " London",
    " Madison", " Phoenix",
    " Devon", " Florence",
    " Austin", " Brooklyn"
]

test = proper_nouns
[model.to_single_token(test[i]) for i in range(len(test))]


[46317, 4564, 3012, 9369, 20127, 43330, 28607, 20025]

In [5]:



def get_dataset(N, names):
    prompts = []
    # List of answers, in the format (correct, incorrect)
    answers = []
    # List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
    answer_tokens = []
    for _ in range(N):
        S1, S2 = random.sample(names,2)
        answers.append((S1,S2))
        answers.append((S2,S1))
        prompt1 = f'When{S1} and{S2} went to the shops,{S2} gave the bag to'
        prompt2 = f'When{S1} and{S2} went to the shops,{S1} gave the bag to'
        prompts.append(prompt1)
        prompts.append(prompt2)

        answer_tok_1 = model.to_single_token(S1)
        answer_tok_2 = model.to_single_token(S2)

        answer_tokens.append((answer_tok_1, answer_tok_2))
        answer_tokens.append((answer_tok_2, answer_tok_1))
    assert len(set([len(model.to_str_tokens(prompt)) for prompt in prompts])) == 1
    answer_tokens = t.tensor(answer_tokens).to(device)
    return prompts, answers, answer_tokens



In [6]:
prompts_rel, ans_rel, ans_toks_rel = get_dataset(10,multi_names_religious)
prompts_pl, ans_pl, ans_toks_pl = get_dataset(10,multi_names_places)
prompts_n, ans_n, ans_toks_n = get_dataset(10,nouns)
prompts_pn, ans_pn, ans_toks_pn = get_dataset(10,proper_nouns)



In [7]:
prompts_full = prompts_n + prompts_pl + prompts_rel + prompts_pl
ans_full = ans_n + ans_pl + ans_rel + ans_pl
ans_toks_full = t.concat([ans_toks_n, ans_toks_pl, ans_toks_rel, ans_toks_pl], dim =0)
ans_toks_full.shape, ans_toks_n.shape

(torch.Size([80, 2]), torch.Size([20, 2]))

In [8]:
from rich.table import Table, Column
from rich import print as rprint


In [9]:
control_prompts = ['When John and Mary went to the shops, John gave the bag to',
 'When John and Mary went to the shops, Mary gave the bag to',
 'When Tom and James went to the park, James gave the ball to',
 'When Tom and James went to the park, Tom gave the ball to',
 'When Dan and Sid went to the shops, Sid gave an apple to',
 'When Dan and Sid went to the shops, Dan gave an apple to',
 'After Martin and Amy went to the park, Amy gave a drink to',
 'After Martin and Amy went to the park, Martin gave a drink to']

control_answers = [(' Mary', ' John'),
 (' John', ' Mary'),
 (' Tom', ' James'),
 (' James', ' Tom'),
 (' Dan', ' Sid'),
 (' Sid', ' Dan'),
 (' Martin', ' Amy'),
 (' Amy', ' Martin')]

control_ans_prompts = t.tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]])



In [10]:
print(model(prompts_n).shape)

def ave_logit_diff(prompts, answer_tokens, per_prompt = False):
    final_logits = model(prompts)[:,-1,:]
    answer_logits = final_logits.gather(dim = -1, index = answer_tokens)
    #print(final_logits.shape)
    answer_logit_diff = answer_logits[:,0] - answer_logits[:,1]
    if per_prompt:
        return answer_logit_diff 
    else:
        return answer_logit_diff.mean()

ave_logit_diff(prompts_n, ans_toks_n)

torch.Size([20, 15, 50257])


tensor(0.2733)

In [11]:
### Is the logit diff for these two names the right metric? 

### First, try the logit diff between the IO and the average of the actual topk logits for each example prompt 


def ave_logit_diff_topk(prompts, answer_tokens, k, per_prompt = False):
    final_logits = model(prompts)[:,-1,:]
    answer_logits = final_logits.gather(dim = -1, index = answer_tokens)[:,0]

    sample_logits = model(prompts)[:,-1]
    vals, ids = t.topk(sample_logits,k)
    topk_avg = vals.mean(dim=-1)

    answer_logit_diff = answer_logits - topk_avg
    #print(answer_logits, topk_avg)
    if per_prompt:
        return answer_logit_diff 
    else:
        return answer_logit_diff.mean()

ave_logit_diff_topk(prompts_pn, ans_toks_pn,15)





tensor(1.9502)

In [12]:
noun_test = [[k,ave_logit_diff_topk(prompts_n, ans_toks_n,k)] for k in range(25)]

pnoun_test = [[k,ave_logit_diff_topk(prompts_pn, ans_toks_pn,k)] for k in range(25)]

rel_test = [[k,ave_logit_diff_topk(prompts_rel, ans_toks_rel,k)] for k in range(25)]

pl_test = [[k,ave_logit_diff_topk(prompts_pl, ans_toks_pl,k)] for k in range(25)]


KeyboardInterrupt: 

In [None]:
full_test = [[k,ave_logit_diff_topk(prompts_full, ans_toks_full,k)] for k in range(25)]

: 

scatter(x= t.tensor(noun_test)[:,0], y = t.tensor(noun_test)[:,1],xaxis = 'k', yaxis = 'topk metric', title = 'Avg logit diff between IO and avg of topk answer logits: nouns')

In [None]:
scatter(x= t.tensor(noun_test)[:,0], y = t.tensor(noun_test)[:,1],xaxis = 'k', yaxis = 'topk metric', title = 'Avg logit diff between IO and avg of topk answer logits: nouns')

scatter(x= t.tensor(pnoun_test)[:,0], y = t.tensor(pnoun_test)[:,1],xaxis = 'k', yaxis = 'topk metric', title = 'Avg logit diff between IO and avg of topk answer logits: pnouns')

scatter(x= t.tensor(rel_test)[:,0], y = t.tensor(rel_test)[:,1],xaxis = 'k', yaxis = 'topk metric', title = 'Avg logit diff between IO and avg of topk answer logits: rel')

scatter(x= t.tensor(pl_test)[:,0], y = t.tensor(pl_test)[:,1],xaxis = 'k', yaxis = 'topk metric', title = 'Avg logit diff between IO and avg of topk answer logits: places')

scatter(x= t.tensor(full_test)[:,0], y = t.tensor(full_test)[:,1],xaxis = 'k', yaxis = 'topk metric', title = 'Avg logit diff between IO and avg of topk answer logits: full')


: 

In [13]:
### could also take the difference between the IO logits and the average of the logits for the other names in the set, to check for correlations among the words 


def ave_logit_diff_assoc(prompts, answer_tokens, per_prompt = False):
    final_logits = model(prompts)[:,-1,:]
    answer_logits = final_logits.gather(dim = -1, index = answer_tokens)[:,0]


    # take the average of the other answers in the same distribution
    ave_assoc = (final_logits[:,answer_tokens[:,0]].sum(dim=-1) - answer_logits)/(len(prompts) - 1)
    print(ave_assoc)
    answer_logit_diff = answer_logits - ave_assoc
    #print(answer_logits, topk_avg)
    if per_prompt:
        return answer_logit_diff 
    else:
        return answer_logit_diff.mean()


ave_logit_diff_assoc(prompts_n, ans_toks_n)


tensor([7.3412, 7.0198, 6.9092, 6.9868, 6.6871, 6.4508, 6.6871, 6.4508, 7.3412,
        7.0198, 6.8713, 6.6284, 6.4316, 6.1561, 6.3415, 6.2283, 6.0798, 5.8029,
        7.2027, 6.9211])


tensor(4.5586)

In [14]:
def make_table(prompts, answers, answer_tokens, title):
    cols = [
        "Prompt", 
        Column("Correct", style="rgb(0,200,0) bold"), 
        Column("Incorrect", style="rgb(255,0,0) bold"), 
        Column("Logit Difference", style="bold")
    ]
    logit_diffs = ave_logit_diff(prompts,answer_tokens, per_prompt = True)
    
    ave_logits = ave_logit_diff(prompts,answer_tokens, per_prompt = False)
    logit_diff_table = Table(*cols, title=title + f": Ave logit diff = {ave_logits.item():.3f}")
    logit_diffs = ave_logit_diff(prompts,answer_tokens, per_prompt = True)
    
    ave_logits = ave_logit_diff(prompts,answer_tokens, per_prompt = False)
                                 
    ave_logit_diff(prompts, answer_tokens)
    for prompt, ans, logit_diff in zip(prompts, answers,logit_diffs):
        logit_diff_table.add_row(prompt, ans[0], ans[1], f"{logit_diff.item():.3f}")
    rprint(logit_diff_table)

In [15]:
make_table(control_prompts, control_answers, control_ans_prompts, "Control Names")

In [16]:
#[(model.to_str_tokens(ans_toks_n[k],ans_n[k]), model.to_str_tokens(ans_toks_n[k]),model.to_str_tokens(ans_n[k]), ans_n[k]) for k in range(len(ans_toks_n))]

test = ans_toks_n[10]
test_ans = ans_n[10]
[model.tokenizer.decode(test[i]) for i in range(len(test))], model.to_str_tokens(test), ans_n[10]


([' lamp', ' flower'], [' lamp', ' flower'], (' lamp', ' flower'))

In [17]:
make_table(prompts_n, ans_n, ans_toks_n, "nouns")
make_table(prompts_pn, ans_pn, ans_toks_pn, "proper nouns")
make_table(prompts_rel, ans_rel, ans_toks_rel, "Religious Names")
make_table(prompts_pl, ans_pl, ans_toks_pl, "Place Names")
make_table(prompts_full, ans_full, ans_toks_full, "Full Dist")

In [80]:
test_prompts = prompts_n[:4]+prompts_pn[:4] +prompts_rel[-9:-5] + prompts_pl[-4:]
test_ans = ans_n[:4]+ans_pn[:4] +ans_rel[-9:-5] + ans_pl[-4:]
test_toks = t.concat([ans_toks_n[:4], ans_toks_pn[:4], ans_toks_rel[-9:-5], ans_toks_pl[-4:]])

test_toks
make_table(test_prompts, test_ans, test_toks, "small sample of full Dist")

Looking at the top-k logits for each prompt

In [18]:

prompts = prompts_full
k = 5

for sample_prompt in prompts:
    sample_logits = model(sample_prompt)
    sample_probs = t.softmax(sample_logits[0, -1], dim = -1)
    vals, ids = t.topk(sample_probs,k)

    print(f"Prompt = {sample_prompt}")
    for i in range(k):
            print(f"Top {i}th logit. prob = {vals[i]:.2%}, token = {model.tokenizer.decode(ids[i])}")


Prompt = When lamp and google went to the shops, google gave the bag to
Top 0th logit. prob = 15.12%, token =  the
Top 1th logit. prob = 11.21%, token =  them
Top 2th logit. prob = 8.51%, token =  me
Top 3th logit. prob = 4.63%, token =  Google
Top 4th logit. prob = 4.62%, token =  google
Prompt = When lamp and google went to the shops, lamp gave the bag to
Top 0th logit. prob = 16.18%, token =  the
Top 1th logit. prob = 10.15%, token =  me
Top 2th logit. prob = 5.17%, token =  them
Top 3th logit. prob = 4.80%, token =  a
Top 4th logit. prob = 3.93%, token =  Google
Prompt = When church and lamp went to the shops, lamp gave the bag to
Top 0th logit. prob = 28.59%, token =  the
Top 1th logit. prob = 6.17%, token =  a
Top 2th logit. prob = 2.75%, token =  them
Top 3th logit. prob = 1.85%, token =  church
Top 4th logit. prob = 1.44%, token =  their
Prompt = When church and lamp went to the shops, church gave the bag to
Top 0th logit. prob = 34.99%, token =  the
Top 1th logit. prob = 6.55%

Looking at the attention heads on each distribution

In [19]:
noun_tokens = model.to_tokens(prompts_n).to(device)
noun_logits, noun_cache = model.run_with_cache(prompts_n)
print(noun_cache["pattern", 0, "attn"].shape)

pnoun_tokens = model.to_tokens(prompts_pn).to(device)
pnoun_logits, pnoun_cache = model.run_with_cache(prompts_pn)

rel_tokens = model.to_tokens(prompts_rel).to(device)
rel_logits, rel_cache = model.run_with_cache(prompts_rel)

full_tokens = model.to_tokens(prompts_full).to(device)
full_logits, full_cache = model.run_with_cache(prompts_full)


pl_tokens = model.to_tokens(prompts_pl).to(device)
pl_logits, pl_cache = model.run_with_cache(prompts_pl)
print(noun_cache["pattern", 0, "attn"].shape)



torch.Size([20, 12, 15, 15])
torch.Size([20, 12, 15, 15])


In [69]:

print("Layer 9 Head Attention Patterns:")
display(cv.attention.attention_patterns(tokens=model.to_str_tokens(noun_tokens[10]), attention= noun_cache["pattern", 1][10]))




Layer 9 Head Attention Patterns:


In [65]:
#len(noun_tokens), einops.rearrange(pattern, "b h s1 s2 -> h s1 s2 b").shape


model.to_str_tokens(noun_tokens[0]), pattern[0].shape

(['<|endoftext|>',
  'When',
  ' lamp',
  ' and',
  ' google',
  ' went',
  ' to',
  ' the',
  ' shops',
  ',',
  ' google',
  ' gave',
  ' the',
  ' bag',
  ' to'],
 torch.Size([12, 15, 15]))

Use the logit lens to look at the attention output at each head/layer

Define the logit_diff_directions in different ways: according to the 3 metrics described above

In [21]:

def get_IOI_dir(answer_tokens):
    answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
    print("Answer residual directions shape:", answer_residual_directions.shape)
    logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
    print("Logit difference directions shape:", logit_diff_directions.shape)
    return logit_diff_directions

def residual_stack_to_logit_diff(residual_stack: Float[t.Tensor, "components batch d_model"], cache: ActivationCache, logit_diff_directions) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)





def get_per_head_logit_diffs(cache,logit_diff_directions):
    per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
    per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache, logit_diff_directions)
    per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
    return per_head_logit_diffs


answer_tokens = ans_toks_pn
cache = pnoun_cache
logit_diff_dirs = get_IOI_dir(answer_tokens)
imshow(get_per_head_logit_diffs(cache, logit_diff_dirs), labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head - proper nouns distr")

answer_tokens = ans_toks_full
cache = full_cache
logit_diff_dirs = get_IOI_dir(answer_tokens)
imshow(get_per_head_logit_diffs(cache, logit_diff_dirs), labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head - full distr")

Answer residual directions shape: torch.Size([20, 2, 768])
Logit difference directions shape: torch.Size([20, 768])
Tried to stack head results when they weren't cached. Computing head results now


Answer residual directions shape: torch.Size([80, 2, 768])
Logit difference directions shape: torch.Size([80, 768])
Tried to stack head results when they weren't cached. Computing head results now


In [22]:
answer_tokens = ans_toks_n
cache = noun_cache
logit_diff_dirs = get_IOI_dir(answer_tokens)
imshow(get_per_head_logit_diffs(cache, logit_diff_dirs), labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head - nouns distr")

answer_tokens = ans_toks_rel
cache = rel_cache
logit_diff_dirs = get_IOI_dir(answer_tokens)
imshow(get_per_head_logit_diffs(cache, logit_diff_dirs), labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head - rel names distr")

answer_tokens = ans_toks_pl
cache = pl_cache
logit_diff_dirs = get_IOI_dir(answer_tokens)
imshow(get_per_head_logit_diffs(cache, logit_diff_dirs), labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head - place names distr")

Answer residual directions shape: torch.Size([20, 2, 768])
Logit difference directions shape: torch.Size([20, 768])
Tried to stack head results when they weren't cached. Computing head results now


Answer residual directions shape: torch.Size([20, 2, 768])
Logit difference directions shape: torch.Size([20, 768])
Tried to stack head results when they weren't cached. Computing head results now


Answer residual directions shape: torch.Size([20, 2, 768])
Logit difference directions shape: torch.Size([20, 768])
Tried to stack head results when they weren't cached. Computing head results now


In [23]:
lh_backup = [[10,10],[10,6],[10,2],[10,1],[11,2],[11,9],[9,0],[9,7]]
lh_list_pos = [[9,6],[9,9], [10,0]]
lh_list_neg = [[10,7], [11,10]]

In [None]:
def visualize_attention_patterns(
    heads, 
    local_cache, 
    local_tokens, 
    title: str=""):
    # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, list) or isinstance(heads, t.Tensor):
        heads = utils.to_numpy(heads)
    # Cache defaults to the original activation cache
    if local_cache is None:
        local_cache = cache
    # Tokens defaults to the tokenization of the first prompt (including the BOS token)
    # if local_tokens is None:
    #     # The tokens of the first prompt
    #     local_tokens = tokens[0]
    
    labels = []
    patterns = []
    batch_index = 0
    for head in heads:
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    str_tokens = model.to_str_tokens(local_tokens)
    patterns = t.stack(patterns, dim=-1)
    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

In [87]:

top_k = 3

answer_tokens = ans_toks_pn
cache = pnoun_cache
logit_diff_dirs = get_IOI_dir(answer_tokens)
       
top_positive_logit_attr_heads = t.topk(get_per_head_logit_diffs(cache, logit_diff_dirs).flatten(), k=top_k).indices
#visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")
# top_negative_logit_attr_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices
# visualize_attention_patterns(top_negative_logit_attr_heads, title=f"Top {top_k} Negative Logit Attribution Heads")

# print("Layer 9 Head Attention Patterns:")
# display(cv.attention.attention_patterns(tokens=model.to_str_tokens(noun_tokens[10]), attention= noun_cache["pattern", 1][10]))

top_positive_logit_attr_heads

Answer residual directions shape: torch.Size([20, 2, 768])
Logit difference directions shape: torch.Size([20, 768])


tensor([117, 114, 120])