# Setup

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-lpd4a9w7
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-lpd4a9w7
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 0ffcc8ad647d9e991f4c2596557a9d7475617773
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting wandb>=0.13.5
  Downloading wandb-0.15.1-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers>=4.25.1
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode
# depending on the renderer - this is bad for developing demos! Thus creating a debug mode:
# the static "png" renderer is used only for local runs with DEBUG_MODE switched on.
pio.renderers.default = "png" if (DEBUG_MODE and not IN_COLAB) else "colab"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [None]:
import circuitsvis as cv
# Testing that the library works: render the library's built-in example component.
cv.examples.hello("Neel")

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
# Inference-only notebook: switch autograd off globally for every cell that follows.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f027c13b7c0>

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, midpoint=0.0, **kwargs):
    """Show ``tensor`` as a Plotly heatmap on the "RdBu" colour scale.

    Args:
        tensor: Data to plot; passed through ``utils.to_numpy`` before plotting.
        renderer: Forwarded to ``Figure.show``; ``None`` uses the default renderer.
        midpoint: Value the continuous colour scale is centred on.
        **kwargs: Extra keyword arguments forwarded to ``px.imshow``.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=midpoint,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Show ``tensor`` as a Plotly line chart.

    Args:
        tensor: Y values; passed through ``utils.to_numpy`` before plotting.
        renderer: Forwarded to ``Figure.show``; ``None`` uses the default renderer.
        **kwargs: Extra keyword arguments forwarded to ``px.line`` (e.g. ``x``, ``title``).
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a Plotly scatter plot of ``y`` against ``x``.

    Args:
        x, y: Coordinates; each is passed through ``utils.to_numpy`` before plotting.
        xaxis, yaxis, caxis: Display labels for the x axis, y axis and colour legend.
        renderer: Forwarded to ``Figure.show``; ``None`` uses the default renderer.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter`` (e.g. ``color``).
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x":xaxis, "y":yaxis, "color":caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [None]:
# Smoke test for the plotting helpers: plot the values 0..4 as a line chart.
line(np.arange(5))

Set up the device:

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pronoun prediction

The task is choosing the right pronouns (e.g. he vs she vs it vs they)

A good setup is a rhetorical question (so it doesn’t spoil the answer!) like “Lina is a great friend, isn’t” (h/t Marius Hobbhahn)

The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer.

In [None]:
# Load the pretrained "gpt2-small" weights into a HookedTransformer placed on `device`.
# NOTE(review): refactor_factored_attn_matrices=True presumably rewrites the attention weight
# matrices into an equivalent factored form - confirm against the TransformerLens docs.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    refactor_factored_attn_matrices=True, 
    device=device
    )

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


The next step is to verify that the model can actually do the task!

In [None]:
# Sanity check with a female name: print the tokenization and the model's top next-token
# predictions, with " she" as the expected answer.
example_prompt = "Mary is a great friend, isn’t"
example_answer = " she"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 17.41 Prob: 82.67% Token: | she|
Top 1th token. Logit: 14.69 Prob:  5.45% Token: | it|
Top 2th token. Logit: 13.51 Prob:  1.68% Token: | he|
Top 3th token. Logit: 13.11 Prob:  1.12% Token: | there|
Top 4th token. Logit: 12.74 Prob:  0.78% Token: | I|
Top 5th token. Logit: 12.72 Prob:  0.76% Token: | we|
Top 6th token. Logit: 12.67 Prob:  0.72% Token: | you|
Top 7th token. Logit: 12.35 Prob:  0.52% Token: | her|
Top 8th token. Logit: 12.23 Prob:  0.47% Token: | this|
Top 9th token. Logit: 12.16 Prob:  0.43% Token: | that|


In [None]:
# Same check with a male name; the expected answer is " he".
example_prompt = "John is a great friend, isn’t"
example_answer = " he"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'John', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Tokenized answer: [' he']


Top 0th token. Logit: 17.47 Prob: 83.43% Token: | he|
Top 1th token. Logit: 14.81 Prob:  5.81% Token: | it|
Top 2th token. Logit: 13.22 Prob:  1.18% Token: | there|
Top 3th token. Logit: 13.06 Prob:  1.01% Token: | you|
Top 4th token. Logit: 13.04 Prob:  0.99% Token: | we|
Top 5th token. Logit: 12.69 Prob:  0.70% Token: | I|
Top 6th token. Logit: 12.62 Prob:  0.65% Token: | she|
Top 7th token. Logit: 12.52 Prob:  0.59% Token: | that|
Top 8th token. Logit: 12.26 Prob:  0.45% Token: | this|
Top 9th token. Logit: 11.78 Prob:  0.28% Token: | the|


In [None]:
# Same check with an inanimate subject; the expected answer is " it".
example_prompt = "Matrix is a great movie, isn’t"
example_answer = " it"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'Matrix', ' is', ' a', ' great', ' movie', ',', ' isn', '�', '�', 't']
Tokenized answer: [' it']


Top 0th token. Logit: 18.16 Prob: 94.84% Token: | it|
Top 1th token. Logit: 13.65 Prob:  1.04% Token: | there|
Top 2th token. Logit: 13.15 Prob:  0.63% Token: | that|
Top 3th token. Logit: 12.37 Prob:  0.29% Token: | this|
Top 4th token. Logit: 12.37 Prob:  0.29% Token: | he|
Top 5th token. Logit: 12.32 Prob:  0.28% Token: | the|
Top 6th token. Logit: 12.10 Prob:  0.22% Token: | you|
Top 7th token. Logit: 12.09 Prob:  0.22% Token: |?|
Top 8th token. Logit: 11.81 Prob:  0.16% Token: | I|
Top 9th token. Logit: 11.41 Prob:  0.11% Token: | they|


It can do it pretty well!

Let's see how the toy models do on this task.

In [None]:
# The previous cells overwrote these globals; reset them to the "Mary" example, which the
# toy-model sweep below reads.
example_prompt = "Mary is a great friend, isn’t"
example_answer = " she"

In [None]:
# Sweep every toy architecture/depth combination and record the probability each model puts
# on the correct pronoun. The four lists are parallel (one entry per model) and feed the
# scatter plot further down.
model_types = []
n_layers = []
model_names = []
probs = []

for model_type, n_layer in itertools.product(["attn-only","gelu","solu"], ["1","2","3","4"]):
    model_name = f"{model_type}-{n_layer}l"
    print(f"{model_name}\n")

    toy_model = HookedTransformer.from_pretrained(model_name, device=device)
    utils.test_prompt(example_prompt, example_answer, toy_model,top_k = 5)

    # Probability of the correct answer token at the final prompt position.
    correct_index = toy_model.to_single_token(example_answer)
    tokens = toy_model.to_tokens(example_prompt)
    logits =  toy_model(tokens, return_type="logits")
    prob = logits.softmax(dim=-1)[:,-1,correct_index].item()

    model_types.append(model_type)
    n_layers.append(n_layer)
    model_names.append(model_name)
    probs.append(prob)

    print("\n")

attn-only-1l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/205M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.04M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Loaded pretrained model attn-only-1l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 16.26 Prob: 13.56% Token: | a|
Top 1th token. Logit: 15.56 Prob:  6.75% Token: | the|
Top 2th token. Logit: 15.50 Prob:  6.32% Token: | it|
Top 3th token. Logit: 15.21 Prob:  4.72% Token: | just|
Top 4th token. Logit: 14.70 Prob:  2.83% Token: | an|




attn-only-2l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/210M [00:00<?, ?B/s]

Loaded pretrained model attn-only-2l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 17.27 Prob: 30.10% Token: | it|
Top 1th token. Logit: 15.49 Prob:  5.11% Token: | just|
Top 2th token. Logit: 15.46 Prob:  4.95% Token: | that|
Top 3th token. Logit: 15.33 Prob:  4.33% Token: | the|
Top 4th token. Logit: 15.15 Prob:  3.62% Token: | she|




attn-only-3l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/216M [00:00<?, ?B/s]

Loaded pretrained model attn-only-3l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 18.01 Prob: 48.57% Token: | it|
Top 1th token. Logit: 16.77 Prob: 14.15% Token: | that|
Top 2th token. Logit: 15.58 Prob:  4.28% Token: | the|
Top 3th token. Logit: 14.75 Prob:  1.87% Token: | always|
Top 4th token. Logit: 14.60 Prob:  1.62% Token: | a|




attn-only-4l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/221M [00:00<?, ?B/s]

Loaded pretrained model attn-only-4l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 17.03 Prob: 61.47% Token: | it|
Top 1th token. Logit: 14.59 Prob:  5.35% Token: | she|
Top 2th token. Logit: 14.50 Prob:  4.88% Token: | the|
Top 3th token. Logit: 13.88 Prob:  2.64% Token: | that|
Top 4th token. Logit: 13.78 Prob:  2.38% Token: | he|




gelu-1l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/213M [00:00<?, ?B/s]

Loaded pretrained model gelu-1l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 16.66 Prob: 24.92% Token: | it|
Top 1th token. Logit: 15.65 Prob:  9.07% Token: | she|
Top 2th token. Logit: 15.25 Prob:  6.09% Token: | the|
Top 3th token. Logit: 14.92 Prob:  4.39% Token: | a|
Top 4th token. Logit: 14.58 Prob:  3.11% Token: | her|




gelu-2l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/227M [00:00<?, ?B/s]

Loaded pretrained model gelu-2l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 17.96 Prob: 43.40% Token: | it|
Top 1th token. Logit: 16.96 Prob: 15.96% Token: | she|
Top 2th token. Logit: 16.04 Prob:  6.36% Token: | that|
Top 3th token. Logit: 15.66 Prob:  4.33% Token: | he|
Top 4th token. Logit: 15.24 Prob:  2.86% Token: | the|




gelu-3l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loaded pretrained model gelu-3l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 19.44 Prob: 51.23% Token: | it|
Top 1th token. Logit: 18.65 Prob: 23.24% Token: | she|
Top 2th token. Logit: 17.38 Prob:  6.55% Token: | he|
Top 3th token. Logit: 16.79 Prob:  3.61% Token: | that|
Top 4th token. Logit: 16.46 Prob:  2.62% Token: | there|




gelu-4l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/254M [00:00<?, ?B/s]

Loaded pretrained model gelu-4l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 18.88 Prob: 56.46% Token: | she|
Top 1th token. Logit: 17.51 Prob: 14.36% Token: | it|
Top 2th token. Logit: 17.01 Prob:  8.68% Token: | he|
Top 3th token. Logit: 16.71 Prob:  6.47% Token: | that|
Top 4th token. Logit: 15.47 Prob:  1.87% Token: | afraid|




solu-1l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/213M [00:00<?, ?B/s]

Loaded pretrained model solu-1l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 17.60 Prob: 47.17% Token: | it|
Top 1th token. Logit: 15.15 Prob:  4.07% Token: | a|
Top 2th token. Logit: 15.07 Prob:  3.73% Token: | the|
Top 3th token. Logit: 15.03 Prob:  3.59% Token: | just|
Top 4th token. Logit: 14.92 Prob:  3.22% Token: | she|




solu-2l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/227M [00:00<?, ?B/s]

Loaded pretrained model solu-2l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 17.09 Prob: 25.19% Token: | it|
Top 1th token. Logit: 17.01 Prob: 23.28% Token: | she|
Top 2th token. Logit: 15.94 Prob:  8.02% Token: | he|
Top 3th token. Logit: 15.75 Prob:  6.62% Token: | that|
Top 4th token. Logit: 15.27 Prob:  4.10% Token: | a|




solu-3l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loaded pretrained model solu-3l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 19.28 Prob: 45.19% Token: | it|
Top 1th token. Logit: 18.62 Prob: 23.43% Token: | she|
Top 2th token. Logit: 17.72 Prob:  9.53% Token: | he|
Top 3th token. Logit: 17.07 Prob:  4.95% Token: | you|
Top 4th token. Logit: 16.88 Prob:  4.09% Token: | that|




solu-4l



Downloading (…)lve/main/config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

Downloading model_final.pth:   0%|          | 0.00/255M [00:00<?, ?B/s]

Loaded pretrained model solu-4l into HookedTransformer
Tokenized prompt: ['<|BOS|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '’', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 18.11 Prob: 33.81% Token: | she|
Top 1th token. Logit: 17.70 Prob: 22.43% Token: | it|
Top 2th token. Logit: 17.55 Prob: 19.35% Token: | he|
Top 3th token. Logit: 16.16 Prob:  4.80% Token: | that|
Top 4th token. Logit: 15.70 Prob:  3.02% Token: | there|






Let's visualize this!

In [None]:
# Plot the correct-answer probability of every toy model against its depth, coloured by
# architecture. (Fixes the "Performace" typo in the figure title.)
scatter(
    x=n_layers, 
    y=probs, 
    xaxis="N Layers",
    yaxis="Probability of correct answer", # Not sure if probability is the right metric here
    caxis="Model Type",
    hover_name = model_names,
    color=model_types,
    title="Performance of toy models on pronoun prediction task")

Ok so it's not looking good for attention only models.

Also `gelu-4l` and `solu-4l` are the only models to correctly predict the answer.

Another observation from the `test_prompt` output is that there is a heavy bias for the " it" token in the toy models.

In [None]:
# For every toy model, compare the unembedding bias of " it" with that of " she" and report
# the implied probability ratio exp(it_bias - she_bias).
for model_type in ["attn-only","gelu","solu"]:
    for n_layer in ["1","2","3","4"]:

        model_name = f"{model_type}-{n_layer}l"
        print(f"{model_name}\n")

        toy_model = HookedTransformer.from_pretrained(model_name, device=device)
        # Token ids must come from the toy model's own tokenizer: the toy models do not share
        # GPT-2's vocabulary (their prompts tokenize differently, e.g. '<|BOS|>'), so indexing
        # b_U with ids from `model` would read the bias of unrelated tokens.
        it_bias = toy_model.unembed.b_U[toy_model.to_single_token(' it')]
        she_bias = toy_model.unembed.b_U[toy_model.to_single_token(' she')]

        print(f"it bias: {it_bias.item():.4f}")
        print(f"she bias: {she_bias.item():.4f}")
        print(f"prob ratio bias: {torch.exp(it_bias - she_bias).item():.4f}x")

        print("\n")

attn-only-1l

Loaded pretrained model attn-only-1l into HookedTransformer
it bias: 8.2050
she bias: 4.0616
prob ratio bias: 63.0147x


attn-only-2l

Loaded pretrained model attn-only-2l into HookedTransformer
it bias: 5.3323
she bias: -0.1984
prob ratio bias: 252.3091x


attn-only-3l

Loaded pretrained model attn-only-3l into HookedTransformer
it bias: 4.4752
she bias: 0.2956
prob ratio bias: 65.3380x


attn-only-4l

Loaded pretrained model attn-only-4l into HookedTransformer
it bias: 4.0389
she bias: 0.6067
prob ratio bias: 30.9431x


gelu-1l

Loaded pretrained model gelu-1l into HookedTransformer
it bias: 4.6283
she bias: 0.9287
prob ratio bias: 40.4317x


gelu-2l

Loaded pretrained model gelu-2l into HookedTransformer
it bias: 3.7967
she bias: 0.7352
prob ratio bias: 21.3587x


gelu-3l

Loaded pretrained model gelu-3l into HookedTransformer
it bias: 3.1698
she bias: 0.4404
prob ratio bias: 15.3239x


gelu-4l

Loaded pretrained model gelu-4l into HookedTransformer
it bias: 2.3943
she

This can also be seen in the `gpt2-small` model, though to a lesser degree.

In [None]:
# Same comparison for GPT-2 Small: read the unembedding bias entries for " it" and " she".
it_bias = model.unembed.b_U[model.to_single_token(' it')]
she_bias = model.unembed.b_U[model.to_single_token(' she')]

print(f"it bias: {it_bias.item():.4f}")
print(f"she bias: {she_bias.item():.4f}")
# exp of the bias difference is the factor by which the bias alone favours " it" over " she".
print(f"Prob ratio bias: {torch.exp(it_bias - she_bias).item():.4f}x")

it bias: 5.2323
she bias: 3.6625
Prob ratio bias: 4.8057x


The bias created across the unembed due to LayerNorm folding favours " it" over " she" by about 1.6 logits! All other things being equal, this makes the " it" token about 4.8x more likely than the " she" token.

OK let's generate reference prompts for the task to run the model on.

We'll run the model on 20 instances of this task, each prompt format with each name.

In [None]:
# Build the evaluation set: every prompt template combined with every name, together with
# the (correct, incorrect) pronoun pair for each prompt as strings and as token ids.
prompt_formats = [
    "{} is a great friend, isn’t",
    "{} is an amazing person, isn’t",    
    "{} is a fantastic colleague, isn’t",    
    "{} is a wonderful partner, isn’t",    
    "{} is an excellent student, isn’t"
    ]

pronouns = [" she", " he"]

# List of names, in the format (name, pronoun) where pronoun is an index into `pronouns`
names = [
    ("Mary", 0), 
    ("John", 1),
    ("Dan", 1),
    ("Amy", 0),
]

# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []

for prompt_format in prompt_formats:
    for name, pronoun_idx in names:
        prompts.append(prompt_format.format(name))

        # With exactly two pronouns, `1 - pronoun_idx` selects the other (incorrect) one.
        answers.append(
            (
                pronouns[pronoun_idx], 
                pronouns[1-pronoun_idx]
            )
            )
        
        answer_tokens.append(
            (
                model.to_single_token(answers[-1][0]),
                model.to_single_token(answers[-1][1]),
            )
        )
# Create the tensor on the same device the model was loaded on. A hard-coded `.cuda()`
# would crash on CPU-only machines even though `device` already handles that case.
answer_tokens = torch.tensor(answer_tokens, device=device)
print(prompts)
print(answers)

['Mary is a great friend, isn’t', 'John is a great friend, isn’t', 'Dan is a great friend, isn’t', 'Amy is a great friend, isn’t', 'Mary is an amazing person, isn’t', 'John is an amazing person, isn’t', 'Dan is an amazing person, isn’t', 'Amy is an amazing person, isn’t', 'Mary is a fantastic colleague, isn’t', 'John is a fantastic colleague, isn’t', 'Dan is a fantastic colleague, isn’t', 'Amy is a fantastic colleague, isn’t', 'Mary is a wonderful partner, isn’t', 'John is a wonderful partner, isn’t', 'Dan is a wonderful partner, isn’t', 'Amy is a wonderful partner, isn’t', 'Mary is an excellent student, isn’t', 'John is an excellent student, isn’t', 'Dan is an excellent student, isn’t', 'Amy is an excellent student, isn’t']
[(' she', ' he'), (' he', ' she'), (' he', ' she'), (' she', ' he'), (' she', ' he'), (' he', ' she'), (' he', ' she'), (' she', ' he'), (' she', ' he'), (' he', ' she'), (' he', ' she'), (' she', ' he'), (' she', ' he'), (' he', ' she'), (' he', ' she'), (' she', 

In [None]:
# ensuring all prompts are same number of tokens
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)

Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'John', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Dan', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Amy', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Mary', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'John', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Dan', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Amy', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']


We now run the model on these prompts and use run_with_cache to get both the logits and a cache of all internal activations for later analysis.

In [None]:
# Tokenize all prompts as one batch, with a BOS token prepended to each.
tokens = model.to_tokens(prompts, prepend_bos=True)
# Move the tokens to the device the model lives on (a hard-coded `.cuda()` fails on CPU-only machines)
tokens = tokens.to(device)
# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the **logit difference**, the difference in logit between the correct pronoun and the incorrect pronoun (eg, `logit( she)-logit( he)`). 

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Logit difference between the correct and the incorrect answer token.

    Args:
        logits: Tensor indexed as [batch, position, vocab]; only the last position is used.
        answer_tokens: Integer tensor indexed as [batch, 2] holding
            (correct_token, incorrect_token) for each prompt.
        per_prompt: When true, return one difference per prompt instead of the mean.

    Returns:
        A tensor with one logit difference per prompt if ``per_prompt`` is true, otherwise
        their mean as a scalar tensor.
    """
    # Only the final position predicts the answer.
    last_position_logits = logits[:, -1, :]
    picked = torch.gather(last_position_logits, -1, answer_tokens)
    diffs = picked[:, 0] - picked[:, 1]
    return diffs if per_prompt else diffs.mean()

# Report the metric per prompt and averaged over the batch. The average is stored as the
# clean baseline for later cells, and printed from that stored value rather than recomputed.
print("Per prompt logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True))
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff.item())

Per prompt logit difference: tensor([3.8977, 4.8540, 3.9663, 3.4600, 5.1857, 4.7339, 3.7226, 4.7266, 3.9201,
        4.4638, 3.5116, 3.5193, 4.1234, 3.2621, 2.4767, 3.4701, 4.3414, 4.8153,
        3.8098, 3.8759], device='cuda:0')
Average logit difference: 4.006812572479248


We see that the average logit difference is 4.0 - for context, this represents putting an $e^{4.0}\approx 55\times$ higher probability on the correct answer. 

# Direct Logit Attribution

We use `model.tokens_to_residual_directions` to map the answer tokens to residual stream direction, and then convert this to a logit difference direction for each batch.

In [None]:
# Map each (correct, incorrect) answer token pair to its residual stream directions.
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
# Per prompt: direction of the correct answer minus direction of the incorrect answer.
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([20, 2, 768])
Logit difference directions shape: torch.Size([20, 768])


To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer. 

In [None]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff.item())

Final residual stream shape: torch.Size([20, 11, 768])
Calculated average logit diff: 4.006814002990723
Original logit difference: 4.006812572479248


### Logit Lens

We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequent layers. 

In [None]:
def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache: ActivationCache) -> Float[torch.Tensor, "..."]:
    """Average logit difference contributed by each component of a residual stack.

    Each component is scaled with ``cache.apply_ln_to_stack`` (final layer, final position),
    projected onto the module-level ``logit_diff_directions`` and averaged over the batch by
    dividing by ``len(prompts)``.

    Args:
        residual_stack: Stack of residual stream components at the final token position.
        cache: Activation cache used to apply the LayerNorm scaling.

    Returns:
        A tensor with the leading (non-batch) dimensions of ``residual_stack``, one average
        logit difference per component. Note this is a tensor, not a Python float.
    """
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)

In [None]:
# Logit lens: take the accumulated residual stream at the final token position (incl_mid=True
# also requests the mid-layer points), convert each entry to an average logit difference and plot it.
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
# The x axis runs in steps of 0.5 over n_layers*2+1 points.
# NOTE(review): presumably two points per layer (pre-attention and mid) plus the final output - confirm.
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulate Residual Stream")

We see that the model is utterly unable to do the task until layer 8, and then performance increases from there in a stepwise fashion, with jumps at the attention part of each layer.

### Layer Attribution

We can repeat the above analysis but for each layer (this is equivalent to the differences between adjacent residual streams)

In [None]:
# Layer attribution: decompose the final-position residual stream into labelled components
# and plot the average logit difference attributed to each one.
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

We see that only attention layers matter! And again we note that attention layers 9, 10 and 11 improve things a lot.

### Head Attribution

We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.

In [None]:
# Head attribution: stack the output of every attention head at the final position and
# compute the average logit difference attributed to each head.
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
# Reshape the flat (layer * head) vector into a [layer, head] grid for the heatmap.
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


We see that only a few heads really matter - heads L9H7, L10H9 and L11H8 contribute a lot positively (explaining why attention layers 9, 10 and 11 are so important). There are also several heads that matter positively or negatively but less strongly.

### Attention Analysis

We use Anthropic's PySvelte library to visualize the attention patterns! We visualize the top 3 positive heads by direct logit attribution, and show these for the first prompt (as an illustration).

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]], 
    local_cache: Optional[ActivationCache]=None, 
    local_tokens: Optional[torch.Tensor]=None, 
    title: str=""):
    """Display the attention patterns of the given heads with PySvelte.

    Args:
        heads: One flat head index or a collection of them; a flat index ``h`` refers to
            layer ``h // n_heads``, head ``h % n_heads``.
        local_cache: Activation cache to read patterns from; defaults to the global ``cache``.
        local_tokens: Tokens used to label the pattern; defaults to the first prompt of the
            global ``tokens`` batch.
        title: Heading rendered above the visualization.

    Only the first batch element of the cache is visualized.
    """
    # Normalise `heads` to something iterable over flat head indices.
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, (list, torch.Tensor)):
        heads = utils.to_numpy(heads)
    # Fall back to the notebook-level cache and the first prompt's tokens (BOS included).
    local_cache = cache if local_cache is None else local_cache
    local_tokens = tokens[0] if local_tokens is None else local_tokens

    batch_index = 0
    heads_per_layer = model.cfg.n_heads
    head_coords = [(head // heads_per_layer, head % heads_per_layer) for head in heads]

    labels = [f"L{layer}H{head_index}" for layer, head_index in head_coords]
    str_tokens = model.to_str_tokens(local_tokens)
    # Attention patterns have shape [batch, head_index, query_pos, key_pos]; the selected
    # [query_pos, key_pos] slices are stacked along a new trailing dimension, one per head.
    patterns = torch.stack(
        [local_cache["attn", layer][batch_index, head_index] for layer, head_index in head_coords],
        dim=-1,
    )
    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

In [None]:
# Pick the top_k heads with the largest logit attribution. `flatten()` turns the
# [layer, head] grid into flat head indices, the format visualize_attention_patterns expects.
top_k = 3
top_positive_logit_attr_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")

pysvelte components appear to be unbuilt or stale
Running npm install...
Building pysvelte components with webpack...


# Activation Patching

### Residual Stream

Let's begin by patching in the residual stream at the start of each layer and for each token position.

We first create a set of corrupted tokens - where we swap each pair of prompts to have the opposite answer.

In [None]:
# Build the corrupted batch by swapping each adjacent pair of prompts (0<->1, 2<->3, ...).
# Adjacent prompts share a template but use names with opposite pronouns, so every
# corrupted prompt has the opposite answer. Assumes len(prompts) is even.
corrupted_prompts = []
for i in range(0, len(prompts), 2):
    corrupted_prompts.append(prompts[i+1])
    corrupted_prompts.append(prompts[i])
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
# Scored against the *clean* answers, so the corrupted run is expected to come out negative.
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
print("Clean Average Logit Diff", original_average_logit_diff)

Corrupted Average Logit Diff tensor(-4.0068, device='cuda:0')
Clean Average Logit Diff tensor(4.0068, device='cuda:0')


In [None]:
# Display the corrupted prompts to check the pairwise swap.
corrupted_prompts

['John is a great friend, isn’t',
 'Mary is a great friend, isn’t',
 'Amy is a great friend, isn’t',
 'Dan is a great friend, isn’t',
 'John is an amazing person, isn’t',
 'Mary is an amazing person, isn’t',
 'Amy is an amazing person, isn’t',
 'Dan is an amazing person, isn’t',
 'John is a fantastic colleague, isn’t',
 'Mary is a fantastic colleague, isn’t',
 'Amy is a fantastic colleague, isn’t',
 'Dan is a fantastic colleague, isn’t',
 'John is a wonderful partner, isn’t',
 'Mary is a wonderful partner, isn’t',
 'Amy is a wonderful partner, isn’t',
 'Dan is a wonderful partner, isn’t',
 'John is an excellent student, isn’t',
 'Mary is an excellent student, isn’t',
 'Amy is an excellent student, isn’t',
 'Dan is an excellent student, isn’t']

We now intervene on the corrupted run and patch in the clean residual stream at a specific layer and position.

In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    """Overwrite one position of a corrupted activation with the cached clean value.

    Args:
        corrupted_residual_component: activation from the corrupted run; modified in place.
        hook: hook point object; its ``name`` is used as the key into ``clean_cache``.
        pos: index along the position axis (axis 1) to patch.
        clean_cache: mapping from hook name to the activation recorded on the clean run.

    Returns:
        The same tensor object that was passed in, with position ``pos`` replaced.
    """
    clean_activation = clean_cache[hook.name]
    corrupted_residual_component[:, pos, :] = clean_activation[:, pos, :]
    return corrupted_residual_component

def normalize_patched_logit_diff(patched_logit_diff):
    """Rescale a patched logit diff relative to the corrupted and clean baselines.

    Reads the notebook-level ``corrupted_average_logit_diff`` and
    ``original_average_logit_diff``. The result is 0 when patching changed nothing,
    1 when clean performance is fully recovered, negative when patching made things
    worse, and above 1 when it improved on clean performance.
    """
    # Improvement of the patched run over the corrupted baseline.
    improvement = patched_logit_diff - corrupted_average_logit_diff
    # Total gap between the clean and corrupted baselines.
    total_gap = original_average_logit_diff - corrupted_average_logit_diff
    return improvement / total_gap

# One normalized score per (layer, position). The buffer is allocated on the
# device of the corrupted tokens rather than a hard-coded "cuda", so the cell
# also runs on machines without a GPU.
patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers,
    tokens.shape[1],
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for layer in range(model.cfg.n_layers):
    # The hook point name only depends on the layer, so look it up once per layer.
    resid_pre_hook_name = utils.get_act_name("resid_pre", layer)
    for position in range(tokens.shape[1]):
        # Replace the corrupted activation at this position with the value stored in `cache`.
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(resid_pre_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Label every position of the first prompt as "<string token>_<position index>".
prompt_position_labels = [
    f"{tok}_{i}"
    for i, tok in enumerate(model.to_str_tokens(tokens[0]))
]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Normalized Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)

Initially, the subject (Mary) token is all that matters, and all relevant information remains here until heads in layers 4 and 5 move this to other tokens like " is", " isn" and " friend", from where heads in layers 9 and 10 move this to the final token where it's used to predict the pronoun.

This result is consistent for larger model sizes as well.

### Layers

We can apply exactly the same idea, but this time patching in attention or MLP layers.

In [None]:
# One normalized score per (layer, position) for attention outputs and for MLP outputs.
# Buffers live on the device of the corrupted tokens instead of a hard-coded "cuda",
# so the cell also runs on machines without a GPU.
patched_attn_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=corrupted_tokens.device, dtype=torch.float32
)
patched_mlp_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=corrupted_tokens.device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    # Hook point names only depend on the layer, so look them up once per layer.
    attn_out_hook_name = utils.get_act_name("attn_out", layer)
    mlp_out_hook_name = utils.get_act_name("mlp_out", layer)
    for position in range(tokens.shape[1]):
        # The same patching hook is reused for both hook points; it indexes the
        # activation as [:, pos, :], which is applied here to attn_out and mlp_out.
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(attn_out_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(patched_attn_logits, answer_tokens)
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(mlp_out_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(patched_mlp_logits, answer_tokens)

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff)
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff)

In [None]:
# Heatmap of the attention-layer patching scores, with layers on y and positions on x.
imshow(
    patched_attn_diff,
    x=prompt_position_labels,
    title="Normalized Logit Difference From Patched Attention Layer",
    labels={"x": "Position", "y": "Layer"},
)

We see that several attention layers are significant but that, matching the residual stream results, early layers matter on the subject token and other tokens, and later layers matter on the final token. As with direct logit attribution, layers 9, 10 and 11 are positive, suggesting that the late layers only matter for direct logit effects, but we also see that layers 4 and 6 matter significantly.

In [None]:
# Heatmap of the MLP-layer patching scores, with layers on y and positions on x.
imshow(
    patched_mlp_diff,
    x=prompt_position_labels,
    title="Normalized Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
)

We see that several early MLP layers also matter. MLP 3 and 5 matter for the subject token and MLP 7 matters for other tokens.

And patching MLP 11 is negative for the logit diff, in line with what we saw in the logit lens results. 

MLP 0 also matters a lot, but this is just a generally true statement about MLP 0 rather than being about the circuit on this task. It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance.

### Heads

We can refine the above analysis by patching in individual heads!

In [None]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    position,
    clean_cache,
):
    """Patch one head's vector at one position with the cached clean value.

    Args:
        corrupted_head_vector: per-head activation from the corrupted run; modified in place.
        hook: hook point object; its ``name`` is used as the key into ``clean_cache``.
        head_index: index along the head axis (axis 2) to patch.
        position: index along the position axis (axis 1) to patch.
        clean_cache: mapping from hook name to the activation recorded on the clean run.

    Returns:
        The same tensor object that was passed in, with the selected slice replaced.
    """
    clean_head_vector = clean_cache[hook.name]
    corrupted_head_vector[:, position, head_index, :] = clean_head_vector[:, position, head_index, :]
    return corrupted_head_vector

# One normalized score per (flattened head, position). The buffer is allocated on
# the device of the corrupted tokens rather than a hard-coded "cuda", so the cell
# also runs on machines without a GPU.
patched_head_z_diff = torch.zeros(
    model.cfg.n_layers * model.cfg.n_heads,
    tokens.shape[1],
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for i in range(model.cfg.n_layers * model.cfg.n_heads):
    # Row i is head (i % n_heads) of layer (i // n_heads); none of this depends
    # on the position, so it is computed once per head.
    layer = i // model.cfg.n_heads
    head_index = i % model.cfg.n_heads
    z_hook_name = utils.get_act_name("z", layer, "attn")
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_head_vector, head_index=head_index, position=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(z_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff[i, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# "L<layer>H<head>" labels in layer-major order.
head_names = [
    f"L{l}H{h}"
    for l in range(model.cfg.n_layers)
    for h in range(model.cfg.n_heads)
]

imshow(
    patched_head_z_diff,
    y=head_names,
    title="Normalized Logit Difference From Patched Head Output",
    labels={"x": "Position", "y": "Head"},
)

In [None]:
def patch_head_vector_all_pos(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    """Patch one head's vector at every position with the cached clean value.

    Args:
        corrupted_head_vector: per-head activation from the corrupted run; modified in place.
        hook: hook point object; its ``name`` is used as the key into ``clean_cache``.
        head_index: index along the head axis (axis 2) to patch.
        clean_cache: mapping from hook name to the activation recorded on the clean run.

    Returns:
        The same tensor object that was passed in, with the selected head replaced.
    """
    clean_head_vector = clean_cache[hook.name]
    corrupted_head_vector[:, :, head_index, :] = clean_head_vector[:, :, head_index, :]
    return corrupted_head_vector

# One normalized score per (layer, head), patching all positions at once.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_z_diff_all_pos = torch.zeros(
    model.cfg.n_layers,
    model.cfg.n_heads,
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for layer in range(model.cfg.n_layers):
    # The hook point name only depends on the layer, so look it up once per layer.
    z_hook_name = utils.get_act_name("z", layer, "attn")
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector_all_pos, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(z_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff_all_pos[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of head-output patching scores, with layers on y and heads on x.
imshow(
    patched_head_z_diff_all_pos,
    title="Normalized Logit Difference From Patched Head Output (All Position)",
    labels={"x": "Head", "y": "Layer"},
)

We can now see that, in mid layers the heads L6H0 and L4H3 matter for other tokens (explaining why attention layer 4 and 6 were significant in the patching result) and are presumably responsible for moving information from the subject to other tokens.

Heads L9H7 and L10H9 in late layers matter for the final token.

### Decomposing Heads

First let's patch in the value vectors, to measure when figuring out what to move is important.

In [None]:
# One normalized score per (flattened head, position) for value-vector patching.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_v_diff = torch.zeros(
    model.cfg.n_layers * model.cfg.n_heads,
    tokens.shape[1],
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for i in range(model.cfg.n_layers * model.cfg.n_heads):
    # Row i is head (i % n_heads) of layer (i // n_heads); computed once per head.
    layer = i // model.cfg.n_heads
    head_index = i % model.cfg.n_heads
    v_hook_name = utils.get_act_name("v", layer, "attn")
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_head_vector, head_index=head_index, position=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(v_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_v_diff[i, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of value patching scores, with heads on y and positions on x.
imshow(
    patched_head_v_diff,
    y=head_names,
    title="Normalized Logit Difference From Patched Head Value",
    labels={"x": "Position", "y": "Head"},
)

In [None]:
# One normalized score per (layer, head) for value-vector patching at all positions.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_v_diff_all_pos = torch.zeros(
    model.cfg.n_layers,
    model.cfg.n_heads,
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for layer in range(model.cfg.n_layers):
    # The hook point name only depends on the layer, so look it up once per layer.
    v_hook_name = utils.get_act_name("v", layer, "attn")
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector_all_pos, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(v_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_v_diff_all_pos[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of value patching scores (all positions), with layers on y and heads on x.
imshow(
    patched_head_v_diff_all_pos,
    title="Normalized Logit Difference From Patched Head Value (All Position)",
    labels={"x": "Head", "y": "Layer"},
)

It's very easy to interpret if we plot a scatter plot against patching head outputs.

In [None]:
# Compare value patching (x) against output patching (y), one point per head.
scatter(
    x=utils.to_numpy(patched_head_v_diff_all_pos.flatten()),
    y=utils.to_numpy(patched_head_z_diff_all_pos.flatten()),
    xaxis="Value Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_names,
    # Layer index of every head in flattened (layer, head) order.
    color=np.repeat(np.arange(model.cfg.n_layers), model.cfg.n_heads),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs value patching (All Position)",
)

Now let's patch in the attention pattern, to measure when figuring out where to move is important.

In [None]:
def patch_head_pattern(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index dest_pos src_pos"],
    hook,
    head_index,
    position,
    clean_cache,
):
    """Patch one head's attention pattern row for one destination position.

    Args:
        corrupted_head_pattern: attention pattern from the corrupted run; modified in place.
        hook: hook point object; its ``name`` is used as the key into ``clean_cache``.
        head_index: index along the head axis (axis 1) to patch.
        position: index along the destination-position axis (axis 2) to patch.
        clean_cache: mapping from hook name to the activation recorded on the clean run.

    Returns:
        The same tensor object that was passed in, with the selected row replaced.
    """
    clean_head_pattern = clean_cache[hook.name]
    corrupted_head_pattern[:, head_index, position, :] = clean_head_pattern[:, head_index, position, :]
    return corrupted_head_pattern

# One normalized score per (flattened head, destination position) for pattern patching.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_attn_diff = torch.zeros(
    model.cfg.n_layers * model.cfg.n_heads,
    tokens.shape[1],
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for i in range(model.cfg.n_layers * model.cfg.n_heads):
    # Row i is head (i % n_heads) of layer (i // n_heads); computed once per head.
    layer = i // model.cfg.n_heads
    head_index = i % model.cfg.n_heads
    pattern_hook_name = utils.get_act_name("attn", layer, "attn")
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_head_pattern, head_index=head_index, position=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(pattern_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_attn_diff[i, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of pattern patching scores, with heads on y and positions on x.
imshow(
    patched_head_attn_diff,
    y=head_names,
    title="Normalized Logit Difference From Patched Head Pattern",
    labels={"x": "Position", "y": "Head"},
)

Note the difference in logit diff scale for this and value patching.

In [None]:
def patch_head_pattern_all_pos(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index dest_pos src_pos"],
    hook,
    head_index,
    clean_cache):
    """Patch one head's entire attention pattern with the cached clean value.

    The shape annotation matches ``patch_head_pattern``: the last two axes are
    destination and source positions (the tensor is an attention pattern, so
    there is no ``d_head`` axis).

    Args:
        corrupted_head_pattern: attention pattern from the corrupted run; modified in place.
        hook: hook point object; its ``name`` is used as the key into ``clean_cache``.
        head_index: index along the head axis (axis 1) to patch.
        clean_cache: mapping from hook name to the activation recorded on the clean run.

    Returns:
        The same tensor object that was passed in, with the selected head replaced.
    """
    corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][:, head_index, :, :]
    return corrupted_head_pattern

# One normalized score per (layer, head) for pattern patching at all positions.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_attn_diff_all_pos = torch.zeros(
    model.cfg.n_layers,
    model.cfg.n_heads,
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for layer in range(model.cfg.n_layers):
    # The hook point name only depends on the layer, so look it up once per layer.
    pattern_hook_name = utils.get_act_name("attn", layer, "attn")
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_pattern_all_pos, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(pattern_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_attn_diff_all_pos[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of pattern patching scores (all positions), with layers on y and heads on x.
imshow(
    patched_head_attn_diff_all_pos,
    title="Normalized Logit Difference From Patched Head Pattern (All Position)",
    labels={"x": "Head", "y": "Layer"},
)
# Compare pattern patching (x) against output patching (y), one point per head.
scatter(
    x=utils.to_numpy(patched_head_attn_diff_all_pos.flatten()),
    y=utils.to_numpy(patched_head_z_diff_all_pos.flatten()),
    xaxis="Attention Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_names,
    # Layer index of every head in flattened (layer, head) order.
    color=np.repeat(np.arange(model.cfg.n_layers), model.cfg.n_heads),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs attention patching (All Position)",
)

Let's patch in the query vectors!

In [None]:
# One normalized score per (flattened head, position) for query-vector patching.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_q_diff = torch.zeros(
    model.cfg.n_layers * model.cfg.n_heads,
    tokens.shape[1],
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for i in range(model.cfg.n_layers * model.cfg.n_heads):
    # Row i is head (i % n_heads) of layer (i // n_heads); computed once per head.
    layer = i // model.cfg.n_heads
    head_index = i % model.cfg.n_heads
    q_hook_name = utils.get_act_name("q", layer, "attn")
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_head_vector, head_index=head_index, position=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(q_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_q_diff[i, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of query patching scores, with heads on y and positions on x.
imshow(
    patched_head_q_diff,
    y=head_names,
    title="Normalized Logit Difference From Patched Head Query",
    labels={"x": "Position", "y": "Head"},
)

In [None]:
# One normalized score per (layer, head) for query-vector patching at all positions.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_q_diff_all_pos = torch.zeros(
    model.cfg.n_layers,
    model.cfg.n_heads,
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for layer in range(model.cfg.n_layers):
    # The hook point name only depends on the layer, so look it up once per layer.
    q_hook_name = utils.get_act_name("q", layer, "attn")
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector_all_pos, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(q_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_q_diff_all_pos[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of query patching scores (all positions), with layers on y and heads on x.
imshow(
    patched_head_q_diff_all_pos,
    title="Normalized Logit Difference From Patched Head Query (All Position)",
    labels={"x": "Head", "y": "Layer"},
)

It's very easy to interpret if we plot a scatter plot against patching head outputs.

In [None]:
# Compare query patching (x) against output patching (y), one point per head.
scatter(
    x=utils.to_numpy(patched_head_q_diff_all_pos.flatten()),
    y=utils.to_numpy(patched_head_z_diff_all_pos.flatten()),
    xaxis="Query Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_names,
    # Layer index of every head in flattened (layer, head) order.
    color=np.repeat(np.arange(model.cfg.n_layers), model.cfg.n_heads),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs query patching (All Position)",
)

Let's patch in the key vectors!

In [None]:
# One normalized score per (flattened head, position) for key-vector patching.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_k_diff = torch.zeros(
    model.cfg.n_layers * model.cfg.n_heads,
    tokens.shape[1],
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for i in range(model.cfg.n_layers * model.cfg.n_heads):
    # Row i is head (i % n_heads) of layer (i // n_heads); computed once per head.
    layer = i // model.cfg.n_heads
    head_index = i % model.cfg.n_heads
    k_hook_name = utils.get_act_name("k", layer, "attn")
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_head_vector, head_index=head_index, position=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(k_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_k_diff[i, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of key patching scores, with heads on y and positions on x.
imshow(
    patched_head_k_diff,
    y=head_names,
    title="Normalized Logit Difference From Patched Head Key",
    labels={"x": "Position", "y": "Head"},
)

In [None]:
# One normalized score per (layer, head) for key-vector patching at all positions.
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_head_k_diff_all_pos = torch.zeros(
    model.cfg.n_layers,
    model.cfg.n_heads,
    device=corrupted_tokens.device,
    dtype=torch.float32,
)
for layer in range(model.cfg.n_layers):
    # The hook point name only depends on the layer, so look it up once per layer.
    k_hook_name = utils.get_act_name("k", layer, "attn")
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector_all_pos, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(k_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_k_diff_all_pos[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# Heatmap of key patching scores (all positions), with layers on y and heads on x.
imshow(
    patched_head_k_diff_all_pos,
    title="Normalized Logit Difference From Patched Head Key (All Position)",
    labels={"x": "Head", "y": "Layer"},
)

It's very easy to interpret if we plot a scatter plot against patching head outputs.

In [None]:
# Compare key patching (x) against output patching (y), one point per head.
scatter(
    x=utils.to_numpy(patched_head_k_diff_all_pos.flatten()),
    y=utils.to_numpy(patched_head_z_diff_all_pos.flatten()),
    xaxis="Key Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_names,
    # Layer index of every head in flattened (layer, head) order.
    color=np.repeat(np.arange(model.cfg.n_layers), model.cfg.n_heads),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs key patching (All Position)",
)

### Neurons

Let's patch individual neurons!

We'll do this for MLP layers 0, 3, 5, 7 and 11, since these are the important ones according to the MLP patching results; patching every neuron in all layers would also take too long.

In [None]:
def patch_neuron(
    corrupted_neuron_vector: Float[torch.Tensor, "batch pos d_mlp"],
    hook,
    neuron,
    clean_cache,
):
    """Patch a single neuron at every position with the cached clean value.

    Args:
        corrupted_neuron_vector: MLP activation from the corrupted run; modified in place.
        hook: hook point object; its ``name`` is used as the key into ``clean_cache``.
        neuron: index along the last axis to patch.
        clean_cache: mapping from hook name to the activation recorded on the clean run.

    Returns:
        The same tensor object that was passed in, with the selected neuron replaced.
    """
    clean_neuron_vector = clean_cache[hook.name]
    corrupted_neuron_vector[:, :, neuron] = clean_neuron_vector[:, :, neuron]
    return corrupted_neuron_vector

imp_mlp = [0, 3, 5, 7, 11] # Important MLP layers

# One normalized score per (selected layer, neuron); row i corresponds to imp_mlp[i].
# The buffer is allocated on the device of the corrupted tokens rather than a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
patched_neuron_diff = torch.zeros(
    len(imp_mlp),
    model.cfg.d_mlp,
    device=corrupted_tokens.device,
    dtype=torch.float32,
)

for i, layer in enumerate(imp_mlp):
    # The hook point name only depends on the layer, so look it up once per layer.
    mlp_pre_hook_name = utils.get_act_name("pre", layer, "mlp")
    for neuron in range(model.cfg.d_mlp):
        hook_fn = partial(patch_neuron, neuron=neuron, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(mlp_pre_hook_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)
        patched_neuron_diff[i, neuron] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
# One point per (selected layer, neuron): neuron index on x, patching score on y,
# coloured by the layer the neuron belongs to.
scatter(
    # Neuron indices 0..d_mlp-1, repeated once for every selected layer.
    x=torch.arange(model.cfg.d_mlp).repeat(len(imp_mlp)),
    y=patched_neuron_diff.view(-1),
    xaxis="Neuron",
    yaxis="Patch Improvement",
    caxis="Layer",
    # Layer label (as a string) for each flattened (layer, neuron) entry.
    color=[str(mlp_layer) for mlp_layer in imp_mlp for _ in range(model.cfg.d_mlp)],
    title="Normalized Logit Difference From Patched Neuron (All Position)",
)

Neurons that stand out in above result:
- [L0N1457](https://neuroscope.io/gpt2-small/0/1457.html): Maximally activates on tokens like " her", " she" and female names.
- [L0N1160](https://neuroscope.io/gpt2-small/0/1160.html): Maximally activates on tokens "She".
- [L0N2055](https://neuroscope.io/gpt2-small/0/2055.html): Maximally activates on tokens like " he", " him", " his" and male names.
- [L0N1844](https://neuroscope.io/gpt2-small/0/1844.html): Maximally activates on male names.
- [L3N3040](https://neuroscope.io/gpt2-small/3/3040.html): Maximally activates on tokens "She", "Ms", "Mrs" and " wife".
- [L5N1319](https://neuroscope.io/gpt2-small/5/1319.html): Maximally activates on sub words of female names.
- [L7N3050](https://neuroscope.io/gpt2-small/7/3050.html): Not sure about this one though seems to activate with names.
- [L7N79](https://neuroscope.io/gpt2-small/7/79.html): Maximally activates on tokens like " he", "him" and " his".
- [L11N2627](https://neuroscope.io/gpt2-small/11/2627.html): Maximally activates on token " since".
- [L11N2607](https://neuroscope.io/gpt2-small/11/2607.html): Maximally activates on tokens where the subsequent token is " he", " him" and " his".
- [L11N2652](https://neuroscope.io/gpt2-small/11/2652.html): Maximally activates on tokens where the subsequent token is " he" and " his".
- [L11N2926](https://neuroscope.io/gpt2-small/11/2926.html): Maximally activates on tokens where the subsequent token is " her", " community" and " refugees".
- [L11N2980](https://neuroscope.io/gpt2-small/11/2980.html): Maximally activates on tokens where the subsequent token is " she".

# Visualizing Attention Patterns

Looking at the attention patterns of these heads. Let's take the top 15 heads by output patching (in absolute value) and split it into early, middle and late.

In [None]:
top_k = 15
# Flattened indices (layer * n_heads + head) of the top_k heads by absolute
# output-patching score.
top_heads_by_output_patch = torch.topk(
    patched_head_z_diff_all_pos.abs().flatten(),
    k=top_k,
).indices
first_mid_layer = 4
first_late_layer = 8
# A head is "early" below the first middle layer and "late" from the first late
# layer onwards; everything in between is "middle".
is_early_head = top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer
is_late_head = top_heads_by_output_patch >= model.cfg.n_heads * first_late_layer
early_heads = top_heads_by_output_patch[is_early_head]
mid_heads = top_heads_by_output_patch[~(is_early_head | is_late_head)]
late_heads = top_heads_by_output_patch[is_late_head]
visualize_attention_patterns(early_heads, title="Top Early Heads")
visualize_attention_patterns(mid_heads, title="Top Middle Heads")
visualize_attention_patterns(late_heads, title="Top Late Heads")

# Backup Heads

If we knock out one of the heads, then there are some backup heads in later layers that *change their behaviour* and do (some of) the job of the original head.

Let's test this! Let's ablate the most important head (head L10H9) on just the final token using a custom ablation hook and then cache all new activations and compare performance. We focus on the final position because we want to specifically ablate the direct logit effect.

In [None]:
# Pick the head with the largest direct logit attribution and recover its (layer, head) indices.
top_head = per_head_logit_diffs.flatten().argmax().item()
top_head_layer = top_head//model.cfg.n_heads
top_head_head = top_head % model.cfg.n_heads
print(f"Top head to ablate: L{top_head_layer}H{top_head_head}")
def ablate_top_head_hook(z: Float[torch.Tensor, "batch pos head_index d_head"], hook):
    """Zero the selected head's z vector at the final position (in place) and return z."""
    z[:, -1, top_head_head, :] = 0
    return z
# Adds a hook into global model state
model.blocks[top_head_layer].attn.hook_z.add_hook(ablate_top_head_hook)
try:
    # Runs the model with the ablation hook active and caches the resulting activations.
    ablated_logits, ablated_cache = model.run_with_cache(tokens)
finally:
    # Explicitly remove the ablation hook. Depending on the TransformerLens version,
    # run_with_cache may only remove the caching hooks it added itself, which would
    # leave the ablation hook attached for every later cell.
    model.reset_hooks()
print(f"Original logit diff: {original_average_logit_diff}")
print(f"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, answer_tokens).item()}")
print(f"Direct Logit Attribution of top head: {per_head_logit_diffs.flatten()[top_head].item()}")
print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_head].item()}")

Top head to ablate: L10H9
Original logit diff: 4.006812572479248
Post ablation logit diff: 3.2031781673431396
Direct Logit Attribution of top head: 1.35689377784729
Naive prediction of post ablation logit diff: 2.649918794631958


We see that naively, ablating the top head should reduce the logit diff by 1.35, from 4.0 to 2.64. **But actually, it only goes down to 3.2!**

As before, we can look at the direct logit attribution of each head after ablating head L10H9 to see what's going on. It's easiest to interpret if plotted as a scatter plot against the initial per head logit difference. We also plot the difference between the original and post-ablation logit diff of each head.

In [None]:
# Per-head contributions at the final position, taken from the post-ablation cache.
per_head_ablated_residual, labels = ablated_cache.stack_head_results(
    layer=-1,
    pos_slice=-1,
    return_labels=True,
)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, ablated_cache)
# Arrange the flat per-head values as a (layer, head) grid.
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(
    model.cfg.n_layers,
    model.cfg.n_heads,
)
imshow(
    per_head_ablated_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Post-Ablation Direct Logit Attribution of Heads",
)
# Original (y) against post-ablation (x) attribution, one point per head.
scatter(
    y=per_head_logit_diffs.flatten(),
    x=per_head_ablated_logit_diffs.flatten(),
    hover_name=head_names,
    range_x=(-3, 3),
    range_y=(-3, 3),
    xaxis="Ablated",
    yaxis="Original",
    title="Original vs Post-Ablation Direct Logit Attribution of Heads",
)
# Change in each head's attribution caused by the ablation (post-ablation minus original).
imshow(
    (per_head_ablated_logit_diffs - per_head_logit_diffs),
    labels={"x": "Head", "y": "Layer"},
    title="Difference in Post-Ablation Direct Logit Attribution and Original Direct Logit Attribution of Heads",
)

Tried to stack head results when they weren't cached. Computing head results now




We can see some minor differences in a few heads! Head L11H1 increases its effect, adding +0.19 to the logit diff, and head L11H8 decreases its effect, adding -0.12 to the logit diff. However, it's not clear whether this is enough to justify calling them backup heads.