# Problem

Interpret a model trained to predict the output to a simple code functions: predicting the bold text in problems like
$$
a = [1, 2, 3] \\
a[2] = 4 \\
a -> [\textbf{1, 2, 4}] \\
$$

# Setup
(No need to read)

In [1]:
!git clone https://github.com/ckkissane/mech-interp-practice.git

Cloning into 'mech-interp-practice'...
remote: Enumerating objects: 327, done.[K
remote: Counting objects: 100% (146/146), done.[K
remote: Compressing objects: 100% (132/132), done.[K
remote: Total 327 (delta 85), reused 32 (delta 14), pack-reused 181[K
Receiving objects: 100% (327/327), 38.61 MiB | 9.83 MiB/s, done.
Resolving deltas: 100% (164/164), done.


In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-5a0mas49
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-5a0mas49
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 54d548de4995a1ecc5b01b9c03aceaf0966c0eb3
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.14.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━

In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [5]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [6]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [7]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

In [8]:
def visualize_attn_patterns(heads, local_tokens, local_cache, title: str = ""):
    labels = []
    patterns = []
    batch_index = 0

    for head in heads:
        if isinstance(head, tuple):
            layer, head_index = head
        else:
            layer, head_index = head // model.cfg.n_heads, head % model.cfg.n_heads
        patterns.append(local_cache["pattern", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    patterns = torch.stack(patterns, dim=-1)
    attn_viz = pysvelte.AttentionMulti(tokens=model.to_str_tokens(local_tokens[batch_index]), attention=patterns, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attn_viz.show()

# Load Model

In [9]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f3e4a8002b0>

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


The smallest model I found that can learn this task is a 2L, 1H, attn-only transformer with no biases or layernorms. I also use [shortformer](https://github.com/neelnanda-io/TransformerLens/blob/main/further_comments.md#shortformer-attention-positional_embeddings_type--shortformer) positional embeddings to make things cleaner.

![picture](https://drive.google.com/uc?id=115AlOlXFE24PMYugkPeAadrkrh-hkOcS)

The model has already been trained and is loaded into this notebook below:

In [11]:
LIST_LEN = 3
MAX_NUM = 50
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=1,
    d_model=128,
    d_head=128,
    n_ctx=2*LIST_LEN+3+2, #BOS a1 a2 a3 \n idx elt \n b1 b2 b3
    d_vocab=MAX_NUM+2, #0,...,MAX_NUM-1, BOS, \n
    d_vocab_out=MAX_NUM,
    attn_only=True,
    normalization_type=None,
    positional_embedding_type="shortformer",
    device=device,
    seed=0
)

model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_attn_input): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [12]:
filename = "mech-interp-practice/models/array_indexing_model.pt"
state_dict = torch.load(filename)
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# Task description

Given an array 'a', and integers 'idx' and 'elt', the model is trained to update a[idx] = elt, while keeping the other elements of 'a' the same. The input is always of the form:


```
BOS a_0 a_1 a_2 \n idx elt \n b_0 b_1 b_2
```


* BOS (50) is a special token at the start of every sequence
* '\n' (51) is supposed to represent a newline character. These are also always at the same positions for every sequence. (You can just think of it as a special "MID" token too)
* [a_0, a_1, a_2] is the original list before updating. These tokens should range between [0, 49] inclusive.
* “idx” is the index of the list we are to update. This should always be in {0, 1, 2}.
* “elt” is the new element we are writing to “idx” in the list. This ranges between [0, 49] inclusive.
* [b_0, b_1, b_2] is the correct new list after the update: b_{idx} = elt. Otherwise b_i = a_i for i != idx.


For a concrete example, consider:



```
BOS 1 2 3 \n 2 4 \n 1 2 4
```



In this example [1,2,3] is the original list, idx=2 and elt=4. Thus the model should just update a[2]=4, predicting [1, 2, 4].



The list length is fixed to 3 for every example that the model sees in training. The model is trained to predict [b_0, b_1, b_2] at positions [\n, b_0, b_1] respectively.



Below I prove a data loader and sample tokens that you can use to start investigating the model.

In [13]:
def make_data_generator(cfg, batch_size, seed=0):
    torch.manual_seed(seed)
    BOS_TOKEN = cfg.d_vocab-2
    NEW_LINE_TOKEN = cfg.d_vocab-1
    while True:
        bos_vec = (torch.ones(batch_size) * BOS_TOKEN)[:, None]
        nl_vec = (torch.ones(batch_size) * NEW_LINE_TOKEN)[:, None]

        list_toks = torch.randint(0, MAX_NUM, (batch_size, LIST_LEN))

        idx_tok = random.randint(0, LIST_LEN-1)
        idx_vec = (torch.ones(batch_size) * idx_tok)[:, None]

        elt_tok = random.randint(0, MAX_NUM-1)
        elt_vec = (torch.ones(batch_size) * elt_tok)[:, None]

        ans_toks = list_toks.clone()
        ans_toks[:, idx_tok] = elt_tok

        x = torch.cat([bos_vec, list_toks, nl_vec, idx_vec, elt_vec, nl_vec, ans_toks], dim=-1).to(torch.long)
        yield x

batch_size = 256
data_loader = make_data_generator(cfg, batch_size, seed=42)

test_sample = next(data_loader).to(device)
print(test_sample.shape)
print(test_sample[:5])

torch.Size([256, 11])
tensor([[50, 42, 17, 26, 51,  1, 48, 51, 42, 48, 26],
        [50, 14, 26, 35, 51,  1, 48, 51, 14, 48, 35],
        [50, 20, 24,  0, 51,  1, 48, 51, 20, 48,  0],
        [50, 13, 28, 14, 51,  1, 48, 51, 13, 48, 14],
        [50, 10,  4, 31, 51,  1, 48, 51, 10, 48, 31]], device='cuda:0')


The model was trained to minimize cross entropy loss. I provide the exact loss function used during training below:

In [14]:
def loss_fn(logits, tokens):
    logits = logits[:, -LIST_LEN-1:-1, :]
    logits = logits.to(torch.float64)
    labels = tokens[:, -LIST_LEN:]
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    return -correct_log_probs.mean()

It can be useful to convert tokens to a list of strings for plotly / CircuitsVis labels. Here is a function that you can optionally use for this:

In [15]:
def tokens_to_labels(tokens):
    if tokens.ndim==2:
        tokens = tokens[0]
    res = [f"{tok}_{i}" for i, tok in enumerate(tokens)]
    res[0] = "BOS"
    res[LIST_LEN+1] = "nl_0"
    res[-LIST_LEN-1] = "nl_1"
    return res

print(tokens_to_labels(test_sample[1]))

['BOS', '14_1', '26_2', '35_3', 'nl_0', '1_5', '48_6', 'nl_1', '14_8', '48_9', '35_10']
