# Problem

Interpret a toy model trained to sort fixed-length lists.

# Setup
(No need to read)

In [1]:
!git clone https://github.com/ckkissane/mech-interp-practice.git



In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [5]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [6]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [7]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

In [8]:
def visualize_attn_patterns(heads, local_tokens, local_cache, title: str = ""):
    labels = []
    patterns = []
    batch_index = 0

    for head in heads:
        if isinstance(head, tuple):
            layer, head_index = head
        else:
            layer, head_index = head // model.cfg.n_heads, head % model.cfg.n_heads
        patterns.append(local_cache["pattern", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    patterns = torch.stack(patterns, dim=-1)
    attn_viz = pysvelte.AttentionMulti(tokens=model.to_str_tokens(local_tokens[batch_index]), attention=patterns, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attn_viz.show()

# Load Model

In [9]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7d5b7c620400>

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


The smallest model I found that could solve this task is a 1L, 1 head, attn-only transformer with no biases or layernorms.

The model has already been trained and is loaded into this notebook below.

In [11]:
LIST_LEN = 10
MAX_NUM = 50
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=128,
    d_head=128,
    n_ctx=LIST_LEN*2 + 2, # BOS 1 4 2 MID 1 2 4
    d_vocab=MAX_NUM+2, # 0, 1, ..., MAX_NUM-1, BOS, MID
    d_vocab_out=MAX_NUM,
    attn_only=True,
    normalization_type=None,
    device=device,
    seed=0
)

model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [12]:
filename = "mech-interp-practice/models/sort_fixed_len_list_model.pt"
state_dict = torch.load(filename, map_location=device)  # map_location lets a checkpoint saved on GPU load on a CPU-only machine

model.load_state_dict(state_dict, strict=False);

# Task Description

The model is given inputs of the form:


[BOS, a_1, ..., a_n, MID, b_1, ..., b_n]


Where:


* BOS (50) and MID (51) are special tokens that appear at the same positions in every example
* [a_1, ..., a_n] is the original (unsorted) array. Its elements are integers in [0, 49] inclusive
* [b_1, ..., b_n] is the sorted array, sort([a_1, ..., a_n]), so its elements are also in [0, 49] inclusive


The model is trained to predict (b_1, ..., b_n) at positions (MID, b_1, ..., b_{n-1}) respectively. Note the model is trained with a causal mask, so it cannot look ahead to see the answers. The array may contain duplicate elements, and n=10 is fixed for every training example.
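To make the indexing concrete, here is a small pure-Python sketch (no model needed) that builds one sequence in the layout above and lists which position predicts which label. The helper name `build_example` is hypothetical; the unsorted list is copied from the first row of `test_data` printed below, and the two slices mirror the `logits[:, -(LIST_LEN+1):-1, :]` / `tokens[:, -LIST_LEN:]` alignment used in `loss_fn`.

```python
LIST_LEN = 10
BOS, MID = 50, 51  # special tokens, as described above

def build_example(unsorted):
    """Return [BOS, a_1..a_n, MID, b_1..b_n] with b = sorted(a)."""
    return [BOS] + list(unsorted) + [MID] + sorted(unsorted)

tokens = build_example([42, 17, 26, 14, 26, 35, 20, 24, 0, 13])
seq_len = len(tokens)  # 2 * LIST_LEN + 2 = 22

# Prediction positions run from MID (index LIST_LEN + 1) to b_{n-1} (index 2 * LIST_LEN),
# i.e. the same positions selected by the slice [-(LIST_LEN + 1):-1].
pred_positions = list(range(seq_len))[-(LIST_LEN + 1):-1]
# The labels are the last LIST_LEN tokens: each position is scored on the NEXT token.
targets = tokens[-LIST_LEN:]

for pos, tgt in zip(pred_positions, targets):
    print(f"position {pos} (token {tokens[pos]}) -> predicts {tgt}")
```

Position 11 holds MID and predicts the minimum; position 20 holds b_9 and predicts the maximum.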


Below I provide a data loader and some example tokens that you can use to start investigating the model.

In [13]:
def make_data_generator(cfg, batch_size, seed=0):
    torch.manual_seed(seed)
    MID_TOKEN = cfg.d_vocab-1
    BOS_TOKEN = cfg.d_vocab-2
    while True:
        x = torch.randint(0, cfg.d_vocab_out, (batch_size, LIST_LEN))
        sorted_x = x.sort(dim=-1).values
        mid_vec = torch.ones(batch_size).unsqueeze(-1) * MID_TOKEN
        bos_vec = torch.ones(batch_size).unsqueeze(-1) * BOS_TOKEN

        tokens = torch.cat([bos_vec, x, mid_vec, sorted_x], dim=-1).to(torch.long)
        yield tokens

batch_size = 256
train_data_loader = make_data_generator(cfg, batch_size, seed=42)
test_data = next(train_data_loader).to(device)
print(test_data.shape)
print(test_data[:5])

torch.Size([256, 22])
tensor([[50, 42, 17, 26, 14, 26, 35, 20, 24,  0, 13, 51,  0, 13, 14, 17, 20, 24,
         26, 26, 35, 42],
        [50, 28, 14, 10,  4, 31, 22, 15, 45, 17,  6, 51,  4,  6, 10, 14, 15, 17,
         22, 28, 31, 45],
        [50, 49, 26, 23, 11, 49, 13, 41, 19, 37, 19, 51, 11, 13, 19, 19, 23, 26,
         37, 41, 49, 49],
        [50, 22, 30, 25, 29, 33, 14, 39, 26, 32, 10, 51, 10, 14, 22, 25, 26, 29,
         30, 32, 33, 39],
        [50, 36, 22, 27, 19,  7, 23, 43, 44, 43, 27, 51,  7, 19, 22, 23, 27, 27,
         36, 43, 43, 44]], device='cuda:0')


The model was trained to minimize cross entropy loss. The loss function used during training is provided below:

In [14]:
def loss_fn(logits, tokens):
    logits = logits[:, -(LIST_LEN+1):-1, :]
    log_probs = logits.log_softmax(dim=-1)
    labels = tokens[:, -LIST_LEN:]
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()

# Solution

## Sanity check

First, let's sanity check that the model has actually learned this task. Note this isn't really a held-out test set, just a random sample, but that's probably fine since we have essentially infinite data.

In [15]:
with torch.inference_mode():
    logits = model(test_data)
    logits = logits[:, -(LIST_LEN+1):-1, :]
    preds = logits.argmax(dim=-1)
    labels = test_data[:, -LIST_LEN:]
    acc = (preds == labels).float().mean()
    print("accuracy on test sample:", acc.item())

accuracy on test sample: 0.9945312738418579


## Stare at attn patterns

Staring at attn patterns is a reasonable next step. Based on the ["min of two ints"](https://colab.research.google.com/github/ckkissane/mech-interp-practice/blob/main/tutorials/min_of_two_ints_tutorial.ipynb) problem, a natural guess would be that it attends to the correct answer for each prediction position. I'll start with the average over our test sample.


In [16]:
original_logits, cache = model.run_with_cache(test_data)
print(original_logits.shape)
print(cache)

torch.Size([256, 22, 50])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post']


In [17]:
layer_0_patterns = cache['pattern', 0]
print(layer_0_patterns.shape)

torch.Size([256, 1, 22, 22])


In [18]:
def tokens_to_plotly_labels(tokens):
    if tokens.ndim == 2:
        tokens = tokens[0]
    labels = [f"{tok}_{i}" for i, tok in enumerate(tokens)]
    labels[0] = 'BOS'
    labels[LIST_LEN+1] = 'MID'
    return labels

tokens_to_plotly_labels(test_data[0])

['BOS',
 '42_1',
 '17_2',
 '26_3',
 '14_4',
 '26_5',
 '35_6',
 '20_7',
 '24_8',
 '0_9',
 '13_10',
 'MID',
 '0_12',
 '13_13',
 '14_14',
 '17_15',
 '20_16',
 '24_17',
 '26_18',
 '26_19',
 '35_20',
 '42_21']

In [19]:
imshow(
    einops.reduce(layer_0_patterns[:, 0, :, :], "batch dest src -> dest src", "mean"),
    title=f"attn 0.0 patterns avged over test sample",
    xaxis="src", yaxis="dest",
    x=tokens_to_plotly_labels(test_data[0]),
    y=tokens_to_plotly_labels(test_data[0]),
)

One interesting observation is that it almost never attends to positions after MID. This might seem counterintuitive since you would expect that to be useful information, but it actually makes sense if the model is copying the answer. This gives us a hint that the attention head might have learned a copying OV circuit - a very common motif in these toy models.
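To put a number on "almost never", one option is to sum the attention mass each destination position places on the sorted prefix after MID. This is a minimal NumPy sketch; the helper name is hypothetical, and it assumes the pattern tensor has been converted to a `[batch, dest, src]` array (for example with `utils.to_numpy(layer_0_patterns[:, 0])`).

```python
import numpy as np

LIST_LEN = 10
MID_POS = LIST_LEN + 1  # index of the MID token

def post_mid_attention_fraction(patterns: np.ndarray) -> np.ndarray:
    """patterns: [batch, dest, src] attention probabilities (rows sum to 1).

    Returns, for each dest position, the batch-averaged attention mass
    placed on src positions strictly after MID (the sorted prefix).
    """
    post_mid_mass = patterns[:, :, MID_POS + 1:].sum(axis=-1)  # [batch, dest]
    return post_mid_mass.mean(axis=0)  # [dest]

# Usage on a synthetic baseline: uniform causal attention over a 22-token sequence.
n = 2 * LIST_LEN + 2
causal = np.tril(np.ones((n, n)))
uniform_pattern = (causal / causal.sum(axis=-1, keepdims=True))[None]
baseline = post_mid_attention_fraction(uniform_pattern)
print(baseline.round(3))
```

A uniform causal baseline already puts up to 10/22 of the mass after MID at the last position, so values from the real patterns that sit far below that baseline would support the observation.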


Let's check the attn patterns on some random examples to refine our hypotheses.


In [20]:
for _ in range(3):
    r = random.randint(0, test_data.shape[0]-1)
    labels = tokens_to_plotly_labels(test_data[r])
    imshow(
        layer_0_patterns[r, 0, :, :],
        title=f"attn 0.0 pattern for example {labels}",
        yaxis="dest", xaxis="src",
        x=labels, y=labels
    )

The attn patterns are a lot less clear-cut than in the "min of two ints" problem, but the model does appear to attend more to src tokens the closer they are to the correct answer.


Another observation is that the last prediction position (which predicts b_n) appears to attend more confidently to the max. I could imagine the model learning a separate circuit for this position, since we are always sorting fixed-length lists.
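A rough way to quantify the first observation is to check, at each prediction position, whether the most-attended src token is the correct answer. Below is a hypothetical NumPy helper for that, assuming `patterns` is `[batch, dest, src]` (for example `utils.to_numpy(layer_0_patterns[:, 0])`) and `tokens` is `utils.to_numpy(test_data)`. Returning one number per position also makes it easy to see whether the last position behaves differently.

```python
import numpy as np

LIST_LEN = 10

def top_attended_token_accuracy(patterns: np.ndarray, tokens: np.ndarray) -> np.ndarray:
    """patterns: [batch, dest, src]; tokens: [batch, pos] integer array.

    For each prediction position (MID .. b_{n-1}), check whether the src token
    receiving the most attention equals the label (the next token).
    Returns per-position accuracy with shape [LIST_LEN].
    """
    pred_pos = np.arange(LIST_LEN + 1, 2 * LIST_LEN + 1)  # dest positions 11..20
    top_src = patterns[:, pred_pos, :].argmax(axis=-1)    # [batch, LIST_LEN]
    attended_tokens = np.take_along_axis(tokens, top_src, axis=-1)
    labels = tokens[:, pred_pos + 1]                       # [batch, LIST_LEN]
    return (attended_tokens == labels).mean(axis=0)

# Synthetic check: one example where each prediction position attends fully
# to an occurrence of its answer in the unsorted list.
toks = np.array([[50, 42, 17, 26, 14, 26, 35, 20, 24, 0, 13,
                  51, 0, 13, 14, 17, 20, 24, 26, 26, 35, 42]])
ideal = np.zeros((1, 22, 22))
for p in range(LIST_LEN + 1, 2 * LIST_LEN + 1):
    label = toks[0, p + 1]
    src = int(np.where(toks[0, 1:LIST_LEN + 1] == label)[0][0]) + 1
    ideal[0, p, src] = 1.0

print(top_attended_token_accuracy(ideal, toks))
```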


## Hypothesis: OV circuit copies

Now that we know the model seems to be paying attn to the answers, a natural hypothesis is that the OV circuit just copies the token that we attend to.

One way this problem differs from the min of two ints is that we now have positional embeddings, so the math is a bit different.

![picture](https://drive.google.com/uc?id=122RKveFXuNqYBN1lnEL3d8N9YTw1Y0Jj)

$$
\begin{aligned}
logits &= x_1W_U\\
&= \left(h_{0.0}(x_0) + x_0\right)W_U\\
&= h_{0.0}(x_0)W_U + x_0W_U\\
& = A^{0.0}x_0W_{OV}^{0.0}W_U + x_0W_U\\
& = A^{0.0}(e+p)W_{OV}^{0.0}W_U + (e+p)W_U\\
& = A^{0.0}(eW_{OV}^{0.0}W_U+pW_{OV}^{0.0}W_U) + eW_U+pW_U\\
& = A^{0.0}eW_{OV}^{0.0}W_U+A^{0.0}pW_{OV}^{0.0}W_U + eW_U+pW_U\\
& = A^{0.0}tW_EW_{OV}^{0.0}W_U+A^{0.0}posW_{pos}W_{OV}^{0.0}W_U + tW_EW_U+posW_{pos}W_U\\
\end{aligned}
$$

Where $t$ is the matrix of one-hot token vectors and $pos$ is the matrix of one-hot position vectors, so that $e = tW_E$ and $p = posW_{pos}$.
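Because every step in this derivation is linear, the four-term expansion can be sanity checked numerically. The sketch below uses random NumPy matrices with the same shapes as this model (not the trained weights), treats $W_{OV}$ as a single $d_{model} \times d_{model}$ matrix (like `model.OV[0, 0].AB`), and assumes no biases, as in this model.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_vocab, d_model, d_vocab_out = 22, 52, 128, 50

# Random stand-ins for the real weights; only the shapes mirror the model.
W_E = rng.normal(size=(d_vocab, d_model))
W_pos = rng.normal(size=(n_pos, d_model))
W_OV = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
W_U = rng.normal(size=(d_model, d_vocab_out))

token_ids = rng.integers(0, d_vocab, size=n_pos)
t = np.eye(d_vocab)[token_ids]  # one-hot tokens, [n_pos, d_vocab]
pos = np.eye(n_pos)             # one-hot positions, [n_pos, n_pos]

# Any causal attention pattern whose rows sum to 1.
A = np.tril(rng.random((n_pos, n_pos)))
A /= A.sum(axis=-1, keepdims=True)

e, p = t @ W_E, pos @ W_pos
x0 = e + p
logits_direct = (A @ x0 @ W_OV + x0) @ W_U

logits_decomposed = (
    A @ t @ W_E @ W_OV @ W_U        # token path through the head (full OV circuit)
    + A @ pos @ W_pos @ W_OV @ W_U  # position path through the head
    + t @ W_E @ W_U                 # direct path, token
    + pos @ W_pos @ W_U             # direct path, position
)
print("max abs difference:", np.abs(logits_direct - logits_decomposed).max())
```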


Intuitively the $W_EW_{OV}^{0.0}W_U$ full OV circuit should be the most important, since the src token has more valuable information than the src position, so we can start by looking at that. We expect to see a strong diagonal indicating copying behavior, similar to the min of two ints problem.

In [21]:
full_OV_circuit = model.W_E @ model.OV[0,0].AB @ model.W_U
full_OV_circuit.shape

torch.Size([52, 50])

In [22]:
imshow(
    full_OV_circuit,
    title="attn 0.0 Full OV circuit W_E @ W_OV @ W_U",
    xaxis="logit", yaxis="src token"
)

In [23]:
top_1_acc = (full_OV_circuit[:-2, :].argmax(dim=-1) == torch.arange(cfg.d_vocab_out, device=device)).float().mean()
print("fraction of time top logit is on diagonal:", top_1_acc.item())

fraction of time top logit is on diagonal: 1.0


It does appear to be copying as expected.

Now we can also look at the $W_{pos}W_{OV}W_U$ circuit to see how the position we attend to is mapped to output logits. Intuitively this isn't useful information, so we should expect something like a constant bias term.

In [24]:
imshow(
    model.W_pos @ model.OV[0, 0].AB @ model.W_U,
    title="attn 0.0 Full OV circuit W_pos @ W_OV @ W_U",
    xaxis="logit", yaxis="src position"
)

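Recall that the model only attends to positions 1-10 in practice, and each of these rows looks the same. Since the attention weights at each destination sum to 1, identical rows mean this term adds roughly the same vector to the logits however attention is split among positions 1-10. Because of this, the hypothesis that this circuit just writes a constant bias term seems reasonable.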
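To back up "each of these rows looks the same" with a number rather than by eyeballing the heatmap, one simple (hypothetical) measure is how far the rows for positions 1-10 deviate from their mean row, relative to the scale of that mean row. A value near zero would support the constant-bias reading. In the notebook this could be applied to `utils.to_numpy(model.W_pos @ model.OV[0, 0].AB @ model.W_U)` with `slice(1, 11)`; the example below only uses synthetic matrices.

```python
import numpy as np

def row_spread(M: np.ndarray, rows: slice) -> float:
    """Largest absolute deviation of the selected rows from their mean row,
    divided by the largest absolute entry of that mean row.
    (Assumes the mean row is not all zeros.)"""
    sub = M[rows]
    mean_row = sub.mean(axis=0)
    return float(np.abs(sub - mean_row).max() / np.abs(mean_row).max())

# Synthetic example: 22 identical rows, then one row perturbed.
identical = np.tile(np.arange(1.0, 6.0), (22, 1))
perturbed = identical.copy()
perturbed[3] += 0.5

print(row_spread(identical, slice(1, 11)))  # identical rows
print(row_spread(perturbed, slice(1, 11)))  # one row differs
```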

## Study QK circuits

We want to figure out how the attn pattern is computed by looking at the weights. Note that unlike the min of two ints problem, this toy model has positional embeddings. When we do the QK circuit math, we can decompose the attn scores into 4 different components:




$$
\begin{aligned}
A^{0.0} &= softmax(\frac{x_0W_{QK}^{0.0}x_0^T}{\sqrt{d_{head}}})^*\\
&= softmax(\frac{(e+p)W_{QK}^{0.0}(e+p)^T}{\sqrt{d_{head}}})^*\\
&= softmax(\frac{(e+p)W_{QK}^{0.0}(e^T+p^T)}{\sqrt{d_{head}}})^*\\
&= softmax(\frac{(eW_{QK}^{0.0}+pW_{QK}^{0.0})(e^T+p^T)}{\sqrt{d_{head}}})^*\\
&= softmax(\frac{eW_{QK}^{0.0}e^T+eW_{QK}^{0.0}p^T+pW_{QK}^{0.0}e^T+pW_{QK}^{0.0}p^T}{\sqrt{d_{head}}})^*\\
&= softmax(\frac{tW_EW_{QK}^{0.0}W_E^Tt^T+tW_EW_{QK}^{0.0}W_{pos}^Tpos^T+posW_{pos}W_{QK}^{0.0}W_E^Tt^T+posW_{pos}W_{QK}^{0.0}W_{pos}^Tpos^T}{\sqrt{d_{head}}})^*\\
\end{aligned}
$$

Where $t$ is the input as one-hot token vectors, $pos$ is the position indices as one-hot vectors, and $^*$ marks that the causal mask is applied to the scores before the softmax. Notice we now have 4 different components that contribute to the attn scores. Thankfully the model is small enough that we can still just stare at all of them.
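Every step of this expansion is plain bilinear algebra, so it can be sanity checked numerically. This NumPy sketch uses random matrices with the model's shapes (not the trained weights), takes $W_{QK} = W_QW_K^T$ with no biases, and implements the $^*$ as setting future positions to $-\infty$ before the softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_vocab, d_model, d_head = 22, 52, 128, 128

# Random stand-ins with the model's shapes (not the trained weights).
W_E = rng.normal(size=(d_vocab, d_model))
W_pos = rng.normal(size=(n_pos, d_model))
W_Q = rng.normal(size=(d_model, d_head)) / np.sqrt(d_model)
W_K = rng.normal(size=(d_model, d_head)) / np.sqrt(d_model)
W_QK = W_Q @ W_K.T  # [d_model, d_model]

token_ids = rng.integers(0, d_vocab, size=n_pos)
e = W_E[token_ids]  # same as one-hot t @ W_E
p = W_pos           # same as one-hot pos @ W_pos for positions 0..n_pos-1
x0 = e + p

# Direct computation from queries and keys.
scores_direct = (x0 @ W_Q) @ (x0 @ W_K).T / np.sqrt(d_head)

# Sum of the four components.
scores_decomposed = (
    e @ W_QK @ e.T + e @ W_QK @ p.T + p @ W_QK @ e.T + p @ W_QK @ p.T
) / np.sqrt(d_head)

def causal_softmax(scores: np.ndarray) -> np.ndarray:
    """Mask future src positions with -inf, then softmax over src."""
    mask = np.tril(np.ones_like(scores, dtype=bool))
    masked = np.where(mask, scores, -np.inf)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

A = causal_softmax(scores_decomposed)
print("max abs difference:", np.abs(scores_direct - scores_decomposed).max())
```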

The first one I'll check is $W_{pos}W_{QK}W_E^T$. We should expect this to be easily interpretable, since it would make sense for early positions to pay attn to smaller numbers, while later positions pay attn to big numbers.

In [25]:
full_QK_circuit = model.W_pos @ model.QK[0, 0].AB @ model.W_E.T
full_QK_circuit.shape

torch.Size([22, 52])

In [26]:
imshow(
    full_QK_circuit,
    title="attn 0.0 full QK circuit W_pos @ W_QK @ W_E.T",
    yaxis="dest pos", xaxis="src token"
)

Let's zoom in on the destination positions that correspond to the actual predictions (positions 11-20):

In [27]:
imshow(
    full_QK_circuit[LIST_LEN+1:-1, :-2],
    title="attn 0.0 full QK circuit W_pos @ W_QK @ W_E.T",
    yaxis="dest prediction pos", xaxis="src token"
)

This is what we expected: For the first predictions the model gives higher scores to smaller numbers, but at later positions the model gives higher scores to bigger numbers.


We can also see that the last prediction position is a bit special: the model puts bigger scores here. This makes sense: to predict b_n, all it has to do is attend to the max int in the sequence, so the model probably wants this circuit to dominate the other attn score components. This is also consistent with the attention patterns we viewed earlier.

Another natural full QK circuit to check is $W_{pos}W_{QK}W_{pos}^T$. We noticed that attn was almost never paid to positions after MID, so I expect these weights to explain that. Since this matrix is indexed by (dest pos, src pos), just like the attn scores, we can apply the causal mask to it with `torch.tril`.

In [28]:
imshow(
    torch.tril(model.W_pos @ model.QK[0, 0].AB @ model.W_pos.T),
    title="Full QK circuit W_pos @ W_QK @ W_pos.T",
    yaxis="dest pos", xaxis="src pos"
)

It is basically what we expected: this circuit gives uniform positive attn scores to the src positions holding the unsorted list, and negative attn scores to the positions after MID. This makes sense: at any prediction position, the visible tokens after MID form the sorted prefix, so each of them is <= the next answer. Unless the answer duplicates the value just before it, none of them is the answer, and since the OV circuit copies, attending to them would only boost wrong logits.

Now we can check $W_EW_{QK}W_E^T$:

In [29]:
imshow(
    model.W_E @ model.QK[0,0].AB @ model.W_E.T,
    title="Full QK circuit W_E @ W_QK @ W_E.T",
    yaxis="dest token", xaxis="src token"
)

We can interpret this as "attend to src tokens that are bigger than the current destination token, but not too much bigger." One rough metric to double check this is the fraction of destination tokens whose highest attn score goes to dest_token + 1:


In [30]:
top_1_acc = ((model.W_E @ model.QK[0,0].AB @ model.W_E.T)[:-2, :-2].max(dim=-1).indices == (torch.arange(cfg.d_vocab_out, device=device) + 1)).float().mean()
print("Fraction of the time dest_token gives highest score to dest_token+1:", top_1_acc.item())

Fraction of the time dest_token gives highest score to dest_token+1: 0.6200000047683716
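The exact `dest_token + 1` check is strict, while the interpretation above is "bigger, but not too much bigger". A hypothetical relaxation is to count how often the top-scoring src token lies in a small window above the dest token. The helper below is a sketch; the window size `max_gap` is an arbitrary choice, not something read off the model. In the notebook it could be applied to `utils.to_numpy((model.W_E @ model.QK[0,0].AB @ model.W_E.T)[:-2, :-2])`; the example uses a synthetic matrix whose peak sits two tokens above each dest token.

```python
import numpy as np

def frac_argmax_within(qk: np.ndarray, max_gap: int) -> float:
    """qk: [dest_token, src_token] score matrix over the number tokens only.

    Returns the fraction of dest tokens d whose highest-scoring src token s
    satisfies d < s <= d + max_gap. With max_gap=1 this reduces to the
    exact "dest_token + 1" check.
    """
    dest = np.arange(qk.shape[0])
    gap = qk.argmax(axis=-1) - dest
    return float(((gap > 0) & (gap <= max_gap)).mean())

# Synthetic scores peaking at src = dest + 2 (clipped to the top token 49).
src = np.arange(50)
dest = np.arange(50)[:, None]
qk_synthetic = -np.abs(src[None, :] - (dest + 2)).astype(float)

print(frac_argmax_within(qk_synthetic, max_gap=1))
print(frac_argmax_within(qk_synthetic, max_gap=2))
```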


Finally, the $W_EW_{QK}W_{pos}^T$ term:

In [31]:
imshow(
    model.W_E @ model.QK[0,0].AB @ model.W_pos.T,
    title="Full QK circuit W_E @ W_QK @ W_pos.T",
    yaxis="dest token", xaxis="src pos",
    x=tokens_to_plotly_labels(test_data)
)

The left half of this suggests that every token should attend uniformly to the positions holding the unsorted list. The right half is confusing: we see some pretty big positive scores for positions after MID, but in practice we never attend to these positions. Also, the $W_{pos}W_{QK}W_{pos}^T$ circuit we viewed earlier suggests the model doesn't want to attend to these at all.


I suspect these just get dominated by negative scores from the other components. We can decompose attn score activations to check this.

## Decompose attn scores

Recall that the attn scores are the sum of four different terms (up to the $1/\sqrt{d_{head}}$ scaling). We should be able to look at each of them individually on some examples to see how they interact. First, we can just stare at their averages over the test sample.

In [32]:
W_QK = model.QK[0,0].AB
W_QK.shape

torch.Size([128, 128])

In [33]:
e = cache['embed']
p = cache['pos_embed']

decomposed_e_W_QK_e = einops.einsum(
    e, W_QK, e,
    "batch q_pos d_model_q, d_model_q d_model_k, batch k_pos d_model_k -> batch q_pos k_pos"
)

decomposed_e_W_QK_p = einops.einsum(
    e, W_QK, p,
    "batch q_pos d_model_q, d_model_q d_model_k, batch k_pos d_model_k -> batch q_pos k_pos"
)

decomposed_p_W_QK_e = einops.einsum(
    p, W_QK, e,
    "batch q_pos d_model_q, d_model_q d_model_k, batch k_pos d_model_k -> batch q_pos k_pos"
)

decomposed_p_W_QK_p = einops.einsum(
    p, W_QK, p,
    "batch q_pos d_model_q, d_model_q d_model_k, batch k_pos d_model_k -> batch q_pos k_pos"
)

imshow(
    torch.stack(
        [decomposed_e_W_QK_e.mean(0).tril(),
         decomposed_e_W_QK_p.mean(0).tril(),
         decomposed_p_W_QK_e.mean(0).tril(),
         decomposed_p_W_QK_p.mean(0).tril(),]),
    facet_labels=["e @ W_QK @ e.T", "e @ W_QK @ p.T", "p @ W_QK @ e.T", "p @ W_QK @ p.T"],
    facet_col=0,
    title="all components of attn scores, avged over batch",
    xaxis="k_pos", yaxis="q_pos",
    x=tokens_to_plotly_labels(test_data),
    y=tokens_to_plotly_labels(test_data),
)

Notice that the positive scores in the bottom right corner basically just get canceled out by negative scores from other components. We see the same thing if we zoom in on a single example:

In [34]:
batch_index = 0
e = cache['embed'][batch_index, ...].unsqueeze(0)
p = cache['pos_embed'][batch_index, ...].unsqueeze(0)


decomposed_e_W_QK_e = einops.einsum(
    e, W_QK, e,
    "batch q_pos d_model_q, d_model_q d_model_k, batch k_pos d_model_k -> batch q_pos k_pos"
)

decomposed_e_W_QK_p = einops.einsum(
    e, W_QK, p,
    "batch q_pos d_model_q, d_model_q d_model_k, batch k_pos d_model_k -> batch q_pos k_pos"
)

decomposed_p_W_QK_e = einops.einsum(
    p, W_QK, e,
    "batch q_pos d_model_q, d_model_q d_model_k, batch k_pos d_model_k -> batch q_pos k_pos"
)

decomposed_p_W_QK_p = einops.einsum(
    p, W_QK, p,
    "batch q_pos d_model_q, d_model_q d_model_k, batch k_pos d_model_k -> batch q_pos k_pos"
)

imshow(
    torch.stack(
        [decomposed_e_W_QK_e.mean(0).tril(),
         decomposed_e_W_QK_p.mean(0).tril(),
         decomposed_p_W_QK_e.mean(0).tril(),
         decomposed_p_W_QK_p.mean(0).tril(),]),
    facet_labels=["e @ W_QK @ e.T", "e @ W_QK @ p.T", "p @ W_QK @ e.T", "p @ W_QK @ p.T"],
    facet_col=0,
    title="all components of attn scores, for one example",
    xaxis="k_pos", yaxis="q_pos",
    x=tokens_to_plotly_labels(test_data[batch_index]),
    y=tokens_to_plotly_labels(test_data[batch_index]),
)

When we add them up, we can see that the negative scores for positions after MID dominate:

In [35]:
imshow(
    (decomposed_e_W_QK_e.mean(0) + decomposed_e_W_QK_p.mean(0) +decomposed_p_W_QK_e.mean(0) + decomposed_p_W_QK_p.mean(0)).tril(),
    title="total attn scores for one example",
    x=tokens_to_plotly_labels(test_data[batch_index]),
    y=tokens_to_plotly_labels(test_data[batch_index]),
    xaxis="src", yaxis="dest"
)

To summarize an interpretation of each component:

* $pW_{QK}p^T$: "attend to the token positions for the unsorted list, not the sorted prefix after MID"

* $pW_{QK}e^T$: "attend to smaller numbers at earlier positions, and bigger numbers at later positions"

* $eW_{QK}e^T$: "attend to tokens bigger than you, but ideally close in magnitude"

* $eW_{QK}p^T$: "each token should attend uniformly to all the positions for the unsorted list"

# Summary

We find that a 1L, 1-head, attn-only transformer with no biases or layernorms learns to sort fixed-length lists with the following algorithm:


1. Attend most to the correct token in the unsorted list with the QK circuit.
2. Copy the token you attend to most with the OV circuit.


The way it attends to correct numbers is a bit more complicated, and seems to combine a few different heuristics such as:
1. Attend to numbers slightly bigger than the token at the current position ($W_EW_{QK}W_E^T$)
2. Attend to smaller numbers at earlier positions, and bigger numbers at later positions ($W_{pos}W_{QK}W_E^T$)
3. Don't attend to tokens after the MID token ($W_{pos}W_{QK}W_{pos}^T$)
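To see why heuristic 1 can't be the whole story, here is an idealized, non-neural sketch of it on its own: always copy the smallest token strictly greater than the current one. This is a simplification for illustration, not the model's mechanism. It runs teacher-forced (the true sorted prefix is fed in, as in training) on the first example from `test_data`, and the MID position is represented by `current = -1`, a hypothetical convention so that it picks the minimum.

```python
def next_bigger(unsorted, current):
    """Heuristic 1 in isolation: the smallest token strictly greater than
    `current`, or None if there is none."""
    bigger = [x for x in unsorted if x > current]
    return min(bigger) if bigger else None

def teacher_forced_heuristic_1(unsorted):
    """Apply the rule at each prediction position, using the true sorted
    prefix as the current token. MID is represented as -1."""
    target = sorted(unsorted)
    dest_tokens = [-1] + target[:-1]
    preds = [next_bigger(unsorted, d) for d in dest_tokens]
    return preds, target

preds, target = teacher_forced_heuristic_1([42, 17, 26, 14, 26, 35, 20, 24, 0, 13])
mistakes = [i for i, (a, b) in enumerate(zip(preds, target)) if a != b]
print(preds)
print(target)
print("wrong at indices:", mistakes)
```

The strict rule sorts correctly except where a value repeats (the second 26 here), which is the duplicate case mentioned for $W_{pos}W_{QK}W_{pos}^T$. Since the model reaches 0.9945 accuracy on random lists where repeats are common, this suggests it handles repeated values in some way that this strict rule does not capture.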


The lines of evidence are:
1. Attention pattern activations: We saw that the patterns attend most to the correct answers.
2. OV circuit weights: When we multiplied out the OV circuit, we could see it was clearly copying.
3. QK circuit weights: When we multiplied out the full QK circuit, we could read off different heuristics from the weights.
4. Decomposing attn scores: When we decomposed attn scores into 4 separate components, the activations were consistent with our interpretations of the corresponding weights.


## General Techniques you can apply to other problems

* Staring at attn patterns
* Multiplying out full QK / OV circuits
* Decomposing activations into the sum of their components