# Problem

Train a transformer without positional embeddings to predict the previous token and reverse engineer how it does this.

# Setup
(No need to read)

In [1]:
!git clone https://github.com/ckkissane/mech-interp-practice.git

fatal: destination path 'mech-interp-practice' already exists and is not an empty directory.


In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-bp3qx6wi
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-bp3qx6wi
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 68bdb6d3a9d28cd155ca460e3cc25a8d2ca824c7
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [3]:
try:
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
except:
    import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-vz9ujeth
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-vz9ujeth
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit 3b7148fc7fb534e551a6ede1448293cc4a14f815
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [4]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [5]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import circuitsvis as cv
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [6]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [7]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [8]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

# Load Model

In [9]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x78e0a6864550>

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


The model is a 2L, 1H (per layer), attn-only transformer with no biases, layernorms, or positional embeddings.

![picture](https://drive.google.com/uc?id=115AlOlXFE24PMYugkPeAadrkrh-hkOcS)

It is loaded into this notebook below:

In [11]:
MAX_NUM=100
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=1,
    d_model=128,
    d_head=128,
    attn_only=True,
    normalization_type=None,
    n_ctx=10,
    d_vocab=MAX_NUM+1, # 0,...,MAX_NUM-1, BOS
    device=device,
    seed=0
)

model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [12]:
filename = "mech-interp-practice/models/rederive_position_model.pt"
state_dict = torch.load(filename)
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# Task description

The model is given a sequence of tokens and trained to predict the previous token at each position. The inputs are of the form

```
[BOS, a_0, a_1, ..., a_8]
```

* BOS (100) is a special token at the beginning of each sequence
* Each a_i is a random token between [0, 99] inclusive

Each example is fixed at length 10 (including BOS). At each position $i \in [1, ..., 9]$ (0 indexed), the model is trained to predict the token at position i-1. It isn't expected to make a prediction at position 0, since there is no previous token to predict.

A data loader and example tokens are provided below for you to begin your investigation:

In [13]:
BOS_TOKEN = cfg.d_vocab-1
def make_data_generator(cfg, batch_size, seed=0):
    torch.manual_seed(seed)
    while True:
        tokens = torch.randint(0, cfg.d_vocab, (batch_size, cfg.n_ctx))
        tokens[:, 0] = BOS_TOKEN
        yield tokens

batch_size = 256
data_loader = make_data_generator(cfg, batch_size, seed=42)
test_sample = next(data_loader).to(device)
print(test_sample.shape)
print(test_sample[:5])

torch.Size([256, 10])
tensor([[100,  32,  94,  55,   2,  21,  10,  12,  47,  30],
        [100,  38,  38,  58,  86,   4,  43,  66,  86,  46],
        [100,   6,  11,  26,  52,  96,  42,  53,  85,  90],
        [100,  32,  48,  28,  66,  74,  24,  60,  80,  68],
        [100,  63,  81,  89,  23,  26,  15,  26,  78,  91]], device='cuda:0')


The model is trained to minimize cross entropy loss. The exact loss function used during training is provided below:

In [14]:
def loss_fn(logits, tokens):
    logits = logits[:, 1:, :]
    logits = logits.to(torch.float64)
    labels = tokens[:, :-1]
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()

# Solution

## Sanity check

Let's start by sanity checking that the model has actually learned the task:

In [15]:
with torch.inference_mode():
    logits = model(test_sample)
    logits = logits[:, 1:]
    labels = test_sample[:, :-1]
    preds = logits.argmax(dim=-1)
    acc = (preds == labels).float().mean()
    print("Accuracy on test sample:", acc.item())

Accuracy on test sample: 0.9891493320465088


## Stare at 1.0 attn patterns

Normally it's a good practice to use techniques like [direct logit attribution](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=disz2gTx-jooAcR0a5r8e7LZ) to narrow down what bits of the model are important for this task. However since this is the smallest model I found that could solve this task, I'm just going to assume everything is important.

The next natural step is to try to understand meaningful activations. Since this is an attention only model, this just entails the attention patterns. It's often helpful to work backwards when reverse engineering models, so we can start by looking at the attention patterns for L1H0.

In [16]:
original_logits, cache = model.run_with_cache(test_sample)
print(original_logits.shape)
print(cache)

torch.Size([256, 10, 101])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_post']


Since this task is fully position based, we might expect each example to have very similar attention patterns. Thus we can start by looking at the average of all attention patterns in the test sample.

In [17]:
def tokens_to_labels(tokens):
    if tokens.ndim==2:
        tokens = tokens[0]

    labels = [f"{tok}_{i}" for i, tok in enumerate(tokens)]
    labels[0] = "BOS_0"
    return labels

labels = tokens_to_labels(test_sample)
print(labels)

['BOS_0', '32_1', '94_2', '55_3', '2_4', '21_5', '10_6', '12_7', '47_8', '30_9']


In [18]:
general_labels = ["BOS"] + [f"a_{i}" for i in range(model.cfg.n_ctx-1)]
layer, head_index = 1, 0
imshow(
    cache['pattern', layer][:, head_index, :, :].mean(dim=0), # TL attn patterns (cache['pattern', layer]) have shape [batch, n_heads, dest_pos, src_pos]
    title=f"{layer}.{head_index} attn patterns (avged over test sample)",
    xaxis="src", yaxis="dest",
    x=general_labels, y=general_labels
)

We can see that each destination position mostly attends to the previous src position. This is interesting, because it implies that at this point the model has already successfully re-derived some positional information, despite only starting with token embeddings.


One next natural hypothesis is that the OV circuit for 1.0 "copies", predicting the token that it most strongly attends to. As usual we can pretty easily check this by multiplying out the full OV circuit weights: $W_EW_{OV}^{1.0}W_U$

## 1.0 OV circuit copies

In [19]:
W_E = model.W_E
print(W_E.shape)
W_U = model.W_U
print(W_U.shape)

torch.Size([101, 128])
torch.Size([128, 101])


In [20]:
token_labels = list(map(str, range(cfg.d_vocab)))
token_labels[-1] = "BOS"
print(token_labels)

['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', 'BOS']


In [21]:
layer, head_index = 1, 0
imshow(
    W_E @ model.OV[layer, head_index].AB @ W_U,
    title=f"{layer}.{head_index} full OV circuit W_E @ W_OV @ W_U",
    xaxis="logit effect", yaxis="src token",
    x=token_labels, y=token_labels
)

Recall the interpretation of this full OV circuit: the direct logit effect of 1.0, given that 1.0 fully attends to some src token. The strong diagonal implies that if 1.0 attends to some src token, it will boost the logit for that same token. Thus it is "copying" what it attends to, as expected.


We can quantitatively verify this with a nice summary statistic, the fraction of the time that top logit is on the diagonal:

In [22]:
layer, head_index = 1, 0
top_1_acc = ((W_E @ model.OV[layer, head_index].AB @ W_U).argmax(dim=-1) == torch.arange(cfg.d_vocab, device=device)).float().mean()
print("Fraction of time top logit is on diagonal:", top_1_acc.item())

Fraction of time top logit is on diagonal: 1.0


## Stare at 0.0 attn patterns

So far we have some understanding of the 1.0 attn pattern, and how they connect with the logits via the OV circuit: 1.0 attends to the previous token, and copies this to the logits to predict that token. But how did 1.0 learn to attend to the previous token in the first place?


Continuing to work backwards, let's see if we can also understand 0.0 attn patterns.


Once again, let's just look at the average over the batch.

In [23]:
layer, head_index = 0, 0
imshow(
    cache['pattern', layer][:, head_index, :, :].mean(dim=0), # TL attn patterns (cache['pattern', layer]) have shape [batch, n_heads, dest_pos, src_pos]
    title="0.0 attn patterns (avged over test sample)",
    xaxis="src", yaxis="dest",
    x=general_labels, y=general_labels
)

Fascinatingly, this attn pattern also has clear structure. we see that the attn paid to BOS decreases smoothly as we increase destination position, and the attn paid to all non-BOS tokens looks roughly uniform for each row (although there is slightly less attn paid to itself).


How does the model actually compute this distinct pattern without positional embeddings? One guess is that each token gives the same attn score to BOS (and to all non-BOS tokens), and then the causal mask + softmax causes the decreasing attn we see above. The first way to sanity check this is just to directly look at the attn scores. We expect to see that the scores look roughly for every row.

In [24]:
layer, head_index = 0, 0
imshow(
    cache['attn_scores', layer][:, head_index, :, :].mean(dim=0).tril(), # TL attn scores (cache['attn_scores', layer]) have shape [batch, n_heads, dest_pos, src_pos]
    title=f"{layer}.{head_index} attn scores",
    xaxis="src", yaxis="dest",
    x=general_labels, y=general_labels
)

Here we clearly see that each destination token distributes attn scores in the same matter. Thus it is the causal mask + softmax that is creating the distinct attn pattern that we saw above.


We should also be able to see this directly encoded in the QK circuit weights, since $attnscores_{0.0} = \frac{tW_EW_{QK}^{0.0}W_E^Tt^T}{\sqrt{d_{head}}}$, where $t$ is the input tokens as one hot encoded vectors.

## 0.0 QK circuit

In [25]:
layer, head_index = 0, 0
imshow(
    W_E @ model.QK[layer, head_index].AB @ W_E.T,
    title=f"{layer}.{head_index} QK full circuit W_E @ W_QK @ W_E.T",
    xaxis="src token", yaxis="dest token",
    x=token_labels, y=token_labels
)

We notice that every row looks roughly the same. Each destination token gives:

1. High positive score to BOS
2. Low negative score to itself
3. Medium ~uniform negative scores to everything else

These are consistent with the attn pattern and attn scores we saw above.

At this point we have a some understanding of:
1. The empirical structure of both attn patterns, $A^{1.0}$ and $A^{0.0}$
2. How the model computes $A^{0.0}$ with 0.0 QK circuit, given the input tokens
3. How the model computes the logits by copying the tokens that $A^{1.0}$ attends to with 1.0 OV circuit


We are left to understand how $A^{1.0}$ is computed. Recall


$$
\begin{align}
A^{1.0} &= softmax(attnscores_{1.0})
\end{align}
$$


where


$$
\begin{align}
attnscores_{1.0} &= \frac{eW_{QK}^{1.0}e^T + eW_{QK}^{1.0}(W_{OV}^{0.0})^Te^T(A^{0.0})^T + A^{0.0}eW_{OV}^{0.0}W_{QK}^{1.0}e^T + A^{0.0}eW_{OV}^{0.0}W_{QK}^{1.0}(W_{OV}^{0.0})^Te^T(A^{0.0})^T}{\sqrt{d_{head}}}
\end{align}
$$


and $e = tW_E$ is the token embedding.


Given the striking attn paid to BOS we saw in $A^{0.0}$ we might expect the model to convert this information into some notion of position:


![picture](https://drive.google.com/uc?id=12iQg-tgDmnoXWELU-8kCBKiwOtRmS9kY)


(Note this diagram is actually wrong, as we show later. The key point is the hypothesis that the model should need both K and Q composition to represent positional information after L0.)




If 0.0 is moving the BOS token info to re-derive position, we would expect the K and Q composition term, $A^{0.0}eW_{OV}^{0.0}W_{QK}^{1.0}(W_{OV}^{0.0})^Te^T(A^{0.0})^T$, of 1.0 attn scores to be the most important. One way to sanity check this is to zero ablate each term and see how much it trashes performance. (See https://transformer-circuits.pub/2021/framework/index.html#term-importance-analysis for inspiration for this technique)

## 1.0 QK circuit - Term importance analysis

In [26]:
e = cache['embed'][:, None, :, :]
print(e.shape)
A_0 = cache['pattern', 0]
print(A_0.shape)
W_QK_1 = model.QK[1,0].AB
print(W_QK_1.shape)
W_OV_0 = model.OV[0,0].AB
print(W_OV_0.shape)

torch.Size([256, 1, 10, 128])
torch.Size([256, 1, 10, 10])
torch.Size([128, 128])
torch.Size([128, 128])


In [27]:
components = [
    (e @ W_QK_1 @ e.mT) / np.sqrt(cfg.d_head),
    (e @ W_QK_1 @ W_OV_0.mT @ e.mT @ A_0.mT) / np.sqrt(cfg.d_head),
    (A_0 @ e @ W_OV_0 @ W_QK_1 @ e.mT) / np.sqrt(cfg.d_head),
    (A_0 @ e @ W_OV_0 @ W_QK_1 @ W_OV_0.mT @ e.mT @ A_0.mT) / np.sqrt(cfg.d_head)
]

component_names = [
    "e @ W_QK_1 @ e.mT",
    "e @ W_QK_1 @ W_OV_0.mT @ e.mT @ A_0.mT",
    "A_0 @ e @ W_OV_0 @ W_QK_1 @ e.mT",
    "A_0 @ e @ W_OV_0 @ W_QK_1 @ W_OV_0.mT @ e.mT @ A_0.mT"
]

attn_scores_1 = cache['attn_scores', 1]
print(attn_scores_1.shape)

summed_components = torch.zeros_like(attn_scores_1)
for component in components:
    summed_components += component

summed_components = model.blocks[1].attn.apply_causal_mask(summed_components)
assert torch.allclose(attn_scores_1, summed_components, atol=1e-3)

torch.Size([256, 1, 10, 10])


In [28]:
original_loss = loss_fn(original_logits, test_sample)
print(original_loss)

tensor(0.0407, device='cuda:0', dtype=torch.float64)


In [29]:
def ablate_attn_score_component(attn_scores, hook, component):
    attn_scores -= component
    attn_scores = model.blocks[hook.layer()].attn.apply_causal_mask(attn_scores)
    return attn_scores

print(f"Original loss: {original_loss.item()}")
for name, component in zip(component_names, components):
    hook_fn = partial(ablate_attn_score_component, component=component)
    ablated_logits = model.run_with_hooks(
        test_sample,
        fwd_hooks=[(utils.get_act_name('attn_scores', 1), hook_fn)]
    )
    ablated_loss = loss_fn(ablated_logits, test_sample)
    print(f"Loss after ablating {name}: {ablated_loss.item()}")

Original loss: 0.04070200162050724
Loss after ablating e @ W_QK_1 @ e.mT: 3.4529969129421243
Loss after ablating e @ W_QK_1 @ W_OV_0.mT @ e.mT @ A_0.mT: 0.11072090213701982
Loss after ablating A_0 @ e @ W_OV_0 @ W_QK_1 @ e.mT: 0.10577847936937634
Loss after ablating A_0 @ e @ W_OV_0 @ W_QK_1 @ W_OV_0.mT @ e.mT @ A_0.mT: 0.4725709885204433


We do see that the K and Q composition component $A^{0.0}eW_{OV}^{0.0}W_{QK}^{1.0}(W_{OV}^{0.0})^Te^T(A^{0.0})^T$ is very important, but surprisingly the component that just uses token embeddings as keys/queries, $eW_{QK}^{1.0}e^T$, is also very important according to this metric. The other two components seem relatively less important, so let's ignore those for now.

Let's try to better understand these terms by zooming in.

## Understand 1.0 QK circuit components

Let's start with the $eW_{QK}^{1.0}e^T$ component by looking at the weights $W_EW_{QK}^{1.0}W_E^T$. Recall this can be interpreted a lookup table for what attn scores destination tokens should give source tokens for 1.0:

In [30]:
imshow(
    W_E @ W_QK_1 @ W_E.T,
    title="1.0 full QK circuit W_E @ W_QK @ W_E.T",
    xaxis="src token", yaxis="dest token",
    x=token_labels, y=token_labels
)

Notice the red diagonal. This suggests a "dont attend to the same token" interpretation of this component. This disproves the initial hypothesis that the K and Q composition circuits fully encode absolute positional information. We can refine our hypothesis:


1. The K and Q composition component, $A^{0.0}eW_{OV}^{0.0}W_{QK}^{1.0}(W_{OV}^{0.0})^Te^T(A^{0.0})^T$, component causes destination tokens to look for recent src tokens
2. But this also gives high score to itself, so the $eW_{QK}^{1.0}e^T$ implements "dont attend to the same token" to solve this


Let's check this by investigating the $A^{0.0}eW_{OV}^{0.0}W_{QK}^{1.0}(W_{OV}^{0.0})^Te^T(A^{0.0})^T$ component. We can start with the scores:

In [31]:
imshow(
    (components[-1][:, 0, :, :].mean(dim=0)).tril(),
    title="1.0 K and Q comp scores component (avged over test sample)",
    xaxis="src", yaxis="dest",
    x=general_labels, y=general_labels
)

We broadly see that this component gives higher scores to more recent tokens, including itself. (Although the first 3 destination positions look a bit different)

This is some evidence that this component's role is to give high attn scores to nearby tokens (rather than fully rederive absolute position). But how does it compute this? Let's look at the corresponding weights:

In [32]:
imshow(
    W_E @ W_OV_0 @ W_QK_1 @ W_OV_0.T @ W_E.T,
    title="1.0 K and Q comp circuit W_E @ W_OV_0 @ W_QK_1 @ W_OV_0.T @ W_E.T",
    xaxis="token embedding moved to src", yaxis="token embedding moved to dest",
    x=token_labels, y=token_labels
)

Recall the interpretation of this heatmap: "Lookup table of 1.0 attn scores if token y is moved by 0.0 to destination token and used as query, and token x is moved by 0.0 to src token and used as key". This is a bit harder to interpret, since empirically 0.0 spreads its attention over many tokens. However we can still reason about it if we notice there are just 3 regimes:
1. High positive score if both query and key use non-BOS token info
2. Low negative score if one side uses BOS token info, but the other used non-BOS token info
3. Medium positive score if both query and key use BOS token info.


If we recall $A^{0.0}$, we can see that mostly BOS information is moved for early tokens, but non-BOS token information tokens starts to dominate as we increase position. Thus regime #1 will become less important after the first few positions.


Regime #2 will cause very negative scores for src tokens far away, since the older tokens will have relatively more BOS information, and the current token will have relatively more non-BOS info (also notice that this will dominate this term, since later destination tokens have less BOS used as query).


Regime #3 will cause higher scores for more recent tokens, since later positions attend to more non-BOS information.


All of this considered, we can roughly interpret this K and Q composition circuit to implement "attend more to recent tokens", as we saw in this component's attn scores.

## Adversarial examples

If we really understand this model, we should be able to predict its behavior on inputs that we've never seen. In particular we can create adversarial examples where the model will perform badly.


If our understanding of the algorithm is correct, then we should expect the model to perform worse on text with consecutive repeats. The reason for this is it should exploit the "don't attend to my own token" heuristic encoded in $W_EW_{QK}^{1.0}W_E^T$

In [33]:
tokens = einops.repeat(torch.randint(0, cfg.d_vocab-1, (cfg.n_ctx//2,), device=cfg.device), "i -> (i j)", j=2)[None, :]
tokens[:, 0] = BOS_TOKEN
print(tokens.shape)
print(tokens)

torch.Size([1, 10])
tensor([[100,  43,   9,   9,  98,  98,  50,  50,  93,  93]], device='cuda:0')


In [34]:
advex_logits, advex_cache = model.run_with_cache(tokens)
print(advex_logits.shape)
print(advex_cache)

torch.Size([1, 10, 101])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_post']


In [35]:
advex_loss = loss_fn(advex_logits, tokens)
print("original loss", original_loss.item())
print("advex_loss", advex_loss.item())

original loss 0.04070200162050724
advex_loss 0.11380505347082581


Notice the loss on this adversarial example is significantly worse than our normal test sample. Since we know that this model copied whatever 1.0 attends to, we should also expect to see that the 1.0 attn pattern attends less to the previous token when prev_token == current_token:

In [36]:
imshow(
    advex_cache['pattern', 1][0, 0],
    title="1.0 attn pattern on adversarial example",
    xaxis="src", yaxis="dest",
    x=tokens_to_labels(tokens), y=tokens_to_labels(tokens)
)

Another adversarial example: If we don't include a BOS token we should expect the model to struggle (since it seems to rely on this information on this in 1.0 K/Q composition)

In [37]:
tokens = torch.randint(0, cfg.d_vocab-1, (1, cfg.n_ctx), device=device)
print(tokens.shape)
print(tokens)

torch.Size([1, 10])
tensor([[ 3, 69,  0, 24, 62, 93, 70, 32, 17, 40]], device='cuda:0')


In [38]:
logits, cache = model.run_with_cache(tokens)
print(logits.shape)
print(cache)

torch.Size([1, 10, 101])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_post']


In [39]:
loss = loss_fn(logits, tokens)
print("original loss", original_loss.item())
print("loss without BOS", loss.item())

original loss 0.04070200162050724
loss without BOS 1.1932205836434049


Loss is way worse as expected.

In [40]:
labels = tokens_to_labels(tokens)
labels[0] = str(tokens[0, 0].item()) + '_0'
imshow(
    cache['pattern', 1][0,0],
    title="1.0 pattern on example without BOS",
    xaxis="src", yaxis="dest",
    x=labels, y=labels
)
imshow(
    cache['pattern', 0][0,0],
    title="0.0 pattern on example without BOS",
    xaxis="src", yaxis="dest",
    x=labels, y=labels
)

# Summary

We find that a 2L attn-only transformer with no positional embeddings, biases, or layernorms can still predict the previous token using the following algorithm:


1. Use K and Q composition to attend to recent tokens in 1.0, but attend less your own token, resulting in attending most strongly to the previous token.
2. Predict the token you attended to in 1.0 by copying with the OV circuit




The lines of evidence we used:
1. Empirically $A^{1.0}$ attends most strongly to the previous token.
2. 1.0 full OV circuit weights show a clear "copying" pattern, suggesting the model will predict what 1.0 attends to.
3. 1.0 attn score term importance analysis: we found the K+Q composition term and the embeddings (no composition) term are the most important by far.
4. 1.0 QK circuit weight analysis: We saw that the $W_EW_{QK}^{1.0}W_E^T$ circuit gave low attn scores to the same token, and reasoned that the K+Q composition circuit, $W_EW_{OV}^{0.0}W_{QK}^{1.0}(W_{OV}^{0.0})^TW_E^T$, gave high scores to more recent tokens.
6. Adversarial examples: We were able to use mechanistic understanding to handcraft inputs that the model failed

## General techniques you can apply to other problems

* Staring at attn patterns
* Multiplying out QK/OV circuits
* Using zero ablations for term importance analysis
* Adversarial examples to sanity check hypotheses