# Problem

Interpret a zero-layer (0L) transformer trained to predict the next token on a tiny dataset (see the 'Task Description' section for the full dataset).

Concretely, a solution looks like:
1. Describing a mechanism for how the model is able to minimize loss when predicting the next token. Ideally, your explanation should include an interpretation of the model weights.
2. Providing evidence that this interpretation is actually true, ideally using some mechanistic interpretability technique(s). Hint: This model and dataset are small enough that you should be able to understand all of the weights.

Since this is the first problem in this sequence, feel free to jump straight to the solution to get a feel for what a solution to these problems looks like. However, I still recommend at least poking around the model to build some intuition and guesses before doing so.

# Setup
(No need to read)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
    # (run_line_magic replaces the deprecated ipython.magic)
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [2]:
try:
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
except:
    import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")


In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import circuitsvis as cv
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [5]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [6]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [7]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

In [8]:
def disable_biases(model):
    for name, param in model.named_parameters():
        if 'b_' in name:
            param.requires_grad = False

def disable_pos_embed(model):
    assert model.cfg.positional_embedding_type == "standard"
    model.pos_embed.W_pos = nn.Parameter(torch.zeros_like(model.pos_embed.W_pos))
    model.pos_embed.W_pos.requires_grad = False

# Task Description

In this notebook, I train a 0L transformer to predict the next token on a tiny corpus of data. Your job is to reverse engineer what this model actually learns.

The dataset is extremely tiny, which is unrealistic, but it makes the analysis much easier. Below, 'tokens' is the full dataset that we will use to train this model (each character is one token).

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [10]:
text = """Before moving on to more complex models, it’s useful to briefly consider a zero-layer transformer. Such a model takes a token, embeds it, unembeds it to produce logits predicting the next token"""
vocab = set(text)
print(vocab)
char_to_token = {ch: i for i, ch in enumerate(vocab)}
print(char_to_token)
tokens = torch.tensor([char_to_token[ch] for ch in text]).unsqueeze(0).to(device)
print(tokens.shape)
print(tokens)

{'h', 'i', 'B', 'y', 'x', 'o', 'd', 'c', 'm', '.', 'a', 'r', 'u', 'f', 'n', '-', 'p', 's', 'z', 'l', 't', '’', 'e', 'S', 'k', 'g', ',', 'v', 'b', ' '}
{'h': 0, 'i': 1, 'B': 2, 'y': 3, 'x': 4, 'o': 5, 'd': 6, 'c': 7, 'm': 8, '.': 9, 'a': 10, 'r': 11, 'u': 12, 'f': 13, 'n': 14, '-': 15, 'p': 16, 's': 17, 'z': 18, 'l': 19, 't': 20, '’': 21, 'e': 22, 'S': 23, 'k': 24, 'g': 25, ',': 26, 'v': 27, 'b': 28, ' ': 29}
torch.Size([1, 193])
tensor([[ 2, 22, 13,  5, 11, 22, 29,  8,  5, 27,  1, 14, 25, 29,  5, 14, 29, 20,
          5, 29,  8,  5, 11, 22, 29,  7,  5,  8, 16, 19, 22,  4, 29,  8,  5,  6,
         22, 19, 17, 26, 29,  1, 20, 21, 17, 29, 12, 17, 22, 13, 12, 19, 29, 20,
          5, 29, 28, 11,  1, 22, 13, 19,  3, 29,  7,  5, 14, 17,  1,  6, 22, 11,
         29, 10, 29, 18, 22, 11,  5, 15, 19, 10,  3, 22, 11, 29, 20, 11, 10, 14,
         17, 13,  5, 11,  8, 22, 11,  9, 29, 23, 12,  7,  0, 29, 10, 29,  8,  5,
          6, 22, 19, 29, 20, 10, 24, 22, 17, 29, 10, 29, 20,  5, 24, 22, 14, 26,
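The notebook tokenizes at the character level: each distinct character in `text` gets its own id. One caveat worth flagging: `vocab = set(text)` has no guaranteed order, and because Python randomizes string hashes per process, the character-to-id mapping printed above may differ between runs (the model initialization is seeded, but the ids it trains on are not). A minimal sketch of a reproducible variant, assuming ids assigned in sorted-character order is acceptable; the names `stoi`, `itos`, `encode` and `decode` are hypothetical and not part of the notebook:

```python
from typing import Dict, List

# Same training text as above
text = "Before moving on to more complex models, it’s useful to briefly consider a zero-layer transformer. Such a model takes a token, embeds it, unembeds it to produce logits predicting the next token"

# Sorting makes the character -> id mapping identical in every Python process
vocab_sorted: List[str] = sorted(set(text))
stoi: Dict[str, int] = {ch: i for i, ch in enumerate(vocab_sorted)}  # hypothetical name
itos: Dict[int, str] = {i: ch for ch, i in stoi.items()}             # hypothetical name

def encode(s: str) -> List[int]:
    """Map each character to its token id."""
    return [stoi[ch] for ch in s]

def decode(ids: List[int]) -> str:
    """Map token ids back to the original string."""
    return "".join(itos[i] for i in ids)

ids = encode(text)
assert decode(ids) == text  # character-level tokenization is lossless
print(len(vocab_sorted), len(ids))  # 30 193
```

A tensor would then follow as `torch.tensor(encode(text)).unsqueeze(0)`. Switching would change the token ids (and so the axis order of the heatmaps below), but not the analysis.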


# Load Model

A 0L transformer just takes a token, embeds it, and unembeds it to produce logits predicting the next token:

In [11]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7c39e40b2ec0>

In [12]:
@dataclass
class Config:
    d_vocab: int = len(vocab)
    d_model: int = 16
    seed = 0

cfg = Config()
print(cfg)

Config(d_vocab=30, d_model=16)


In [13]:
class ZeroLayerTransformer(HookedRootModule):
    def __init__(self, cfg):
        super().__init__()
        torch.manual_seed(cfg.seed)
        self.W_E = nn.Parameter(torch.randn((cfg.d_vocab, cfg.d_model)) / np.sqrt(cfg.d_model))
        self.W_U = nn.Parameter(torch.randn((cfg.d_model, cfg.d_vocab)) / np.sqrt(cfg.d_model))

        self.hook_embed = HookPoint()
        super().setup()

    def forward(self, x):
        embed = self.hook_embed(self.W_E[x])
        return embed @ self.W_U

model = ZeroLayerTransformer(cfg).to(device)
print(model)

ZeroLayerTransformer(
  (hook_embed): HookPoint()
)


In [14]:
for name, param in model.named_parameters():
    print(name, param.shape, param.requires_grad)

W_E torch.Size([30, 16]) True
W_U torch.Size([16, 30]) True
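With only these two parameters, the forward pass `self.W_E[x] @ self.W_U` is just a row lookup into the `d_vocab x d_vocab` matrix $W_EW_U$ (equivalently, a one-hot vector times $W_E$ times $W_U$). A quick check of that equivalence on freshly sampled random matrices with the same shapes (stand-ins, not the trained weights):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_vocab, d_model = 30, 16  # same shapes as the Config above
W_E = torch.randn(d_vocab, d_model) / d_model ** 0.5
W_U = torch.randn(d_model, d_vocab) / d_model ** 0.5
tokens = torch.randint(0, d_vocab, (1, 10))

# Forward pass as written in ZeroLayerTransformer: embed by indexing, then unembed
logits_model = W_E[tokens] @ W_U

# Same computation as a lookup into the d_vocab x d_vocab table W_E @ W_U
table = W_E @ W_U
logits_table = table[tokens]

# ...and as a one-hot vector multiplied through both matrices
one_hot = F.one_hot(tokens, d_vocab).float()
logits_onehot = one_hot @ W_E @ W_U

assert torch.allclose(logits_model, logits_table, atol=1e-5)
assert torch.allclose(logits_model, logits_onehot, atol=1e-5)
print(logits_model.shape)  # torch.Size([1, 10, 30])
```

So everything the model can express is contained in the single matrix $W_EW_U$, whose row for token $i$ is the logit vector predicted after seeing $i$.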


# Training

Training only takes ~10 seconds to converge, so we just train in this notebook:

## Loss Fn

The model is trained to minimize cross-entropy loss:

In [15]:
def loss_fn(logits, tokens):
    logits = logits[:, :-1]
    labels = tokens[:, 1:]

    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()

with torch.inference_mode():
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    print(loss)

tensor(3.4129, device='cuda:0')
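`loss_fn` drops the logits at the last position (there is no next token to predict there) and shifts the labels left by one, so position $i$ is scored on token $i+1$. As a sanity check, it should agree with PyTorch's built-in `F.cross_entropy` applied to the flattened, shifted tensors. A sketch on random logits (the `loss_fn` body is copied from the cell above):

```python
import torch
import torch.nn.functional as F

def loss_fn(logits, tokens):
    # Copy of the notebook's loss: predict token i+1 from position i
    logits = logits[:, :-1]
    labels = tokens[:, 1:]
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()

torch.manual_seed(0)
d_vocab = 30
tokens = torch.randint(0, d_vocab, (1, 193))
logits = torch.randn(1, 193, d_vocab)

# F.cross_entropy expects (N, C) logits and (N,) integer targets
reference = F.cross_entropy(
    logits[:, :-1].reshape(-1, d_vocab),
    tokens[:, 1:].reshape(-1),
)
print(loss_fn(logits, tokens).item(), reference.item())
```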


In [16]:
print("uniform loss:", np.log(cfg.d_vocab))

uniform loss: 3.4011973816621555
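The uniform baseline is $\ln 30$: constant logits give every token probability $1/30$, so every target costs $-\ln(1/30)$. The initial loss of 3.4129 sitting just above this is consistent with the small random initialization producing near-uniform predictions. A quick check with all-zero logits:

```python
import math
import torch
import torch.nn.functional as F

d_vocab = 30
n_targets = 192  # 193 tokens give 192 next-token predictions
targets = torch.randint(0, d_vocab, (n_targets,))

# All-zero (or any row-constant) logits give a uniform distribution over the vocab
uniform_logits = torch.zeros(n_targets, d_vocab)
uniform_loss = F.cross_entropy(uniform_logits, targets).item()
print(uniform_loss, math.log(d_vocab))
```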


## Setup optimizer

In [17]:
lr = 1e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

## Training Loop

In [18]:
num_epochs = 10_000

train_losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
    logits = model(tokens)
    loss = loss_fn(logits, tokens)
    loss.backward()
    train_losses.append(loss.item())

    optimizer.step()
    optimizer.zero_grad()

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, train loss: {loss.item()}")


Epoch: 0, train loss: 3.412891387939453
Epoch: 100, train loss: 3.0832862854003906
Epoch: 200, train loss: 2.645434617996216
Epoch: 300, train loss: 2.256800651550293
Epoch: 400, train loss: 2.0204482078552246
Epoch: 500, train loss: 1.8752901554107666
Epoch: 600, train loss: 1.783310890197754
Epoch: 700, train loss: 1.724074363708496
Epoch: 800, train loss: 1.6863884925842285
Epoch: 900, train loss: 1.662880301475525
Epoch: 1000, train loss: 1.6479673385620117
Epoch: 1100, train loss: 1.6381254196166992
Epoch: 1200, train loss: 1.631356954574585
Epoch: 1300, train loss: 1.6265252828598022
Epoch: 1400, train loss: 1.6229636669158936
Epoch: 1500, train loss: 1.6202653646469116
Epoch: 1600, train loss: 1.6181726455688477
Epoch: 1700, train loss: 1.616517424583435
Epoch: 1800, train loss: 1.6151856184005737
Epoch: 1900, train loss: 1.614098310470581
Epoch: 2000, train loss: 1.6131986379623413
Epoch: 2100, train loss: 1.6124461889266968
Epoch: 2200, train loss: 1.6118104457855225
Epoch: 23

In [19]:
line(
    train_losses,
    title="Loss Curve",
    xaxis="Epoch", yaxis="Loss"
)

# Solution

The key here is to reason about the constraints of this tiny model. When it predicts the next token, the only information it has access to is the current token. Therefore, the best it can do is learn the bigram statistics of the training data, i.e. the distribution of the next token conditioned on the current one. See [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html#zero-layer-transformers) for a more detailed explanation.
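To make "the best the model can do" concrete: for any model that sees only the current token, mean cross-entropy is minimized by predicting the empirical conditional distribution $P(\text{next} \mid \text{current})$, and the resulting loss is the conditional entropy of the bigram statistics. A minimal pure-Python sketch that computes this floor directly from the training text (the variable names are illustrative, not from the notebook):

```python
import math
from collections import Counter

# Same training text as above
text = "Before moving on to more complex models, it’s useful to briefly consider a zero-layer transformer. Such a model takes a token, embeds it, unembeds it to produce logits predicting the next token"

# Counts of (current, next) character pairs, and of each character in the "current" slot
pair_counts = Counter(zip(text, text[1:]))
prev_counts = Counter(text[:-1])
n_pairs = len(text) - 1  # 192 prediction targets

# Mean loss when predicting P(next | current) = count(current, next) / count(current),
# i.e. the conditional entropy of the empirical bigram distribution (in nats)
optimal_loss = -sum(
    c * math.log(c / prev_counts[cur]) for (cur, _nxt), c in pair_counts.items()
) / n_pairs

print(f"bigram-optimal loss: {optimal_loss:.4f}")
print(f"uniform loss:        {math.log(len(set(text))):.4f}")
```

If the bigram interpretation holds, the training loss should approach this value from above. It can only reach it in the limit, since unseen bigrams need logits of $-\infty$ to get zero probability; the rank-16 factorization and AdamW's default weight decay may also keep it slightly above.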


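How do we show this? Note that the factored linear map $W_EW_U$ can be thought of as a lookup table from input token to output logits: row $i$ holds the logits the model outputs whenever the current token is $i$. Thus, if we multiply out these matrices, each row should, after a softmax, approximate the bigram distribution of the next token given that input token (equivalently, the logits approximate the log bigram probabilities up to a per-row constant). Let's compute the bigram statistics and compare them to the model weights.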
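Two small facts sit behind this comparison, shown here on random stand-in matrices rather than the trained weights. First, softmax is unchanged by adding a constant to every logit in a row, so $W_EW_U$ can only be expected to match the log bigram probabilities up to a per-row shift, which is why the comparison is made after a softmax. Second, $W_EW_U$ factors through a `d_model`-dimensional space, so its rank is at most `d_model` = 16: it is a low-rank approximation of the 30 x 30 bigram table.

```python
import torch

torch.manual_seed(0)
d_vocab, d_model = 30, 16

# A random row-stochastic table standing in for the real bigram statistics
probs = torch.rand(d_vocab, d_vocab)
probs = probs / probs.sum(dim=-1, keepdim=True)

# Logits are only determined up to a per-row additive constant
shift = torch.randn(d_vocab, 1)
logits = probs.log() + shift
assert torch.allclose(logits.softmax(dim=-1), probs, atol=1e-6)

# W_E @ W_U factors through d_model dimensions, so its rank is at most d_model
W_E = torch.randn(d_vocab, d_model)
W_U = torch.randn(d_model, d_vocab)
rank = torch.linalg.matrix_rank(W_E @ W_U).item()
print(rank)  # at most 16
```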

In [20]:
token_to_char = {i: ch for i, ch in enumerate(vocab)}
token_labels = list(token_to_char.values())

In [21]:
bigram_freqs = torch.zeros((cfg.d_vocab, cfg.d_vocab), device=device)
for ch1, ch2 in zip(tokens[0], tokens[0, 1:]):
    bigram_freqs[ch1, ch2] += 1

imshow(
    bigram_freqs,
    title="Bigram frequency table (for training corpus)",
    xaxis="next token", yaxis="input token",
    x=token_labels, y=token_labels
)
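The Python loop above is fine for 192 pairs. As an aside, the same table can be built in one call by encoding each (current, next) pair as the single index `current * d_vocab + next` and histogramming with `torch.bincount`. A sketch (the helper name `bigram_counts` is hypothetical), checked against the loop on random tokens:

```python
import torch

def bigram_counts(tokens: torch.Tensor, d_vocab: int) -> torch.Tensor:
    """Count (current, next) token pairs in a 1-D tensor of token ids."""
    prev, nxt = tokens[:-1], tokens[1:]
    # Encode each pair as one flat index, then histogram all d_vocab**2 possible pairs
    flat = torch.bincount(prev * d_vocab + nxt, minlength=d_vocab * d_vocab)
    return flat.view(d_vocab, d_vocab)

# Compare against the explicit loop used in the notebook
torch.manual_seed(0)
d_vocab = 30
toks = torch.randint(0, d_vocab, (193,))
loop = torch.zeros(d_vocab, d_vocab)
for a, b in zip(toks, toks[1:]):
    loop[a, b] += 1

assert torch.equal(bigram_counts(toks, d_vocab).float(), loop)
```

In the notebook this would be called as `bigram_counts(tokens[0], cfg.d_vocab).float()`, since `tokens` has a leading batch dimension.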

Now let's compare the normalized bigram frequencies to the model weights, after applying a softmax to $W_EW_U$:

In [22]:
W_E = model.W_E
print(W_E.shape)
W_U = model.W_U
print(W_U.shape)

torch.Size([30, 16])
torch.Size([16, 30])


In [23]:
imshow(
    [bigram_freqs / bigram_freqs.sum(dim=-1, keepdim=True), (W_E @ W_U).softmax(dim=-1)],
    title="Bigram statistics vs W_E @ W_U",
    xaxis="next token", yaxis="input token",
    facet_col=0,
    facet_labels=["Bigram statistics", "W_E @ W_U (post softmax)"],
    x=token_labels, y=token_labels
)

As you can see, they look almost identical!
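"Almost identical" can also be put into numbers. One option (our choice, not the notebook's) is the per-row KL divergence from the empirical bigram distribution to the model's softmax output. Weighted by how often each token appears as the current token, it equals exactly the model's cross-entropy minus the bigram-optimal loss. A sketch on synthetic stand-ins for `bigram_freqs` and $W_EW_U$; the helper `row_kl` is hypothetical:

```python
import torch

def row_kl(counts: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    """Per-input-token KL(P_bigram(. | a) || softmax(logits[a])), in nats."""
    p = counts / counts.sum(dim=-1, keepdim=True)
    log_q = logits.log_softmax(dim=-1)
    # Use the convention 0 * log 0 = 0 so unseen bigrams contribute nothing
    log_p = torch.where(p > 0, p.log(), torch.zeros_like(p))
    return (p * (log_p - log_q)).sum(dim=-1)

# Synthetic stand-ins for bigram_freqs and W_E @ W_U
torch.manual_seed(0)
d_vocab = 30
counts = torch.randint(0, 5, (d_vocab, d_vocab)).float()
counts[:, 0] += 1  # every input token is followed by something at least once
logits = torch.randn(d_vocab, d_vocab)

kl = row_kl(counts, logits)
weights = counts.sum(dim=-1) / counts.sum()  # how often each token is the "current" one
weighted_kl = (weights * kl).sum()

# Mean cross-entropy over all observed bigrams, and its floor (the conditional entropy)
n = counts.sum()
p = counts / counts.sum(dim=-1, keepdim=True)
log_p = torch.where(p > 0, p.log(), torch.zeros_like(p))
cross_entropy = -(counts * logits.log_softmax(dim=-1)).sum() / n
cond_entropy = -(counts * log_p).sum() / n

# Identity: cross-entropy minus its floor equals the frequency-weighted KL
print(float(cross_entropy - cond_entropy), float(weighted_kl))
```

On the trained model, the analogous call would be something like `row_kl(bigram_freqs, (W_E @ W_U).detach())`; its weighted sum should match the model's loss on `tokens` minus the bigram-optimal loss, and values near zero would back up the visual comparison. We have not run this here, so no values are reported.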

# Summary

Although this was an extremely simple model, the general lesson is that we were able to leverage our understanding of the constraints of the model architecture to significantly prune our hypothesis space.


We can also generalize this lesson to larger transformers. Any transformer can, in principle, approximate bigram statistics with this same $W_EW_U$ 'direct path' term (although we haven't shown here that they actually do so in practice).