In [1]:
import torch
import numpy as np
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'colab'
import einops
import plotly.graph_objects as go

from functools import partial
import tqdm.auto as tqdm
import datasets

In [2]:
from transformer_lens import EasyTransformer

In [3]:
#Plotting functions
# This is mostly a bunch of over-engineered mess to hack Plotly into producing 
# the pretty pictures I want, I recommend not reading too closely unless you 
# want Plotly hacking practice
def to_numpy(tensor, flat=False):
    """Convert a torch tensor to a numpy array; pass anything else through.

    Args:
        tensor: A ``torch.Tensor`` (or subclass such as ``torch.nn.Parameter``),
            or any other object.
        flat: If True, tensors are flattened to 1-D before conversion.

    Returns:
        A numpy array for tensor inputs (detached and moved to CPU first);
        non-tensor inputs are returned unchanged, even when ``flat`` is True.
    """
    # isinstance (rather than an exact type comparison) so that Tensor
    # subclasses such as nn.Parameter are converted too.
    if not isinstance(tensor, torch.Tensor):
        return tensor
    if flat:
        tensor = tensor.flatten()
    return tensor.detach().cpu().numpy()
def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', **kwargs):
    """Show a heatmap of ``tensor`` with Plotly Express.

    Args:
        tensor: A torch tensor, or any array-like numpy can squeeze.
            Size-1 dimensions are removed before plotting.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        animation_name: Label stored under the 'animation_name' key of the
            ``labels`` dict handed to ``px.imshow``.
        **kwargs: Forwarded unchanged to ``px.imshow``.
    """
    if isinstance(tensor, torch.Tensor):
        data = to_numpy(torch.squeeze(tensor), flat=False)
    else:
        # torch.squeeze rejects numpy arrays / lists, so squeeze those with
        # numpy instead of raising.
        data = np.squeeze(np.asarray(tensor))
    px.imshow(data,
              labels={'x':xaxis, 'y':yaxis, 'animation_name':animation_name},
              **kwargs).show()
# Default colour scheme for imshow: a diverging red/blue scale centred on zero,
# suited to data with both positive and negative values. Call-time keyword
# arguments still override these defaults.
_DIVERGING_COLOUR_DEFAULTS = {'color_continuous_scale': 'RdBu', 'color_continuous_midpoint': 0.0}
imshow = partial(imshow, **_DIVERGING_COLOUR_DEFAULTS)

def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    """Draw a line plot with Plotly Express.

    Args:
        x: When ``y`` is None, the data to plot (handed to ``px.line`` as its
            first positional argument). Otherwise the x coordinates.
        y: Optional y values. Torch tensors are flattened to numpy arrays.
        hover: Passed to ``px.line`` as ``hover_name``.
        xaxis: x-axis title.
        yaxis: y-axis title.
        **kwargs: Forwarded unchanged to ``px.line``.
    """
    # to_numpy leaves non-tensors (including None) untouched.
    x = to_numpy(x, flat=True)
    y = to_numpy(y, flat=True)
    y_is_column_ref = isinstance(y, str) or (
        isinstance(y, (list, tuple)) and len(y) > 0 and all(isinstance(c, str) for c in y)
    )
    if y is None or y_is_column_ref:
        # Single-argument use, or data-frame style use where y names column(s):
        # keep passing x positionally, exactly as before.
        fig = px.line(x, y=y, hover_name=hover, **kwargs)
    else:
        # Both coordinates supplied: x must go in by keyword. Positionally it
        # would land in px.line's data_frame slot and not be used as x values.
        fig = px.line(x=x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()
def scatter(x, y, **kwargs):
    """Scatter-plot ``y`` against ``x``; tensors are flattened to numpy first.

    Extra keyword arguments are forwarded unchanged to ``px.scatter``.
    """
    x_values = to_numpy(x, flat=True)
    y_values = to_numpy(y, flat=True)
    figure = px.scatter(x=x_values, y=y_values, **kwargs)
    figure.show()
def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    """Plot several series on one Plotly figure.

    Args:
        lines_list: A sequence of series (tensors or array-likes), or a tensor
            whose first dimension indexes the series.
        x: Shared x values. Defaults to ``0..len(first series)-1``. Tensors
            are converted to numpy.
        mode: Plotly scatter mode for every trace.
        labels: Optional per-series trace names; the series index is used
            when omitted.
        xaxis: x-axis title.
        yaxis: y-axis title.
        title: Figure title.
        log_y: If True, use a logarithmic y axis.
        hover: Passed to every trace as ``hovertext``.
        **kwargs: Forwarded unchanged to each ``go.Scatter``.
    """
    if isinstance(lines_list, torch.Tensor):
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        # An empty list gives an empty x axis instead of raising IndexError.
        x = np.arange(len(lines_list[0])) if len(lines_list) > 0 else np.arange(0)
    else:
        # Convert tensor x values the same way the other helpers do.
        x = to_numpy(x)
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    # Named `series` so the loop variable does not shadow the `line` helper.
    for c, series in enumerate(lines_list):
        series = to_numpy(series)
        label = labels[c] if labels is not None else c
        fig.add_trace(go.Scatter(x=x, y=series, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

In [4]:
# Use the GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# 'gpt2' loads GPT-2 Small (described in the original note as 80M parameters).
# gpt2-medium, gpt2-large, gpt2-xl give larger sizes, see the demo notebook for more
model = EasyTransformer.from_pretrained('gpt2').to(device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cpu


In [5]:

# Turn off autograd globally: this notebook only runs forward passes.
# As the last expression of the cell, the returned object is echoed as output.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x2cdd24cd0>

In [None]:
example_prompt = "After John and Mary went to the shops, Mary handed a bottle of milk to" # @param
example_answer = " John" #@param
# Map the prompt and the answer to their tokens as text. Only the prompt gets
# a BOS token: the answer is a continuation, so a BOS there would inflate
# answer_length and shift the token labels printed below by one.
example_prompt_str_tokens = model.to_str_tokens(example_prompt, prepend_bos=True)
example_answer_str_tokens = model.to_str_tokens(example_answer, prepend_bos=False)
print("Tokenized prompt:", example_prompt_str_tokens)
print("Tokenized answer:", example_answer_str_tokens)
prompt_length = len(example_prompt_str_tokens)
answer_length = len(example_answer_str_tokens)
example_logits = model(example_prompt+example_answer)
for index in range(prompt_length, prompt_length + answer_length):
    print("Logits for token:", example_answer_str_tokens[index - prompt_length])
    # The logits at position index-1 are read as the prediction for the token at position index.
    token_logits = example_logits[0, index-1]
    probs = torch.nn.functional.softmax(token_logits, dim=-1)
    values, indices = token_logits.sort(descending=True)
    # Report the ten highest-logit candidate tokens.
    for i in range(10):
        print(f"Top {i}th logit. Logit: {values[i].item():.6} Prob: {probs[indices[i]].item():.2%} Token: |{model.tokenizer.decode(indices[i])}|")

In [7]:
# A prompt pair that differs only in which name is repeated as the subject.
example_text = "After John and Mary went to the shops, Mary handed a bottle of milk to"
example_text_reverse = "After John and Mary went to the shops, John handed a bottle of milk to"

example_str_tokens = model.to_str_tokens(example_text, prepend_bos=True)
print("Input split into tokens:", example_str_tokens)

# Vocabulary index of the first token produced for each name.
john_index, mary_index = (model.tokenizer.encode(name)[0] for name in (" John", " Mary"))
print(f"Index of John token: {john_index}. Index of Mary token: {mary_index}")

def get_logit_diff(logits):
    """Return the John logit minus the Mary logit at the final position.

    ``logits`` is a batch x position x vocab tensor; only the first batch
    element is read. The token indices come from the module-level
    ``john_index`` and ``mary_index``.
    """
    final_position_logits = logits[0, -1]
    return final_position_logits[john_index] - final_position_logits[mary_index]
# Run both prompts through the model; each result has shape batch x position x vocab.
example_logits = model(example_text)
example_logits_reverse = model(example_text_reverse)

# Compare the John and Mary logits at the final position of each prompt.
example_logit_diff = get_logit_diff(example_logits)
example_logit_diff_reverse = get_logit_diff(example_logits_reverse)

print(f"Input text: {example_text}, John logit - Mary logit: {example_logit_diff.item()}")
print(f"Input text: {example_text_reverse}, John logit - Mary logit: {example_logit_diff_reverse.item()}")

Input split into tokens: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' handed', ' a', ' bottle', ' of', ' milk', ' to']
Index of John token: 1757. Index of Mary token: 5335
Input text: After John and Mary went to the shops, Mary handed a bottle of milk to, John logit - Mary logit: 4.0183610916137695
Input text: After John and Mary went to the shops, John handed a bottle of milk to, John logit - Mary logit: -2.918598175048828


In [8]:
def _run_and_cache(text):
    """Run ``model`` on ``text`` with caching hooks attached; return the filled cache dict.

    The hooks are removed again even if the forward pass raises.
    """
    cache = {}
    model.cache_all(cache, remove_batch_dim=True)
    try:
        model(text)
    finally:
        model.reset_hooks()
    return cache

# use_attn_result is switched on only while caching (presumably this is what
# exposes the per-head `hook_result` activations - confirm against the library
# docs) and is switched back off afterwards, even on failure.
model.cfg.use_attn_result = True
try:
    example_cache = _run_and_cache(example_text)
    reverse_example_cache = _run_and_cache(example_text_reverse)
finally:
    model.cfg.use_attn_result = False
# List every cached activation with its shape.
for act_name in example_cache:
    print(act_name, example_cache[act_name].shape)



hook_embed torch.Size([17, 768])
hook_pos_embed torch.Size([17, 768])
blocks.0.hook_resid_pre torch.Size([17, 768])
blocks.0.ln1.hook_scale torch.Size([17, 1])
blocks.0.ln1.hook_normalized torch.Size([17, 768])
blocks.0.attn.hook_q torch.Size([17, 12, 64])
blocks.0.attn.hook_k torch.Size([17, 12, 64])
blocks.0.attn.hook_v torch.Size([17, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([12, 17, 17])
blocks.0.attn.hook_pattern torch.Size([12, 17, 17])
blocks.0.attn.hook_z torch.Size([17, 12, 64])
blocks.0.attn.hook_result torch.Size([17, 12, 768])
blocks.0.hook_attn_out torch.Size([17, 768])
blocks.0.hook_resid_mid torch.Size([17, 768])
blocks.0.ln2.hook_scale torch.Size([17, 1])
blocks.0.ln2.hook_normalized torch.Size([17, 768])
blocks.0.mlp.hook_pre torch.Size([17, 3072])
blocks.0.mlp.hook_post torch.Size([17, 3072])
blocks.0.hook_mlp_out torch.Size([17, 768])
blocks.0.hook_resid_post torch.Size([17, 768])
blocks.1.hook_resid_pre torch.Size([17, 768])
blocks.1.ln1.hook_scale torch.S

In [12]:
# Re-establish the first prompt pair (same values as the earlier cell that defines them).
example_text = "After John and Mary went to the shops, Mary handed a bottle of milk to"
example_text_reverse = "After John and Mary went to the shops, John handed a bottle of milk to"

example_str_tokens = model.to_str_tokens(example_text, prepend_bos=True)
print("Input split into tokens:", example_str_tokens)

# Vocabulary index of the first token produced for each name.
john_index, mary_index = (model.tokenizer.encode(name)[0] for name in (" John", " Mary"))
print(f"Index of John token: {john_index}. Index of Mary token: {mary_index}")

def get_logit_diff(logits):
    """Return the John logit minus the Mary logit at the final position.

    ``logits`` is a batch x position x vocab tensor; only the first batch
    element is read. The token indices come from the module-level
    ``john_index`` and ``mary_index``.
    """
    final_position_logits = logits[0, -1]
    return final_position_logits[john_index] - final_position_logits[mary_index]


# Run both prompts through the model; each result has shape batch x position x vocab.
example_logits = model(example_text)
example_logits_reverse = model(example_text_reverse)

# Compare the John and Mary logits at the final position of each prompt.
example_logit_diff = get_logit_diff(example_logits)
example_logit_diff_reverse = get_logit_diff(example_logits_reverse)

print(f"Input text: {example_text}, John logit - Mary logit: {example_logit_diff.item()}")
print(f"Input text: {example_text_reverse}, John logit - Mary logit: {example_logit_diff_reverse.item()}")

Input split into tokens: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' handed', ' a', ' bottle', ' of', ' milk', ' to']
Index of John token: 1757. Index of Mary token: 5335
Input text: After John and Mary went to the shops, Mary handed a bottle of milk to, John logit - Mary logit: 4.0183610916137695
Input text: After John and Mary went to the shops, John handed a bottle of milk to, John logit - Mary logit: -2.918598175048828


In [10]:
person_1 = "John"
person_2 = "Mary"


# Prompts where person_2 is the repeated subject.
example_texts = [f"After {person_1} and {person_2} went to the shops, {person_2} handed a bottle of milk to",
                 f"After {person_1} and {person_2} got home, {person_2} handed the car keys to", 
                 f"After {person_1} and {person_2} started corresponding, {person_2} wrote a letter to", 
                 f"After {person_1} and {person_2} fought, {person_2} apologized to",
                 f"After {person_1} and {person_2} shared dinner, {person_2} thanked"] 


# The same prompts with person_1 as the repeated subject instead.
example_text_reverses = [f"After {person_1} and {person_2} went to the shops, {person_1} handed a bottle of milk to",
                            f"After {person_1} and {person_2} got home, {person_1} handed the car keys to",
                            f"After {person_1} and {person_2} started corresponding, {person_1} wrote a letter to",
                            f"After {person_1} and {person_2} fought, {person_1} apologized to",
                            f"After {person_1} and {person_2} shared dinner, {person_1} thanked"]

# The loop uses its own names on purpose. Other cells pair example_text and
# example_str_tokens with example_cache, so overwriting those globals with the
# last prompt here would break the notebook when run top to bottom.
for i, (prompt_text, prompt_text_reverse) in enumerate(zip(example_texts, example_text_reverses)):
    print(i)
    prompt_str_tokens = model.to_str_tokens(prompt_text, prepend_bos=True)
    print("Input split into tokens:", prompt_str_tokens)
    prompt_logit_diff = get_logit_diff(model(prompt_text)) # model output has shape batch x position x vocab
    prompt_logit_diff_reverse = get_logit_diff(model(prompt_text_reverse))
    print(f"Input text: {prompt_text}, John logit - Mary logit: {prompt_logit_diff.item()}")
    print(f"Input text: {prompt_text_reverse}, John logit - Mary logit: {prompt_logit_diff_reverse.item()}")


0
Input split into tokens: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' handed', ' a', ' bottle', ' of', ' milk', ' to']
Input text: After John and Mary went to the shops, Mary handed a bottle of milk to, John logit - Mary logit: 4.0183610916137695
Input text: After John and Mary went to the shops, John handed a bottle of milk to, John logit - Mary logit: -2.918598175048828
1
Input split into tokens: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' got', ' home', ',', ' Mary', ' handed', ' the', ' car', ' keys', ' to']
Input text: After John and Mary got home, Mary handed the car keys to, John logit - Mary logit: 3.8062944412231445
Input text: After John and Mary got home, John handed the car keys to, John logit - Mary logit: -2.6187286376953125
2
Input split into tokens: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' started', ' corresponding', ',', ' Mary', ' wrote', ' a', ' letter', ' to']
Input text: After John

In [13]:
# Number of token positions in the prompt (BOS included).
example_length = model.to_tokens(example_text, prepend_bos=True)[0].shape[0]
norm_difference = torch.zeros((model.cfg.n_layers, example_length))
for layer in range(model.cfg.n_layers):
    resid_name = f"blocks.{layer}.hook_resid_pre"
    example_resid = example_cache[resid_name]
    reverse_example_resid = reverse_example_cache[resid_name]
    resid_diff = example_resid - reverse_example_resid
    # Normalise the per-position difference by the geometric mean of the two activation norms.
    geometric_mean_norm = torch.sqrt(example_resid.norm(dim=-1) * reverse_example_resid.norm(dim=-1))
    norm_difference[layer] = resid_diff.norm(dim=-1) / geometric_mean_norm

# Labels take the form {token}_{index} so each is unique; Plotly gets confused by duplicate labels in imshow.
token_labels = [f"{token}_{c}" for c, token in enumerate(example_str_tokens)]
imshow(norm_difference, yaxis='Layer', x=token_labels, title='Norm of difference in residual stream')

In [15]:
# Number of token positions in the prompt (BOS included).
example_length = len(model.to_tokens(example_text, prepend_bos=True)[0])
norm_difference = torch.zeros(model.cfg.n_layers, example_length)
for layer in range(model.cfg.n_layers):
    example_attn_out = example_cache[f"blocks.{layer}.hook_attn_out"]
    reverse_example_attn_out = reverse_example_cache[f"blocks.{layer}.hook_attn_out"]
    attn_out_diff = (example_attn_out - reverse_example_attn_out)
    # Normalise the per-position difference by the geometric mean of the two activation norms.
    norm_difference[layer] = attn_out_diff.norm(dim=-1)/(example_attn_out.norm(dim=-1) * reverse_example_attn_out.norm(dim=-1)).sqrt()

# Labels take the form {token}_{index} so each is unique; Plotly gets confused by duplicate labels in imshow.
# They are passed as x here so this plot's columns are labelled like the residual-stream plot's.
imshow(norm_difference, yaxis='Layer', x=[f"{token}_{c}" for c, token in enumerate(example_str_tokens)], title='Norm of difference in attn_out')