# Problem

Interpret a toy model trained to permute lists.

# Setup
(No need to read)

In [70]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook


In [71]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [72]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [73]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [74]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor), 
        labels={"x": xaxis, "y": yaxis}, 
        template="simple_white",
        **kwargs).show(renderer)

In [75]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

In [76]:
def visualize_attn_patterns(heads, local_tokens, local_cache, title: str = ""):
    labels = []
    patterns = []
    batch_index = 0
    
    for head in heads:
        if isinstance(head, tuple):
            layer, head_index = head
        else:
            layer, head_index = head // model.cfg.n_heads, head % model.cfg.n_heads
        patterns.append(local_cache["pattern", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    patterns = torch.stack(patterns, dim=-1)
    attn_viz = pysvelte.AttentionMulti(tokens=model.to_str_tokens(local_tokens[batch_index]), attention=patterns, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attn_viz.show()

# Load Model

In [77]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc202a46140>

In [78]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


The smallest model I found was a 2L, 1H, attn-only transformer with no layernorm or biases. A 1L transformer was not able to solve this task (even with up to 8 heads).


Note that the input has the form [BOS, 44, 40, 15, MID, 2, 1, 0, END, 15, 40, 44]. The first 3 tokens after BOS, [44, 40, 15], are the list the model needs to permute; I call them "list tokens". The 3 tokens after MID, [2, 1, 0], are the permutation indices; I call these the "perm tokens". The model makes its first prediction at the END token and continues until it has output the entire permuted list, so I call [END, 15, 40] the "prediction tokens". The last 3 tokens, [15, 40, 44], are the correct permutation of the list.

![picture](https://drive.google.com/uc?id=15DCokm0rjNEBFn5znNJ8mEu5ATF5z7CH)
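
As a minimal sketch of this layout, the helper below assembles one sequence from a list and a permutation. The function name `build_example` is hypothetical, and the special-token ids 52, 51, 50 are an assumption based on the `d_vocab = 53` vocabulary used in the training code (BOS, MID, END taken from the end of the vocabulary).

```python
def build_example(values, perm, bos=52, mid=51, end=50):
    """Assemble [BOS, list..., MID, perm..., END, answer...] as a flat list of ints."""
    # Answer position j holds the list element at index perm[j]
    answer = [values[i] for i in perm]
    return [bos] + list(values) + [mid] + list(perm) + [end] + answer


tokens = build_example([44, 40, 15], [2, 1, 0])
print(tokens)  # [52, 44, 40, 15, 51, 2, 1, 0, 50, 15, 40, 44]
```

The last three entries are the labels the model is trained to predict from the END token onwards.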

In [79]:
LIST_LEN = 3
MAX_INT = 50
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=1,
    d_model=128,
    d_head=128,
    n_ctx=3*LIST_LEN+3, # BOS d1 d2 d3 MID p1 p2 p3 END a1 a2 a3
    d_vocab= MAX_INT+3, # 0, ..., MAX_INT-1, END, MID, BOS
    d_vocab_out=MAX_INT,
    attn_only=True,
    normalization_type=None,
    device=device,
    seed=0,
)
model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [80]:
# disable biases
for name, param in model.named_parameters():
    if 'b_' in name:
        param.requires_grad = False
    print(name, param.shape, param.requires_grad)

embed.W_E torch.Size([53, 128]) True
pos_embed.W_pos torch.Size([12, 128]) True
blocks.0.attn.W_Q torch.Size([1, 128, 128]) True
blocks.0.attn.W_K torch.Size([1, 128, 128]) True
blocks.0.attn.W_V torch.Size([1, 128, 128]) True
blocks.0.attn.W_O torch.Size([1, 128, 128]) True
blocks.0.attn.b_Q torch.Size([1, 128]) False
blocks.0.attn.b_K torch.Size([1, 128]) False
blocks.0.attn.b_V torch.Size([1, 128]) False
blocks.0.attn.b_O torch.Size([128]) False
blocks.1.attn.W_Q torch.Size([1, 128, 128]) True
blocks.1.attn.W_K torch.Size([1, 128, 128]) True
blocks.1.attn.W_V torch.Size([1, 128, 128]) True
blocks.1.attn.W_O torch.Size([1, 128, 128]) True
blocks.1.attn.b_Q torch.Size([1, 128]) False
blocks.1.attn.b_K torch.Size([1, 128]) False
blocks.1.attn.b_V torch.Size([1, 128]) False
blocks.1.attn.b_O torch.Size([128]) False
unembed.W_U torch.Size([128, 50]) True
unembed.b_U torch.Size([50]) False


# Training (Optional)

Mostly boilerplate. Skip unless you want to read the training code.

## Task Dataset

In [81]:
def make_data_generator(cfg, batch_size, seed=0):
    torch.manual_seed(seed)
    BOS_TOKEN = cfg.d_vocab-1
    MID_TOKEN = cfg.d_vocab-2
    END_TOKEN = cfg.d_vocab-3
    while True:
        seq = torch.randint(0, MAX_INT, (batch_size, LIST_LEN))
        perm = torch.randperm(LIST_LEN)
        ans = seq[:, perm]

        bos_tensor = einops.repeat(torch.tensor(BOS_TOKEN), " -> i 1", i=batch_size)
        mid_tensor = einops.repeat(torch.tensor(MID_TOKEN), " -> i 1", i=batch_size)
        end_tensor = einops.repeat(torch.tensor(END_TOKEN), " -> i 1", i=batch_size)

        x = torch.cat([bos_tensor, seq, mid_tensor, einops.repeat(perm, "seq -> batch seq", batch=batch_size), end_tensor, ans], dim=-1)
        yield x

print(next(make_data_generator(cfg, 2)))


tensor([[52, 44, 39, 33, 51,  1,  2,  0, 50, 39, 33, 44],
        [52, 10, 13, 29, 51,  1,  2,  0, 50, 13, 29, 10]])


## Loss fn

In [82]:
def loss_fn(logits, tokens):
    logits=logits[:, -LIST_LEN-1:-1, :]
    logits = logits.to(torch.float64)
    labels = tokens[:, -LIST_LEN:]

    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()

with torch.no_grad():
    tokens = next(make_data_generator(cfg, 2)).to(device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    print(loss)

tensor(3.9504, device='cuda:0', dtype=torch.float64)


In [83]:
print('uniform loss:', np.log(cfg.d_vocab_out))

uniform loss: 3.912023005428146
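
To make the slicing in `loss_fn` concrete, here is a small pure-Python sketch (not part of the original notebook) of which positions are scored, and why an untrained model should sit near the uniform loss printed above. The variable names are illustrative.

```python
import math

LIST_LEN = 3
MAX_INT = 50
n_ctx = 3 * LIST_LEN + 3  # BOS + list + MID + perm + END + answer = 12

positions = list(range(n_ctx))
# Logits at these positions are scored (END and the first two answer tokens)...
pred_positions = positions[-LIST_LEN - 1:-1]
# ...against the tokens at these positions (the three answer tokens)
label_positions = positions[-LIST_LEN:]
print(pred_positions, label_positions)  # [8, 9, 10] [9, 10, 11]

# Cross-entropy of a uniform guess over the MAX_INT output classes (about 3.912)
uniform_loss = math.log(MAX_INT)
print(uniform_loss)
```

Each scored position predicts the token one step to its right, which is why the logits slice is shifted by one relative to the labels slice.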


## Setup Optimizer / Dataloader

In [84]:
lr = 1e-3
wd = 0.01
betas = (0.9, 0.98)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

batch_size = 256
train_data_loader = make_data_generator(cfg, batch_size)

## Training loop

In [85]:
num_epochs = 4000

train_losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
    tokens = next(train_data_loader).to(device)
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()
    train_losses.append(loss.item())

    optimizer.step()
    optimizer.zero_grad()

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, train_loss: {loss.item()}")


Epoch: 0, train_loss: 3.9134952571706085
Epoch: 100, train_loss: 1.6822969221402797
Epoch: 200, train_loss: 1.4785068812896835
Epoch: 300, train_loss: 1.0795397263291033
Epoch: 400, train_loss: 0.7309813751495354
Epoch: 500, train_loss: 0.7055179790887276
Epoch: 600, train_loss: 0.23109689003146974
Epoch: 700, train_loss: 0.07998041691299426
Epoch: 800, train_loss: 0.03755524799578232
Epoch: 900, train_loss: 0.01263726209428748
Epoch: 1000, train_loss: 0.004681462867619774
Epoch: 1100, train_loss: 0.012007361321558628
Epoch: 1200, train_loss: 0.011887284356212124
Epoch: 1300, train_loss: 0.0018767118027489004
Epoch: 1400, train_loss: 0.0015075586712569589
Epoch: 1500, train_loss: 0.006591717328680522
Epoch: 1600, train_loss: 0.002214333290520977
Epoch: 1700, train_loss: 0.0007977900374918777
Epoch: 1800, train_loss: 0.002376661674472637
Epoch: 1900, train_loss: 0.0005226649263945612
Epoch: 2000, train_loss: 8.124520167040763e-05
Epoch: 2100, train_loss: 0.0008260170140734301
Epoch: 220

In [86]:
line(train_losses,
     title="Loss curve",
     xaxis="Epoch", yaxis="Loss")

## Sanity check

In [87]:
# get a test sample with multiple different permutations
test_batch_size = 256
test_data = []
sub_batch_size = 4
for i in range(test_batch_size // sub_batch_size):
    test_data.append(next(make_data_generator(cfg, sub_batch_size, seed=i)))

test_data = torch.cat(test_data, dim=0).to(device)
print(test_data.shape)

torch.Size([256, 12])


In [88]:
with torch.inference_mode():
    test_logits = model(test_data)
    test_logits = test_logits[:, -LIST_LEN-1:-1, :]
    preds = test_logits.argmax(dim=-1)
    test_labels = test_data[:, -LIST_LEN:]
    
    acc = (preds == test_labels).float().mean()
    print("Test sample accuracy:", acc.item())

Test sample accuracy: 1.0


# Reverse Engineering

## Direct logit attribution

Direct logit attribution is a great technique to localize which bits of the model directly affect the logits. We can start with the logit lens technique, which can give us intuitions about when the model has done enough processing to make a prediction.


Note the model has to make three predictions (one for each list element), but for simplicity, we'll just study the first prediction in this section, corresponding to destination token "END".

In [89]:
original_logits, cache = model.run_with_cache(test_data)
print(original_logits.shape)
print(cache)

torch.Size([256, 12, 50])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_post']


In [90]:
accumulated_resid_stack, labels = cache.accumulated_resid(layer=-1, pos_slice=-4, return_labels=True)
print(accumulated_resid_stack.shape)
print(labels)

torch.Size([3, 256, 128])
['0_pre', '1_pre', 'final_post']


In [91]:
correct_answers = test_data[:, -3]
print(correct_answers.shape)
correct_logit_dirs = model.tokens_to_residual_directions(correct_answers)
print(correct_logit_dirs.shape)

torch.Size([256])
torch.Size([256, 128])


In [92]:
accumulated_contrib = einops.einsum(
    accumulated_resid_stack,
    correct_logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / test_data.shape[0]

line(
    accumulated_contrib,
    title="Logit lens: direct logit attribution for prefix of model (first prediction)",
    xaxis="Layer", yaxis="attribution",
    x=labels
)

We see that the model is only able to predict the correct answer after the second attn layer.


We can also check the effect of each individual component's output on the logits by using the decomposed residual stream. (Note that this should tell the same story as the plot above, since each point there is just the sum of a prefix of the decomposed components.)

In [93]:
decomposed_resid_stack, labels = cache.decompose_resid(layer=-1, pos_slice=-4, return_labels=True)
print(decomposed_resid_stack.shape)
print(labels)

torch.Size([4, 256, 128])
['embed', 'pos_embed', '0_attn_out', '1_attn_out']


In [94]:
decomposed_contrib = einops.einsum(
    decomposed_resid_stack,
    correct_logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / test_data.shape[0]

line(
    decomposed_contrib,
    title="Per component direct logit attribution for first prediction",
    xaxis="Component", yaxis="attribution",
    x=labels
)

As expected, the second attn layer has the biggest direct effect on the logits.
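
The relationship between the two plots can be checked on toy data. The sketch below uses NumPy with random vectors standing in for the cached residual-stream components (the shapes and names are assumptions for illustration, not the notebook's tensors). Because the dot product with the correct-token direction is linear, each logit lens value is a running sum of the per-component attributions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8

# Stand-ins for the four components at one position:
# embed, pos_embed, 0_attn_out, 1_attn_out
components = rng.normal(size=(4, d_model))
# Stand-in for the unembedding direction of the correct token
correct_dir = rng.normal(size=d_model)

# Per-component direct logit attribution
per_component = components @ correct_dir

# Accumulated residual stream at 0_pre, 1_pre, final_post
accumulated_resid = np.stack([
    components[:2].sum(axis=0),
    components[:3].sum(axis=0),
    components[:4].sum(axis=0),
])
accumulated = accumulated_resid @ correct_dir

# Logit lens values are prefix sums of the per-component attributions
prefix_sums = np.cumsum(per_component)[1:]
print(np.allclose(accumulated, prefix_sums))  # True
```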


This attribution, combined with the fact that a 1L model could not perform the task, suggests the model is doing some kind of composition. A natural next step is to try to understand interpretable activations - in this case the attn patterns.

## Stare at 1.0 patterns

When trying to understand activations, it's often useful to work backwards, so we start with the head 1.0 pattern.

Since the correct prediction is earlier in the context, one natural guess is that head 1.0 might attend to the correct token and copy it to the logits - a common motif that we've also seen in the "min of two ints" and "sort fixed length lists" toy model tutorials.

In [95]:
layer_1_patterns = cache['pattern', 1]
print(layer_1_patterns.shape)

torch.Size([256, 1, 12, 12])


First we can look at the average pattern over the test sample to get a lay of the land:

In [96]:
def tokens_to_plotly_labels(tokens):
    if tokens.ndim==2:
        tokens = tokens[0]
    res = [f"{tok}_{i}" for i, tok in enumerate(tokens)]
    res[0] = 'BOS'
    res[LIST_LEN+1] = 'MID'
    res[2*LIST_LEN+2]= 'END'
    return res

tokens_to_plotly_labels(test_data)

['BOS',
 '44_1',
 '39_2',
 '33_3',
 'MID',
 '0_5',
 '2_6',
 '1_7',
 'END',
 '44_9',
 '33_10',
 '39_11']

In [97]:
labels = tokens_to_plotly_labels(test_data)
imshow(
    layer_1_patterns[:, 0, :, :].mean(dim=0),
    title="Attn head 1.0 pattern (averaged over batch)",
    xaxis="src", yaxis="dest",
    x=labels, y=labels
)

Interestingly, the prediction positions do not attend to the answer token and copy it, as I guessed. Instead, each of the 3 prediction positions almost always attends to the same fixed perm position, regardless of the input.


Why would the model do this? The token and positional embeddings at these source positions clearly don't contain information about the answer, so the information must have been moved to their residual stream by attn 0.0, and then moved again by 1.0! This is the first instance of attn head composition we've seen in these tutorials.


If this hypothesis is correct, this would be an example of a "V-composition" (see https://transformer-circuits.pub/2021/framework/index.html#three-kinds-of-composition). We can also think of it as a single "virtual attention head" - effectively a single head where the prediction positions attend to the correct list token, and then copy (see https://transformer-circuits.pub/2021/framework/index.html#virtual-attention-heads)


![picture](https://drive.google.com/uc?id=1orrWb0s7zTbkVShAl8MbMytq_8nHlc_Q)
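
A toy illustration of this hypothesis, assuming idealized one-hot attention (the real patterns are soft) and assuming each prediction position attends to the perm position three places earlier. Each head is written as a mapping from destination position to the source position it attends to, using the example input from the setup section; composing the two mappings gives the virtual head.

```python
# Example input: [BOS, 44, 40, 15, MID, 2, 1, 0, END, 15, 40, 44]
tokens = ["BOS", 44, 40, 15, "MID", 2, 1, 0, "END", 15, 40, 44]
LIST_LEN = 3

# Head 0.0 (idealized): perm token i attends to list position i + 1
head_0 = {pos: tokens[pos] + 1 for pos in range(LIST_LEN + 2, 2 * LIST_LEN + 2)}

# Head 1.0 (assumed): each prediction position attends to the perm position
# LIST_LEN places earlier
head_1 = {pos: pos - LIST_LEN for pos in range(2 * LIST_LEN + 2, 3 * LIST_LEN + 2)}

# Virtual head: follow head 1.0 to a perm position, then head 0.0 to a list position
virtual_head = {dest: head_0[mid] for dest, mid in head_1.items()}
predictions = [tokens[src] for src in virtual_head.values()]

print(virtual_head)  # {8: 3, 9: 2, 10: 1}
print(predictions)   # [15, 40, 44]
```

Under these assumptions the virtual head lands on exactly the list tokens that form the answer.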

Continuing to work backwards, we can check some attn patterns for 0.0. We expect to see that the permutation tokens attend to their corresponding list positions (perm token=i attends to pos=i+1).

## Stare at 0.0 patterns

In [98]:
layer_0_patterns = cache['pattern', 0]
print(layer_0_patterns.shape)

torch.Size([256, 1, 12, 12])


Once again I'll just start with the avg across the batch dimension:

In [99]:
labels = tokens_to_plotly_labels(test_data)
imshow(
    layer_0_patterns[:, 0, :, :].mean(dim=0),
    title="Attn head 0.0 patterns (averaged over batch)",
    xaxis="src", yaxis="dest",
    x=labels, y=labels
)

We see that the permutation tokens mostly attend to the list positions, although this isn't very surprising. Let's zoom in on some concrete examples:

In [100]:
for _ in range(3):
    r = random.randint(0, test_data.shape[0]-1)
    labels = tokens_to_plotly_labels(test_data[r])
    imshow(
        layer_0_patterns[r, 0, :, :],
        title=f"Attn head 0.0 patterns for test example {r}",
        xaxis="src", yaxis="dest",
        x=labels, y=labels
    )

Notice that the perm destination tokens mostly attend to the corresponding positions in the list, as expected. (For example, perm token 0 always attends to pos=1, 0-indexed.)

One way to double check this is to look at QK circuits. Recall that in the 0th layer we can decompose the attn scores into 4 terms. We expect the $W_E @ W_{QK} @ W_{pos}^T$ term to stand out, with a stripe when dest_token=i, src_pos=i+1 for i = 0, 1, 2.
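
The four-term decomposition follows from writing the layer-0 input as token embedding plus positional embedding, $x_0 = e + p$, so the attention score is bilinear in $(e + p)$. Here is a quick NumPy sanity check on random stand-in matrices (the sizes and names are arbitrary assumptions, and the score scaling and causal mask are ignored):

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_model = 5, 4

e = rng.normal(size=(n_pos, d_model))       # token embeddings per position
p = rng.normal(size=(n_pos, d_model))       # positional embeddings per position
W_QK = rng.normal(size=(d_model, d_model))  # stand-in for W_Q @ W_K.T

# Full pre-softmax scores: rows are destinations (queries), columns are sources (keys)
full_scores = (e + p) @ W_QK @ (e + p).T

terms = {
    "dest token, src token": e @ W_QK @ e.T,
    "dest token, src pos": e @ W_QK @ p.T,  # the term expected to carry the stripe
    "dest pos, src token": p @ W_QK @ e.T,
    "dest pos, src pos": p @ W_QK @ p.T,
}
print(np.allclose(full_scores, sum(terms.values())))  # True
```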

## Understand 0.0 QK circuits

In [101]:
W_E = model.W_E
print(W_E.shape)
W_pos = model.W_pos
print(W_pos.shape)
W_QK = model.QK[0,0].AB
print(W_QK.shape)

torch.Size([53, 128])
torch.Size([12, 128])
torch.Size([128, 128])


In [102]:
imshow(
    W_E @ W_QK @ W_pos.T,
    title="Attn 0.0 full QK circuit W_E @ W_QK @ W_pos.T",
    xaxis="src pos", yaxis="dest token",
    x=labels
)

Notice we see the stripe we expected, and it stands out pretty clearly. Let's quickly look at the other QK circuit components:

In [103]:
imshow(
    W_E @ W_QK @ W_E.T,
    title="0.0 full QK circuit W_E @ W_QK @ W_E.T",
    xaxis="src token", yaxis="dest token"
)

imshow(
    W_pos @ W_QK @ W_E.T,
    title="0.0 full QK circuit W_pos @ W_QK @ W_E.T",
    xaxis="src token", yaxis="dest pos",
    y=labels
)

imshow(
    torch.tril(W_pos @ W_QK @ W_pos.T),
    title="0.0 full QK circuit W_pos @ W_QK @ W_pos.T",
    xaxis="src pos", yaxis="dest pos",
    x=labels, y=labels
)

While there's not as much clear structure here, the rows corresponding to the perm tokens in $W_EW_{QK}W_E^T$ stand out. This weakly suggests that this head is specialized to move information to these positions.


One way to sanity check this is to decompose the query / key vectors into their components, $q = x_0W_Q = (e+p)W_Q = eW_Q + pW_Q$ (same idea for $k$), and check the norms. We expect to see that the query vectors are dominated by the perm token embeddings, and the keys are dominated by the positional embeddings for the list:

In [104]:
imshow(
    torch.stack(
        [(cache['embed'] @ model.W_Q[0, 0]).norm(dim=-1).mean(dim=0),
        (cache['pos_embed'] @ model.W_Q[0, 0]).norm(dim=-1).mean(dim=0)]
    ),
    title="Query component norms",
    xaxis="pos", yaxis="component",
    y=["Embed", "Pos embed"], x=labels
)

imshow(
    torch.stack(
        [(cache['embed'] @ model.W_K[0, 0]).norm(dim=-1).mean(dim=0),
        (cache['pos_embed'] @ model.W_K[0, 0]).norm(dim=-1).mean(dim=0)]
    ),
    title="Key component norms",
    xaxis="pos", yaxis="component",
    y=["Embed", "Pos embed"], x=labels
)

It's weaker than I expected, but we do see the perm token embeddings slightly stand out in the queries, and the list positional embeddings stand out for their keys.

## Understand 1.0 QK circuits

Since we were able to get some traction on understanding the QK circuit weights for 0.0, let's try the same for 1.0.

We expect that the attn pattern is mainly determined by the positions (recall the destination positions always attended very strongly to the fixed perm positions), so we expect the $W_{pos}W_{QK}W_{pos}^T$ term to explain this.

In [105]:
W_QK = model.QK[1, 0].AB
W_QK.shape

torch.Size([128, 128])

In [106]:
imshow(
    torch.tril(W_pos @ W_QK @ W_pos.T),
    title="1.0 full QK circuit W_pos @ W_QK @ W_pos.T",
    xaxis="src pos", yaxis='dest pos',
    x=labels, y=labels
)

Notice that the diagonal stripe for the positions we expected stands out. This partially explains how the model computes the 1.0 patterns, although there are still loads of terms we haven't considered.

## Decompose logits

I feel somewhat convinced that 0.0 causes perm tokens to attend to their corresponding list position, and that 1.0 causes prediction positions to attend to the fixed perm positions. However, we have yet to determine *what* information is moved. It would be really natural for the composition of both heads to just relay the token information for the correct answer using V-composition, acting as a virtual attn head.


In the past we were able to decompose the logits into a sum of terms, and stare at full OV circuits to understand what was going on, but here we have a lot more terms:

$$
\begin{align}
logits &= x_2W_U\\
&= (h_{1.0}(x_1) + x_1)W_U \\
&= (A^{1.0}x_1W_{OV}^{1.0} + x_1)W_U \\
&= (A^{1.0}(A^{0.0}x_0W_{OV}^{0.0} + x_0)W_{OV}^{1.0} + (A^{0.0}x_0W_{OV}^{0.0} + x_0))W_U \\
&= (A^{1.0}A^{0.0}x_0W_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}x_0W_{OV}^{1.0} + A^{0.0}x_0W_{OV}^{0.0} + x_0)W_U \\
&= (A^{1.0}A^{0.0}(e+p)W_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}(e+p)W_{OV}^{1.0} + A^{0.0}(e+p)W_{OV}^{0.0} + (e+p))W_U \\
&= (A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}A^{0.0}pW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}eW_{OV}^{1.0} + A^{1.0}pW_{OV}^{1.0} + A^{0.0}eW_{OV}^{0.0} + A^{0.0}pW_{OV}^{0.0} + e + p)W_U \\
&= (A^{1.0}A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}A^{0.0}posW_{pos}W_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}tW_EW_{OV}^{1.0} + A^{1.0}posW_{pos}W_{OV}^{1.0} + A^{0.0}tW_EW_{OV}^{0.0} + A^{0.0}posW_{pos}W_{OV}^{0.0} + tW_E + posW_{pos})W_U \\
&= A^{1.0}A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0}W_U + A^{1.0}A^{0.0}posW_{pos}W_{OV}^{0.0}W_{OV}^{1.0}W_U + A^{1.0}tW_EW_{OV}^{1.0}W_U + A^{1.0}posW_{pos}W_{OV}^{1.0}W_U + A^{0.0}tW_EW_{OV}^{0.0}W_U + A^{0.0}posW_{pos}W_{OV}^{0.0}W_U + tW_EW_U + posW_{pos}W_U \\
\end{align}
$$

Where $t$ and $pos$ are the tokens and positions as one hot encoded vectors respectively. It would be a pain to study every term, but if the model is doing what we expect, then the $A^{1.0}A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ term should dominate in importance. We should be able to check this through a more refined direct logit attribution.
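
The expansion can be verified numerically. The sketch below uses NumPy with random stand-ins for the embeddings, attention patterns, and OV matrices (all names and sizes are assumptions for illustration), treats the attention patterns as fixed, and checks that the eight terms sum to the logits of a 2-layer attention-only model without biases or layernorm.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_model, d_vocab_out = 6, 4, 5


def random_pattern(n):
    """Random causal attention pattern: lower-triangular rows that sum to 1."""
    scores = np.tril(rng.random((n, n)) + 0.1)
    return scores / scores.sum(axis=-1, keepdims=True)


e = rng.normal(size=(n_pos, d_model))  # token embeddings (t @ W_E)
p = rng.normal(size=(n_pos, d_model))  # positional embeddings (pos @ W_pos)
A0, A1 = random_pattern(n_pos), random_pattern(n_pos)
W_OV0 = rng.normal(size=(d_model, d_model))
W_OV1 = rng.normal(size=(d_model, d_model))
W_U = rng.normal(size=(d_model, d_vocab_out))

# Forward pass with the patterns held fixed
x0 = e + p
x1 = x0 + A0 @ x0 @ W_OV0
x2 = x1 + A1 @ x1 @ W_OV1
logits = x2 @ W_U

# The eight terms, in the order they appear in the last line of the derivation
terms = [
    A1 @ A0 @ e @ W_OV0 @ W_OV1 @ W_U,
    A1 @ A0 @ p @ W_OV0 @ W_OV1 @ W_U,
    A1 @ e @ W_OV1 @ W_U,
    A1 @ p @ W_OV1 @ W_U,
    A0 @ e @ W_OV0 @ W_U,
    A0 @ p @ W_OV0 @ W_U,
    e @ W_U,
    p @ W_U,
]
print(np.allclose(logits, sum(terms)))  # True
```

Holding the patterns fixed is what makes the model linear in $e$ and $p$; the patterns themselves still depend on the input through the QK circuits.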

## Finer grained direct logit attribution

Recall we already used direct logit attribution to show $h_{1.0}(x_1)$ had the biggest effect on the logits. However we want to further decompose: 
$$
h_{1.0}(x_1) = A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}A^{0.0}pW_{OV}^{0.0}W_{OV}^{1.0} + A^{1.0}eW_{OV}^{1.0} + A^{1.0}pW_{OV}^{1.0}
$$

 Now we can just apply direct logit attribution on these terms, and prioritize the terms that matter most:

In [107]:
embed_both_heads_term = einops.einsum(
    cache['pattern', 1], cache['pattern', 0], cache['embed'], model.OV[0,0].AB, model.OV[1, 0].AB,
    "batch n_heads pos1 pos2, batch n_heads pos2 pos3, batch pos3 d_model1, d_model1 d_model2, d_model2 d_model3 -> batch pos1 d_model3"
)

pos_embed_both_heads_term = einops.einsum(
    cache['pattern', 1], cache['pattern', 0], cache['pos_embed'], model.OV[0,0].AB, model.OV[1, 0].AB,
    "batch n_heads pos1 pos2, batch n_heads pos2 pos3, batch pos3 d_model1, d_model1 d_model2, d_model2 d_model3 -> batch pos1 d_model3"
)

embed_head_1_term = einops.einsum(
    cache['pattern', 1], cache['embed'], model.OV[1, 0].AB,
    "batch n_heads q_pos k_pos, batch k_pos d_model1, d_model1 d_model2 -> batch q_pos d_model2"
)

pos_embed_head_1_term = einops.einsum(
    cache['pattern', 1], cache['pos_embed'], model.OV[1, 0].AB,
    "batch n_heads q_pos k_pos, batch k_pos d_model1, d_model1 d_model2 -> batch q_pos d_model2"
)


assert torch.allclose(embed_both_heads_term + pos_embed_both_heads_term + embed_head_1_term + pos_embed_head_1_term, cache['attn_out', 1], atol=1e-5)

In [108]:
resid_stack = torch.stack([embed_both_heads_term, pos_embed_both_heads_term, embed_head_1_term, pos_embed_head_1_term])
resid_stack = resid_stack[:, :, -4, :]
print(resid_stack.shape)
labels = ['embed_both_heads_term', 'pos_embed_both_heads_term', 'embed_head_1_term', 'pos_embed_head_1_term',]
print(labels)

torch.Size([4, 256, 128])
['embed_both_heads_term', 'pos_embed_both_heads_term', 'embed_head_1_term', 'pos_embed_head_1_term']


In [109]:
per_component_contribution = einops.einsum(
    resid_stack,
    correct_logit_dirs,
    "... batch d_model, batch d_model -> ..."
) / test_data.shape[0]

bar(
    per_component_contribution,
    title="Direct logit attribution for components of attn_out_1",
    xaxis="Component", yaxis="attribution",
    x=labels
)

We show that the $A^{1.0}A^{0.0}eW_{OV}^{0.0}W_{OV}^{1.0}$ term is the most important by far, as expected. Note this is pretty great! We had a huge equation with 8 terms, but direct logit attribution allows us to prioritize one path through the model.

## Virtual attn pattern

Now that we know that the $A^{1.0}A^{0.0}tW_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ term explains most of the model's behavior, we want to understand $A^{1.0}A^{0.0}$ and $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$. Note $A^{1.0}A^{0.0}$ tells us *where* the virtual attn head moves information from, while $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ tells us *what* information is moved from the src to the logits (given that we've attended to it). Let's tackle the $A^{1.0}A^{0.0}$ term first. We should expect that the virtual attn head effectively just attends to the correct next token.

In [110]:
for _ in range(3):
    r = random.randint(0, test_data.shape[0]-1)
    labels = tokens_to_plotly_labels(test_data[r])
    imshow(
        layer_1_patterns[r, 0, :, :] @ layer_0_patterns[r, 0, :, :],
        title=f"virtual attn pattern A^1.0 @ A^0.0 for example {r}",
        xaxis="src", yaxis="dest",
        x=labels, y=labels
    )

This is basically what we expected. The prediction destination tokens attend most strongly to the correct list token.
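
One reason the product can be read as an attention pattern in its own right: multiplying two causal attention patterns gives another matrix that is lower-triangular with rows summing to 1. A small NumPy check on random stand-in patterns (not the model's actual patterns; the helper name is hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos = 12


def random_pattern(n):
    """Random causal attention pattern: lower-triangular rows that sum to 1."""
    scores = np.tril(rng.random((n, n)) + 0.1)
    return scores / scores.sum(axis=-1, keepdims=True)


A0, A1 = random_pattern(n_pos), random_pattern(n_pos)
virtual = A1 @ A0

rows_sum_to_one = np.allclose(virtual.sum(axis=-1), 1.0)
is_causal = np.allclose(virtual, np.tril(virtual))
print(rows_sum_to_one, is_causal)  # True True
```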

## Virtual OV circuit

Now that we are convinced the virtual attn head just attends to the correct next token, it would be natural for $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ to just copy that to the logits. We expect to see a strong diagonal as we've seen in many other circuits:

In [111]:
virtual_OV_circuit = W_E @ model.OV[0,0].AB @ model.OV[1, 0].AB @ model.W_U
print(virtual_OV_circuit.shape)

torch.Size([53, 50])


In [112]:
imshow(
    virtual_OV_circuit,
    title="Virtual OV circuit: W_E @ W_OV^0.0 @ W_OV^1.0 @ W_U",
    xaxis="logit", yaxis="src token",
)

In [113]:
top_1_acc = (virtual_OV_circuit[:MAX_INT].argmax(dim=-1) == torch.arange(MAX_INT, device=device)).float().mean()
print("Fraction of time top logit is on diagonal:", top_1_acc.item())

Fraction of time top logit is on diagonal: 1.0


We see clear copying, confirming the V-composition that we expected.

# Summary

We found that a 2L, 1H, attn only transformer with no biases or LN can learn to permute lists by leveraging V-composition:


1. attn 0.0 moves answer information from list indices to corresponding perm positions.
2. attn 1.0 routes the same information from perm positions to the prediction positions.


The lines of evidence we used are:


1. Direct logit attribution: we found the last attn layer had the greatest direct effect on the logits by far (specifically the term corresponding to the virtual attn head).
2. 1.0 attn patterns showed that the prediction positions empirically attended to fixed perm positions, and the $W_{pos}W_{QK}W_{pos}^T$ term of the QK circuit partially explained this.
3. 0.0 attn patterns showed that perm token i attended to the corresponding list position i+1, and the QK circuit weights explained this.
4. The virtual attn pattern $A^{1.0}A^{0.0}$ showed that the prediction positions effectively attended to the correct list token for the examples we checked.
5. The virtual OV circuit $W_EW_{OV}^{0.0}W_{OV}^{1.0}W_U$ showed clear copying (100% top 1 logit on diagonal).

## General techniques you can apply to other problems

* Direct logit attribution
* Staring at attn patterns
* Decomposing activations into a sum of terms, and analyzing their importance
* Multiplying out QK / OV circuits