# Problem

Interpret a toy model trained to take the minimum of two ints.

# Setup
(No need to read)

In [1]:
!git clone https://github.com/ckkissane/mech-interp-practice.git

Cloning into 'mech-interp-practice'...
remote: Enumerating objects: 93, done.
remote: Counting objects: 100% (93/93), done.
remote: Compressing objects: 100% (89/89), done.
remote: Total 93 (delta 31), reused 30 (delta 3), pack-reused 0
Receiving objects: 100% (93/93), 20.77 MiB | 4.55 MiB/s, done.
Resolving deltas: 100% (31/31), done.


In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [5]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [6]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [7]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

In [8]:
def visualize_attn_patterns(heads, local_tokens, local_cache, title: str = ""):
    labels = []
    patterns = []
    batch_index = 0

    for head in heads:
        if isinstance(head, tuple):
            layer, head_index = head
        else:
            layer, head_index = head // model.cfg.n_heads, head % model.cfg.n_heads
        patterns.append(local_cache["pattern", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    patterns = torch.stack(patterns, dim=-1)
    attn_viz = pysvelte.AttentionMulti(tokens=model.to_str_tokens(local_tokens[batch_index]), attention=patterns, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attn_viz.show()

# Load Model

In [9]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f82097c5f00>

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


The simplest model I found that can learn this task is a 1L, 1 head, attn-only transformer with no biases, layernorms, or positional embeddings.

The model has already been trained separately, and is loaded into this notebook below:

In [11]:
MAX_NUM=200
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=1,
    d_model=32,
    d_head=32,
    n_ctx=3, # a b =
    d_vocab=MAX_NUM+1, # 0, 1, ..., MAX_NUM-1, =
    d_vocab_out=MAX_NUM,
    normalization_type=None,
    attn_only=True,
    device=device,
    seed=0
)

model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [12]:
filename = "mech-interp-practice/models/min_int_model.pt"
# map_location keeps loading working on CPU-only machines
state_dict = torch.load(filename, map_location=device)
model.load_state_dict(state_dict, strict=False);

# Task description

Given a sequence [a, b, =], the model was trained to output min(a, b) at the '=' position. Both $a$ and $b$ are integers in the range [0, 199] inclusive. The special '=' token (200) is always the last token of the sequence.


Below I provide the same dataset I used to train the model.


* 'dataset' is the full dataset - every possible [a, b, =] example for $a, b \in$ [0, 199] inclusive
* 'labels' are the corresponding labels: min(a,b) for each [a,b] pair in 'dataset'
* 'train_data' is a subset of 'dataset' used to train the model. 'train_labels' are the corresponding labels
* 'test_data' and 'test_labels' are the remaining 'dataset' examples and corresponding labels. These were held out during training, and just used to evaluate the model.

In [13]:
a_vec = einops.repeat(torch.arange(MAX_NUM), "i -> (i j)", j=MAX_NUM)
print(a_vec.shape)
print(a_vec[:5])

b_vec = einops.repeat(torch.arange(MAX_NUM), "i -> (j i)", j=MAX_NUM)
print(b_vec.shape)
print(b_vec[:5])

eq_vec = torch.ones_like(a_vec) * (cfg.d_vocab - 1)  # '=' token id (200)
print(eq_vec.shape)
print(eq_vec[:5])

dataset = torch.stack([a_vec, b_vec, eq_vec], dim=-1).to(device)
print(dataset.shape)
print(dataset[:5])

torch.Size([40000])
tensor([0, 0, 0, 0, 0])
torch.Size([40000])
tensor([0, 1, 2, 3, 4])
torch.Size([40000])
tensor([200, 200, 200, 200, 200])
torch.Size([40000, 3])
tensor([[  0,   0, 200],
        [  0,   1, 200],
        [  0,   2, 200],
        [  0,   3, 200],
        [  0,   4, 200]], device='cuda:0')


In [14]:
# The '=' token (200) is larger than every operand, so the row-wise min is min(a, b)
labels = dataset.min(dim=-1).values
print(labels.shape)
print(labels[:5])

torch.Size([40000])
tensor([0, 0, 0, 0, 0], device='cuda:0')


In [15]:
DATA_SEED = 4
torch.manual_seed(DATA_SEED)

train_frac = 0.8

indices = torch.randperm(dataset.shape[0])
cutoff = int(train_frac * dataset.shape[0])

train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = dataset[train_indices]
train_labels = labels[train_indices]
print(train_data.shape)
print(train_data[:5])
print(train_labels.shape)
print(train_labels[:5])

test_data = dataset[test_indices]
test_labels = labels[test_indices]
print(test_data.shape)
print(test_data[:5])
print(test_labels.shape)
print(test_labels[:5])

torch.Size([32000, 3])
tensor([[  7, 130, 200],
        [182,   0, 200],
        [110, 155, 200],
        [ 82,   8, 200],
        [126,  97, 200]], device='cuda:0')
torch.Size([32000])
tensor([  7,   0, 110,   8,  97], device='cuda:0')
torch.Size([8000, 3])
tensor([[ 19, 171, 200],
        [  6, 195, 200],
        [ 77,   0, 200],
        [ 97,  39, 200],
        [194,  42, 200]], device='cuda:0')
torch.Size([8000])
tensor([19,  6,  0, 39, 42], device='cuda:0')


The model was trained to minimize cross entropy loss. I provide the exact loss function I used during training here:

In [16]:
def loss_fn(logits, labels):
    if logits.ndim == 3:
        logits = logits[:, -1, :]

    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])
    return -correct_log_probs.mean()
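
To get a feel for the scale of this loss, here is a small NumPy sketch that is not part of the original notebook. It mirrors the same computation and evaluates it on all-zero logits. The name `cross_entropy_last_pos` and the example labels are made up for illustration. A model that spreads probability uniformly over the 200 output classes should score ln(200) ≈ 5.30, which gives a rough baseline to compare a trained model against.

```python
import numpy as np

def cross_entropy_last_pos(logits: np.ndarray, labels: np.ndarray) -> float:
    """NumPy mirror of loss_fn: mean negative log-prob of the label at the final position."""
    if logits.ndim == 3:
        logits = logits[:, -1, :]  # keep only the '=' position
    logits = logits.astype(np.float64)
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    correct = np.take_along_axis(log_probs, labels[:, None], axis=-1)
    return float(-correct.mean())

# Uniform predictions over 200 classes: every label has probability 1/200
uniform_logits = np.zeros((4, 3, 200))
example_labels = np.array([0, 7, 110, 199])  # made-up labels
uniform_loss = cross_entropy_last_pos(uniform_logits, example_labels)
print(uniform_loss, np.log(200))
```

Any loss well below this value means the model is doing much better than guessing uniformly.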

# Solution

Where do we start? According to Neel Nanda: "Understanding a transformer breaks down into two high-level parts - interpreting the features represented by non-linear activations (output probabilities, attention patterns, neuron activations), and interpreting the circuits that calculate each feature (the weights representing the computation done to convert earlier features into later features)."


We already understand the output probabilities, and this model has no MLPs, so that leaves the attention patterns. Note that for a bigger model we might want to localize which bits of the model are important before staring at activations, but this model is small enough that we can safely skip this step.

## Stare at attn patterns

Attn patterns are often interpretable, and they are super easy to visualize using the TransformerLens ActivationCache.

In [17]:
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.shape)
print(cache)

torch.Size([40000, 3, 200])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post']


In [18]:
layer_0_patterns = cache['pattern', 0]
print(layer_0_patterns.shape)

torch.Size([40000, 1, 3, 3])


First I'll just look at the average attn pattern over the entire dataset. Note that I'll only look at the row of the pattern corresponding to destination token '=', since this is where the relevant prediction is made.

In [19]:
bar(
    einops.reduce(layer_0_patterns[:, 0, -1, :], "batch src -> src", "mean"),
    title="attn 0.0 from = to src residual stream, avged over entire dataset",
    xaxis="src", yaxis="attn",
    x=["a", "b", "="]
)

We can see that the "=" token only attends to "a" and "b", never itself. A natural guess might be that it always attends more to the minimum, so the OV circuit can just copy that as the answer. We can sanity check this on some random examples.

In [20]:
def tokens_to_plotly_labels(tokens):
    if tokens.ndim == 2:
        tokens = tokens[0]
    res = [f"{tok}_{i}" for i, tok in enumerate(tokens)]
    res[-1] = '='
    return res

In [21]:
for _ in range(3):
    r = random.randint(0, dataset.shape[0]-1)
    # Use a separate name so the dataset-level `labels` tensor is not overwritten
    plot_labels = tokens_to_plotly_labels(dataset[r])
    bar(
        layer_0_patterns[r, 0, -1, :],
        title=f"attn 0.0 from = to src residual stream, for {plot_labels}",
        xaxis="src", yaxis="attn",
        x=plot_labels
    )

This seems to be what we expected. We can be more rigorous by plotting the attention paid to "a" and to "b" for every dataset example. We expect to see a clear divide: attention goes fully to "a" when a < b, and fully to "b" when b < a:

In [22]:
imshow(
    torch.stack([layer_0_patterns[:, 0, -1, 0].reshape(MAX_NUM, MAX_NUM),
     layer_0_patterns[:, 0, -1, 1].reshape(MAX_NUM, MAX_NUM)]),
    yaxis='a', xaxis='b',
    title="attn paid from = to a and b, over entire dataset",
    facet_col=0,
    facet_labels = ["attn paid to a", "attn paid to b"]
)

This is what we expected, although it's fuzzier when a and b are approximately equal. I feel pretty convinced that this model has learned to pay more attn to the minimum token.
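
One way to turn this visual impression into a number is to count how often the larger share of attention lands on the smaller operand. The sketch below is an assumption about how one might do this, not something the original notebook runs. `frac_attending_to_min` is a hypothetical helper and the tiny arrays are made up for illustration. Ties (a == b) are skipped because either position is a correct choice.

```python
import numpy as np

def frac_attending_to_min(attn_from_eq: np.ndarray, tokens: np.ndarray) -> float:
    """Fraction of a != b examples where '=' attends most to the smaller operand.

    attn_from_eq: [batch, 3] attention from '=' to positions (a, b, =)
    tokens:       [batch, 3] token ids laid out as (a, b, =)
    """
    a, b = tokens[:, 0], tokens[:, 1]
    not_tied = a != b
    min_pos = np.where(a < b, 0, 1)                 # position holding min(a, b)
    top_pos = attn_from_eq[:, :2].argmax(axis=-1)   # operand that gets more attention
    return float((top_pos == min_pos)[not_tied].mean())

# Made-up example: three sequences, the last one attends to the wrong operand
toy_tokens = np.array([[3, 9, 200], [50, 2, 200], [10, 11, 200]])
toy_attn = np.array([[0.98, 0.02, 0.0], [0.01, 0.99, 0.0], [0.45, 0.55, 0.0]])
toy_frac = frac_attending_to_min(toy_attn, toy_tokens)
print(toy_frac)
```

In this notebook it could be called with `utils.to_numpy(layer_0_patterns[:, 0, -1, :])` and `utils.to_numpy(dataset)`.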

## Hypothesis: OV circuit just copies

Now that we have a good understanding of the nonlinear activations, we can move on to "interpreting the circuits that calculate each feature (the weights representing the computation done to convert earlier features into later features)."


In particular I want to interpret how the OV circuit maps the token it attends to into high output logit for that token.


To figure out the relevant weights, we can do some transformer math:

![picture](https://drive.google.com/uc?id=1awi0xgOPet5cnOe-hTEnW6dfVN4_QJt4)

$$
\begin{aligned}
logits &= x_1W_U\\
&= \left(h_{0.0}(e) + e\right)W_U\\
&= h_{0.0}(e)W_U + eW_U\\
& = A^{0.0}eW_{OV}^{0.0}W_U + eW_U\\
& = A^{0.0}tW_E W_{OV}^{0.0}W_U + tW_EW_U\\
\end{aligned}
$$

Where $e = tW_E$ is the token embeddings, $t$ represents the input tokens as one-hot encoded vectors, $h_{0.0}$ is the output of head 0.0, $x_1$ is the final residual stream, $W_{OV} = W_VW_O$, and $A$ represents the attention pattern. You can see $W_E W_{OV}^{0.0}W_U$ is the circuit that maps the "token we attended to" to the logits. We'll call this the full OV circuit.

* Aside: I claim we can ignore the $tW_EW_U$ term. At the '=' position, where the prediction is made, this term is the same for every example regardless of 'a' and 'b', since every sequence ends with the '=' token.
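
The derivation above can be checked numerically. The following is a minimal NumPy sketch with random weights standing in for the trained model. The toy sizes and all variable names are assumptions for illustration, and it is not how TransformerLens computes things internally. It runs an ordinary forward pass for a 1L, 1 head, attn-only model and compares it with the decomposed form.

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab, d_vocab_out, d_model, d_head = 6, 5, 4, 4  # toy sizes, not the notebook's

# Random stand-ins for the weights (no biases, layernorms, or positional embeddings)
W_E = rng.normal(size=(d_vocab, d_model))
W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))
W_V = rng.normal(size=(d_model, d_head))
W_O = rng.normal(size=(d_head, d_model))
W_U = rng.normal(size=(d_model, d_vocab_out))

tokens = np.array([1, 3, 5])      # plays the role of [a, b, =]
t = np.eye(d_vocab)[tokens]       # one-hot tokens, shape [3, d_vocab]
e = t @ W_E                       # token embeddings

# Attention pattern with a causal mask
scores = (e @ W_Q) @ (e @ W_K).T / np.sqrt(d_head)
scores = np.where(np.tril(np.ones((3, 3), dtype=bool)), scores, -np.inf)
A = np.exp(scores - scores.max(axis=-1, keepdims=True))
A = A / A.sum(axis=-1, keepdims=True)

# Ordinary forward pass: residual stream = embedding + head output
head_out = A @ (e @ W_V) @ W_O
logits_forward = (e + head_out) @ W_U

# Decomposed form from the derivation
W_OV = W_V @ W_O
ov_path = A @ t @ W_E @ W_OV @ W_U
direct_path = t @ W_E @ W_U
logits_decomposed = ov_path + direct_path

print(np.allclose(logits_forward, logits_decomposed))
```

Note that `direct_path[-1]` depends only on the final token, which is the point of the aside: with '=' always last, that row never changes between examples.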

In [23]:
full_OV_circuit = model.W_E @ model.OV[0, 0].AB @ model.W_U
print(full_OV_circuit.shape)

torch.Size([201, 200])


In [24]:
imshow(
    full_OV_circuit,
    title="attn 0.0 full OV circuit",
    yaxis="src token", xaxis="logit"
)

We can interpret this as a big look up table from "token we fully attended to" to corresponding output logits. The strong diagonal suggests that it gives high logits to the same token that we attend to. (Jargon note: I've seen many refer to this as a "copying" OV circuit. It's a fairly common motif in attn-only circuits.)

We can double check this by checking the fraction of the time the top1 logit is on the diagonal.

In [25]:
top_1_acc = (full_OV_circuit[:-1, :].argmax(dim=-1) == torch.arange(cfg.d_vocab_out, device=device)).float().mean()
print("Fraction of the time top score is on the diagonal:", top_1_acc.item())

Fraction of the time top score is on the diagonal: 0.9950000047683716


## Hypothesis: QK circuit gives higher attn scores to smaller numbers

Now that we know how the model goes from the "token we attended to" to the output logits, we want to know how it goes from input tokens to the attn pattern. Once again we can do some transformer math to determine the relevant circuit:


$$
\begin{aligned}
A^{0.0} &= softmax(\frac{eW_{QK}^{0.0}e^T}{\sqrt{d_{head}}})^*\\
&= softmax(\frac{tW_EW_{QK}^{0.0}W_E^Tt^T}{\sqrt{d_{head}}})^*\\
\end{aligned}
$$


Where $W_{QK} = W_QW_K^T$ and $t$ is the input tokens as one hot encoded vectors. Notice $W_EW_{QK}W_E^T$ is the bilinear form that computes the attn scores given destination and source tokens. We call this the "full QK circuit".
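
As a quick check of this identity, here is a small NumPy sketch with random stand-in weights. The toy sizes and names are assumptions, not values from the trained model. It confirms that the query-key dot products equal the entries read from a [dest token, src token] lookup table. The table here is the unscaled score, before dividing by $\sqrt{d_{head}}$ and applying the softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab, d_model, d_head = 6, 4, 4  # toy sizes, not the notebook's

W_E = rng.normal(size=(d_vocab, d_model))
W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))

tokens = np.array([1, 3, 5])   # plays the role of [a, b, =]
e = W_E[tokens]

# Scores the usual way: query-key dot products (before scaling and softmax)
scores_qk = (e @ W_Q) @ (e @ W_K).T

# Scores read from the full QK circuit, a [dest token, src token] lookup table
W_QK = W_Q @ W_K.T
full_QK = W_E @ W_QK @ W_E.T
scores_lookup = full_QK[tokens[:, None], tokens[None, :]]

print(np.allclose(scores_qk, scores_lookup))
```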

In [26]:
full_QK_circuit = model.W_E @ model.QK[0, 0].AB @ model.W_E.T
print(full_QK_circuit.shape)

torch.Size([201, 201])


In [27]:
imshow(
    full_QK_circuit,
    xaxis="src token", yaxis="dest token",
    title="attn 0.0 Full QK circuit"
)

Notice that the last row has the most structure. This makes sense: information must be moved to the '=' token, so that is the only important destination token. We can zoom in on this row as a bar graph:

In [28]:
bar(
    full_QK_circuit[-1, :-1],
    xaxis="src token", yaxis="attn score",
    title="0.0 full QK circuit for = destination token"
)

This seems pretty clear: we can read off from the weights that '=' destination token gives higher attn scores to small numbers, and lower scores to big numbers.
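
To connect these scores back to the attention pattern: if '=' puts negligible attention on itself, the softmax reduces to a two-way choice, and the attention on "a" is a sigmoid of the scaled score gap, $\sigma\left((s_a - s_b)/\sqrt{d_{head}}\right)$. The sketch below illustrates this with a made-up, linearly decreasing score function. `hypothetical_score` and its slope are assumptions for illustration, not values read from the model's weights.

```python
import math

def attn_to_a(score_a: float, score_b: float, d_head: int = 32) -> float:
    """Attention on 'a' from a two-way softmax over scaled scores (ignores the '=' position)."""
    gap = (score_a - score_b) / math.sqrt(d_head)
    return 1.0 / (1.0 + math.exp(-gap))

def hypothetical_score(n: int, slope: float = 2.0) -> float:
    """Made-up score that decreases linearly in the token value, NOT read from the model."""
    return -slope * n

far = attn_to_a(hypothetical_score(10), hypothetical_score(150))   # a much smaller than b
near = attn_to_a(hypothetical_score(10), hypothetical_score(11))   # a barely smaller than b
print(far, near)
```

Under these assumptions, a large gap gives nearly all the attention to the smaller number, while a gap of one token leaves the attention only modestly above 50%. This is one plausible reading of why the attention plots look fuzzier when a and b are close.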

# Summary

A 1L, 1 head attn-only transformer with no layernorms, biases, or positional embeddings learns to compute the min of two ints with the following algorithm:

1. Attend to the smaller number with the QK circuit
2. Copy the "token you attended to" to high logits with the OV circuit

The lines of evidence we found for this are:
1. Attn pattern activations: we saw that for the vast majority of dataset examples, the '=' destination token pays ~100% attention to the token corresponding to min(a, b)
2. OV circuit weights: when we multiplied out the full OV circuit, we saw that it tends to give the highest logit to the token it attends to (the top logit is on the diagonal for 99.5% of source tokens)
3. QK circuit weights: when we multiplied out the full QK circuit, we saw that the '=' destination token gives higher attention scores to the smaller source tokens, and lower attention scores to the bigger tokens
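
To illustrate that these two steps are enough to solve the task, here is a hand-built NumPy sketch of the algorithm. It uses idealized, made-up tables (a strictly decreasing score row and a perfect copying table) rather than the trained weights, and a smaller number range to keep it light. All names are hypothetical.

```python
import numpy as np

TOY_MAX = 50  # smaller than the notebook's 200, just to keep the example light

# Step 1 (QK): hypothetical '=' row of the QK table, strictly decreasing in the source token
eq_row_scores = -np.arange(TOY_MAX, dtype=np.float64)

# Step 2 (OV): hypothetical perfect copying table, src token -> logit for the same token
copy_table = np.eye(TOY_MAX)

def hand_built_min(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Predict min(a, b) using only 'attend to the smaller number' and 'copy'."""
    scores = np.stack([eq_row_scores[a], eq_row_scores[b]], axis=-1)  # [batch, 2]
    attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn = attn / attn.sum(axis=-1, keepdims=True)                    # softmax over a and b
    logits = attn[:, :1] * copy_table[a] + attn[:, 1:] * copy_table[b]
    return logits.argmax(axis=-1)

# Every (a, b) pair with a, b in [0, TOY_MAX - 1]
a_all, b_all = np.meshgrid(np.arange(TOY_MAX), np.arange(TOY_MAX), indexing="ij")
a_all, b_all = a_all.ravel(), b_all.ravel()
preds = hand_built_min(a_all, b_all)
accuracy = float((preds == np.minimum(a_all, b_all)).mean())
print(accuracy)
```

Because the smaller number always receives more than half of the attention, its copied logit is always the largest, so this idealized version gets every pair right.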

## General techniques you can apply to other problems

1. Staring at attention pattern activations
2. Multiplying out full QK and OV circuits