### Summarising key findings with a bunch of graphs!

### Setup

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [5]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn off automatic differentiation to save GPU memory, since this notebook focuses on model inference rather than model training.

In [6]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fd51c95adc0>

Plotting helper functions:

In [7]:
def imshow(tensor, renderer=None, midpoint=0.0, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=midpoint, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [8]:
line(np.arange(5))

Set up the device:

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

### Pronoun prediction

The task is choosing the right pronoun (e.g. he vs she vs it vs they).

A good setup is a rhetorical question (so it doesn’t spoil the answer!) like “Lina is a great friend, isn’t” (h/t Marius Hobbhahn).

The first step is to load our model, GPT-2 Small, a 12-layer transformer with roughly 85M non-embedding parameters (124M including embeddings).

In [10]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True, 
    device=device
    )

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


The next step is to verify that the model can actually do the task!

In [11]:
example_prompt = "Mary is a great friend, isn’t"
example_answer = " she"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Tokenized answer: [' she']


Top 0th token. Logit: 17.41 Prob: 82.67% Token: | she|
Top 1th token. Logit: 14.69 Prob:  5.45% Token: | it|
Top 2th token. Logit: 13.51 Prob:  1.68% Token: | he|
Top 3th token. Logit: 13.11 Prob:  1.12% Token: | there|
Top 4th token. Logit: 12.74 Prob:  0.78% Token: | I|
Top 5th token. Logit: 12.72 Prob:  0.76% Token: | we|
Top 6th token. Logit: 12.67 Prob:  0.72% Token: | you|
Top 7th token. Logit: 12.35 Prob:  0.52% Token: | her|
Top 8th token. Logit: 12.23 Prob:  0.47% Token: | this|
Top 9th token. Logit: 12.16 Prob:  0.43% Token: | that|


In [12]:
example_prompt = "John is a great friend, isn’t"
example_answer = " he"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'John', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Tokenized answer: [' he']


Top 0th token. Logit: 17.47 Prob: 83.43% Token: | he|
Top 1th token. Logit: 14.81 Prob:  5.81% Token: | it|
Top 2th token. Logit: 13.22 Prob:  1.18% Token: | there|
Top 3th token. Logit: 13.06 Prob:  1.01% Token: | you|
Top 4th token. Logit: 13.04 Prob:  0.99% Token: | we|
Top 5th token. Logit: 12.69 Prob:  0.70% Token: | I|
Top 6th token. Logit: 12.62 Prob:  0.65% Token: | she|
Top 7th token. Logit: 12.52 Prob:  0.59% Token: | that|
Top 8th token. Logit: 12.26 Prob:  0.45% Token: | this|
Top 9th token. Logit: 11.78 Prob:  0.28% Token: | the|


In [13]:
example_prompt = "Matrix is a great movie, isn’t"
example_answer = " it"
utils.test_prompt(example_prompt, example_answer, model)

Tokenized prompt: ['<|endoftext|>', 'Matrix', ' is', ' a', ' great', ' movie', ',', ' isn', '�', '�', 't']
Tokenized answer: [' it']


Top 0th token. Logit: 18.16 Prob: 94.84% Token: | it|
Top 1th token. Logit: 13.65 Prob:  1.04% Token: | there|
Top 2th token. Logit: 13.15 Prob:  0.63% Token: | that|
Top 3th token. Logit: 12.37 Prob:  0.29% Token: | this|
Top 4th token. Logit: 12.37 Prob:  0.29% Token: | he|
Top 5th token. Logit: 12.32 Prob:  0.28% Token: | the|
Top 6th token. Logit: 12.10 Prob:  0.22% Token: | you|
Top 7th token. Logit: 12.09 Prob:  0.22% Token: |?|
Top 8th token. Logit: 11.81 Prob:  0.16% Token: | I|
Top 9th token. Logit: 11.41 Prob:  0.11% Token: | they|


Generate reference prompts for the task to run the model on.

We'll run the model on 20 instances of this task: each of the 5 prompt formats combined with each of the 4 names.

In [14]:
prompt_formats = [
    "{} is a great friend, isn’t",
    "{} is an amazing person, isn’t",    
    "{} is a fantastic colleague, isn’t",    
    "{} is a wonderful partner, isn’t",    
    "{} is an excellent student, isn’t"
    ]

pronouns = [" she", " he"]

# List of names, in the format (name, pronoun_idx), where pronoun_idx indexes into `pronouns`
names = [
    ("Mary", 0), 
    ("John", 1),
    ("Dan", 1),
    ("Amy", 0),
]

# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []

for prompt_format in prompt_formats:
    for name, pronoun_idx in names:
        prompts.append(prompt_format.format(name))

        answers.append(
            (
                pronouns[pronoun_idx], 
                pronouns[1-pronoun_idx]
            )
            )
        
        answer_tokens.append(
            (
                model.to_single_token(answers[-1][0]),
                model.to_single_token(answers[-1][1]),
            )
        )
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(answers)

['Mary is a great friend, isn’t', 'John is a great friend, isn’t', 'Dan is a great friend, isn’t', 'Amy is a great friend, isn’t', 'Mary is an amazing person, isn’t', 'John is an amazing person, isn’t', 'Dan is an amazing person, isn’t', 'Amy is an amazing person, isn’t', 'Mary is a fantastic colleague, isn’t', 'John is a fantastic colleague, isn’t', 'Dan is a fantastic colleague, isn’t', 'Amy is a fantastic colleague, isn’t', 'Mary is a wonderful partner, isn’t', 'John is a wonderful partner, isn’t', 'Dan is a wonderful partner, isn’t', 'Amy is a wonderful partner, isn’t', 'Mary is an excellent student, isn’t', 'John is an excellent student, isn’t', 'Dan is an excellent student, isn’t', 'Amy is an excellent student, isn’t']
[(' she', ' he'), (' he', ' she'), (' he', ' she'), (' she', ' he'), (' she', ' he'), (' he', ' she'), (' he', ' she'), (' she', ' he'), (' she', ' he'), (' he', ' she'), (' he', ' she'), (' she', ' he'), (' she', ' he'), (' he', ' she'), (' he', ' she'), (' she', 

In [15]:
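# Check that all prompts have the same number of tokens, so they can be batched without padding
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)
assert len({len(model.to_str_tokens(prompt)) for prompt in prompts}) == 1, "Prompts differ in token length"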

Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Mary', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'John', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Dan', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Amy', ' is', ' a', ' great', ' friend', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Mary', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'John', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Dan', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'Amy', ' is', ' an', ' amazing', ' person', ',', ' isn', '�', '�', 't']


We now run the model on these prompts and use `run_with_cache` to get both the logits and a cache of all internal activations for later analysis.

In [16]:
tokens = model.to_tokens(prompts, prepend_bos=True)
# Move the tokens to the same device as the model
tokens = tokens.to(device)
# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

We'll later evaluate how model performance changes under various interventions, so it's useful to have a metric for model performance. Our metric here will be the **logit difference**: the difference in logit between the correct pronoun and the incorrect pronoun (e.g. `logit(" she") - logit(" he")`).

In [17]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

print("Per prompt logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True))
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff.item())

Per prompt logit difference: tensor([3.8977, 4.8540, 3.9663, 3.4600, 5.1857, 4.7339, 3.7226, 4.7266, 3.9201,
        4.4638, 3.5116, 3.5193, 4.1234, 3.2621, 2.4767, 3.4701, 4.3414, 4.8153,
        3.8098, 3.8759], device='cuda:0')
Average logit difference: 4.006812572479248


We see that the average logit difference is 4.0. For context, this corresponds to putting an $e^{4.0}\approx 55\times$ higher probability on the correct answer than on the incorrect one.
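Why a logit difference reads as a probability ratio: softmax divides every exponentiated logit by the same normaliser. That normaliser cancels in $p_a / p_b$, leaving $\log(p_a/p_b) = \text{logit}_a - \text{logit}_b$. The sketch below checks this on a hypothetical 4-token vocabulary. One caveat applies to the 4.0 above: it is an average over prompts, so $e^{4.0}$ is the ratio at the average logit difference, not the average of the per-prompt ratios.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical final-position logits over a toy 4-token vocabulary;
# index 0 stands in for " she" and index 1 for " he".
logits = [6.0, 2.0, 1.5, 0.5]
probs = softmax(logits)

logit_diff = logits[0] - logits[1]
log_prob_ratio = math.log(probs[0] / probs[1])

# The softmax normaliser cancels in the ratio, so these two agree
print(f"logit diff = {logit_diff:.4f}, log(p0/p1) = {log_prob_ratio:.4f}")
# Hence a logit difference of 4.0 means roughly a 55x probability ratio
print(f"exp(4.0) = {math.exp(4.0):.1f}")
```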

### Direct Logit Attribution

We use `model.tokens_to_residual_directions` to map each answer token to its direction in the residual stream. We then take the difference between the correct and incorrect directions to get a logit difference direction for each prompt in the batch.

In [18]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([20, 2, 768])
Logit difference directions shape: torch.Size([20, 768])
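The direction trick in the cell above works because unembedding is linear. As we understand it, `tokens_to_residual_directions` returns each token's unembedding vector, i.e. its column of W_U. So the correct-minus-incorrect logit equals one dot product with the difference of the two columns. The sketch below uses hypothetical toy sizes and ignores the unembedding bias `b_U` for now.

```python
# Toy sizes (hypothetical): d_model = 3, vocab = 4
resid = [0.5, -1.0, 2.0]           # final-position residual, assumed already LayerNorm-scaled
W_U = [                            # toy unembedding matrix, shape [d_model][d_vocab]
    [0.1, 0.4, -0.2, 0.3],
    [0.2, -0.3, 0.5, 0.0],
    [0.7, 0.1, 0.0, -0.4],
]
correct, incorrect = 0, 1

def column(matrix, j):
    return [row[j] for row in matrix]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Route 1: compute every logit, then subtract
logits = [dot(resid, column(W_U, j)) for j in range(len(W_U[0]))]
diff_from_logits = logits[correct] - logits[incorrect]

# Route 2: a single projection onto the logit-difference direction
direction = [a - b for a, b in zip(column(W_U, correct), column(W_U, incorrect))]
diff_from_direction = dot(resid, direction)

print(diff_from_logits, diff_from_direction)
```

Route 2 is what makes the later attribution steps cheap. Any piece of the residual stream can be projected onto `direction` on its own.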


To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and check that we get the same answer.

In [19]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff.item())

Final residual stream shape: torch.Size([20, 11, 768])
Calculated average logit diff: 4.006814002990723
Original logit difference: 4.006812572479248
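The check above projects onto W_U columns only and never adds the unembedding bias `b_U`. The bias is nonzero: see the LayerNorm folding cell near the end of this notebook. Yet the two averages agree. Our reading is that the prompt set is balanced, with 10 prompts whose answer is " she" and 10 whose answer is " he". The per-prompt bias term is therefore +Δb for half the prompts and -Δb for the other half, and it cancels in the mean. Per-prompt values computed this way would each be off by ±Δb. A toy illustration with hypothetical bias-free logit differences:

```python
# b_U[" she"] - b_U[" he"], using the bias values printed in the LayerNorm folding cell
delta_b = 3.6625 - 4.5582

# Hypothetical bias-free logit differences (resid . direction) for a balanced toy batch.
# The boolean marks whether the correct answer is " she" (True) or " he" (False).
examples = [(4.8, True), (3.9, False), (4.1, True), (3.2, False)]

def bias_term(she_is_correct):
    # Correct-minus-incorrect bias: +delta_b for " she" prompts, -delta_b for " he" prompts
    return delta_b if she_is_correct else -delta_b

with_bias = [d + bias_term(s) for d, s in examples]
without_bias = [d for d, _ in examples]

mean_with = sum(with_bias) / len(with_bias)
mean_without = sum(without_bias) / len(without_bias)
print(f"mean with bias {mean_with:.4f}, mean without bias {mean_without:.4f}")
```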


### Logit Lens

We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens). It looks at the residual stream after each layer and calculates the logit difference from that, which simulates what happens if we delete all subsequent layers.

In [20]:
def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache: ActivationCache) -> Float[torch.Tensor, "components"]:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)

In [21]:
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

We see that the model is utterly unable to do the task until layer 8. After that, performance increases in a step-like fashion, with the jumps coming from the attention part of each layer.
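The accumulated-residual readout behind this plot can be sketched as follows. All numbers are hypothetical. The sketch makes one assumption: every partial sum is centred and then divided by one fixed LayerNorm scale, not by its own norm. Our understanding is that `cache.apply_ln_to_stack` reuses the scale cached from the actual final residual stream in this way, which keeps the readout linear. Each entry of `logit_lens` is the logit difference we would get if all later sublayers were deleted.

```python
# Hypothetical final-position components of a tiny residual stream (d_model = 3)
embed = [0.2, -0.1, 0.0]
sublayer_outputs = [        # what each sublayer writes, in order
    [0.0, 0.1, -0.1],       # attn 0
    [0.1, 0.0, 0.0],        # mlp 0
    [0.9, -0.4, 0.6],       # attn 1
    [0.1, 0.2, -0.1],       # mlp 1
]
direction = [1.0, -1.0, 0.5]   # hypothetical logit-difference direction
final_scale = 1.7              # assumed: one frozen LayerNorm scale reused for every partial sum

def center(v):
    mean = sum(v) / len(v)
    return [x - mean for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Accumulated residual after the embedding and after each sublayer
accumulated = [list(embed)]
for out in sublayer_outputs:
    accumulated.append([a + b for a, b in zip(accumulated[-1], out)])

# Logit lens: read the logit difference off each partial sum, as if later sublayers were deleted
logit_lens = [dot(center(r), direction) / final_scale for r in accumulated]
for i, value in enumerate(logit_lens):
    print(f"after {i} sublayers: logit diff {value:+.3f}")
```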

### Layer Attribution

We can repeat the above analysis for each layer individually (this is equivalent to taking the differences between adjacent residual streams).

In [22]:
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

We see that only the attention layers matter! And again, attention layers 9, 10 and 11 improve things a lot.
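Per-layer attribution is the logit lens taken in differences. Each component's contribution is the gap between adjacent accumulated residual streams. Because the readout is linear, the contributions sum exactly to the final logit difference. A minimal sketch with hypothetical numbers, leaving out the LayerNorm scale (dividing every term by the same fixed scale would not change the identity):

```python
# Hypothetical accumulated residual streams at the final position:
# after the embedding, after attn 0, after mlp 0, after attn 1 (d_model = 2)
accumulated = [
    [0.2, -0.1],
    [0.2, 0.0],
    [0.3, 0.0],
    [1.2, -0.4],
]
direction = [1.0, -1.0]   # hypothetical logit-difference direction

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# The embedding is the first component; every later one is a difference of adjacent streams
components = [accumulated[0]] + [
    [b - a for a, b in zip(prev, nxt)] for prev, nxt in zip(accumulated, accumulated[1:])
]
per_component = [dot(c, direction) for c in components]

# Linearity: the per-component attributions add up to the final readout
total = dot(accumulated[-1], direction)
print(per_component, sum(per_component), total)
```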

### Head Attribution

We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.

In [23]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


We see that only a few heads really matter: heads L9H7, L10H9 and L11H8 contribute a lot positively, which explains why attention layers 9, 10 and 11 are so important. Several other heads matter positively or negatively, but less strongly.
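Splitting a layer's output into per-head pieces is exact, because applying one output matrix to the concatenated head vectors equals summing each head's own projection. The sketch below checks that identity with hypothetical toy sizes and ignores any output bias, which belongs to the layer rather than to individual heads.

```python
# Toy sizes (hypothetical): 2 heads, d_head = 2, d_model = 3
z = [[1.0, 2.0], [0.5, -1.0]]          # per-head attention-weighted values, [n_heads][d_head]
W_O = [                                # per-head output matrices, [n_heads][d_head][d_model]
    [[0.1, 0.0, 0.2], [0.3, -0.1, 0.0]],
    [[0.0, 0.5, 0.1], [-0.2, 0.1, 0.4]],
]

def vec_mat(v, M):
    # v has length n, M is n x m; returns a list of length m
    return [sum(v[i] * M[i][j] for i in range(len(v))) for j in range(len(M[0]))]

# View 1: concatenate all heads and apply one stacked W_O of shape [n_heads * d_head][d_model]
z_concat = [x for head in z for x in head]
W_O_stacked = [row for head in W_O for row in head]
attn_out_concat = vec_mat(z_concat, W_O_stacked)

# View 2: each head writes its own result to the residual stream; the layer output is their sum
head_results = [vec_mat(z[h], W_O[h]) for h in range(len(z))]
attn_out_sum = [sum(col) for col in zip(*head_results)]

print(attn_out_concat, attn_out_sum)
```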

### Attention Analysis

We use Anthropic's PySvelte library to visualize the attention patterns! We visualize the top 3 positive heads by direct logit attribution, and show these for the first prompt (as an illustration).

In [24]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]], 
    local_cache: Optional[ActivationCache]=None, 
    local_tokens: Optional[torch.Tensor]=None, 
    title: str=""):
    # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, list) or isinstance(heads, torch.Tensor):
        heads = utils.to_numpy(heads)
    # Cache defaults to the original activation cache
    if local_cache is None:
        local_cache = cache
    # Tokens defaults to the tokenization of the first prompt (including the BOS token)
    if local_tokens is None:
        # The tokens of the first prompt
        local_tokens = tokens[0]
    
    labels = []
    patterns = []
    batch_index = 0
    for head in heads:
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    str_tokens = model.to_str_tokens(local_tokens)
    patterns = torch.stack(patterns, dim=-1)
    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

In [25]:
top_k = 3
top_positive_logit_attr_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")
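`visualize_attention_patterns` takes heads as flat indices in `[0, n_layers * n_heads)`. Those indices come from flattening the layer-by-head grid row-major, matching the `"(layer head_index) -> layer head_index"` rearrange used earlier, so `divmod(index, n_heads)` recovers `(layer, head)`. A stdlib-only sketch of the top-k step on a hypothetical 3 x 4 grid (GPT-2 Small is 12 x 12):

```python
n_layers, n_heads = 3, 4   # hypothetical small grid

# Hypothetical per-head scores, indexed [layer][head]
scores = [
    [0.1, -0.2, 0.0, 0.3],
    [0.05, 1.2, -0.4, 0.0],
    [0.9, 0.0, 2.1, -0.3],
]

# Flatten row-major, so flat index = layer * n_heads + head
flat = [s for row in scores for s in row]

def top_k_indices(values, k):
    # Indices of the k largest values, largest first
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]

def flat_to_label(i):
    layer, head = divmod(i, n_heads)
    return f"L{layer}H{head}"

top = top_k_indices(flat, 3)
head_labels = [flat_to_label(i) for i in top]
print(top, head_labels)
```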

### Ablation

In [26]:
prompts_with_answers = [prompt + answer[0] for prompt, answer in zip(prompts, answers)]
prompts_with_answers

['Mary is a great friend, isn’t she',
 'John is a great friend, isn’t he',
 'Dan is a great friend, isn’t he',
 'Amy is a great friend, isn’t she',
 'Mary is an amazing person, isn’t she',
 'John is an amazing person, isn’t he',
 'Dan is an amazing person, isn’t he',
 'Amy is an amazing person, isn’t she',
 'Mary is a fantastic colleague, isn’t she',
 'John is a fantastic colleague, isn’t he',
 'Dan is a fantastic colleague, isn’t he',
 'Amy is a fantastic colleague, isn’t she',
 'Mary is a wonderful partner, isn’t she',
 'John is a wonderful partner, isn’t he',
 'Dan is a wonderful partner, isn’t he',
 'Amy is a wonderful partner, isn’t she',
 'Mary is an excellent student, isn’t she',
 'John is an excellent student, isn’t he',
 'Dan is an excellent student, isn’t he',
 'Amy is an excellent student, isn’t she']

In [27]:
# We define a head ablation hook
# The type annotations are NOT necessary, they're just a useful guide to the reader
# 
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    value[:, :, head_index_to_ablate, :] = 0.
    return value

tokens = model.to_tokens(prompts_with_answers)

original_loss = model(tokens, loss_per_token=True, return_type="loss")[:,-1].mean()

# We make a tensor to store the results for each ablation run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
ablation_result = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for head in range(model.cfg.n_heads):

        # Use functools.partial to create a temporary hook function with the head fixed
        temp_hook_fn = partial(head_ablation_hook, head_index_to_ablate=head)
        # Run the model with the patching hook

        ablated_loss = model.run_with_hooks(
            tokens, 
            return_type="loss",
            loss_per_token=True,
            fwd_hooks=[(
                utils.get_act_name("v", layer), # try v -> o
                temp_hook_fn
                )]
            )[:,-1].mean()
        ablation_result[layer, head] = ablated_loss

model.reset_hooks()


In [28]:
%matplotlib inline

imshow(ablation_result, midpoint=original_loss.item(), labels={"x":"Head", "y":"Layer"}, title="Ablated loss for every head")

Ablating some heads, like L0H9, increases the loss, whereas ablating others, like L7H3, decreases it.
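The ablation hook zeroes one head's value vectors at `utils.get_act_name("v", layer)`. At every position, a head's output is its attention pattern applied to the value vectors, projected through W_O. Zero values therefore mean that head contributes nothing at any position, while the rest of the layer runs as normal (the layer's shared output bias is ignored here). A single-head toy version with hypothetical numbers:

```python
# Toy single head (hypothetical): 3 positions, d_head = 2, d_model = 2
pattern = [                  # causal attention pattern [query_pos][key_pos]; rows sum to 1
    [1.0, 0.0, 0.0],
    [0.4, 0.6, 0.0],
    [0.2, 0.3, 0.5],
]
v = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]   # value vectors [key_pos][d_head]
W_O = [[0.1, 0.2], [0.3, -0.4]]             # output matrix [d_head][d_model]

def head_output(pattern, v, W_O):
    out = []
    for q_row in pattern:
        # z = attention-weighted sum of the value vectors for this query position
        z = [sum(a * v_k[d] for a, v_k in zip(q_row, v)) for d in range(len(v[0]))]
        # Project z into the residual stream through W_O
        out.append([sum(z[d] * W_O[d][j] for d in range(len(z))) for j in range(len(W_O[0]))])
    return out

normal = head_output(pattern, v, W_O)
# Zero-ablation: replace every value vector with zeros, as head_ablation_hook does for one head
ablated = head_output(pattern, [[0.0] * len(v[0]) for _ in v], W_O)
print(normal)
print(ablated)
```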

In [29]:
# We define a head ablation hook
# The type annotations are NOT necessary, they're just a useful guide to the reader
# 
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    value[:, :, head_index_to_ablate, :] = 0.
    return value

tokens = model.to_tokens(prompts)

correct_index = [model.to_single_token(answer[0]) for answer in answers]
incorrect_index = [model.to_single_token(answer[1]) for answer in answers]

original_logits = model(tokens, return_type="logits")
original_logit_diff = (original_logits[torch.arange(len(correct_index)), -1, correct_index] - original_logits[torch.arange(len(correct_index)), -1, incorrect_index]).mean()

# We make a tensor to store the results for each ablation run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
ablation_result = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for head in range(model.cfg.n_heads):

        # Use functools.partial to create a temporary hook function with the head fixed
        temp_hook_fn = partial(head_ablation_hook, head_index_to_ablate=head)
        # Run the model with the patching hook

        ablated_logits = model.run_with_hooks(
            tokens, 
            return_type="logits", 
            fwd_hooks=[(
                utils.get_act_name("v", layer), 
                temp_hook_fn
                )]
            )
        
        ablated_logit_diff = (ablated_logits[torch.arange(len(correct_index)), -1, correct_index] - ablated_logits[torch.arange(len(correct_index)), -1, incorrect_index]).mean()
        ablation_result[layer, head] = ablated_logit_diff

model.reset_hooks()


In [30]:
%matplotlib inline

imshow(ablation_result, midpoint=original_logit_diff.item(), labels={"x":"Head", "y":"Layer"}, title="Ablated logit difference for every head")

Ablating some heads, like L4H3, decreases the logit difference, whereas ablating others, like L0H11, increases it.

### Activation Patching

In [31]:
clean_prompt = "Mary is a great friend, isn’t"
corrupted_prompt = "John is a great friend, isn’t"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" she", incorrect_answer=" he"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# We don't need to cache on the corrupted prompt.
corrupted_logits = model(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

Clean logit difference: 3.898
Corrupted logit difference: -4.854


In [32]:
clean_tokens.shape, corrupted_tokens.shape

(torch.Size([1, 11]), torch.Size([1, 11]))

In [33]:
# We define a residual stream patching hook
# We choose to act on the residual stream at the start of the layer, so we call it resid_pre
# The type annotations are a guide to the reader and are not necessary
def residual_stream_patching_hook(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    position: int
) -> Float[torch.Tensor, "batch pos d_model"]:
    # Each HookPoint has a name attribute giving the name of the hook.
    clean_resid_pre = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre

# We make a tensor to store the results for each patching run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
num_positions = len(clean_tokens[0])
patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)


In [34]:
%matplotlib inline
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(patching_result, x=token_labels, labels={"x":"Position", "y":"Layer"}, title="Normalized Logit Difference After Patching Residual Stream")

Initially, only the subject token (Mary) matters, and all the relevant information stays there until heads in layers 4 and 5 move it to the " is" token. From there, heads in layers 9 and 10 move it to the final token, where it is used to predict the pronoun.

This result is consistent for larger model sizes as well.
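The heatmap's normalised score rescales each patched logit difference against the clean (3.898) and corrupted (-4.854) runs reported above. A score of 0 means the patch did not help, and 1 means it fully restored the clean behaviour. Values slightly outside [0, 1] are possible, hence the "(ish)" in the code comment. The patched values below are hypothetical and only show how the scale reads.

```python
# Logit differences reported by the activation-patching cell
clean_logit_diff = 3.898
corrupted_logit_diff = -4.854

def normalized_patching_score(patched_logit_diff):
    """0 = no better than the corrupted run, 1 = fully restores the clean run."""
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

# Hypothetical patched logit differences
for patched in (-4.854, 0.0, 3.898):
    print(f"patched {patched:+.3f} -> score {normalized_patching_score(patched):.3f}")
```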

### LayerNorm folding bias

In [35]:
he_bias = model.unembed.b_U[model.to_single_token(' he')]
she_bias = model.unembed.b_U[model.to_single_token(' she')]

print(f"he bias: {he_bias.item():.4f}")
print(f"she bias: {she_bias.item():.4f}")
print(f"Prob ratio bias: {torch.exp(he_bias - she_bias).item():.4f}x")

he bias: 4.5582
she bias: 3.6625
Prob ratio bias: 2.4490x


The unembedding bias created by LayerNorm folding favours " he" over " she" by about 0.9! All else being equal, this makes the " he" token about 2.4 times more likely than the " she" token.
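Why folding creates a token-dependent bias at all: GPT-2's unembedding has no bias of its own, but the final LayerNorm ends with an affine step `ln_w * x_hat + ln_b`. Folding scales the rows of W_U by `ln_w` and pushes `ln_b` through W_U, giving a new bias `ln_b @ W_U`, which generally differs from token to token. The sketch below is a toy illustration of that general idea, not TransformerLens's actual implementation (which, for example, also interacts with `center_unembed`). It then reproduces the 2.4x figure from the bias values printed above.

```python
import math

# Toy final LayerNorm parameters and unembedding (hypothetical): d_model = 2, vocab = 2
ln_w = [1.5, 0.5]                  # LayerNorm gain
ln_b = [0.2, -0.1]                 # LayerNorm bias
W_U = [[1.0, -0.5], [0.3, 0.8]]    # [d_model][d_vocab]
b_U = [0.0, 0.0]                   # GPT-2 has no unembedding bias of its own
x_hat = [0.7, -0.7]                # residual after centring and dividing by its norm

def vec_mat(v, M):
    return [sum(v[i] * M[i][j] for i in range(len(v))) for j in range(len(M[0]))]

# Unfolded: apply the LayerNorm affine step, then unembed
ln_out = [w * x + b for w, x, b in zip(ln_w, x_hat, ln_b)]
logits_unfolded = [l + b for l, b in zip(vec_mat(ln_out, W_U), b_U)]

# Folded: scale the rows of W_U by ln_w and move ln_b through W_U into the bias
W_U_folded = [[ln_w[i] * W_U[i][j] for j in range(len(W_U[0]))] for i in range(len(W_U))]
b_U_folded = [a + b for a, b in zip(vec_mat(ln_b, W_U), b_U)]
logits_folded = [l + b for l, b in zip(vec_mat(x_hat, W_U_folded), b_U_folded)]

# The notebook's values: b_U[" he"] - b_U[" she"] and the implied probability ratio
bias_gap = 4.5582 - 3.6625
prob_ratio = math.exp(bias_gap)
print(logits_unfolded, logits_folded, f"{prob_ratio:.4f}")
```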

### Analysing circuit formation during training