### summarising key findings with a bunch of graphs!

### Setup

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-0we3kvu6
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-0we3kvu6
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 25a9c07cd883f762725f4a5a0cab7b36bc4096cc
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting wandb>=0.13.5
  Downloading wandb-0.15.0-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m60.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers>=4.25.1
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab.
import plotly.io as pio

# Plotly graphics show up either in Colab OR in VSCode depending on the renderer, which is
# awkward when developing demos. The static "png" renderer is therefore only selected when
# running locally with DEBUG_MODE switched on; every other combination uses "colab".
pio.renderers.default = "colab" if (IN_COLAB or not DEBUG_MODE) else "png"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
import circuitsvis as cv
# Smoke test: render the library's built-in example to confirm circuitsvis output displays
# in this environment before relying on it for attention visualisations.
cv.examples.hello("Neel")

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [5]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [6]:
# Turn autograd off for the whole notebook: only inference is performed, so tracking
# gradients would just consume GPU memory.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc2459748e0>

Plotting helper functions:

In [7]:
def imshow(tensor, renderer=None, midpoint=0.0, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a Plotly heatmap on a red/blue diverging colour scale.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed to `Figure.show`; `None` uses the default renderer.
        midpoint: Value placed at the centre of the colour scale.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `plotly.express.imshow`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=midpoint,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a Plotly line chart.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed to `Figure.show`; `None` uses the default renderer.
        xaxis: Label requested for the x axis.
        yaxis: Label requested for the y axis.
        **kwargs: Forwarded unchanged to `plotly.express.line`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a Plotly scatter plot of `y` against `x`.

    Args:
        x: Horizontal coordinates; converted with `utils.to_numpy` first.
        y: Vertical coordinates; converted with `utils.to_numpy` first.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Passed to `Figure.show`; `None` uses the default renderer.
        **kwargs: Forwarded unchanged to `plotly.express.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [8]:
# Quick check that the plotting helpers render in this environment: plot the values 0..4.
line(np.arange(5))

Set up the device: use the GPU when one is available, otherwise fall back to the CPU.

In [9]:
# Device string used when loading the model: CUDA if PyTorch can see a GPU, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Pronoun prediction

The task is choosing the right pronouns (e.g. he vs she vs it vs they)

A good setup is a rhetorical question (so it doesn’t spoil the answer!) like “Lina is a great friend, isn’t” (h/t Marius Hobbhahn)

The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer.

In [10]:
# Load pretrained GPT-2 Small through TransformerLens onto the selected device.
# The boolean keyword flags are TransformerLens weight-processing options.
# NOTE(review): their exact effect is not visible in this notebook — see the
# `HookedTransformer.from_pretrained` documentation before changing them.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True, 
    device=device
    )

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


The next step is to verify that the model can actually do the task!

In [11]:
# Check the model's top next-token predictions for a prompt whose expected completion is " she".
example_prompt = "Mary is a great friend, isn’t"
example_answer = " she"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 17.41 Prob: 82.67% Token: | she|
Top 1th token. Logit: 14.69 Prob:  5.45% Token: | it|
Top 2th token. Logit: 13.51 Prob:  1.68% Token: | he|
Top 3th token. Logit: 13.11 Prob:  1.12% Token: | there|
Top 4th token. Logit: 12.74 Prob:  0.78% Token: | I|
Top 5th token. Logit: 12.72 Prob:  0.76% Token: | we|
Top 6th token. Logit: 12.67 Prob:  0.72% Token: | you|
Top 7th token. Logit: 12.35 Prob:  0.52% Token: | her|
Top 8th token. Logit: 12.23 Prob:  0.47% Token: | this|
Top 9th token. Logit: 12.16 Prob:  0.43% Token: | that|


In [12]:
# Same check with a different name, where the expected completion is " he".
example_prompt = "John is a great friend, isn’t"
example_answer = " he"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'John', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Tokenized answer: [' he']


Top 0th token. Logit: 17.47 Prob: 83.43% Token: | he|
Top 1th token. Logit: 14.81 Prob:  5.81% Token: | it|
Top 2th token. Logit: 13.22 Prob:  1.18% Token: | there|
Top 3th token. Logit: 13.06 Prob:  1.01% Token: | you|
Top 4th token. Logit: 13.04 Prob:  0.99% Token: | we|
Top 5th token. Logit: 12.69 Prob:  0.70% Token: | I|
Top 6th token. Logit: 12.62 Prob:  0.65% Token: | she|
Top 7th token. Logit: 12.52 Prob:  0.59% Token: | that|
Top 8th token. Logit: 12.26 Prob:  0.45% Token: | this|
Top 9th token. Logit: 11.78 Prob:  0.28% Token: | the|


In [13]:
# Same check for a non-person subject, where the expected completion is " it".
example_prompt = "Matrix is a great movie, isn’t"
example_answer = " it"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'Matrix', ' is', ' a', ' great', ' movie', ',', ' isn', '�', '�', 't']
Tokenized answer: [' it']


Top 0th token. Logit: 18.16 Prob: 94.84% Token: | it|
Top 1th token. Logit: 13.65 Prob:  1.04% Token: | there|
Top 2th token. Logit: 13.15 Prob:  0.63% Token: | that|
Top 3th token. Logit: 12.37 Prob:  0.29% Token: | this|
Top 4th token. Logit: 12.37 Prob:  0.29% Token: | he|
Top 5th token. Logit: 12.32 Prob:  0.28% Token: | the|
Top 6th token. Logit: 12.10 Prob:  0.22% Token: | you|
Top 7th token. Logit: 12.09 Prob:  0.22% Token: |?|
Top 8th token. Logit: 11.81 Prob:  0.16% Token: | I|
Top 9th token. Logit: 11.41 Prob:  0.11% Token: | they|


Generate reference prompts for the task to run the model on.

We'll run the model on 25 instances of this task, each prompt format with each name.

In [14]:
# Sentence templates; "{}" is replaced by each name below.
prompt_formats = [
    "{} is a great friend, isn’t",
    "{} is an amazing person, isn’t",
    "{} is a fantastic colleague, isn’t",
    "{} is a wonderful partner, isn’t",
    "{} is an excellent student, isn’t"
    ]

pronouns = [" she", " he"]

# List of names, in the format (name, index of the correct pronoun in `pronouns`)
names = [
    ("Mary", 0),
    ("John", 1),
    ("Tom", 1),
    ("Dan", 1),
    ("Amy", 0),
]

# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []

for prompt_format in prompt_formats:
    for name, pronoun_idx in names:
        prompts.append(prompt_format.format(name))

        # With exactly two pronouns, `1 - pronoun_idx` selects the other (incorrect) one.
        answers.append((pronouns[pronoun_idx], pronouns[1 - pronoun_idx]))

        answer_tokens.append(
            (
                model.to_single_token(answers[-1][0]),
                model.to_single_token(answers[-1][1]),
            )
        )

# Place the answer tokens on the notebook-wide `device` (CUDA when available, else CPU).
# A hard-coded `.cuda()` would crash on machines without a GPU.
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(answers)

['Mary is a great friend, isn’t', 'John is a great friend, isn’t', 'Tom is a great friend, isn’t', 'Dan is a great friend, isn’t', 'Amy is a great friend, isn’t', 'Mary is an amazing person, isn’t', 'John is an amazing person, isn’t', 'Tom is an amazing person, isn’t', 'Dan is an amazing person, isn’t', 'Amy is an amazing person, isn’t', 'Mary is a fantastic colleague, isn’t', 'John is a fantastic colleague, isn’t', 'Tom is a fantastic colleague, isn’t', 'Dan is a fantastic colleague, isn’t', 'Amy is a fantastic colleague, isn’t', 'Mary is a wonderful partner, isn’t', 'John is a wonderful partner, isn’t', 'Tom is a wonderful partner, isn’t', 'Dan is a wonderful partner, isn’t', 'Amy is a wonderful partner, isn’t', 'Mary is an excellent student, isn’t', 'John is an excellent student, isn’t', 'Tom is an excellent student, isn’t', 'Dan is an excellent student, isn’t', 'Amy is an excellent student, isn’t']
[(' she', ' he'), (' he', ' she'), (' he', ' she'), (' he', ' she'), (' she', ' he')

In [15]:
# ensuring all prompts are same number of tokens
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)

Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'John', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Tom', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Dan', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Amy', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Mary', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'John', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Tom', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Pro

We now run the model on these prompts and use run_with_cache to get both the logits and a cache of all internal activations for later analysis.

In [16]:
# Tokenize every prompt (with the BOS token prepended) into a single batch.
tokens = model.to_tokens(prompts, prepend_bos=True)
# Move the tokens to the same device the model was loaded on. Using `device` rather than a
# hard-coded `.cuda()` keeps the notebook runnable on CPU-only machines.
tokens = tokens.to(device)
# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the **logit difference**, the difference in logit between the correct pronoun and the incorrect pronoun (eg, `logit( she)-logit( he)`). 

In [17]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit difference between the correct and incorrect answer at the final position.

    Args:
        logits: Tensor indexed as [batch, position, vocab].
        answer_tokens: Integer tensor indexed as [batch, 2]; column 0 holds the correct
            token id and column 1 the incorrect one.
        per_prompt: When True return one difference per batch row, otherwise their mean.

    Returns:
        A 1-D tensor of per-prompt differences, or a scalar tensor with their mean.
    """
    # Only the prediction made at the last position matters for the answer.
    last_position_logits = logits[:, -1, :]
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    diff = picked[..., 0] - picked[..., 1]
    return diff if per_prompt else diff.mean()

# Report the metric for every prompt, then keep the batch average as the baseline value.
print("Per prompt logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True))
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff.item())

Per prompt logit difference: tensor([3.8977, 4.8540, 3.9514, 3.9663, 3.4600, 5.1857, 4.7339, 3.8067, 3.7226,
        4.7266, 3.9201, 4.4639, 3.5893, 3.5116, 3.5193, 4.1234, 3.2621, 2.5424,
        2.4767, 3.4700, 4.3414, 4.8153, 3.8081, 3.8098, 3.8759],
       device='cuda:0')
Average logit difference: 3.913365125656128


We see that the average logit difference is 3.9 - for context, this represents putting an $e^{3.9}\approx 49\times$ higher probability on the correct answer. 

### visualize attention pattern

In [18]:
def attn_pattern(text, layer=0):
    """Render the attention patterns of one layer for `text` using circuitsvis.

    Args:
        text: Prompt to run through the model.
        layer: Index of the layer whose attention patterns are shown (default 0).

    Returns:
        The circuitsvis `attention_patterns` visualisation object.
    """
    token_ids = model.to_tokens(text)
    # The logits are not needed here; only the cached activations are used.
    _, activations = model.run_with_cache(token_ids, remove_batch_dim=True)
    layer_pattern = activations["pattern", layer, "attn"]
    return cv.attention.attention_patterns(
        tokens=model.to_str_tokens(text),
        attention=layer_pattern,
    )

In [19]:
# Layer-0 attention patterns (the default `layer`) for the "Mary" prompt.
attn_pattern("Mary is a great friend, isn’t")

In [20]:
# Layer-0 attention patterns (the default `layer`) for the "John" prompt.
attn_pattern("John is a great friend, isn’t")

In [21]:
# Layer-0 attention patterns (the default `layer`) for the "Matrix" prompt.
attn_pattern("Matrix is a great movie, isn’t")

### Ablation

In [22]:
# Head ablation hook.
# The type annotations are NOT necessary, they're just a useful guide to the reader.
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero-ablate a single attention head.

    Args:
        value: Activation tensor handed to the hook; modified in place.
        hook: The hook point that triggered the call (unused).
        head_index_to_ablate: Index along the head axis to overwrite with zeros.

    Returns:
        The same `value` tensor, with the selected head's entries set to zero.
    """
    value[:, :, head_index_to_ablate] = 0.0
    return value

# Prompt that includes the correct pronoun, so the loss also covers predicting " she".
tokens = model.to_tokens("Mary is a great friend, isn’t she")

# Baseline loss with no head ablated; the heatmap cell uses it as the colour-scale midpoint.
original_loss = model(tokens, return_type="loss")

# We make a tensor to store the results for each ablation run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
ablation_result = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

# Sweep over every (layer, head) pair, ablating exactly one head per forward pass.
for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for head in range(model.cfg.n_heads):

        # Use functools.partial to create a temporary hook function with the head fixed
        temp_hook_fn = partial(head_ablation_hook, head_index_to_ablate=head)
        # Run the model with the patching hook

        ablated_loss = model.run_with_hooks(
            tokens, 
            return_type="loss", # NOTE(review): presumably the loss averaged over all prompt positions, not only the final token — confirm
            fwd_hooks=[(
                utils.get_act_name("v", layer), # hook this layer's "v" activation; author's note: try v -> o
                temp_hook_fn
                )]
            )
        ablation_result[layer, head] = ablated_loss

# Clear any hooks still attached to the model before later cells run.
model.reset_hooks()

  0%|          | 0/12 [00:00<?, ?it/s]

In [23]:
# NOTE(review): the figure below is drawn with Plotly, so this matplotlib magic is probably unnecessary — confirm.
%matplotlib inline

# Centre the diverging colour scale on the un-ablated loss so heads whose ablation raises
# the loss and heads whose ablation lowers it fall on opposite sides of the scale.
imshow(ablation_result, midpoint=original_loss.item(), xaxis="Head", yaxis="Layer", title="Ablated loss for every head")

Ablating some heads (e.g. layer 0, head 7) increases the loss, whereas ablating others (e.g. layer 2, head 0) decreases it.

In [24]:
# Repeat the per-head ablation sweep, this time measuring the " she" - " he" logit difference
# at the final position instead of the loss. `head_ablation_hook` is reused from the
# loss-ablation cell earlier in the notebook rather than being redefined identically here.

# The prompt stops before the pronoun so the final position is where the pronoun is predicted.
tokens = model.to_tokens("Mary is a great friend, isn’t")

correct_index = model.to_single_token(" she")
incorrect_index = model.to_single_token(" he")

# Baseline logit difference with no head ablated; the heatmap cell uses it as the colour-scale midpoint.
original_logits = model(tokens, return_type="logits")
original_logit_diff = original_logits[0, -1, correct_index] - original_logits[0, -1, incorrect_index]

# We make a tensor to store the results for each ablation run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
ablation_result = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for head in range(model.cfg.n_heads):

        # Use functools.partial to create a temporary hook function with the head fixed
        temp_hook_fn = partial(head_ablation_hook, head_index_to_ablate=head)
        # Run the model with the patching hook

        ablated_logits = model.run_with_hooks(
            tokens,
            return_type="logits",
            fwd_hooks=[(
                utils.get_act_name("v", layer),
                temp_hook_fn
                )]
            )

        ablated_logit_diff = ablated_logits[0, -1, correct_index] - ablated_logits[0, -1, incorrect_index]
        ablation_result[layer, head] = ablated_logit_diff

# Clear any hooks still attached to the model before later cells run.
model.reset_hooks()

  0%|          | 0/12 [00:00<?, ?it/s]

In [25]:
# NOTE(review): the figure below is drawn with Plotly, so this matplotlib magic is probably unnecessary — confirm.
%matplotlib inline

# Centre the diverging colour scale on the un-ablated logit difference so increases and
# decreases caused by ablation fall on opposite sides of the scale.
imshow(ablation_result, midpoint=original_logit_diff.item(), xaxis="Head", yaxis="Layer", title="Ablated logit difference for every head")

Ablating some heads (e.g. layer 0, head 8) decreases the logit difference, whereas ablating others (e.g. layer 7, head 3) increases it.

### Activation Patching

In [26]:
# The clean and corrupted prompts differ only in the name, so they call for different pronouns.
clean_prompt = "Mary is a great friend, isn’t"
corrupted_prompt = "John is a great friend, isn’t"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" she", incorrect_answer=" he"):
    """Final-position logit of `correct_answer` minus that of `incorrect_answer`.

    Only the first batch element of `logits` is read.

    Args:
        logits: Tensor indexed as [batch, position, vocab].
        correct_answer: String that must correspond to a single token.
        incorrect_answer: String that must correspond to a single token.

    Returns:
        A scalar tensor holding the logit difference.
    """
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_id = model.to_single_token(correct_answer)
    incorrect_id = model.to_single_token(incorrect_answer)
    final_logits = logits[0, -1]
    return final_logits[correct_id] - final_logits[incorrect_id]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# We don't need to cache on the corrupted prompt.
# Both differences use the default " she" minus " he", so a prompt that favours " he"
# yields a negative value here.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

Clean logit difference: 3.898
Corrupted logit difference: -4.854


In [27]:
# Position-wise patching copies activations between the two runs at the same index, so both
# prompts need the same number of tokens.
clean_tokens.shape, corrupted_tokens.shape

(torch.Size([1, 11]), torch.Size([1, 11]))

In [28]:
# Residual stream patching hook.
# It acts on the residual stream at the start of a layer, hence the name resid_pre.
# The type annotations are a guide to the reader and are not necessary.
def residual_stream_patching_hook(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    position: int
) -> Float[torch.Tensor, "batch pos d_model"]:
    """Overwrite one position of the residual stream with its value from the clean run.

    Args:
        resid_pre: Activation tensor handed to the hook; modified in place.
        hook: The hook point that triggered the call; its `name` selects the matching
            entry in the module-level `clean_cache`.
        position: Sequence position to copy over from the clean activations.

    Returns:
        The same `resid_pre` tensor with the chosen position replaced.
    """
    # Each HookPoint has a name attribute giving the name of the hook.
    clean_activation = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_activation[:, position, :]
    return resid_pre

# We make a tensor to store the results for each patching run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
# Rows index the layer being patched and columns index the token position being patched.
num_positions = len(clean_tokens[0])
patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

# Each forward pass runs the corrupted prompt with exactly one (layer, position) patched from the clean run.
for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        # 0 means the patch left the corrupted behaviour unchanged; 1 means it fully restored the clean logit difference.
        patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)

  0%|          | 0/12 [00:00<?, ?it/s]

In [29]:
# NOTE(review): the figure below is drawn with Plotly, so this matplotlib magic is probably unnecessary — confirm.
%matplotlib inline
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
# Heatmap of the normalized logit difference: layers on the y axis, clean-prompt token positions on the x axis.
imshow(patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream")

Initially, the subject (Mary) token is all that matters, and all relevant information remains here until heads in layer 4 and 5 move this to the "is" token, from where heads in layer 9 and 10 move this to the final token where it's used to predict the pronoun.

This result is consistent for larger model sizes as well.

### LayerNorm folding bias

In [30]:
# Read the unembedding bias entries for the two pronoun tokens and compare them.
he_bias, she_bias = (
    model.unembed.b_U[model.to_single_token(pronoun)] for pronoun in (" he", " she")
)

print(f"he bias: {he_bias.item():.4f}")
print(f"she bias: {she_bias.item():.4f}")
# exp of a logit gap is the ratio of the corresponding probabilities, all else being equal.
print(f"Prob ratio bias: {torch.exp(he_bias - she_bias).item():.4f}x")

he bias: 4.5582
she bias: 3.6625
Prob ratio bias: 2.4490x


The unembedding bias that results from LayerNorm folding favours " he" over " she" by about 0.9 logits. All other things being equal, this makes the " he" token about 2.4x more likely than the " she" token.

### Analysing circuit formation during training