# Problem

Interpret a model trained to predict the output of simple code functions, e.g. predicting the bold text in problems like
$$
a = [1, 2, 3] \\
a[2] = 4 \\
a -> [\textbf{1, 2, 4}] \\
$$

# Setup
(No need to read)

In [239]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook


In [240]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [241]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [242]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [243]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [244]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

In [245]:
def visualize_attn_patterns(heads, local_tokens, local_cache, title: str = ""):
    labels = []
    patterns = []
    batch_index = 0

    for head in heads:
        if isinstance(head, tuple):
            layer, head_index = head
        else:
            layer, head_index = head // model.cfg.n_heads, head % model.cfg.n_heads
        patterns.append(local_cache["pattern", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    patterns = torch.stack(patterns, dim=-1)
    attn_viz = pysvelte.AttentionMulti(tokens=model.to_str_tokens(local_tokens[batch_index]), attention=patterns, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attn_viz.show()

# Load Model

In [246]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fdc7205d480>

In [247]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


The smallest model I found that can learn this task is a 2L, 1H, attn-only transformer with no biases or layernorms. I also use [shortformer](https://github.com/neelnanda-io/TransformerLens/blob/main/further_comments.md#shortformer-attention-positional_embeddings_type--shortformer) positional embeddings to make things cleaner.

![picture](https://drive.google.com/uc?id=115AlOlXFE24PMYugkPeAadrkrh-hkOcS)

The input has the form `BOS a1 a2 a3 \n idx elt \n b1 b2 b3`, where [a1, a2, a3] is the original list, “idx” is the index of the list element to update, “elt” is the new element written to position “idx”, and [b1, b2, b3] is the correct list after the update.
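
As a concrete illustration of this layout, here is a minimal sketch in plain Python. The helper `build_sequence` is hypothetical (the notebook builds whole batches with `make_data_generator` below); the token ids for BOS and newline follow the convention used there, `d_vocab - 2` and `d_vocab - 1`.

```python
# Hypothetical helper for illustration only; not part of the notebook's pipeline.
LIST_LEN = 3
MAX_NUM = 50
BOS_TOKEN = MAX_NUM           # 50, i.e. d_vocab - 2
NEW_LINE_TOKEN = MAX_NUM + 1  # 51, i.e. d_vocab - 1


def build_sequence(lst, idx, elt):
    # Layout: BOS a1 a2 a3 NL idx elt NL b1 b2 b3
    assert len(lst) == LIST_LEN and 0 <= idx < LIST_LEN and 0 <= elt < MAX_NUM
    updated = list(lst)
    updated[idx] = elt  # the "a[idx] = elt" assignment the model must simulate
    return [BOS_TOKEN, *lst, NEW_LINE_TOKEN, idx, elt, NEW_LINE_TOKEN, *updated]


example_sequence = build_sequence([14, 23, 48], idx=1, elt=26)
print(example_sequence)
# [50, 14, 23, 48, 51, 1, 26, 51, 14, 26, 48]
```

The sequence has 2 * LIST_LEN + 5 = 11 tokens, which is where the `n_ctx` value in the config comes from.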

In [248]:
LIST_LEN = 3
MAX_NUM = 50
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=1,
    d_model=128,
    d_head=128,
    n_ctx=2*LIST_LEN+3+2, #BOS a1 a2 a3 \n idx elt \n b1 b2 b3
    d_vocab=MAX_NUM+2, #0,...,MAX_NUM-1, BOS, \n
    d_vocab_out=MAX_NUM,
    attn_only=True,
    normalization_type=None,
    positional_embedding_type="shortformer",
    device=device,
    seed=0
)

model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_attn_input): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [249]:
# disable biases
for name, param in model.named_parameters():
    if 'b_' in name:
        param.requires_grad = False
    print(name, param.shape, param.requires_grad)

embed.W_E torch.Size([52, 128]) True
pos_embed.W_pos torch.Size([11, 128]) True
blocks.0.attn.W_Q torch.Size([1, 128, 128]) True
blocks.0.attn.W_K torch.Size([1, 128, 128]) True
blocks.0.attn.W_V torch.Size([1, 128, 128]) True
blocks.0.attn.W_O torch.Size([1, 128, 128]) True
blocks.0.attn.b_Q torch.Size([1, 128]) False
blocks.0.attn.b_K torch.Size([1, 128]) False
blocks.0.attn.b_V torch.Size([1, 128]) False
blocks.0.attn.b_O torch.Size([128]) False
blocks.1.attn.W_Q torch.Size([1, 128, 128]) True
blocks.1.attn.W_K torch.Size([1, 128, 128]) True
blocks.1.attn.W_V torch.Size([1, 128, 128]) True
blocks.1.attn.W_O torch.Size([1, 128, 128]) True
blocks.1.attn.b_Q torch.Size([1, 128]) False
blocks.1.attn.b_K torch.Size([1, 128]) False
blocks.1.attn.b_V torch.Size([1, 128]) False
blocks.1.attn.b_O torch.Size([128]) False
unembed.W_U torch.Size([128, 50]) True
unembed.b_U torch.Size([50]) False


# Training (Optional)

Mostly boilerplate. Skip unless you want to read the training code.

## Task dataset

In [250]:
def make_data_generator(cfg, batch_size, seed=0):
    torch.manual_seed(seed)
    BOS_TOKEN = cfg.d_vocab-2
    NEW_LINE_TOKEN = cfg.d_vocab-1
    while True:
        bos_vec = (torch.ones(batch_size) * BOS_TOKEN)[:, None]
        nl_vec = (torch.ones(batch_size) * NEW_LINE_TOKEN)[:, None]

        list_toks = torch.randint(0, MAX_NUM, (batch_size, LIST_LEN))

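        # Note: idx_tok and elt_tok are drawn once per batch with Python's `random`,
        # so every row in a batch shares the same (idx, elt) pair, and
        # torch.manual_seed above does not control these two draws.
        idx_tok = random.randint(0, LIST_LEN-1)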
        idx_vec = (torch.ones(batch_size) * idx_tok)[:, None]

        elt_tok = random.randint(0, MAX_NUM-1)
        elt_vec = (torch.ones(batch_size) * elt_tok)[:, None]

        ans_toks = list_toks.clone()
        ans_toks[:, idx_tok] = elt_tok

        x = torch.cat([bos_vec, list_toks, nl_vec, idx_vec, elt_vec, nl_vec, ans_toks], dim=-1).to(torch.long)
        yield x

print(next(make_data_generator(cfg, 4)))

tensor([[50, 44, 39, 33, 51,  1, 48, 51, 44, 48, 33],
        [50, 10, 13, 29, 51,  1, 48, 51, 10, 48, 29],
        [50, 27,  3, 47, 51,  1, 48, 51, 27, 48, 47],
        [50, 33,  1, 16, 51,  1, 48, 51, 33, 48, 16]])


## Loss Fn

In [251]:
def loss_fn(logits, tokens):
    logits = logits[:, -LIST_LEN-1:-1, :]
    logits = logits.to(torch.float64)
    labels = tokens[:, -LIST_LEN:]
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()

with torch.no_grad():
    tokens = next(make_data_generator(cfg ,4)).to(device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    print(loss)

tensor(3.9169, device='cuda:0', dtype=torch.float64)
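
The slicing in `loss_fn` relies on the usual next-token offset: the logits at position p are scored against the token at position p + 1. A small sketch in plain Python, using one example sequence from later in the notebook, shows which positions are paired up (the variable names here are illustrative, not from the notebook):

```python
LIST_LEN = 3
tokens = [50, 14, 23, 48, 51, 1, 26, 51, 14, 26, 48]
positions = list(range(len(tokens)))

# Same slices as loss_fn, applied to position indices instead of tensors.
pred_positions = positions[-LIST_LEN - 1:-1]  # where the logits are read
label_positions = positions[-LIST_LEN:]       # where the target tokens live

for p, l in zip(pred_positions, label_positions):
    print(f"logits at position {p} are scored against token {tokens[l]} at position {l}")
```

So the first prediction is made at the second newline token, and the last answer position is never used as a prediction position.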


In [252]:
print("uniform loss:", np.log(cfg.d_vocab_out))

uniform loss: 3.912023005428146
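
This baseline is what cross-entropy gives for a model that spreads its probability evenly over the 50 output tokens, which is why the untrained loss above sits close to it. A quick check in plain Python (the `cross_entropy` helper is written here for illustration):

```python
import math


def cross_entropy(probs, label):
    # Negative log-probability assigned to the correct label.
    return -math.log(probs[label])


d_vocab_out = 50
uniform_probs = [1.0 / d_vocab_out] * d_vocab_out
uniform_loss = cross_entropy(uniform_probs, label=0)  # same for any label
print(uniform_loss)
```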


## Setup Optimizer / dataloader

In [253]:
lr = 1e-3
wd = 1e-2
betas = (0.9, 0.98)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

batch_size = 256
train_data_loader = make_data_generator(cfg, batch_size)

## Training Loop

In [254]:
num_epochs = 1000

train_losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
    tokens = next(train_data_loader).to(device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()
    train_losses.append(loss.item())

    optimizer.step()
    optimizer.zero_grad()

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, train loss: {loss.item()}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 0, train loss: 3.9151253955364718
Epoch: 100, train loss: 2.099208163281922
Epoch: 200, train loss: 0.6149461828937424
Epoch: 300, train loss: 0.18812340879816794
Epoch: 400, train loss: 0.004011678946603272
Epoch: 500, train loss: 0.0026258528138491206
Epoch: 600, train loss: 0.0021738507043532556
Epoch: 700, train loss: 0.0004986208256978274
Epoch: 800, train loss: 0.0011405148919006432
Epoch: 900, train loss: 0.00017650809167173007


In [255]:
line(
    train_losses,
    title="Train Loss Curve",
    xaxis="Epoch", yaxis="Loss"
)

## Sanity Check

In [256]:
test_sample = next(train_data_loader).to(device)
print(test_sample.shape)

torch.Size([256, 11])


In [257]:
with torch.inference_mode():
    logits = model(test_sample)
    logits = logits[:, -LIST_LEN-1:-1, :]
    preds = logits.argmax(dim=-1)

    labels = test_sample[:, -LIST_LEN:]

    acc = (preds == labels).float().mean()
    print("Test sample accuracy:", acc.item())

Test sample accuracy: 1.0


# Reverse Engineering

## Direct logit attribution

We can start with direct logit attribution to determine what bits of the model are writing information that is directly used to make the correct prediction. For simplicity, we will just start with the first prediction rather than dealing with all 3.


As usual, we can start with the logit lens technique. Since I deliberately tried to train the smallest model that works, we expect the model to need both layers to make the correct prediction.


In [258]:
original_logits, cache = model.run_with_cache(test_sample)
print(original_logits.shape)
print(cache)

torch.Size([256, 11, 50])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_post']


In [259]:
first_pred_pos = -LIST_LEN-1

In [260]:
accumulated_resid_stack, labels = cache.accumulated_resid(layer=-1, pos_slice=first_pred_pos, return_labels=True)
print(accumulated_resid_stack.shape)
print(labels)

torch.Size([3, 256, 128])
['0_pre', '1_pre', 'final_post']


In [261]:
logit_dirs = model.tokens_to_residual_directions(test_sample[:, -LIST_LEN])
print(logit_dirs.shape)

torch.Size([256, 128])


In [262]:
accumulated_contrib = einops.einsum(
    accumulated_resid_stack, logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / accumulated_resid_stack.shape[1]


line(
    accumulated_contrib,
    title="Logit lens: prefix contribution to correct logit dirs",
    xaxis="Layer", yaxis="Contribution",
    x=labels
)

We see that the model does seem to need both layers, as expected. We can also decompose each component to sanity check this:

In [263]:
decomposed_resid_stack, labels = cache.decompose_resid(layer=-1, pos_slice=first_pred_pos, return_labels=True)
print(decomposed_resid_stack.shape)
print(labels)

torch.Size([4, 256, 128])
['embed', 'pos_embed', '0_attn_out', '1_attn_out']


In [264]:
decomposed_contrib = einops.einsum(
    decomposed_resid_stack, logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / decomposed_resid_stack.shape[1]


line(
    decomposed_contrib,
    title="Per component contribution to correct logit dirs",
    xaxis="Layer", yaxis="Contribution",
    x=labels
)

We see the output of attn head 1.0 stands out. (Note that since we are using shortformer positional embeddings, the "pos_embed" term doesn't actually mean anything)
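
The reason this decomposition is meaningful is linearity: with no layernorm, the final residual stream is the sum of the component outputs, so the correct-token logit is the sum of each component's dot product with the logit direction. A toy sketch in plain Python, with made-up 3-dimensional vectors standing in for the real activations:

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Made-up numbers for illustration; the real vectors live in d_model = 128.
components = {
    "embed": [0.1, -0.2, 0.0],
    "0_attn_out": [0.3, 0.1, -0.1],
    "1_attn_out": [1.5, 0.4, 0.2],
}
logit_dir = [1.0, 0.5, -0.5]  # stand-in for the unembedding column of the correct token

# Final residual stream = elementwise sum of all component outputs.
resid_final = [sum(vals) for vals in zip(*components.values())]

per_component = {name: dot(vec, logit_dir) for name, vec in components.items()}
total_logit = dot(resid_final, logit_dir)

print(per_component)
print(total_logit)
```

The per-component contributions add up to the total logit, so a component with a large share can be read as writing directly in the direction of the correct answer.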


Since this is an attn-only model with the correct answers in the context, a natural hypothesis is that the last head attends to the correct answer and copies. We can start by staring at the attn patterns for head 1.0:

## Stare at 1.0 patterns

In [265]:
layer_1_patterns = cache['pattern', 1]
print(layer_1_patterns.shape)

torch.Size([256, 1, 11, 11])


First we will look at the average pattern over the entire test sample.

In [266]:
def tokens_to_labels(tokens):
    if tokens.ndim==2:
        tokens = tokens[0]
    res = [f"{tok}_{i}" for i, tok in enumerate(tokens)]
    res[0] = "BOS"
    res[LIST_LEN+1] = "nl_0"
    res[-LIST_LEN-1] = "nl_1"
    return res

labels = tokens_to_labels(test_sample[1])
print(labels)

['BOS', '14_1', '23_2', '48_3', 'nl_0', '1_5', '26_6', 'nl_1', '14_8', '26_9', '48_10']


In [267]:
imshow(
    layer_1_patterns[:, 0, :, :].mean(dim=0),
    title="Attn 1.0 patterns (avged over batch)",
    xaxis="src", yaxis="dest",
    x=labels, y=labels
)

Recall that all of the examples in the test sample have the same (idx, elt) pair. We see that attn 1.0 is clearly attending fully to the correct next token at the prediction positions. A natural hypothesis is that it is copying the answers. We should be able to check this by looking at the full OV circuit.

## 1.0 full OV circuit

In [268]:
W_E = model.W_E
print(W_E.shape)

W_U = model.W_U
print(W_U.shape)

torch.Size([52, 128])
torch.Size([128, 50])


In [269]:
imshow(
    W_E @ model.OV[1,0].AB @ W_U,
    title=f"1.0 Full OV circuit W_E @ W_OV^1.0 @ W_U",
    xaxis="logit", yaxis="src token"
)

In [270]:
full_OV_circuit = (W_E @ model.OV[1,0].AB @ W_U)[:cfg.d_vocab_out]
top_1_acc = (full_OV_circuit.argmax(dim=-1) == torch.arange(cfg.d_vocab_out, device=cfg.device)).float().mean()
print("Fraction of time top logit is on diagonal", top_1_acc.item())

Fraction of time top logit is on diagonal 1.0
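
To spell out what this diagonal check measures, here is the same test on two toy matrices in plain Python. The helper name `diagonal_top1_fraction` and the matrices are made up for illustration; row i plays the role of src token i and column j the role of output logit j.

```python
def diagonal_top1_fraction(matrix):
    # Fraction of rows whose largest entry sits on the diagonal.
    hits = 0
    for i, row in enumerate(matrix):
        argmax = max(range(len(row)), key=row.__getitem__)
        hits += int(argmax == i)
    return hits / len(matrix)


copying = [
    [5.0, 0.0, -1.0],
    [0.0, 4.0, 1.0],
    [-2.0, 0.0, 6.0],
]
not_copying = [
    [0.0, 5.0, -1.0],  # token 0 most boosts logit 1, so it is not copied
    [0.0, 4.0, 1.0],
    [-2.0, 0.0, 6.0],
]

copying_fraction = diagonal_top1_fraction(copying)
not_copying_fraction = diagonal_top1_fraction(not_copying)
print(copying_fraction, not_copying_fraction)
```

A fraction of 1.0 means that every src token most strongly boosts its own logit, which is what a copying circuit looks like.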


We see clear copying. Now it seems clear that the model figures out to attend to the correct next token and then copy, but how does it know to attend to the correct next token in the first place? It must be using some kind of composition. We can start by looking at the attn patterns for 0.0 to get a sense of what information is being moved in L0:

## Stare at 0.0 patterns

In [271]:
layer_0_patterns = cache['pattern', 0]
print(layer_0_patterns.shape)

torch.Size([256, 1, 11, 11])


In [272]:
imshow(
    layer_0_patterns[:, 0, :, :].mean(dim=0),
    title="0.0 patterns (avged over batch)",
    xaxis='src', yaxis='dest',
    x=labels, y=labels
)

We notice that the "elt" token attends very strongly to the "idx" token.  Perhaps it uses this to broadcast "I am at prediction position 1", and the corresponding prediction position can just query for that. This would be an example of K-composition:

![picture](https://drive.google.com/uc?id=12IBkPV_ZZrfGySRgDDL2Kimsc7vmbHry)

We can refine this hypothesis with activation patching.

## Activation patching

In activation patching, we usually want to corrupt some vital piece of information in the input, causing the model to get a wrong answer. Then we patch in activations from the clean run and measure how much performance we can recover. This helps us localize which specific activations from the clean run contain the most important information.
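
The mechanics can be sketched on a toy "model" in plain Python. Everything here (the `run` function, the activation names `a` and `b`) is a made-up stand-in, not the TransformerLens API: the point is only the clean run, corrupted run, and patch-one-activation loop.

```python
def run(inputs, patch=None):
    # Toy two-branch "model": each intermediate activation is stored in a cache.
    cache = {}
    cache["a"] = 2 * inputs["x"]  # activation that depends on x
    cache["b"] = inputs["y"] + 1  # activation that depends on y
    if patch:
        cache.update(patch)       # overwrite chosen activations
    output = cache["a"] - cache["b"]
    return output, cache


clean_inputs = {"x": 3, "y": 1}
corrupted_inputs = {"x": 0, "y": 1}  # only x is corrupted

clean_out, clean_cache = run(clean_inputs)
corrupted_out, _ = run(corrupted_inputs)

# Re-run on the corrupted input, patching in one clean activation at a time.
patched_outputs = {}
for name in ("a", "b"):
    patched_outputs[name], _ = run(corrupted_inputs, patch={name: clean_cache[name]})

print(clean_out, corrupted_out, patched_outputs)
# 4 -2 {'a': 4, 'b': -2}
```

Patching `a` restores the clean output while patching `b` changes nothing, which localizes the corrupted information to `a`.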

There are many important tokens in the input, but I will start by corrupting the "idx" token, since we hypothesize this is used in a K-composition circuit.

In [273]:
clean_tokens = test_sample[1]
print(clean_tokens)
print(tokens_to_labels(clean_tokens))

tensor([50, 14, 23, 48, 51,  1, 26, 51, 14, 26, 48], device='cuda:0')
['BOS', '14_1', '23_2', '48_3', 'nl_0', '1_5', '26_6', 'nl_1', '14_8', '26_9', '48_10']


In [274]:
corrupted_tokens = clean_tokens.clone()
corrupted_tokens[5] = 2
corrupted_tokens[-2] = 23
corrupted_tokens[-1] = 26
print(corrupted_tokens)
print(tokens_to_labels(corrupted_tokens))

tensor([50, 14, 23, 48, 51,  2, 26, 51, 14, 23, 26], device='cuda:0')
['BOS', '14_1', '23_2', '48_3', 'nl_0', '2_5', '26_6', 'nl_1', '14_8', '23_9', '26_10']


In [275]:
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
print(clean_logits.shape)
print(clean_cache)

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
print(corrupted_logits.shape)
print(corrupted_cache)

torch.Size([1, 11, 50])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_post']
torch.Size([1, 11, 50])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.at

For our patching metric we can use the logit difference between the correct answer given that we update the index (26) and the wrong answer if we didn't update (23):

In [276]:
idx_1_pred_pos = -3
correct_ans = 26
incorrect_ans = 23

clean_logit_diff = clean_logits[:, idx_1_pred_pos, correct_ans] - clean_logits[:, idx_1_pred_pos, incorrect_ans]
print(clean_logit_diff)

corrupted_logit_diff = corrupted_logits[:, idx_1_pred_pos, correct_ans] - corrupted_logits[:, idx_1_pred_pos, incorrect_ans]
print(corrupted_logit_diff)

tensor([10.8760], device='cuda:0', grad_fn=<SubBackward0>)
tensor([-16.4544], device='cuda:0', grad_fn=<SubBackward0>)


In [277]:
def patching_metric(patched_logits):
    patched_logit_diff = patched_logits[:, idx_1_pred_pos, correct_ans] - patched_logits[:, idx_1_pred_pos, incorrect_ans]
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff- corrupted_logit_diff)

print("clean metric", patching_metric(clean_logits))
print("corrupted metric", patching_metric(corrupted_logits))

clean metric tensor([1.], device='cuda:0', grad_fn=<DivBackward0>)
corrupted metric tensor([0.], device='cuda:0', grad_fn=<DivBackward0>)
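
The normalization can be checked by hand with the two logit differences printed above. This sketch in plain Python copies those two values as floats; the function name `normalized_metric` is chosen here for illustration.

```python
clean_logit_diff = 10.8760
corrupted_logit_diff = -16.4544


def normalized_metric(patched_logit_diff):
    # 0 means "no better than the corrupted run", 1 means "as good as the clean run".
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)


midpoint = (clean_logit_diff + corrupted_logit_diff) / 2
print(normalized_metric(clean_logit_diff))
print(normalized_metric(corrupted_logit_diff))
print(normalized_metric(midpoint))
```

A patch that moves the logit difference halfway from the corrupted value to the clean value scores about 0.5; values outside [0, 1] are possible if a patch overshoots in either direction.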


We can start by patching the resid_pre activations, to get a general sense of where information is moved each layer:

In [278]:
resid_pre_patching_results = patching.get_act_patch_resid_pre(
    model,
    corrupted_tokens,
    clean_cache,
    patching_metric
)
resid_pre_patching_results.shape

  0%|          | 0/22 [00:00<?, ?it/s]

torch.Size([2, 11])

In [279]:
labels = tokens_to_labels(clean_tokens)
imshow(
    resid_pre_patching_results,
    title="Resid pre patching results (corrupting idx token)",
    xaxis="pos", yaxis="Layer",
    x=labels
)

This suggests that the “idx” information is moved to the residual stream at the “elt” position. This is roughly what we expected from looking at the attn patterns. Now we can get more granular by patching attn outputs, queries, keys, values, and patterns for each layer.

In [280]:
attn_head_by_pos_every_patching_results = patching.get_act_patch_attn_head_by_pos_every(
    model,
    corrupted_tokens,
    clean_cache,
    patching_metric
)
print(attn_head_by_pos_every_patching_results.shape)

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

torch.Size([5, 2, 11, 1])


In [281]:
imshow(
    attn_head_by_pos_every_patching_results[..., 0],
    title="attn head patching results (corrupted idx token)",
    xaxis="pos", yaxis="Layer",
    facet_col=0,
    facet_labels=["output", "query", "key", "value", "pattern"],
    aspect="auto", x=labels
)

This is very clean and strongly suggests K-composition. Let's go layer by layer:


We see that patching the output at the "elt" dest position, or the value at the "idx" src position, in layer 0 fully recovers performance. This suggests that 0.0 moves the "idx" token information to the "elt" residual stream, as expected.


In L1, we can see that patching in the key at the "elt" source position fully recovers performance. This combined with the L0 patching results suggests that the output of 0.0 is used as a key for the "elt" position.


If this hypothesis is true, we should be able to read this off the QK circuit weights:

## K-composition circuit

The relevant circuit is:


$W_{pos}W_{QK}^{1.0}(W_{OV}^{0.0})^TW_E^T$.


We expect to see a strong diagonal for the query positions corresponding to the prediction positions, and the tokens corresponding to the “idx” tokens. This can be interpreted as "this prediction position should attend strongly to the src residual stream that attended fully to these src tokens in the previous layer".
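
As a sketch of where this product comes from, assume shortformer attention (positional embeddings feed only the queries and keys), ignore the $1/\sqrt{d_{head}}$ scaling, and write $W_{QK}^{1.0} = W_Q^{1.0}(W_K^{1.0})^T$. The indices $d$ (dest position), $s$ (src position) and $j$ (position attended to by 0.0) are notation introduced here. Keeping only the term of the layer 1 attention score where the query reads the positional embedding and the key reads the output of 0.0:

$$
\left(W_{pos}[d]\,W_Q^{1.0}\right)\left(\sum_j A^{0.0}_{sj}\,t_jW_EW_{OV}^{0.0}W_K^{1.0}\right)^T = \sum_j A^{0.0}_{sj}\,W_{pos}[d]\,W_{QK}^{1.0}(W_{OV}^{0.0})^TW_E^Tt_j^T
$$

So entry (dest position, token) of the matrix is the score contribution that a src position would receive if 0.0 had attended fully to that token there.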

In [282]:
W_pos = model.W_pos
print(W_pos.shape)

torch.Size([11, 128])


In [283]:
k_comp_circuit = W_pos @ model.QK[1,0].AB @ model.OV[0,0].AB.T @ W_E.T
print(k_comp_circuit.shape)

torch.Size([11, 52])


In [284]:
imshow(
    k_comp_circuit,
    yaxis="dest pos", xaxis="src token",
    y=labels,
    title="K comp circuit W_pos @ W_QK^1.0 @ (W_OV^0.0)^T @ W_E^T"
)

Notice the dominant 3x3 sub-square in the bottom left with the strong positive diagonal, as expected.


A natural next question is how the model predicts the list positions that aren't updated. One hypothesis is that each prediction position attends to its corresponding list position by default, but the K-comp circuit overwrites this with higher attn scores:

## 1.0 Full QK circuit

In [285]:
full_QK_circuit = W_pos @ model.QK[1,0].AB @ W_pos.T
print(full_QK_circuit.shape)

torch.Size([11, 11])


In [286]:
imshow(
    full_QK_circuit.tril(),
    title="1.0 Full QK circuit W_pos @ W_QK^1.0 @ W_pos.T",
    xaxis="src pos", yaxis="dest pos",
    x=labels, y=labels
)

Notice the 3x3 sub-square, where prediction positions attend strongly to their corresponding list positions. However, the magnitude of these scores is about 1/2 of the magnitude of the scores for the k-comp circuit, suggesting that the k-comp circuit can "overwrite" the attn scores from this circuit.
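
To see why a factor of two in score is enough to "overwrite", recall that attention weights come from a softmax, which turns score gaps into exponential ratios. A sketch in plain Python with hypothetical scores (10 and 20 are illustrative values, not numbers read off the plots):

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


positional_score = 10.0  # hypothetical default "attend to my list position" score
k_comp_score = 20.0      # hypothetical score from the k-comp circuit, twice as large

attn = softmax([positional_score, k_comp_score])
print(attn)
```

With these numbers, nearly all attention goes to the k-comp source, so the default positional preference only matters when the k-comp term is absent.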

## More refined direct logit attribution

I am convinced that attn head 1.0 can perform K-composition with 0.0 to attend to the correct answer and copy. But it's possible this is not the whole story. Direct logit attribution showed that the 1.0 output dominated, but that also includes the path through both attn heads (AKA the virtual attn head term). To check, we can try to do some more refined direct logit attribution by decomposing 1_attn_out into these two terms:


$$
h_{1.0}(x_1) = A^{1.0}tW_EW_{OV}^{1.0}W_U + A^{1.0}
A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0}W_U
$$


Where $t$ is the input tokens as one hot encoded vectors.

In [287]:
embed_both_heads_term = einops.einsum(
    cache['pattern', 1], cache['pattern', 0], cache['embed'], model.OV[0,0].AB, model.OV[1, 0].AB,
    "batch n_heads pos1 pos2, batch n_heads pos2 pos3, batch pos3 d_model1, d_model1 d_model2, d_model2 d_model3 -> batch pos1 d_model3"
)

embed_head_1_term = einops.einsum(
    cache['pattern', 1], cache['embed'], model.OV[1, 0].AB,
    "batch n_heads q_pos k_pos, batch k_pos d_model1, d_model1 d_model2 -> batch q_pos d_model2"
)

print(embed_both_heads_term.shape)
print(embed_head_1_term.shape)

assert torch.allclose(embed_both_heads_term  + embed_head_1_term , cache['attn_out', 1], atol=1e-3)

torch.Size([256, 11, 128])
torch.Size([256, 11, 128])


In [288]:
h1_stack = torch.stack([embed_both_heads_term, embed_head_1_term])
h1_stack = h1_stack[:, :, first_pred_pos, :]
labels = ["embed_both_heads_term", "embed_head_1_term"]
print(h1_stack.shape)
print(labels)

torch.Size([2, 256, 128])
['embed_both_heads_term', 'embed_head_1_term']


In [289]:
h1_contrib = einops.einsum(
    h1_stack, logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / h1_stack.shape[1]

bar(
    h1_contrib,
    title="Decomposed 1_attn_out contribution to correct logit dirs",
    xaxis="component", yaxis="Contribution",
    x=labels
)

Clearly the K-composition circuit doesn't explain the whole story, as the virtual attn head term is also very important. We already showed that 1.0 attends fully to the correct answer token, so this kind of implies that 0.0 also moves some intermediate information in L0 that is then relayed by L1. This is pretty confusing and seems unnecessary, since the K-composition should work fine, but I suppose it has enough residual stream bandwidth to use both.


We can study this virtual attn head term, starting by staring at the virtual attn patterns:

## Stare at Virtual attn patterns

In [290]:
layer_1_patterns.shape

torch.Size([256, 1, 11, 11])

In [291]:
virtual_attn_pattern = einops.einsum(
    layer_1_patterns, layer_0_patterns,
    "batch n_heads pos1 pos2, batch n_heads pos2 pos3 -> batch n_heads pos1 pos3"
)
virtual_attn_pattern.shape

torch.Size([256, 1, 11, 11])
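
The einsum above is a per-example matrix product of the layer 1 pattern with the layer 0 pattern. A toy 3-position sketch in plain Python (the patterns are made up; `matmul` is a helper written here) shows what the product means:

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


# Rows are dest positions, columns are src positions; both patterns are causal.
layer_1_pattern = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],  # position 2 attends fully to position 1
]
layer_0_pattern = [
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],  # position 1 attends mostly to position 0
    [0.2, 0.3, 0.5],
]

virtual_pattern = matmul(layer_1_pattern, layer_0_pattern)
for row in virtual_pattern:
    print(row)
```

Row 2 of the product equals row 1 of the layer 0 pattern: position 2 effectively reads whatever position 1 read in layer 0. Each row still sums to 1, so the product can be read as an attention pattern in its own right.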

In [292]:
labels = tokens_to_labels(test_sample[1])
imshow(
    virtual_attn_pattern[:, 0, :, :].mean(dim=0),
    title="virtual attn pattern (avged over batch)",
    xaxis="src", yaxis="dest",
    x=labels, y=labels
)

It appears that the virtual attn head also attends to the correct answer, although not fully like 1.0. A natural hypothesis is that the virtual OV circuit also just copies:

## Virtual OV circuit

In [293]:
virtual_OV_circuit = W_E @ model.OV[0,0].AB @ model.OV[1,0].AB @ W_U
print(virtual_OV_circuit.shape)

torch.Size([52, 50])


In [294]:
imshow(
    virtual_OV_circuit,
    title="virtual OV circuit W_E @ W_OV^0.0 @ W_OV^1.0 @ W_U",
    xaxis="logit", yaxis="src token"
)

Yes, it appears to copy, although it does not copy the index tokens (0, 1, 2).

In [295]:
top_1_acc = (virtual_OV_circuit[:cfg.d_vocab_out].argmax(dim=-1) == torch.arange(cfg.d_vocab_out, device=device)).float().mean()
print(f"Fraction of time top logit is on diagonal", top_1_acc.item())

Fraction of time top logit is on diagonal 0.9799999594688416


## Direct logit attribution on all preds

Note that the model must make 3 predictions for this task, and we only tried direct logit attribution on the first one. We can quickly do the same for the other predictions to sanity check that they are using the same circuits.

In [296]:
accumulated_resid_stack, labels = cache.accumulated_resid(layer=-1, pos_slice=(first_pred_pos,-1), return_labels=True)
print(accumulated_resid_stack.shape)
print(labels)

torch.Size([3, 256, 3, 128])
['0_pre', '1_pre', 'final_post']


In [297]:
logit_dirs = model.tokens_to_residual_directions(test_sample[:, -LIST_LEN:])
print(logit_dirs.shape)

torch.Size([256, 3, 128])


In [298]:
accumulated_contrib = einops.einsum(
    accumulated_resid_stack, logit_dirs,
    "... batch pos d_model, batch pos d_model -> pos ..."
) / logit_dirs.shape[0]


lines(
    accumulated_contrib,
    title="Logit lens: prefix contribution to correct logit dirs (all 3 predictions)",
    xaxis="Layer", yaxis="Contribution",
    x=labels,
)

In [299]:
decomposed_resid_stack, labels = cache.decompose_resid(layer=-1, pos_slice=(first_pred_pos,-1), return_labels=True)
print(decomposed_resid_stack.shape)
print(labels)

torch.Size([4, 256, 3, 128])
['embed', 'pos_embed', '0_attn_out', '1_attn_out']


In [300]:
decomposed_contrib = einops.einsum(
    decomposed_resid_stack, logit_dirs,
    "... batch pos d_model, batch pos d_model -> pos ..."
) / logit_dirs.shape[0]

lines(
    decomposed_contrib,
    title="Per component contribution to correct logit dirs (all 3 predictions)",
    xaxis="Component", yaxis="Contribution",
    x=labels
)

We see that the direct logit attribution looks very similar for all predictions.

In [301]:
embed_both_heads_term = einops.einsum(
    cache['pattern', 1], cache['pattern', 0], cache['embed'], model.OV[0,0].AB, model.OV[1, 0].AB,
    "batch n_heads pos1 pos2, batch n_heads pos2 pos3, batch pos3 d_model1, d_model1 d_model2, d_model2 d_model3 -> batch pos1 d_model3"
)

embed_head_1_term = einops.einsum(
    cache['pattern', 1], cache['embed'], model.OV[1, 0].AB,
    "batch n_heads q_pos k_pos, batch k_pos d_model1, d_model1 d_model2 -> batch q_pos d_model2"
)

print(embed_both_heads_term.shape)
print(embed_head_1_term.shape)

assert torch.allclose(embed_both_heads_term  + embed_head_1_term , cache['attn_out', 1], atol=1e-3)

torch.Size([256, 11, 128])
torch.Size([256, 11, 128])
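
The assertion above passing suggests that head 1.0's value input is the token embedding plus head 0.0's output, with no positional term (which is what one would expect with shortformer positional embeddings; this reading is inferred from the code rather than stated explicitly). Writing $E$ for `cache['embed']`, $A^{0.0}$ and $A^{1.0}$ for the attention patterns, and $O^{1}$ for `cache['attn_out', 1]`, a sketch of the expansion is:

$$
O^{1} = A^{1.0}\left(E + A^{0.0} E W_{OV}^{0.0}\right) W_{OV}^{1.0}
      = A^{1.0} E W_{OV}^{1.0} + A^{1.0} A^{0.0} E W_{OV}^{0.0} W_{OV}^{1.0}
$$

The first term corresponds to `embed_head_1_term` (the path through 1.0 only) and the second to `embed_both_heads_term` (the virtual head path through both 0.0 and 1.0).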


In [302]:
h1_stack = torch.stack([embed_both_heads_term, embed_head_1_term])
h1_stack = h1_stack[:, :, first_pred_pos:-1, :]
labels = ["embed_both_heads_term", "embed_head_1_term"]
print(h1_stack.shape)
print(labels)

torch.Size([2, 256, 3, 128])
['embed_both_heads_term', 'embed_head_1_term']


In [303]:
h1_contrib = einops.einsum(
    h1_stack, logit_dirs,
    "... batch pos d_model, batch pos d_model -> pos ..."
) / h1_stack.shape[1]

lines(
    h1_contrib,
    title="1_attn_out decomposed logit contribution (all 3 predictions)",
    xaxis="Term", yaxis="Contribution",
    x=labels
)

We can see that the prediction corresponding to the array update (1) relies more on the K-comp circuit, but both paths are important for all 3 predictions.

# Summary

We found that a 2L, 1H attn-only transformer with shortformer positional embeddings and no biases or layer norms can learn to apply index updates to fixed-length arrays using a combination of K-composition and V-composition.


The lines of evidence we used:
1. Direct logit attribution showed that the output of attn head 1.0 was the most important. A more refined direct logit attribution showed that both terms in its decomposition, the path through 1.0 alone and the path through both heads, were very important.
2. Empirically, both the 1.0 attn pattern $A^{1.0}$ and the virtual attn pattern $A^{1.0}A^{0.0}$ showed that the prediction positions attended strongly to the correct answer.
3. We showed that both OV circuits $W_EW_{OV}^{1.0}W_U$ and $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ clearly copied.
4. Activation patching showed that patching in the key for the correct answer token fully recovered performance, suggesting K-composition.
5. We read off the K-comp algorithm from QK circuit weights.

## General techniques you can apply to other problems

* Direct logit attribution
* Staring at attention patterns
* Activation patching
* Multiplying out QK / OV circuits