# Setup

## Installation

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install -Uqq git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install -Uqq git+https://github.com/neelnanda-io/PySvelte.git
    !git clone https://github.com/curt-tigges/logic-mi.git
    !pip install -Uqq circuitsvis
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the TransformerLens code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m76.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m87.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m92.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Plotly figures show up either in Colab or in VSCode depending on the renderer,
# never both, which is awkward when developing demos. DEBUG_MODE is the escape
# hatch: only a local run with DEBUG_MODE enabled switches to the static "png"
# renderer; every other combination uses the "colab" renderer.
pio.renderers.default = "colab" if (IN_COLAB or not DEBUG_MODE) else "png"

## Imports

In [3]:
# Import stuff
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
import pysvelte
import circuitsvis as cv

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [5]:
# Compute device used throughout the notebook: the CUDA ordinal named by the
# LOCAL_RANK environment variable (0 when unset) if a GPU is available,
# otherwise the string "cpu".
device = int(os.environ.get("LOCAL_RANK", 0)) if torch.cuda.is_available() else "cpu"

In [6]:
# Disable gradient tracking globally for the rest of the notebook.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f80f2e9bfa0>

## Visualization

In [7]:
def imshow(tensor, renderer=None, **kwargs):
    """Display `tensor` as a heatmap on a red/blue colour scale centred on zero.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed to the figure's `show` call (None leaves the choice to Plotly).
        **kwargs: Forwarded unchanged to `px.imshow`.
    """
    heatmap = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    heatmap.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Display `tensor` as the y-values of a line chart.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        renderer: Passed to the figure's `show` call.
        **kwargs: Forwarded unchanged to `px.line`.
    """
    y_values = utils.to_numpy(tensor)
    chart = px.line(y=y_values, **kwargs)
    chart.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a scatter plot of `y` against `x`.

    Args:
        x, y: Coordinates; each is converted with `utils.to_numpy` first.
        xaxis, yaxis, caxis: Display labels for the x, y and colour axes.
        renderer: Passed to the figure's `show` call.
        **kwargs: Forwarded unchanged to `px.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [8]:
def get_attn_head_patterns(model, prompt, attn_heads):
    """Run `model` on a single prompt and collect the attention patterns of chosen heads.

    Args:
        model: Object providing `to_tokens`, `run_with_cache` and `to_str_tokens`
            (a HookedTransformer in this notebook).
        prompt: Either a string, which is tokenized with `model.to_tokens`, or an
            input that `run_with_cache` accepts as-is.
        attn_heads: Iterable of `(layer, head)` pairs.

    Returns:
        A `(tokens, attention_pattern, head_name_list)` tuple: the prompt's string
        tokens, the selected heads' patterns stacked along a new leading dimension,
        and one "L{layer}H{head}" label per selected head.
    """
    model_input = model.to_tokens(prompt) if isinstance(prompt, str) else prompt
    # Only the cache is needed; the logits are discarded.
    _, cache = model.run_with_cache(model_input, remove_batch_dim=True)

    selected = [(layer, head) for layer, head in attn_heads]
    per_head_patterns = [cache["pattern", layer, "attn"][head] for layer, head in selected]
    head_name_list = [f"L{layer}H{head}" for layer, head in selected]

    attention_pattern = torch.stack(per_head_patterns, dim=0)
    tokens = model.to_str_tokens(model_input)
    return tokens, attention_pattern, head_name_list

In [9]:
def plot_attention_heads(tensor, title="", top_n=0, range_x=[0, 2.5], threshold=0.02):
    """Print the total of the plotted per-head values and show them as horizontal bars.

    Args:
        tensor: Tensor whose first two dimensions index layer and head; it is moved
            to the CPU, detached and flattened before plotting.
        title: Figure title.
        top_n: When greater than zero, only the `top_n` largest values are kept and
            the bars are ordered by ascending value.
        range_x: x-axis range handed to `px.bar`; it is only read, never modified.
        threshold: Accepted for interface compatibility; this function does not
            use it, so the printed total covers every plotted value.
    """
    scores = tensor.cpu().detach().numpy()

    # One label per (layer, head) cell, in the same row-major order as flatten().
    head_names = [
        f"Layer {layer}, Head {head}"
        for layer in range(scores.shape[0])
        for head in range(scores.shape[1])
    ]
    flat_scores = scores.flatten()

    if top_n > 0:
        # Positions of the top_n largest entries, largest first.
        keep = flat_scores.argsort()[-top_n:][::-1]
        kept_scores = flat_scores[keep]
        kept_names = [head_names[i] for i in keep]
        # Re-order ascending by (value, label) and unzip back into two sequences.
        flat_scores, head_names = zip(*sorted(zip(kept_scores, kept_names)))

    frame = pd.DataFrame({"Logit Diff": flat_scores, "Attention Head": head_names})
    total = np.array(flat_scores).sum()
    print(
        f"Total logit diff contribution above threshold: {total:.2f}"
    )

    bar_chart = px.bar(
        frame,
        x="Logit Diff",
        y="Attention Head",
        orientation="h",
        range_x=range_x,
        title=title,
    )
    bar_chart.show()

In [37]:
def scatter_attention_and_contribution_logic(
    model,
    head,
    prompts,
    answer_residual_directions,
    return_vals=False,
    return_fig=False,
    query_pos=16,
    key_pos=4,
):
    """Scatter a head's attention to the descriptor token against its output along the answer direction.

    For every prompt this records (a) the attention probability the head assigns
    from position `query_pos` to position `key_pos`, and (b) the dot product of the
    head's LayerNorm-scaled final-position output with the first answer direction
    of that prompt.

    Args:
        model: HookedTransformer-like object (`run_with_cache`, `cfg.n_heads`).
        head: `(layer, head_index)` pair identifying the attention head.
        prompts: Indexable batch of prompts; the whole batch is passed to
            `model.run_with_cache` and `prompts[i]` to `get_attn_head_patterns`.
        answer_residual_directions: Per-prompt answer directions; entry `[i][0]` is
            the direction used for prompt `i`. Its length sets the number of rows.
        return_vals: If true, return the DataFrame of collected values.
        return_fig: If true (and `return_vals` is false), return the figure.
        query_pos: Query (destination) position read from the attention pattern.
            Defaults to 16, the value previously hard-coded.
        key_pos: Key (source) position read from the attention pattern.
            Defaults to 4, the value previously hard-coded.

    Returns:
        The DataFrame when `return_vals` is set, else the Plotly figure when
        `return_fig` is set; otherwise the figure is shown and None is returned.
    """
    layer, head_idx = head

    # One batched forward pass gives every head's contribution at the final position.
    _, cache = model.run_with_cache(prompts)
    per_head_residual, _ = cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )
    scaled_residual_stack = cache.apply_ln_to_stack(
        per_head_residual, layer=-1, pos_slice=-1
    )
    # NOTE(review): assumes the stacked heads are ordered layer-major
    # (layer * n_heads + head_index) - confirm against stack_head_results.
    head_resid = scaled_residual_stack[layer * model.cfg.n_heads + head_idx]

    rows = []
    for i in range(len(answer_residual_directions)):
        # Attention pattern of this head on prompt i alone; index 0 selects the
        # single requested head from the stacked patterns.
        _, attn, _ = get_attn_head_patterns(model, prompts[i], [head])

        # Head output projected onto the answer direction for this prompt.
        dot = einsum(
            "d_model, d_model -> ", head_resid[i], answer_residual_directions[i][0]
        )

        # Attention probability from the query position to the key position.
        prob = attn[0, query_pos, key_pos]
        rows.append([prob, dot, "Descriptor", prompts[i]])

    viz_df = pd.DataFrame(
        rows, columns=["Attn Prob on Descriptor", "Dot w Descriptor Embed", "Word Type", "text"]
    )
    fig = px.scatter(
        viz_df,
        x="Attn Prob on Descriptor",
        y="Dot w Descriptor Embed",
        color="Word Type",
        hover_data=["text"],
        color_discrete_sequence=["rgb(114,255,100)", "rgb(201,165,247)"],
        title=f"How Strong {layer}.{head_idx} Writes in the Descriptor Embed Direction Relative to Attn Prob",
    )

    # return_vals takes precedence over return_fig, as before.
    if return_vals:
        return viz_df
    if return_fig:
        return fig
    else:
        fig.show()

## Load Model

In [11]:
# Load the Hugging Face checkpoint first so it can be handed to TransformerLens via `hf_model`.
source_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

# Wrap the same checkpoint in a HookedTransformer; `model` is the object every
# later cell tokenizes with and runs (run_with_cache / run_with_hooks).
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-160m",
    hf_model=source_model,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    #refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/375M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


# Data Setup

In [12]:
# Single-prompt sanity check: print the prompt's token count, then let
# utils.test_prompt report how the model ranks the expected answer token.
# The commented-out lines are alternative prompts kept for experimentation.
#example_prompt = "John moved the chair into the dining room. Jane then pushed the chair to the deck. Jack retrieves the chair from the "
#example_prompt =  "Mary went to the school to pick up her daughter, Sarah. Mary is Sarah's"
#example_prompt = "All dogs are mammals, and no mammals are cold-blooded. Therefore, all dogs are"
example_prompt = "All flowers are angiosperms. Daisies are flowers. Therefore, daisies are"
# Only the first token of the expected continuation is tested.
example_answer = " ang"
print(len(model.to_str_tokens(example_prompt)))
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

21
Tokenized prompt: ['<|endoftext|>', 'All', ' flowers', ' are', ' ang', 'ios', 'per', 'ms', '.', ' D', 'ais', 'ies', ' are', ' flowers', '.', ' Therefore', ',', ' da', 'is', 'ies', ' are']
Tokenized answer: [' ang']


Top 0th token. Logit: 30.60 Prob: 15.97% Token: | flowers|
Top 1th token. Logit: 29.71 Prob:  6.53% Token: | plants|
Top 2th token. Logit: 29.48 Prob:  5.17% Token: | the|
Top 3th token. Logit: 28.62 Prob:  2.20% Token: | ang|
Top 4th token. Logit: 28.55 Prob:  2.04% Token: | a|
Top 5th token. Logit: 28.06 Prob:  1.26% Token: | also|
Top 6th token. Logit: 28.01 Prob:  1.19% Token: | flower|
Top 7th token. Logit: 27.93 Prob:  1.10% Token: | not|
Top 8th token. Logit: 27.72 Prob:  0.89% Token: | plant|
Top 9th token. Logit: 27.63 Prob:  0.81% Token: | floral|


In [13]:
import itertools
def generate_combinations(categories, descriptors, capitalized_nouns, answer_first_tokens, wrong_answer_first_tokens):
    """Build every syllogism prompt, its contrast prompt and the matching answer tokens.

    One prompt is produced for each (category, descriptor, noun) combination. The
    contrast ("opposite") prompt is identical except that the descriptor is replaced
    by the next descriptor in the list (wrapping around at the end).

    Uses the notebook-level `model` for tokenization and `device` for the answer tensor.

    Args:
        categories: Category nouns.
        descriptors: List of descriptor words (must be a list; it is concatenated with a list).
        capitalized_nouns: Capitalized subject nouns; the lower-cased form is used in the conclusion.
        answer_first_tokens: Correct-answer token string for each descriptor, by position.
        wrong_answer_first_tokens: Incorrect-answer token string for each descriptor, by position.

    Returns:
        `(prompts, opposite_prompts, answers, prompt_tokens, opposites_tokens, answer_tokens)`
        where `answers` holds `(correct, wrong)` token-string pairs and `answer_tokens`
        is the tensor of their single-token ids.
    """
    # NOTE(review): the contrast descriptor is simply the *next* list entry, while the
    # caller derives wrong-answer tokens from adjacent pairs (0/1, 2/3, ...). The two
    # agree only for even positions - confirm the later filtering is meant to drop the rest.
    rotated = descriptors[1:] + [descriptors[0]]

    prompts, opposite_prompts, answers = [], [], []
    for category, descriptor, capitalized_noun in itertools.product(categories, descriptors, capitalized_nouns):
        position = descriptors.index(descriptor)
        lowercase_noun = capitalized_noun.lower()
        contrast = rotated[position]
        prompts.append(
            f"All {category} are {descriptor}. {capitalized_noun} are {category}. Therefore, {lowercase_noun} are"
        )
        opposite_prompts.append(
            f"All {category} are {contrast}. {capitalized_noun} are {category}. Therefore, {lowercase_noun} are"
        )
        answers.append((answer_first_tokens[position], wrong_answer_first_tokens[position]))

    def _tokenize_all(texts):
        # One squeezed token tensor per prompt string.
        return [model.to_tokens(text).squeeze() for text in texts]

    prompt_tokens = _tokenize_all(prompts)
    opposites_tokens = _tokenize_all(opposite_prompts)
    answer_ids = [[model.to_single_token(token) for token in pair] for pair in answers]
    answer_tokens = torch.tensor(answer_ids, device=device)
    return prompts, opposite_prompts, answers, prompt_tokens, opposites_tokens, answer_tokens

In [14]:
def check_len(items, specified_len, check_lowercase=False, lowercase_spec_len=None):
    """Return the items whose tokenized length matches the requested length(s).

    An item's length is the number of string tokens `model.to_str_tokens` returns
    for the item with a leading space, minus one (presumably discounting a
    prepended BOS token - TODO confirm).

    Args:
        items: Words to filter.
        specified_len: Required token length of the item as written.
        check_lowercase: When true, the lower-cased item must also have length
            `lowercase_spec_len`.
        lowercase_spec_len: Required token length of the lower-cased item.

    Returns:
        A new list with the matching items, in their original order.
    """
    def _token_count(word):
        return len(model.to_str_tokens(" " + word)) - 1

    kept = []
    for item in items:
        n_tokens = _token_count(item)
        n_tokens_lower = _token_count(item.lower())
        if n_tokens != specified_len:
            continue
        if check_lowercase and n_tokens_lower != lowercase_spec_len:
            continue
        kept.append(item)
    return kept

In [15]:
# Candidate category nouns; keep only those that are exactly one token long per check_len.
categories = ['birds', 'fish', 'insects', 'mammals', 'primates', 'rodents', 'predators', 'prey', 'game']
categories = check_len(categories, 1)

# Candidate descriptors; the wrong-answer loop below treats positions 0/1, 2/3, ...
# as pairs. Keep only descriptors that are exactly three tokens long.
descriptors = [
    'endothermic',
    'ectothermic',
    'herbivorous',
    'carnivorous',
    'bipedal',
    'quadrupedal',
    'exoskeletal',
    'endoskeletal',
    'digitigrade',
    'plantigrade'
    ]
descriptors = check_len(descriptors, 3)

# Token string at index 1 of each tokenized descriptor (index 0 is presumably the BOS token - TODO confirm).
answer_first_tokens = [model.to_str_tokens(" " + d)[1] for d in descriptors]
# The wrong answer for a descriptor is the first token of its neighbour: i+1 for even i, i-1 for odd i.
# NOTE(review): this assumes the filtering above left the pairs adjacent; with an odd
# number of surviving descriptors the last even index raises IndexError.
wrong_answer_first_tokens = []
for i in range(len(answer_first_tokens)):
    if i % 2 == 0:
        wrong_answer_first_tokens.append(answer_first_tokens[i + 1])
    else:
        wrong_answer_first_tokens.append(answer_first_tokens[i - 1])

# NOTE(review): 'Pigs' appears twice in this list, so its prompts are generated twice.
capitalized_nouns = ['Dogs', 'Elephants', 'Lions', 'Zebras', 'Cows', 'Horses',
'Pigs', 'Monkeys', 'Wolves', 'Deer', 'Foxes', 'Bears', 'Mice', 'Bats', 'Snakes',
'Lizards', 'Seals', 'Whales', 'Cats', 'Squid', 'Dolphins', 'Crabs', 'Snails',
'Bees', 'Wasps', 'Emus', 'Ducks', 'Owls', 'Pigs', 'Goats'
]
# Keep nouns that are two tokens long when capitalized and one token long when lower-cased.
capitalized_nouns = check_len(capitalized_nouns, 2, check_lowercase=True, lowercase_spec_len=1)

# Build the full cross product of prompts; the final expression displays how many were generated.
prompts, opposites_prompts, answers, prompt_tokens, opposites_tokens, answer_tokens = generate_combinations(categories, descriptors, capitalized_nouns, answer_first_tokens, wrong_answer_first_tokens)
len(prompts)

1440

In [16]:
def get_logit_diff(logits, answer_token_indices, print_per_prompt=False):
    """Average gap between the correct-token logit and the incorrect-token logit.

    Args:
        logits (torch.Tensor): Either [batch, vocab] or [batch, position, vocab];
            a 3-D input is reduced to its final position.
        answer_token_indices (torch.Tensor): [batch, 2] token ids; column 0 is the
            correct token and column 1 the incorrect one.
        print_per_prompt (bool): When true, print the per-prompt differences
            before averaging.

    Returns:
        torch.Tensor: Scalar tensor holding the mean difference over the batch.
    """
    # Only the prediction at the last position matters.
    final_logits = logits[:, -1, :] if logits.dim() == 3 else logits

    correct_ids = answer_token_indices[:, 0:1]
    incorrect_ids = answer_token_indices[:, 1:2]
    per_prompt_diff = final_logits.gather(1, correct_ids) - final_logits.gather(1, incorrect_ids)

    if print_per_prompt:
        print(per_prompt_diff)
    return per_prompt_diff.mean()

In [17]:
# check answer tokens
# Token position 4 is expected to hold the descriptor's first token; any prompt
# whose token there differs from its correct-answer token is printed for inspection.
for i in range(len(prompts)):
    if prompt_tokens[i][4] != answer_tokens[i][0]:
        print(prompt_tokens[i])

# Accumulators for the prompts that pass the logit-difference filter below.
clean_prompts = []
clean_opposites_prompts = []
clean_answer_tokens = []

# check logit differences
# Keep a prompt only when the model prefers the correct token by more than 3 logits
# on the prompt itself AND prefers the other token by more than 3 logits on the
# opposite prompt (both measured with the same answer-token pair).
for i in range(len(prompts)):
    # NOTE(review): the two caches returned here are never read in this cell.
    test_logits, test_cache = model.run_with_cache(prompt_tokens[i].unsqueeze(0).to(device))
    opposite_logits, opposite_cache = model.run_with_cache(opposites_tokens[i].unsqueeze(0).to(device))
    logit_diffs = get_logit_diff(test_logits, answer_tokens[i].unsqueeze(0))
    opposite_logit_diffs = get_logit_diff(opposite_logits, answer_tokens[i].unsqueeze(0))

    if logit_diffs.item() > 3 and opposite_logit_diffs.item() < -3:
        clean_prompts.append(prompts[i])
        clean_opposites_prompts.append(opposites_prompts[i])
        clean_answer_tokens.append(answer_tokens[i])

In [18]:
# Take a reproducible random sample of 64 entries from the clean dataset
# (each entry has a prompt and its opposite, so pairing them later yields up to 128 prompts).

import random

random.seed(42)
dataset_size = len(clean_prompts)
# random.sample raises ValueError if fewer than 64 entries passed the filter.
random_indices = random.sample(range(dataset_size), 64)
#random_indices.sort()

# NOTE(review): these assignments overwrite the filtered lists, so re-running this
# cell without re-running the filtering cell re-samples from the 64 already chosen.
clean_prompts = [clean_prompts[i] for i in random_indices]
clean_opposites_prompts = [clean_opposites_prompts[i] for i in random_indices]
clean_answer_tokens = torch.stack([clean_answer_tokens[i] for i in random_indices])

In [19]:
def pair_opposites(prompts_list, opposites_prompts_list, answer_tokens_list):
    """Merge prompts and their opposites into one list, skipping repeats.

    Entries are visited in order. Each prompt is appended (with its answer-token
    pair) unless already present; its opposite is then appended, unless already
    present, with the answer-token pair swapped so that the first element is
    again the correct token for that prompt.

    Uses the notebook-level `device` for the swapped answer tensors.

    Args:
        prompts_list: Prompt strings.
        opposites_prompts_list: Opposite prompt for each entry of `prompts_list`.
        answer_tokens_list: Indexable of two-element answer-token pairs, one per entry.

    Returns:
        `(prompts, answer_tokens)` as two parallel lists.
    """
    merged_prompts, merged_answers = [], []
    for position, prompt in enumerate(prompts_list):
        if prompt not in merged_prompts:
            merged_prompts.append(prompt)
            merged_answers.append(answer_tokens_list[position])

        opposite = opposites_prompts_list[position]
        if opposite in merged_prompts:
            continue
        merged_prompts.append(opposite)
        swapped_pair = [answer_tokens_list[position][1], answer_tokens_list[position][0]]
        merged_answers.append(torch.tensor(swapped_pair, device=device))
    return merged_prompts, merged_answers

# Build the final dataset from the sampled entries. Note that this rebinds
# `prompts` and `answer_tokens`, replacing the full generated lists from earlier.
prompts, answer_tokens = pair_opposites(
    clean_prompts, 
    clean_opposites_prompts,
    clean_answer_tokens)

In [20]:
# Preview the first ten prompts alongside their (correct, incorrect) answer-token pairs.
prompts[:10], answer_tokens[:10]

(['All game are herbivorous. Wolves are game. Therefore, wolves are',
  'All game are carnivorous. Wolves are game. Therefore, wolves are',
  'All fish are bipedal. Snakes are fish. Therefore, snakes are',
  'All fish are quadrupedal. Snakes are fish. Therefore, snakes are',
  'All birds are herbivorous. Cows are birds. Therefore, cows are',
  'All birds are carnivorous. Cows are birds. Therefore, cows are',
  'All primates are herbivorous. Elephants are primates. Therefore, elephants are',
  'All primates are carnivorous. Elephants are primates. Therefore, elephants are',
  'All mammals are digitigrade. Whales are mammals. Therefore, whales are',
  'All mammals are plantigrade. Whales are mammals. Therefore, whales are'],
 [tensor([23008, 27771], device='cuda:0'),
  tensor([27771, 23008], device='cuda:0'),
  tensor([15086, 42227], device='cuda:0'),
  tensor([42227, 15086], device='cuda:0'),
  tensor([23008, 27771], device='cuda:0'),
  tensor([27771, 23008], device='cuda:0'),
  tensor(

In [21]:
# Tokenize all paired prompts as one batch.
clean_tokens = model.to_tokens(prompts)
# The corrupted batch swaps every row with its neighbour (0<->1, 2<->3, ...).
# NOTE(review): this relies on `prompts` strictly alternating prompt/opposite -
# confirm that pair_opposites' de-duplication can never break that alternation.
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

# Stack the per-prompt answer pairs into a single tensor for batched indexing.
answer_token_indices = torch.stack(answer_tokens)


Clean string 0 <|endoftext|>All game are herbivorous. Wolves are game. Therefore, wolves are
Corrupted string 0 <|endoftext|>All game are carnivorous. Wolves are game. Therefore, wolves are


In [22]:
# Display the shape of the stacked answer-token tensor.
answer_token_indices.shape

torch.Size([124, 2])

In [23]:
# Sanity check: print any prompt that does not tokenize to exactly 17 string tokens.
# Later cells index fixed token positions, so prompts of another length would be misread.
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    prompt_length = len(str_tokens)
    if prompt_length != 17:
      print(prompt)

# Metrics & Baselines

In [24]:
# Run the model once on each batch, keeping the activation caches for later analysis.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Baseline: mean logit difference on the clean prompts.
clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {clean_logit_diff:.4f}")

# Same metric on the corrupted prompts, still scored against the clean answers.
corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 4.5068
Corrupted logit diff: -4.5068


# Initial Logit Attribution

In [25]:
# Map each answer token to a residual-stream direction via the model.
answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)
print("Answer residual directions shape:", answer_residual_directions.shape)
# Per-prompt direction for the logit difference: correct-token direction minus incorrect-token direction.
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([124, 2, 768])
Logit difference directions shape: torch.Size([124, 768])


In [26]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = clean_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
# Keep only the last position of every prompt.
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

# Sum the per-prompt dot products with the logit-diff directions and divide by the
# number of prompts; this should reproduce the logit difference computed from the logits.
average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",clean_logit_diff)

Final residual stream shape: torch.Size([124, 17, 768])
Calculated average logit diff: 4.506847381591797
Original logit difference: 4.5068464279174805


In [27]:
def residual_stack_to_logit_diff(residual_stack, logit_diff_directions, prompts, cache):
    """Project a stack of residual-stream components onto the logit-difference directions.

    Args:
        residual_stack: Tensor whose last two dimensions are [batch, d_model];
            any leading dimensions (e.g. one per component) are preserved.
        logit_diff_directions: [batch, d_model] direction for each prompt.
        prompts: Sized collection of prompts; only its length is used, as the
            divisor that turns the batch sum into a mean.
        cache: Activation cache used to apply the final LayerNorm scaling
            (`apply_ln_to_stack` with `layer=-1, pos_slice=-1`).

    Returns:
        Tensor with the leading dimensions of `residual_stack`, holding the
        batch-averaged contribution of each component to the logit difference.
    """
    normalized_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    summed_over_batch = einsum(
        "... batch d_model, batch d_model -> ...",
        normalized_stack,
        logit_diff_directions,
    )
    n_prompts = len(prompts)
    return summed_over_batch / n_prompts

## Logit Lens

In [28]:
# Logit lens: take the accumulated residual stream at the final position, project each
# entry of the stack onto the logit-diff directions, and plot one point per entry
# (the x-axis has n_layers + 1 points).
accumulated_residual, labels = clean_cache.accumulated_resid(layer=-1, incl_mid=False, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, logit_diff_directions, prompts, clean_cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers+1), hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

## Layer Attribution

In [29]:
# Decompose the final-position residual stream into its components and plot each
# component's contribution to the logit difference. Note: `labels` is rebound here.
per_layer_residual, labels = clean_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, logit_diff_directions, prompts, clean_cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

## Head Attribution

In [30]:
# Contribution of every attention head (final position) to the logit difference.
per_head_residual, labels = clean_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, logit_diff_directions, prompts, clean_cache)
# Reshape the flat per-head vector into a [layer, head_index] grid for the heatmap.
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


In [31]:
# Bar chart of the 12 largest per-head contributions, expressed as a fraction of the clean logit difference.
plot_attention_heads(per_head_logit_diffs/clean_logit_diff, top_n=12, range_x=[0, .3])

Total logit diff contribution above threshold: 0.80


## Attention Analysis

In [32]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: Optional[ActivationCache] = None,
    local_tokens: Optional[torch.Tensor] = None,
    title: str = "",
):
    """Render the attention patterns of the given heads with PySvelte.

    Args:
        heads: One flat head index or a collection of them, each in
            [0, n_layers * n_heads); index `i` means layer `i // n_heads`,
            head `i % n_heads`.
        local_cache: Activation cache to read patterns from; defaults to the
            notebook-level `clean_cache`.
        local_tokens: Tokens used for the axis labels; defaults to the first
            prompt of `clean_tokens` (BOS included).
        title: Heading displayed above the visualization.

    Only the first batch element of the cache is visualized.
    """
    # Normalize `heads` into something iterable of flat indices.
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, (list, torch.Tensor)):
        heads = utils.to_numpy(heads)

    cache = clean_cache if local_cache is None else local_cache
    tokens = clean_tokens[0] if local_tokens is None else local_tokens

    first_in_batch = 0
    head_labels = []
    head_patterns = []
    for flat_index in heads:
        layer, head_index = divmod(flat_index, model.cfg.n_heads)
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        head_patterns.append(cache["attn", layer][first_in_batch, head_index])
        head_labels.append(f"L{layer}H{head_index}")

    str_tokens = model.to_str_tokens(tokens)
    # Stack along a new trailing dimension: one slice per head.
    stacked_patterns = torch.stack(head_patterns, dim=-1)

    attention_vis = pysvelte.AttentionMulti(
        attention=stacked_patterns, tokens=str_tokens, head_labels=head_labels
    )
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

In [33]:
# Visualize the 12 heads with the largest positive contribution to the logit difference.
# torch.topk on the flattened grid returns flat head indices, which is the form
# visualize_attention_patterns expects.
top_k = 12
top_positive_logit_attr_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")
# Same for the 4 most negative heads: negate the scores so topk picks the smallest. `top_k` is rebound.
top_k = 4
top_negative_logit_attr_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_negative_logit_attr_heads, title=f"Top {top_k} Negative Logit Attribution Heads")

pysvelte components appear to be unbuilt or stale
Running npm install...
Building pysvelte components with webpack...


In [47]:
# Hand-picked groups of heads to inspect, each given as a (layer, head_index) pair.
# NOTE(review): the selection criteria for these groups are not shown in this notebook.
potential_wmh = [(8, 2), (10, 10), (9, 6), (10, 8), (8, 8), (9, 8)]
mystery_heads = [(8, 10), (8, 11), (7, 8)] 
wrong_token_heads = [(10, 6), (10, 11), (10, 1)]

In [42]:
# One attention-vs-contribution scatter plot per head in potential_wmh.
# Each call re-runs the model on the full clean batch.
for h in potential_wmh:
    scatter_attention_and_contribution_logic(model, h, clean_tokens, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


In [44]:
# One attention-vs-contribution scatter plot per head in wrong_token_heads.
for h in wrong_token_heads:
    scatter_attention_and_contribution_logic(model, h, clean_tokens, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


In [45]:
# One attention-vs-contribution scatter plot per head in mystery_heads.
for h in mystery_heads:
    scatter_attention_and_contribution_logic(model, h, clean_tokens, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


In [46]:
# All six potential_wmh heads followed by the first two wrong_token_heads, (10, 6) and (10, 11).
contextual_wmh = [(8, 2), (10, 10), (9, 6), (10, 8), (8, 8), (9, 8), (10, 6), (10, 11)]

# Activation Patching

# Old

## Visualizing Attention Patterns

We can validate this by looking at the attention patterns of these heads! Let's take the top 10 heads by output patching (in absolute value) and split it into early, middle and late.

We see that middle heads attend from the final token to the second subject, and late heads attend from the final token to the indirect object, which is completely consistent with the above speculation! But weirdly, while *one* early head attends from the second subject to its first copy, the other two mysteriously attend to the word *after* the first copy.

In [None]:
# NOTE(review): `patched_head_z_diff` is not defined in the part of the notebook visible
# here, and this cell has no execution count - it will raise NameError on a fresh
# Restart & Run All unless an activation-patching cell defines it first.
top_k = 10
top_heads_by_output_patch = torch.topk(patched_head_z_diff.abs().flatten(), k=top_k).indices
# Layer boundaries used to bucket heads: layers < 7 are "early", 7-8 "middle", >= 9 "late".
first_mid_layer = 7
first_late_layer = 9
# Flat head index = layer * n_heads + head, so layer thresholds become index thresholds.
early_heads = top_heads_by_output_patch[top_heads_by_output_patch<model.cfg.n_heads * first_mid_layer]
mid_heads = top_heads_by_output_patch[torch.logical_and(model.cfg.n_heads * first_mid_layer<=top_heads_by_output_patch, top_heads_by_output_patch<model.cfg.n_heads * first_late_layer)]
late_heads = top_heads_by_output_patch[model.cfg.n_heads * first_late_layer<=top_heads_by_output_patch]
visualize_attention_patterns(early_heads, title=f"Top Early Heads")
visualize_attention_patterns(mid_heads, title=f"Top Middle Heads")
visualize_attention_patterns(late_heads, title=f"Top Late Heads")

## Comparing to the Paper

We can now refer to the (far, far more rigorous and detailed) analysis in the paper to compare our results! Here's the diagram they give of their results. 

![](https://pbs.twimg.com/media/FghGkTAWAAAmkhm.jpg)

(Head 1.2 in their notation is L1H2 in my notation etc. And note - in the [latest version of the paper](https://arxiv.org/pdf/2211.00593.pdf) they add 9.0 as a backup name mover, and remove 11.3)

The heads form three categories corresponding to the early, middle and late categories we found and we did fairly well! Definitely not perfect, but with some fairly generic techniques and some a priori reasoning, we found the broad strokes of the circuit and what it looks like. We focused on the most important heads, so we didn't find all relevant heads in each category (especially not the heads in brackets, which are more minor), but this serves as a good base for doing more rigorous and involved analysis, especially for finding the *complete* circuit (ie all of the parts of the model which participate in this behaviour) rather than just a partial and suggestive circuit. Go check out [their paper](https://arxiv.org/abs/2211.00593) or [our interview](https://www.youtube.com/watch?v=gzwj0jWbvbo) to learn more about what they did and what they found!

Breaking down their categories:

* Early: The duplicate token heads, previous token heads and induction heads. These serve the purpose of detecting that the second subject is duplicated and which earlier name is the duplicate.
    * We found a direct duplicate token head which behaves exactly as expected, L3H0. Heads L5H0 and L6H9 are induction heads, which explains why they don't attend directly to the earlier copy of John!
    * Note that the duplicate token heads and induction heads do not compose with each other - both directly add to the S-Inhibition heads. The diagram is somewhat misleading.
* Middle: They call these S-Inhibition heads - they copy the information about the duplicate token from the second subject to the to token, and their output is used to *inhibit* the attention paid from the name movers to the first subject copy. We found all these heads, and had a decent guess for what they did.
    * In either case they attend to the second subject, so the patch that mattered was their value vectors!
* Late: They call these name movers, and we found some of them. They attend from the final token to the indirect object name and copy that to the logits, using the S-Inhibition heads to inhibit attention to the first copy of the subject token.
    * We did find their surprising result of *negative* name movers - name movers that inhibit the correct answer!
    * They have an entire category of heads we missed called backup name movers - we'll get to these later.

So, now, let's dig into the two anomalies we missed - induction heads and backup name mover heads

# Bonus: Exploring Anomalies

## Early Heads are Induction Heads(?!)

A really weird observation is that some of the early heads detecting duplicated tokens are induction heads, not just direct duplicate token heads. This is very weird! What's up with that? 

First off, what's an induction head? An induction head is an important type of attention head that can detect and continue repeated sequences. It is the second head in a two head induction circuit, which looks for previous copies of the current token and attends to the token *after* it, and then copies that to the current position and predicts that it will come next. They're enough of a big deal that [we wrote a whole paper on them](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html).

![](https://pbs.twimg.com/media/FNWAzXjVEAEOGRe.jpg)

Second, why is it surprising that they come up here? It's surprising because it feels like overkill. The model doesn't care about *what* token comes after the first copy of the subject, just that it's duplicated. And it already has simpler duplicate token heads. My best guess is that it just already had induction heads around and that, in addition to their main function, they *also* only activate on duplicated tokens. So it was useful to repurpose this existing machinery. 

This suggests that as we look for circuits in larger models life may get more and more complicated, as components in simpler circuits get repurposed and built upon. 

We can verify that these are induction heads by running the model on repeated text and plotting the heads.

In [None]:
# Run the model on a sentence repeated twice and visualize the candidate heads on it.
example_text = "Research in mechanistic interpretability seeks to explain behaviors of machine learning models in terms of their internal components."
example_repeated_text = example_text + example_text
example_repeated_tokens = model.to_tokens(example_repeated_text, prepend_bos=True)
example_repeated_logits, example_repeated_cache = model.run_with_cache(example_repeated_tokens)
# Flat head indices (layer * n_heads + head), the form visualize_attention_patterns decodes.
# NOTE(review): which layer/head these map to depends on model.cfg.n_heads - confirm they
# are the intended heads for the model loaded in this notebook.
induction_head_labels = [81, 65]
visualize_attention_patterns(induction_head_labels, example_repeated_cache, example_repeated_tokens, title="Induction Heads")

One implication of this is that it's useful to categorise heads according to whether they occur in simpler circuits, so that as we look for more complex circuits we can easily look for them. This is easy to do here! An interesting fact about induction heads is that they work on a sequence of repeated random tokens - notable for being wildly off distribution from the natural language GPT-2 was trained on. Being able to predict a model's behaviour off distribution is a good mark of success for mechanistic interpretability! This is a good sanity check for whether a head is an induction head or not. 

We can characterise an induction head by just giving a sequence of random tokens repeated once, and measuring the average attention paid from the second copy of a token to the token after the first copy. At the same time, we can also measure the average attention paid from the second copy of a token to the first copy of the token, which is the attention that the induction head would pay if it were a duplicate token head, and the average attention paid to the previous token to find previous token heads.

Note that this is a superficial study of whether something is an induction head - we totally ignore the question of whether it actually does boost the correct token or whether it composes with a single previous head and how. In particular, we sometimes get anti-induction heads which suppress the induction-y token (no clue why!), and this technique will find those too. But given the previous rigorous analysis, we can be pretty confident that this picks up on some true signal about induction heads.

<details> <summary>Technical Implementation Details</summary> 
We can do this again by using hooks, this time just to access the attention patterns rather than to intervene on them. 

Our hook function acts on the attention pattern activation. This has the name "blocks.{layer}.{layer_type}.hook_{activation_name}" in general, here it's "blocks.{layer}.attn.hook_attn". And it has shape [batch, head_index, query_pos, token_pos]. Our hook function takes in the attention pattern activation, calculates the score for the relevant type of head, and writes it to an external cache.

We add in hooks using `model.run_with_hooks(tokens, fwd_hooks=[(names_filter, hook_fn)])` to temporarily add in the hooks and run the model, getting the resulting output. Previously names_filter was the name of the activation, but here it's a boolean function mapping activation names to whether we want to hook them or not. Here it's just whether the name ends with hook_attn. hook_fn must take in the two inputs activation (the activation tensor) and hook (the HookPoint object, which contains the name of the activation and some metadata such as the current layer).

Internally our hooks use the function `tensor.diagonal`, this takes the diagonal between two dimensions, and allows an arbitrary offset - offset by 1 to get previous tokens, seq_len to get duplicate tokens (the distance to earlier copies) and seq_len-1 to get induction heads (the distance to the token *after* earlier copies). Different offsets give a different length of output tensor, and we can now just average to get a score in [0, 1] for each head
</details>

In [None]:
# Score every head as a previous-token, duplicate-token or induction head by running
# the model on random tokens repeated once and averaging attention along the
# relevant off-diagonal of each attention pattern.
seq_len = 100
batch_size = 2

# Score tables are allocated on the notebook-wide `device` (not a hard-coded "cuda")
# so this cell also runs on CPU-only machines.
prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)
def prev_token_hook(pattern, hook):
    """Store, per head of this layer, the mean attention paid to the previous token."""
    layer = hook.layer()
    # offset=1 selects entries where the key is one position before the query.
    diagonal = pattern.diagonal(offset=1, dim1=-1, dim2=-2)
    prev_token_scores[layer] = einops.reduce(diagonal, "batch head_index diagonal -> head_index", "mean")
duplicate_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)
def duplicate_token_hook(pattern, hook):
    """Store, per head of this layer, the mean attention paid to the earlier copy of the current token."""
    layer = hook.layer()
    # offset=seq_len is the distance from a token to its first copy.
    diagonal = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)
    duplicate_token_scores[layer] = einops.reduce(diagonal, "batch head_index diagonal -> head_index", "mean")
induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)
def induction_hook(pattern, hook):
    """Store, per head of this layer, the mean attention paid to the token after the earlier copy."""
    layer = hook.layer()
    # offset=seq_len-1 is the distance to the token *after* the first copy.
    diagonal = pattern.diagonal(offset=seq_len-1, dim1=-1, dim2=-2)
    induction_scores[layer] = einops.reduce(diagonal, "batch head_index diagonal -> head_index", "mean")
# Random token ids repeated once along the sequence dimension.
# NOTE(review): no torch seed is set, so the exact scores differ between runs.
original_tokens = torch.randint(100, 20000, size=(batch_size, seq_len))
repeated_tokens = einops.repeat(original_tokens, "batch seq_len -> batch (2 seq_len)").to(device)

# Hook every activation whose name ends with hook_attn (the attention patterns).
pattern_filter = lambda act_name: act_name.endswith("hook_attn")
loss = model.run_with_hooks(repeated_tokens, return_type="loss", fwd_hooks=[(pattern_filter, prev_token_hook), (pattern_filter, duplicate_token_hook), (pattern_filter, induction_hook)])
print(utils.get_corner(prev_token_scores))
print(utils.get_corner(duplicate_token_scores))
print(utils.get_corner(induction_scores))

tensor([[3.9065e-02, 4.5021e-04, 3.3892e-02],
        [1.9221e-01, 1.7122e-01, 6.6501e-02],
        [1.5884e-01, 2.0945e-02, 4.8036e-01]], device='cuda:0')
tensor([[0.0034, 0.1338, 0.0062],
        [0.0002, 0.0001, 0.0016],
        [0.0024, 0.0113, 0.0008]], device='cuda:0')
tensor([[3.5712e-03, 4.1425e-05, 4.0030e-03],
        [6.6889e-04, 3.8370e-04, 1.5733e-03],
        [2.8047e-03, 9.3957e-03, 8.2193e-04]], device='cuda:0')


We can now plot the head scores, and instantly see that the relevant early heads are induction heads or duplicate token heads (though also that there's a lot of induction heads that are *not* used - I have no idea why!). 

In [None]:

# One heatmap per head-type score table, all on the same Layer x Head axes.
for head_scores, plot_title in (
    (prev_token_scores, "Previous Token Scores"),
    (duplicate_token_scores, "Duplicate Token Scores"),
    (induction_scores, "Induction Head Scores"),
):
    imshow(head_scores, labels={"x":"Head", "y":"Layer"}, title=plot_title)

The above suggests that it would be a useful bit of infrastructure to have a "wiki" for the heads of a model, giving their scores according to some metrics re head functions, like the ones we've seen here. TransformerLens makes this easy to make, as just changing the name input to `HookedTransformer.from_pretrained` gives a different model but in the same architecture, so the same code should work. If you want to make this, I'd love to see it! 

As a proof of concept, [I made a mosaic of all induction heads across the 40 models then in TransformerLens](https://www.neelnanda.io/mosaic).

![](https://firebasestorage.googleapis.com/v0/b/firescript-577a2.appspot.com/o/imgs%2Fapp%2FNeelNanda%2F5vtuFmdzt_.png?alt=media&token=4d613de4-9d14-48d6-ba9d-e591c562d429)

## Backup Name Mover Heads

Another fascinating anomaly is that of the **backup name mover heads**. A standard technique to apply when interpreting model internals is ablations, or knock-out. If we run the model but intervene to set a specific head to zero, what happens? If the model is robust to this intervention, then naively we can be confident that the head is not doing anything important, and conversely if the model is much worse at the task this suggests that head was important. There are several conceptual flaws with this approach, making the evidence only suggestive, eg that the average output of the head may be far from zero and so the knockout may send it far from expected activations, breaking internals on *any* task. But it's still an easy technique to apply to give some data.

But a wild finding in the paper is that models have **built in redundancy**. If we knock out one of the name movers, then there are some backup name movers in later layers that *change their behaviour* and do (some of) the job of the original name mover head. This means that naive knock-out will significantly underestimate the importance of the name movers.


Let's test this! Let's ablate the most important name mover (head L9H9) on just the final token using a custom ablation hook and then cache all new activations and compare performance. We focus on the final position because we want to specifically ablate the direct logit effect. When we do this, we see that naively, removing the top name mover should reduce the logit diff massively, from 3.55 to 0.57. **But actually, it only goes down to 2.92!**

<details> <summary>Implementation Details</summary> 
Ablating heads is really easy in TransformerLens! We can just define a hook on the z activation in the relevant attention layer (recall, z is the mixed values, and comes immediately before multiplying by the output weights $W_O$). z has a head_index axis, so we can set the component for the relevant head and for position -1 to zero, and return it. (Technically we could just edit in place without returning it, but by convention we always return an edited activation). 

We now want to compare all internal activations with a hook, which is hard to do with the nice `run_with_hooks` API. So we can directly access the hook on the z activation with `model.blocks[layer].attn.hook_z` and call its `add_hook` method. This adds in the hook to the *global state* of the model. We can now use run_with_cache, and don't need to care about the global state, because run_with_cache internally adds a bunch of caching hooks, and then removes all hooks after the run, *including* the previously added ablation hook. This can be disabled with the reset_hooks_end flag, but here it's useful! 
</details>

In [None]:
# Pick the head with the largest direct logit attribution and split its flat
# index back into (layer, head).
top_name_mover = per_head_logit_diffs.flatten().argmax().item()
top_name_mover_layer, top_name_mover_head = divmod(top_name_mover, model.cfg.n_heads)
print(f"Top Name Mover to ablate: L{top_name_mover_layer}H{top_name_mover_head}")

def ablate_top_head_hook(z: Float[torch.Tensor, "batch pos head_index d_head"], hook):
    """Zero-ablate the top name mover head at the final position only.

    Args:
        z: the attention layer's z activation (mixed values, before W_O).
        hook: the HookPoint this function is attached to (unused).

    Returns:
        The same tensor, edited in place, with the chosen head's final-position
        output set to zero.
    """
    z[:, -1, top_name_mover_head, :] = 0
    return z

# Adds a hook into global model state
model.blocks[top_name_mover_layer].attn.hook_z.add_hook(ablate_top_head_hook)
try:
    # Runs the model with temporary caching hooks on top of the ablation hook.
    ablated_logits, ablated_cache = model.run_with_cache(tokens)
finally:
    # Explicitly clear *all* hooks, including the ablation hook. Do not rely on
    # run_with_cache for this: it is skipped if the forward pass raises, and
    # whether it clears hooks attached directly to a HookPoint may depend on the
    # installed TransformerLens version. A leftover ablation hook would silently
    # corrupt every later run of the model.
    model.reset_hooks()
print(f"Original logit diff: {original_average_logit_diff}")
print(f"Post ablation logit diff: {logits_to_ave_logit_diff(ablated_logits, answer_tokens).item()}")
print(f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item()}")
print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item()}")

Top Name Mover to ablate: L9H9
Original logit diff: 3.552075147628784
Post ablation logit diff: 2.9173295497894287
Direct Logit Attribution of top name mover head: 2.985630989074707
Naive prediction of post ablation logit diff: 0.5664441585540771


So what's up with this? As before, we can look at the direct logit attribution of each head to see what's going on. It's easiest to interpret if plotted as a scatter plot against the initial per head logit difference.

And we can see a *really* big difference in a few heads! (Hover to see labels) In particular the negative name mover L10H7 decreases its negative effect a lot, adding +1 to the logit diff, and the backup name mover L10H10 adjusts its effect to be more positive, adding +0.8 to the logit diff (with several other marginal changes). (And obviously the ablated head has gone down to zero!)

In [None]:
# Direct logit attribution of every head in the ablated run, at the final position.
per_head_ablated_residual, labels = ablated_cache.stack_head_results(
    layer=-1,
    pos_slice=-1,
    return_labels=True,
)
# Project each head's output onto the logit-diff direction, then lay the flat
# per-head vector out as a [layer, head] grid.
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(
    per_head_ablated_residual, ablated_cache
).reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"})
# Heads on the x=y line are unaffected by the ablation; heads off it changed behaviour.
scatter(
    x=per_head_ablated_logit_diffs.flatten(),
    y=per_head_logit_diffs.flatten(),
    hover_name=head_labels,
    range_x=(-3, 3),
    range_y=(-3, 3),
    xaxis="Ablated",
    yaxis="Original",
    title="Original vs Post-Ablation Direct Logit Attribution of Heads",
)

Tried to stack head results when they weren't cached. Computing head results now


One natural hypothesis is that this is because the final LayerNorm scaling has changed, which can scale up or down the final residual stream. This is slightly true, and we can see that the typical head is a bit off from the x=y line. But the average LN scaling ratio is 1.04, and this should uniformly change *all* heads by the same factor, so this can't be a sufficient explanation.

In [None]:
# Final LayerNorm scale at the last position, before and after the ablation.
original_ln_scale = cache["ln_final.hook_scale"][:, -1]
ablated_ln_scale = ablated_cache["ln_final.hook_scale"][:, -1]

print("Average LN scaling ratio:", (original_ln_scale / ablated_ln_scale).mean().item())
print("Ablation LN scale", ablated_ln_scale)
print("Original LN scale", original_ln_scale)

Average LN scaling ratio: 1.0417343378067017
Ablation LN scale tensor([[18.5225],
        [17.4695],
        [17.8207],
        [17.5070],
        [17.2629],
        [18.2541],
        [16.1799],
        [17.4300]], device='cuda:0')
Original LN scale tensor([[19.5689],
        [18.3550],
        [18.2856],
        [18.6836],
        [17.4878],
        [18.8696],
        [16.4217],
        [18.6801]], device='cuda:0')


**Exercise to the reader:** Can you finish off this analysis? What's going on here? Why are the backup name movers changing their behaviour? Why is one negative name mover becoming significantly less important?