# Problem

Reverse engineer how a transformer trained without positional embeddings predicts the previous token.

# Setup
(No need to read)

In [17]:
!git clone https://github.com/ckkissane/mech-interp-practice.git

fatal: destination path 'mech-interp-practice' already exists and is not an empty directory.


In [18]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-wrpz7vi5
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-wrpz7vi5
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 10d2f8a026d73eada861c7d51064f7e24d8f482c
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [19]:
try:
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
except:
    import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-n4dz_gmr
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-n4dz_gmr
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit cc216772a66af819ff3a77038e53134f3e073af4
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [20]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [21]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import circuitsvis as cv
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [22]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [23]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [24]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

# Load Model

In [25]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7b52b6b04a30>

In [26]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


The model is a 2L, 1H (per layer), attn-only transformer with no biases, layernorms, or positional embeddings.

![picture](https://drive.google.com/uc?id=115AlOlXFE24PMYugkPeAadrkrh-hkOcS)

It is loaded into this notebook below:

In [27]:
MAX_NUM=100
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=1,
    d_model=128,
    d_head=128,
    attn_only=True,
    normalization_type=None,
    n_ctx=10,
    d_vocab=MAX_NUM+1, # 0,...,MAX_NUM-1, BOS
    device=device,
    seed=0
)

model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [28]:
filename = "mech-interp-practice/models/rederive_position_model.pt"
state_dict = torch.load(filename)
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# Task description

The model is given a sequence of tokens and trained to predict the previous token at each position. The inputs are of the form

```
[BOS, a_0, a_1, ..., a_8]
```

* BOS (100) is a special token at the beginning of each sequence
* Each a_i is a random token between [0, 99] inclusive

Each example is fixed at length 10 (including BOS). At each position $i \in [1, ..., 9]$ (0 indexed), the model is trained to predict the token at position i-1. It isn't expected to make a prediction at position 0, since there is no previous token to predict.

A data loader and example tokens are provided below for you to begin your investigation:

In [29]:
BOS_TOKEN = cfg.d_vocab-1
def make_data_generator(cfg, batch_size, seed=0):
    torch.manual_seed(seed)
    while True:
        tokens = torch.randint(0, cfg.d_vocab, (batch_size, cfg.n_ctx))
        tokens[:, 0] = BOS_TOKEN
        yield tokens

batch_size = 256
data_loader = make_data_generator(cfg, batch_size, seed=42)
test_sample = next(data_loader).to(device)
print(test_sample.shape)
print(test_sample[:5])

torch.Size([256, 10])
tensor([[100,  32,  94,  55,   2,  21,  10,  12,  47,  30],
        [100,  38,  38,  58,  86,   4,  43,  66,  86,  46],
        [100,   6,  11,  26,  52,  96,  42,  53,  85,  90],
        [100,  32,  48,  28,  66,  74,  24,  60,  80,  68],
        [100,  63,  81,  89,  23,  26,  15,  26,  78,  91]], device='cuda:0')


The model is trained to minimize cross entropy loss. The exact loss function used during training is provided below:

In [30]:
def loss_fn(logits, tokens):
    logits = logits[:, 1:, :]
    logits = logits.to(torch.float64)
    labels = tokens[:, :-1]
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()