# Problem

Interpret a toy model trained to permute lists.

# Setup
(No need to read)

In [1]:
!git clone https://github.com/ckkissane/mech-interp-practice.git

Cloning into 'mech-interp-practice'...
remote: Enumerating objects: 351, done.
remote: Counting objects: 100% (170/170), done.
remote: Compressing objects: 100% (156/156), done.
remote: Total 351 (delta 103), reused 32 (delta 14), pack-reused 181
Receiving objects: 100% (351/351), 38.64 MiB | 12.07 MiB/s, done.
Resolving deltas: 100% (182/182), done.


In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-zopau3nk
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-zopau3nk
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 20a44fe3a8022d353c9cc7c984a8fcab14552d1c
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
Collecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)

In [3]:
try:
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
except:
    import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-sl3v2xsz
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-sl3v2xsz
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit df9bfc252807e8b1c3a26c3c4796c18342c7fc71
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting importlib-metadata<6.0.0,>=5.1.0 (from circuitsvis==0.0.0)
  Downloading importlib_metadata-5.2.0-py3-none-any.whl (21 kB)
Building wheels for collected packages: circuitsvis
  Building wheel for circuitsvis (pyproject.toml) ... done
  Created wheel for circuitsvis: filename=circuitsvis-0.0.0-py3-none-any.whl size=6170923 sha256=7101b262d0720ea697e8ca0933183e55c9d6b28a592d89e9c7

In [4]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [5]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import circuitsvis as cv
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [6]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [7]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [8]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math
import pandas as pd

In [9]:
def disable_biases(model):
    for name, param in model.named_parameters():
        if 'b_' in name:
            param.requires_grad = False

def disable_pos_embed(model):
    assert model.cfg.positional_embedding_type == "standard"
    model.pos_embed.W_pos = nn.Parameter(torch.zeros_like(model.pos_embed.W_pos))
    model.pos_embed.W_pos.requires_grad = False

# Load Model

In [10]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x79c1331b76a0>

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


The smallest model I found was a 2L, 1H (per layer), attn-only transformer with no layernorm or biases. Note a 1L transformer was not able to solve this task (even with up to 8 heads).

![picture](https://drive.google.com/uc?id=15DCokm0rjNEBFn5znNJ8mEu5ATF5z7CH)

The model has already been trained, and is loaded into this notebook below:

In [12]:
LIST_LEN = 3
MAX_INT = 50
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=1,
    d_model=128,
    d_head=128,
    n_ctx=3*LIST_LEN+3, # BOS d1 d2 d3 MID p1 p2 p3 END a1 a2 a3
    d_vocab= MAX_INT+3, # 0, ..., MAX_INT-1, BOS, MID, END
    d_vocab_out=MAX_INT,
    attn_only=True,
    normalization_type=None,
    device=device,
    seed=0,
)
model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [13]:
filename = "mech-interp-practice/models/permute_lists_model.pt"
state_dict = torch.load(filename, map_location=device)  # map_location lets the checkpoint load on CPU-only machines too
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# Task Description

The model's inputs are of the form:


```
[BOS, a_0, a_1, a_2, MID, idx_0, idx_1, idx_2, END, ans_0, ans_1, ans_2]
```


For example:


```
[BOS, 44, 40, 15, MID, 2, 1, 0, END, 15, 40, 44]
```


* BOS (52), MID (51), and END (50) are special tokens that appear at the same positions in every sequence.
* The first 3 tokens after BOS, [a_0, a_1, a_2], are the original list the model needs to permute. Each list token is an integer in [0, 49] inclusive.
* The 3 tokens after MID, [idx_0, idx_1, idx_2], are the permutation indices. These are always some shuffling of {0, 1, 2}.
* The last three tokens after END, [ans_0, ans_1, ans_2], are the permuted list after applying the perm indices. In general, ans_i = a_{idx_i}. In the concrete example above: ans_0 = a_{idx_0} = a_2 = 15.


Note that the list length is fixed to 3 for every example. The model is trained to predict the ans tokens [ans_0, ans_1, ans_2] at positions [END, ans_0, ans_1] respectively. The model is trained with a causal mask so it can't peek ahead.
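
To make the rule ans_i = a_{idx_i} concrete, here is a minimal pure-Python sketch applied to the example sequence above. The helper name `apply_perm` is hypothetical and is not part of the notebook or the training code.

```python
def apply_perm(values, idx):
    """Return the permuted list, where ans[i] = values[idx[i]]."""
    return [values[j] for j in idx]


a = [44, 40, 15]   # original list [a_0, a_1, a_2]
idx = [2, 1, 0]    # permutation indices [idx_0, idx_1, idx_2]
ans = apply_perm(a, idx)
print(ans)         # the three tokens that follow END in the example
```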


Below I provide a data loader and some example tokens that you can use to start your investigation:

In [14]:
def make_data_generator(cfg, batch_size, seed=0):
    torch.manual_seed(seed)
    BOS_TOKEN = cfg.d_vocab-1
    MID_TOKEN = cfg.d_vocab-2
    END_TOKEN = cfg.d_vocab-3
    while True:
        seq = torch.randint(0, MAX_INT, (batch_size, LIST_LEN))
        perm = torch.randperm(LIST_LEN)
        ans = seq[:, perm]

        bos_tensor = einops.repeat(torch.tensor(BOS_TOKEN), " -> i 1", i=batch_size)
        mid_tensor = einops.repeat(torch.tensor(MID_TOKEN), " -> i 1", i=batch_size)
        end_tensor = einops.repeat(torch.tensor(END_TOKEN), " -> i 1", i=batch_size)

        x = torch.cat([bos_tensor, seq, mid_tensor, einops.repeat(perm, "seq -> batch seq", batch=batch_size), end_tensor, ans], dim=-1)
        yield x

batch_size = 256
data_loader = make_data_generator(cfg, batch_size, seed=42)

test_data = []
sub_batch_size = 4
for i in range(batch_size // sub_batch_size):
    test_data.append(next(make_data_generator(cfg, sub_batch_size, seed=i)))

test_data = torch.cat(test_data, dim=0).to(device)
print(test_data.shape)
print(test_data[:5])

torch.Size([256, 12])
tensor([[52, 44, 39, 33, 51,  0,  2,  1, 50, 44, 33, 39],
        [52, 10, 13, 29, 51,  0,  2,  1, 50, 10, 29, 13],
        [52, 27,  3, 47, 51,  0,  2,  1, 50, 27, 47,  3],
        [52, 33,  1, 16, 51,  0,  2,  1, 50, 33, 16,  1],
        [52, 45, 39, 24, 51,  1,  2,  0, 50, 39, 24, 45]], device='cuda:0')


The model was trained to minimize cross entropy loss. I provide the loss function used in training below.

In [15]:
def loss_fn(logits, tokens):
    logits=logits[:, -LIST_LEN-1:-1, :]
    logits = logits.to(torch.float64)
    labels = tokens[:, -LIST_LEN:]

    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()
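
The two slices in `loss_fn` are offset by one position because the logits at position `i` predict the token at position `i + 1`. A small sketch of that alignment, using position names instead of tensors (the `positions` list is just an illustrative stand-in for a sequence):

```python
LIST_LEN = 3
positions = ["BOS", "a_0", "a_1", "a_2", "MID", "idx_0", "idx_1", "idx_2",
             "END", "ans_0", "ans_1", "ans_2"]

# Positions whose logits are scored (same slice as logits[:, -LIST_LEN-1:-1, :]).
prediction_positions = positions[-LIST_LEN - 1:-1]
# Tokens those logits are compared against (same slice as tokens[:, -LIST_LEN:]).
label_positions = positions[-LIST_LEN:]

for src, tgt in zip(prediction_positions, label_positions):
    print(f"logits at {src} are scored against the token at {tgt}")
```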

# Solution

## Sanity check

First let's sanity check that the model has learned to successfully solve this task on our sample tokens:

In [16]:
with torch.inference_mode():
    test_logits = model(test_data)
    test_logits = test_logits[:, -LIST_LEN-1:-1, :]
    preds = test_logits.argmax(dim=-1)
    test_labels = test_data[:, -LIST_LEN:]

    acc = (preds == test_labels).float().mean()
    print("Test sample accuracy:", acc.item())

Test sample accuracy: 1.0


## Direct logit attribution

Direct logit attribution is a great technique to localize which bits of the model directly affect the logits. We can start with the logit lens technique, which can give us intuitions about when the model has done enough processing to make a prediction.


Note the model has to make three predictions (one for each list element), but for simplicity, we'll just study the first prediction in this section, corresponding to destination token "END".
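
As a rough sketch of the computation used in the cells that follow: direct logit attribution dots each residual stream vector with the unembedding direction of that example's correct token, then averages over the batch. The NumPy example below uses random toy arrays (all shapes and values are made up for illustration, not taken from the model) and checks that this matches reading the correct-token logit off the full logits.

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, batch, d_model, d_vocab = 3, 4, 8, 5

resid_stack = rng.normal(size=(n_components, batch, d_model))  # toy residual vectors
W_U = rng.normal(size=(d_model, d_vocab))                      # toy unembedding matrix
correct_tokens = rng.integers(0, d_vocab, size=batch)          # toy correct answers

# One unembedding direction per example: the column of W_U for its correct token.
correct_dirs = W_U[:, correct_tokens].T                        # [batch, d_model]

# Dot each residual vector with its example's direction, then average over the batch.
attribution = np.einsum("cbd,bd->c", resid_stack, correct_dirs) / batch

# The same quantity, computed by unembedding everything and picking out the correct logit.
logits = resid_stack @ W_U                                     # [component, batch, d_vocab]
via_logits = logits[:, np.arange(batch), correct_tokens].mean(axis=1)

print(attribution)
print(via_logits)
```

This only works as an exact statement because the model has no layernorm, so the map from residual stream to logits is linear.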

In [17]:
original_logits, cache = model.run_with_cache(test_data)
print(original_logits.shape)
print(cache)

torch.Size([256, 12, 50])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_post']


In [18]:
accumulated_resid_stack, labels = cache.accumulated_resid(layer=-1, pos_slice=-4, return_labels=True)
print(accumulated_resid_stack.shape)
print(labels)

torch.Size([3, 256, 128])
['0_pre', '1_pre', 'final_post']


In [19]:
correct_answers = test_data[:, -3]
print(correct_answers.shape)
correct_logit_dirs = model.tokens_to_residual_directions(correct_answers)
print(correct_logit_dirs.shape)

torch.Size([256])
torch.Size([256, 128])


In [20]:
accumulated_contrib = einops.einsum(
    accumulated_resid_stack,
    correct_logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / test_data.shape[0]

line(
    accumulated_contrib,
    title="Logit lens: direct logit attribution for prefix of model (first prediction)",
    xaxis="Layer", yaxis="attribution",
    x=labels
)

We see that the model is only able to predict the correct answer after the second attn layer.


We can also check the direct effect of each individual component's output on the logits by using the decomposed residual stream. (Note that this carries the same information as the plot above, since each accumulated value is just the sum of a prefix of the decomposed components.)
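
The prefix-sum relationship can be checked on toy data. This is a sketch with random vectors standing in for the four components, assuming (as the labels printed in this notebook suggest) that `0_pre` is `embed + pos_embed`, `1_pre` adds `0_attn_out`, and `final_post` adds `1_attn_out`:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8
labels = ["embed", "pos_embed", "0_attn_out", "1_attn_out"]

decomposed = rng.normal(size=(len(labels), d_model))  # toy per-component vectors
logit_dir = rng.normal(size=d_model)                  # toy correct-token direction

decomposed_contrib = decomposed @ logit_dir           # one attribution per component

# Accumulated residual stream at '0_pre', '1_pre', 'final_post'.
accumulated = np.stack([
    decomposed[:2].sum(axis=0),   # embed + pos_embed
    decomposed[:3].sum(axis=0),   # ... + 0_attn_out
    decomposed[:4].sum(axis=0),   # ... + 1_attn_out
])
accumulated_contrib = accumulated @ logit_dir

# Running total of the per-component attributions.
running_total = np.cumsum(decomposed_contrib)
print(accumulated_contrib)
print(running_total[1:])
```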

In [21]:
decomposed_resid_stack, labels = cache.decompose_resid(layer=-1, pos_slice=-4, return_labels=True)
print(decomposed_resid_stack.shape)
print(labels)

torch.Size([4, 256, 128])
['embed', 'pos_embed', '0_attn_out', '1_attn_out']


In [22]:
decomposed_contrib = einops.einsum(
    decomposed_resid_stack,
    correct_logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / test_data.shape[0]

line(
    decomposed_contrib,
    title="Per component direct logit attribution for first prediction",
    xaxis="Component", yaxis="attribution",
    x=labels
)

As expected, the second attn layer has the biggest direct effect on the logits.


This attribution combined with the fact that a 1L model could not perform the task suggests the model is doing some kind of composition. A next natural step is to start trying to understand interpretable activations - in this case the attn patterns.

## Stare at 1.0 patterns

When trying to understand activations, it's often useful to work backwards, so we start with the head 1.0 pattern.

Since the correct prediction is earlier in the context, one natural guess is that head 1.0 might attend to the correct token and copy it to the logits - a common motif that we've also seen in the "min of two ints" and "sort fixed length lists" toy model tutorials.

First we can look at the average pattern over the test sample to get a lay of the land:

In [23]:
def tokens_to_plotly_labels(tokens, incl_pos=True):
    if tokens.ndim==2:
        tokens = tokens[0]
    if incl_pos:
        res = [f"{tok}_{i}" for i, tok in enumerate(tokens)]
    else:
        res = list(map(str, tokens.tolist()))
    res[0] = 'BOS'
    res[LIST_LEN+1] = 'MID'
    res[2*LIST_LEN+2]= 'END'
    return res

tokens_to_plotly_labels(test_data)
tokens_to_plotly_labels(test_data, incl_pos=False)

['BOS', '44', '39', '33', 'MID', '0', '2', '1', 'END', '44', '33', '39']

In [24]:
labels = tokens_to_plotly_labels(test_data)

layer, head_index = 1, 0
imshow(
    cache['pattern', layer][:, head_index, :, :].mean(dim=0),
    title=f"Attn head {layer}.{head_index} pattern (averaged over batch)",
    xaxis="src", yaxis="dest",
    x=labels, y=labels
)

Interestingly, the prediction positions do not attend to the answer token and copy as I guessed. Instead they almost always seem to attend to the same fixed position. Notice that all 3 of the prediction positions attend to fixed perm positions.


Why would the model do this? The token and positional embeddings at these source positions clearly don't contain information about the answer, so the information must have been moved to their residual stream by attn 0.0, and then moved again by 1.0! This is the first instance of attn head composition we've seen in these tutorials.


If this hypothesis is correct, this would be an example of ["V-composition"](https://transformer-circuits.pub/2021/framework/index.html#three-kinds-of-composition). We can also think of it as a single ["virtual attention head"](https://transformer-circuits.pub/2021/framework/index.html#virtual-attention-heads): effectively a single head where the prediction positions attend to the correct list token and then copy it.


![picture](https://drive.google.com/uc?id=1orrWb0s7zTbkVShAl8MbMytq_8nHlc_Q)

Taking a step back to summarize what we know at this point:
* Head 1.0 has the largest direct effect on the logits (DLA)
* The average of all 1.0 attn patterns showed the prediction positions (END, ans_0, ans_1) always attend to the permutation index positions (idx_0, idx_1, idx_2) respectively.

Continuing to work backwards, we can check some attn patterns for 0.0. We expect to see that the permutation tokens attend to their corresponding list positions (perm token=i attends to pos=i+1).

## Stare at 0.0 patterns

Once again I'll just start with the avg across the batch dimension:

In [25]:
labels = tokens_to_plotly_labels(test_data)

layer, head_index = 0, 0
imshow(
    cache['pattern', layer][:, head_index, :, :].mean(dim=0),
    title=f"Attn head {layer}.{head_index} pattern (averaged over batch)",
    xaxis="src", yaxis="dest",
    x=labels, y=labels
)

We see that the permutation tokens mostly attend to the list positions, although this isn't very surprising. Let's zoom in on some concrete examples:

In [26]:
all_labels = [tokens_to_plotly_labels(d, incl_pos=False) for d in test_data]
cv.attention.from_cache(
    cache,
    tokens=all_labels,
    batch_idx=list(range(10)),
    heads=[(0, 0),]
)

Notice that the perm destination tokens mostly attend to the corresponding positions in the list (for example, perm token 0 always attends to pos=1, 0-indexed). This is consistent with our V-composition hypothesis above: 0.0 can move the list token information to the permutation index positions, and head 1.0 can then relay that same token info to the prediction positions.
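
To spell out the hypothesised two-hop route, here is a small pure-Python trace on the first row of `test_data`. It idealises both heads as attending to exactly one position each, which is a simplification of the real (soft) attention patterns, and the helper name `two_hop_source` is made up for this sketch.

```python
LIST_LEN = 3
tokens = ["BOS", 44, 39, 33, "MID", 0, 2, 1, "END", 44, 33, 39]


def two_hop_source(dest_pos, tokens):
    """Follow the hypothesised attention hops from a prediction position to a list position."""
    # Hop 1 (head 1.0): END -> idx_0, ans_0 -> idx_1, ans_1 -> idx_2,
    # i.e. each prediction position looks LIST_LEN positions back.
    idx_pos = dest_pos - LIST_LEN
    # Hop 2 (head 0.0): perm token i attends to list position i + 1 (position 0 is BOS).
    return tokens[idx_pos] + 1


for dest_pos in (8, 9, 10):  # END, ans_0, ans_1
    src_pos = two_hop_source(dest_pos, tokens)
    print(f"dest pos {dest_pos}: reads token {tokens[src_pos]} from pos {src_pos}, "
          f"target is {tokens[dest_pos + 1]}")
```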

We can add this new observation to the list of facts we already know:
* Head 1.0 has the largest direct effect on the logits (DLA)
* The average of all 1.0 attn patterns showed the prediction positions (END, ans_0, ans_1) always attend to the permutation index positions (idx_0, idx_1, idx_2) respectively
* **Every 0.0 pattern we checked shows that permutation tokens (0, 1, 2) always attend to list positions (a_0, a_1, a_2) respectively**

One way to double check this is to look at the QK circuits. Recall that in the 0th layer we can decompose the attn scores into 4 terms. We expect the $W_E @ W_{QK} @ W_{pos}^T$ term to stand out, with a stripe when dest_token=i, src_pos=i+1 for i = 0, 1, 2.
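
As a reminder of where the 4 terms come from: in layer 0 the residual stream is just $e + p$, so the (pre-softmax) scores expand bilinearly. The NumPy sketch below checks this on random toy matrices; it ignores the $1/\sqrt{d_{head}}$ scaling, the causal mask and any biases, and none of the arrays come from the actual model.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_model = 6, 8

e = rng.normal(size=(n_pos, d_model))       # toy token embeddings, one row per position
p = rng.normal(size=(n_pos, d_model))       # toy positional embeddings
W_QK = rng.normal(size=(d_model, d_model))  # toy stand-in for W_Q @ W_K.T

x0 = e + p
scores = x0 @ W_QK @ x0.T                   # rows = dest (query), columns = src (key)

terms = {
    "query=token, key=token": e @ W_QK @ e.T,
    "query=token, key=pos": e @ W_QK @ p.T,  # the term we expect to stand out for 0.0
    "query=pos, key=token": p @ W_QK @ e.T,
    "query=pos, key=pos": p @ W_QK @ p.T,
}

reconstructed = sum(terms.values())
print(np.abs(scores - reconstructed).max())
```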

## Understand 0.0 QK circuits

In [27]:
W_E = model.W_E
print(W_E.shape)
W_pos = model.W_pos
print(W_pos.shape)
W_QK = model.QK[0,0].AB
print(W_QK.shape)

torch.Size([53, 128])
torch.Size([12, 128])
torch.Size([128, 128])


In [28]:
general_labels = ["BOS", "a_0", "a_1", "a_2", "MID", "idx_0", "idx_1", "idx_2", "END", "ans_0", "ans_1", "ans_2"]
imshow(
    W_E @ W_QK @ W_pos.T,
    title="Attn 0.0 full QK circuit W_E @ W_QK @ W_pos.T",
    xaxis="src pos", yaxis="dest token",
    x=general_labels,
    aspect="equal"
)

Notice the 3x3 top left subsquare with the clear diagonal stripe. This suggests tokens (0, 1, 2) will give the highest attention scores to positions (a_0, a_1, a_2) respectively. Given that the permutation index tokens are always some shuffling of (0, 1, 2), this is exactly consistent with our interpretation of the 0.0 attn patterns above.

However there are 3 other QK circuit components for this head, so let's quickly check them to see if we can find additional structure. We can do this all at once by concatenating $W_E$ and $W_{pos}$:

In [29]:
W_E_pos = torch.cat([W_E, W_pos], dim=0)

W_E_labels = [f"W<sub>E</sub>[{i}]" for i in range(model.cfg.d_vocab)]
W_pos_labels = [f"W<sub>pos</sub>[{i}]" for i in range(model.cfg.n_ctx)]

imshow(
    W_E_pos @ W_QK @ W_E_pos.T,
    title="0.0 full QK circuits",
    xaxis="src", yaxis="dest",
    x=W_E_labels+W_pos_labels, y=W_E_labels+W_pos_labels,
)

Viewing these all at once allows us to get a sense of the general importance of each component. We notice the diagonal stripe we just observed above still stands out compared to most of the other values in this heatmap, suggesting that it explains a large part of this head's function. (Note that rigorously showing that certain components are unimportant is often difficult and requires more work. See [path patching](https://arxiv.org/pdf/2304.05969.pdf) / [causal scrubbing](https://www.lesswrong.com/posts/JvZhhzycHu2Yd57RN/causal-scrubbing-a-method-for-rigorously-testing))


Adding to the list of facts we know:
* Head 1.0 has the largest direct effect on the logits (DLA)
* The average of all 1.0 attn patterns showed the prediction positions (END, ans_0, ans_1) always attend to the permutation index positions (idx_0, idx_1, idx_2) respectively
* Every 0.0 pattern we checked shows that permutation tokens (0, 1, 2) always attend to list positions (a_0, a_1, a_2) respectively
* **0.0 full QK circuit shows clear diagonal stripe where tokens (0, 1, 2) give the highest attn scores to positions corresponding to (a_0, a_1, a_2) respectively**


Another observation from the full QK circuit is that the rows corresponding to the perm tokens in $W_EW_{QK}W_E^T$ weakly stand out compared to the other rows in that section. This might suggest that this head is specialized to move information to these positions.


One way to sanity check this is to decompose the query / key vectors into their components, $q = x_0W_Q = (e+p)W_Q = eW_Q + pW_Q$ (same idea for $k$), and check the norms. We expect to see that the query vectors are dominated by the perm token embeddings, and the keys are dominated by the positional embeddings for the list:

In [30]:
imshow(
    torch.stack(
        [(cache['embed'] @ model.W_Q[0, 0]).norm(dim=-1).mean(dim=0),
        (cache['pos_embed'] @ model.W_Q[0, 0]).norm(dim=-1).mean(dim=0)]
    ),
    title="Query component norms",
    xaxis="pos", yaxis="component",
    y=["Embed", "Pos embed"], x=labels
)

imshow(
    torch.stack(
        [(cache['embed'] @ model.W_K[0, 0]).norm(dim=-1).mean(dim=0),
        (cache['pos_embed'] @ model.W_K[0, 0]).norm(dim=-1).mean(dim=0)]
    ),
    title="Key component norms",
    xaxis="pos", yaxis="component",
    y=["Embed", "Pos embed"], x=labels
)

It's weaker than I expected, but we do see the perm token embeddings slightly stand out in the queries, and the list positional embeddings stand out for their keys.

## Understand 1.0 QK circuits

Since we were able to get some traction on understanding the QK circuit weights for 0.0, let's try the same for 1.0.

We expect that the attn pattern is mainly determined by the positions (recall the prediction destination positions always attended very strongly to the fixed perm positions), so we expect the $W_{pos}W_{QK}W_{pos}^T$ term to explain this.

In [31]:
W_QK = model.QK[1, 0].AB
W_QK.shape

torch.Size([128, 128])

In [32]:
imshow(
    torch.tril(W_pos @ W_QK @ W_pos.T),
    title=f"1.0 full QK circuit W_pos @ W_QK @ W_pos.T",
    xaxis="src pos", yaxis='dest pos',
    x=general_labels, y=general_labels,
)

Notice the clear diagonal stripe in the subsquare for rows (END, ans_0, ans_1) and columns (idx_0, idx_1, idx_2). This suggests that the (END, ans_0, ans_1) destination positions will give the highest attention scores to the (idx_0, idx_1, idx_2) source positions respectively. This is exactly consistent with what we observed in the average 1.0 attn patterns earlier in this notebook.


Updating our understanding of this model:
* Head 1.0 has the largest direct effect on the logits (DLA)
* The average of all 1.0 attn patterns showed the prediction positions (END, ans_0, ans_1) always attend to the permutation index positions (idx_0, idx_1, idx_2) respectively.
    * **1.0 $W_{pos}W_{QK}W_{pos}^T$ QK circuit component shows clear diagonal stripe where prediction positions (END, ans_0, ans_1) give the highest scores to positions (idx_0, idx_1, idx_2) respectively**
* Every 0.0 pattern we checked shows that permutation tokens (0, 1, 2) always attend to list positions (a_0, a_1, a_2) respectively.
* 0.0 full QK circuit shows clear diagonal stripe where tokens (0, 1, 2) give the highest attn scores to positions corresponding to (a_0, a_1, a_2) respectively.

## Decompose logits

I feel somewhat convinced that 0.0 causes perm tokens to attend to their corresponding list position, and that 1.0 causes prediction positions to attend to the fixed perm positions. However we have yet to determine *what* information is moved. It would be really natural for the composition of both heads to just relay the token information for the correct answer using V-composition, acting as a virtual attn head.


In my [min of two ints](https://colab.research.google.com/github/ckkissane/mech-interp-practice/blob/main/tutorials/min_of_two_ints_tutorial.ipynb) tutorial we decomposed the logits into a sum of terms. For a quick recap of why this is useful: we can decompose the logits (or more generally, any residual stream vector) into a sum of terms that each correspond to some path through the model. This allows us to localize which paths are the most important (using techniques like DLA or ablations), which helps narrow down what activations / weights to focus on interpreting. Since this is a 2 layer model there are more paths corresponding to composition between the two layers. Thus we have many more terms:

$$
\begin{align}
logits &= x_2W_U\\
&= (h_{1.0}(x_1) + x_1)W_U \\
&= (A^{1.0}x_1W_{OV}^{1.0} + x_1)W_U \\
&= (A^{1.0}(A^{0.0}x_0W_{OV}^{0.0} + x_0)W_{OV}^{1.0} + (A^{0.0}x_0W_{OV}^{0.0} + x_0))W_U \\
&= (A^{1.0}A^{0.0}x_0W_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}x_0W_{OV}^{1.0} + A^{0.0}x_0W_{OV}^{0.0} + x_0)W_U \\
&= (A^{1.0}A^{0.0}(e+p)W_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}(e+p)W_{OV}^{1.0} + A^{0.0}(e+p)W_{OV}^{0.0} + (e+p))W_U \\
&= (A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}A^{0.0}pW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}eW_{OV}^{1.0} + A^{1.0}pW_{OV}^{1.0} + A^{0.0}eW_{OV}^{0.0} + A^{0.0}pW_{OV}^{0.0} + e + p)W_U \\
&= A^{1.0}A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0}W_U + A^{1.0}A^{0.0}posW_{pos}W_{OV}^{0.0}W_{OV}^{1.0}W_U + A^{1.0}tW_EW_{OV}^{1.0}W_U + A^{1.0}posW_{pos}W_{OV}^{1.0}W_U + A^{0.0}tW_EW_{OV}^{0.0}W_U + A^{0.0}posW_{pos}W_{OV}^{0.0}W_U + tW_EW_U + posW_{pos}W_U \\
\end{align}
$$

Where $t$ and $pos$ are the tokens and positions as one hot encoded vectors respectively.
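
If you'd like to convince yourself the expansion is right, it can be checked numerically. The sketch below uses random toy matrices (nothing here comes from the trained model) and treats the attention patterns $A^{0.0}$ and $A^{1.0}$ as fixed, which is what the derivation above implicitly does.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_model, d_vocab = 5, 6, 4


def random_causal_pattern(n):
    """Random lower-triangular attention pattern whose rows sum to 1."""
    scores = rng.normal(size=(n, n))
    scores[np.triu_indices(n, k=1)] = -np.inf  # causal mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return weights / weights.sum(axis=-1, keepdims=True)


e = rng.normal(size=(n_pos, d_model))  # toy token embeddings (t @ W_E)
p = rng.normal(size=(n_pos, d_model))  # toy positional embeddings (pos @ W_pos)
A0, A1 = random_causal_pattern(n_pos), random_causal_pattern(n_pos)
W_OV0 = rng.normal(size=(d_model, d_model))
W_OV1 = rng.normal(size=(d_model, d_model))
W_U = rng.normal(size=(d_model, d_vocab))

# Layer-by-layer forward pass with the patterns held fixed.
x0 = e + p
x1 = x0 + A0 @ x0 @ W_OV0
x2 = x1 + A1 @ x1 @ W_OV1
logits = x2 @ W_U

# The eight paths from the last line of the derivation (before applying W_U).
terms = [
    A1 @ A0 @ e @ W_OV0 @ W_OV1,  # tokens through both heads
    A1 @ A0 @ p @ W_OV0 @ W_OV1,  # positions through both heads
    A1 @ e @ W_OV1,               # tokens through head 1.0 only
    A1 @ p @ W_OV1,               # positions through head 1.0 only
    A0 @ e @ W_OV0,               # tokens through head 0.0 only
    A0 @ p @ W_OV0,               # positions through head 0.0 only
    e,                            # direct path from token embeddings
    p,                            # direct path from positional embeddings
]
logits_from_terms = sum(term @ W_U for term in terms)
print(np.abs(logits - logits_from_terms).max())
```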

To emphasize the key idea that each term corresponds to a path through the model, I recommend you take some time to think about what paths each term corresponds to. For example, the $A^{1.0}A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ term corresponds to a path from the token embeddings, through both attention heads, all the way to the logits:

![picture](https://drive.google.com/uc?id=18bA5FB3KnnDeJEX_4-DcmNmFlVQVo49y)

It would be a pain to study all eight terms, so we'd like to narrow down which terms are the most important. Suggestion: take a few minutes to think about what paths through the model you expect to be the most important before reading on.

Recall the hypothesis that the model is using V composition to predict the correct token. Thus we expect the path from the embedding, through both attention heads, to the logits (the $A^{1.0}A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ term, seen in diagram above) to dominate in importance. We should be able to check this through a more refined direct logit attribution.

## Finer grained direct logit attribution

Recall we already used direct logit attribution to show $h_{1.0}(x_1)$ had the biggest effect on the logits ('Direct logit attribution' section). However we want to further decompose:
$$
h_{1.0}(x_1) = A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}A^{0.0}pW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}eW_{OV}^{1.0} + A^{1.0}pW_{OV}^{1.0}
$$

Now we can just apply direct logit attribution on these terms, and prioritize the terms that matter most:

In [33]:
embed_both_heads_term = einops.einsum(
    cache['pattern', 1], cache['pattern', 0], cache['embed'], model.OV[0,0].AB, model.OV[1, 0].AB,
    "batch n_heads pos1 pos2, batch n_heads pos2 pos3, batch pos3 d_model1, d_model1 d_model2, d_model2 d_model3 -> batch pos1 d_model3"
)

pos_embed_both_heads_term = einops.einsum(
    cache['pattern', 1], cache['pattern', 0], cache['pos_embed'], model.OV[0,0].AB, model.OV[1, 0].AB,
    "batch n_heads pos1 pos2, batch n_heads pos2 pos3, batch pos3 d_model1, d_model1 d_model2, d_model2 d_model3 -> batch pos1 d_model3"
)

embed_head_1_term = einops.einsum(
    cache['pattern', 1], cache['embed'], model.OV[1, 0].AB,
    "batch n_heads q_pos k_pos, batch k_pos d_model1, d_model1 d_model2 -> batch q_pos d_model2"
)

pos_embed_head_1_term = einops.einsum(
    cache['pattern', 1], cache['pos_embed'], model.OV[1, 0].AB,
    "batch n_heads q_pos k_pos, batch k_pos d_model1, d_model1 d_model2 -> batch q_pos d_model2"
)


assert torch.allclose(embed_both_heads_term + pos_embed_both_heads_term + embed_head_1_term + pos_embed_head_1_term, cache['attn_out', 1], atol=1e-5)

In [34]:
resid_stack = torch.stack([embed_both_heads_term, pos_embed_both_heads_term, embed_head_1_term, pos_embed_head_1_term])
resid_stack = resid_stack[:, :, -4, :]
print(resid_stack.shape)
labels = ['embed_both_heads_term', 'pos_embed_both_heads_term', 'embed_head_1_term', 'pos_embed_head_1_term',]
print(labels)

torch.Size([4, 256, 128])
['embed_both_heads_term', 'pos_embed_both_heads_term', 'embed_head_1_term', 'pos_embed_head_1_term']


In [35]:
per_component_contribution = einops.einsum(
    resid_stack,
    correct_logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / test_data.shape[0]

bar(
    per_component_contribution,
    title="Direct logit attribution for components of attn_out_1",
    xaxis="attn_out_1 component", yaxis="attribution",
    x=labels
)

The bar chart shows that the $A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0}$ term is by far the most important. Note this is pretty great! We had a huge equation with 8 terms, but direct logit attribution allows us to prioritize one path through the model.
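
The reason this prioritization is legitimate is that the unembedding is linear, so the per-term attributions add up exactly to the logit produced by their sum. Here is a minimal NumPy sketch of that bookkeeping. It assumes that the correct-logit direction for each example is the `W_U` column of its correct token; the shapes and names are hypothetical stand-ins, not the notebook's tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, batch, d_model, d_vocab = 4, 6, 8, 10

# Hypothetical stand-ins for the stacked residual terms and the unembedding.
resid_stack = rng.normal(size=(n_components, batch, d_model))
W_U = rng.normal(size=(d_model, d_vocab))
correct_tokens = rng.integers(0, d_vocab, size=batch)

# Assumption: one direction per example, the W_U column of the correct token.
correct_logit_dirs = W_U[:, correct_tokens].T  # [batch, d_model]

# Dot each component with the correct direction, then average over the batch.
per_component = np.einsum("cbd,bd->c", resid_stack, correct_logit_dirs) / batch

# Linearity check: attributions sum to the mean correct logit of the summed terms.
total_resid = resid_stack.sum(axis=0)
correct_logits = (total_resid @ W_U)[np.arange(batch), correct_tokens]
print(np.isclose(per_component.sum(), correct_logits.mean()))
```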

Now we can add this to our existing observation that 1.0 had the highest DLA:


* Head 1.0 has the largest direct effect on the logits (DLA)
    * **In particular, decomposing 1.0's output into sub terms and applying more refined DLA shows that the "virtual attn head term" $A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0}$ dominates**
* The average of all 1.0 attn patterns showed the prediction positions (END, ans_0, ans_1) always attend to the permutation index positions (idx_0, idx_1, idx_2) respectively.
    * 1.0 $W_{pos}W_{QK}W_{pos}^T$ QK circuit component shows clear diagonal stripe where prediction positions (END, ans_0, ans_1) give the highest scores to positions (idx_0, idx_1, idx_2) respectively.
* Every 0.0 pattern we checked shows that permutation tokens (0, 1, 2) always attend to list positions (a_0, a_1, a_2) respectively.
    * 0.0 full QK circuit shows clear diagonal stripe where tokens (0, 1, 2) give the highest attn scores to positions corresponding to (a_0, a_1, a_2) respectively.


## Virtual attn pattern

Now that we know that the $A^{1.0}A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ term explains most of the model's behavior, we want to understand $A^{1.0}A^{0.0}$ and $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$. Note that $A^{1.0}A^{0.0}$ tells us *where* the virtual attn head moves information from, while $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ tells us *what* information to move from the src to the logits (given that we've attended to it). Let's tackle the $A^{1.0}A^{0.0}$ term first. We should expect that the virtual attn head effectively just attends to the correct next token.

In [36]:
for _ in range(3):
    r = random.randint(0, test_data.shape[0]-1)
    labels = tokens_to_plotly_labels(test_data[r])
    imshow(
        cache['pattern', 1][r, 0, :, :] @ cache['pattern', 0][r, 0, :, :],
        title=f"virtual attn pattern A^1.0 @ A^0.0 for example {r}",
        xaxis="src", yaxis="dest",
        x=labels, y=labels
    )

Notice that the prediction destination tokens attend most strongly to the correct list token. This is consistent with the V-composition hypothesis, since it implies that via the composition of both heads, each list token's information is routed to the correct prediction position. Adding to our list of observations:

* Head 1.0 has the largest direct effect on the logits (DLA)
    * In particular, decomposing 1.0's output into sub terms and applying more refined DLA shows that the "virtual attn head term" $A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0}$ dominates
* The average of all 1.0 attn patterns showed the prediction positions (END, ans_0, ans_1) always attend to the permutation index positions (idx_0, idx_1, idx_2) respectively.
    * 1.0 $W_{pos}W_{QK}W_{pos}^T$ QK circuit component shows clear diagonal stripe where prediction positions (END, ans_0, ans_1) give the highest scores to positions (idx_0, idx_1, idx_2) respectively.
* Every 0.0 pattern we checked shows that permutation tokens (0, 1, 2) always attend to list positions (a_0, a_1, a_2) respectively.
    * 0.0 full QK circuit shows clear diagonal stripe where tokens (0, 1, 2) give the highest attn scores to positions corresponding to (a_0, a_1, a_2) respectively.
* **In every example we checked, the virtual attn pattern $A^{1.0}A^{0.0}$ showed that the prediction positions effectively attended to the correct list token**
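
To see why multiplying the two patterns gives a meaningful "virtual" pattern, here is a toy NumPy sketch. The sequence layout, the permutation, and the 0.9 attention weight are all illustrative assumptions, not values read from the model.

```python
import numpy as np

# Hypothetical layout for a length-3 list:
# [a_0, a_1, a_2, idx_0, idx_1, idx_2, END, ans_0, ans_1]
n = 3
seq_len = 3 * n
list_pos = np.arange(n)
perm_pos = n + np.arange(n)
pred_pos = 2 * n + np.arange(n)
perm = np.array([2, 0, 1])  # illustrative permutation

def soft_pattern(dest, src, weight=0.9):
    """Rows in `dest` put `weight` on `src` and the rest on themselves.
    All other rows attend only to themselves."""
    A = np.eye(seq_len)
    A[dest, dest] = 1.0 - weight
    A[dest, src] = weight
    return A

# Head 0.0: perm position k mostly attends to list position perm[k].
A0 = soft_pattern(perm_pos, list_pos[perm])
# Head 1.0: prediction position k mostly attends to perm position k.
A1 = soft_pattern(pred_pos, perm_pos)

virtual = A1 @ A0
print(virtual[pred_pos].argmax(axis=-1))  # [2 0 1]
print(virtual[pred_pos, perm])            # 0.9 * 0.9 = 0.81 on each correct list token
```

Each row of the product is still a probability distribution, and the weight a prediction position places on a list position is the product of the weights along the two hops. This is why sharp individual patterns give a sharp virtual pattern.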

## Virtual OV circuit

Now that we are convinced the virtual attn head just attends to the correct next token, it would be natural for $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ to just copy that to the logits. We expect to see a strong diagonal as we've seen in many other circuits:

In [37]:
virtual_OV_circuit = W_E @ model.OV[0,0].AB @ model.OV[1, 0].AB @ model.W_U
print(virtual_OV_circuit.shape)

torch.Size([53, 50])


In [38]:
imshow(
    virtual_OV_circuit,
    title="Virtual OV circuit: W_E @ W_OV^0.0 @ W_OV^1.0 @ W_U",
    xaxis="logit", yaxis="src token",
    aspect="equal"
)

In [39]:
top_1_acc = (virtual_OV_circuit[:MAX_INT].argmax(dim=-1) == torch.arange(MAX_INT, device=device)).float().mean()
print(f"Fraction of time top logit is on diagonal: {top_1_acc.item()}")

Fraction of time top logit is on diagonal: 1.0


Notice the strong main diagonal, indicating clear copying. When combined with the observation that the virtual attention patterns attend to the correct token, this supports the V-composition we hypothesized: heads 0.0 and 1.0 compose to attend to each list token and then copy its information to the correct prediction position.
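
The diagonal check can be packaged as a small helper. Below is a minimal NumPy sketch on synthetic matrices: the helper name and both toy circuits are hypothetical, and only the `[53, 50]` shape and the row slice mirror the cells above.

```python
import numpy as np

def diagonal_top1_fraction(circuit, n_tokens):
    """Fraction of the first `n_tokens` source tokens whose largest logit is their own."""
    top_logit = circuit[:n_tokens].argmax(axis=-1)
    return float((top_logit == np.arange(n_tokens)).mean())

rng = np.random.default_rng(0)
n_vocab, n_tokens = 53, 50

# Toy "copying" circuit: strong main diagonal plus small noise.
copying = 5.0 * np.eye(n_vocab, n_tokens) + rng.normal(scale=0.1, size=(n_vocab, n_tokens))
# Toy "non-copying" circuit: same entries with every column shifted by one.
shifted = np.roll(copying, shift=1, axis=1)

print(diagonal_top1_fraction(copying, n_tokens))  # 1.0
print(diagonal_top1_fraction(shifted, n_tokens))  # 0.0
```

The metric only looks at the argmax of each row, so it says nothing about how large the margin is. The heatmap is still worth inspecting alongside it.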


This completes a pretty healthy collection of evidence showing how the model uses V-composition to solve this task:

* Head 1.0 has the largest direct effect on the logits (DLA)
    * In particular, decomposing 1.0's output into sub terms and applying more refined DLA shows that the "virtual attn head term" $A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0}$ dominates
* The average of all 1.0 attn patterns showed the prediction positions (END, ans_0, ans_1) always attend to the permutation index positions (idx_0, idx_1, idx_2) respectively.
    * 1.0 $W_{pos}W_{QK}W_{pos}^T$ QK circuit component shows clear diagonal stripe where prediction positions (END, ans_0, ans_1) give the highest scores to positions (idx_0, idx_1, idx_2) respectively.
* Every 0.0 pattern we checked shows that permutation tokens (0, 1, 2) always attend to list positions (a_0, a_1, a_2) respectively.
    * 0.0 full QK circuit shows clear diagonal stripe where tokens (0, 1, 2) give the highest attn scores to positions corresponding to (a_0, a_1, a_2) respectively.
* In every example we checked, the virtual attn pattern $A^{1.0}A^{0.0}$ showed that the prediction positions effectively attended to the correct list token
    * **The virtual OV circuit $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ shows clear copying (100% top 1 logit on diagonal).**

# Summary

We found that a 2L, 1H, attn-only transformer with no biases or LN can learn to permute lists by leveraging V-composition:


1. attn 0.0 moves answer information from list indices to corresponding perm positions.
2. attn 1.0 routes the same information from perm positions to the prediction positions.
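
The two steps above can be sketched as an idealized program with hard (one-hot) attention and one-hot token values. This is an illustration of the algorithm, not the trained model: the sequence layout and the function name are hypothetical, and the OV circuits are replaced by exact copying.

```python
import numpy as np

def permute_with_two_hops(values, perm, vocab_size):
    """Idealized V-composition: two attention hops that copy token identity."""
    n = len(values)
    seq_len = 3 * n  # assumed layout: list, perm indices, then prediction slots
    list_pos = np.arange(n)
    perm_pos = n + np.arange(n)
    pred_pos = 2 * n + np.arange(n)

    # Token-identity information lives at the list positions (one-hot rows).
    resid = np.zeros((seq_len, vocab_size))
    resid[list_pos, values] = 1.0

    # Step 1 (head 0.0): perm position k reads from list position perm[k].
    A0 = np.zeros((seq_len, seq_len))
    A0[perm_pos, np.asarray(perm)] = 1.0
    after_head_0 = A0 @ resid

    # Step 2 (head 1.0): prediction position k reads from perm position k.
    A1 = np.zeros((seq_len, seq_len))
    A1[pred_pos, perm_pos] = 1.0
    after_head_1 = A1 @ after_head_0

    # Read off the most strongly represented token at each prediction position.
    return after_head_1[pred_pos].argmax(axis=-1)

print(permute_with_two_hops([7, 3, 9], [2, 0, 1], vocab_size=10))  # [9 7 3]
```

Only the path through both heads is modeled here, which corresponds to the virtual attn head term that dominated the direct logit attribution.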


The lines of evidence we used are:


1. Direct logit attribution: we found the last attn layer had the greatest direct effect on the logits by far (specifically the term corresponding to the virtual attn head).
2. 1.0 attn patterns showed that the prediction positions empirically attended to fixed perm positions, and the QK circuit weights explained this.
3. 0.0 attn patterns showed that perm token i attended to the corresponding list position i+1, and the QK circuit weights explained this.
4. The virtual attn pattern $A^{1.0}A^{0.0}$ showed that the prediction positions effectively attended to the correct list token for the examples we checked.
5. The virtual OV circuit $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ showed clear copying (100% top 1 logit on diagonal).

## General techniques you can apply to other problems

* Direct logit attribution
* Staring at attn patterns
* Decomposing activations into a sum of terms, and analyzing their importance
* Multiplying out QK / OV circuits