# Problem

How does the model decide that the cell for the current move is not blank?


# Setup (Don't Read This)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens==1.2.1
    %pip install git+https://github.com/neelnanda-io/neel-plotly
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off to save GPU memory, since this notebook focuses on model inference, not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f163be0a440>

Plotting helper functions:

In [6]:
from neel_plotly import line, scatter, imshow, histogram

# Othello GPT Setup (Copied from Neel Nanda)

## Loading the model

This loads a conversion of the author's synthetic model checkpoint to TransformerLens format. See [this notebook](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb) for how.

In [7]:
import transformer_lens.utils as utils
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre"
)
model = HookedTransformer(cfg)

In [8]:

sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

<All keys matched successfully>

Testing that the synthetic checkpoint gives the correct outputs:

In [9]:
# An example input
sample_input = torch.tensor([[20, 19, 18, 10, 2, 1, 27, 3, 41, 42, 34, 12, 4, 40, 11, 29, 43, 13, 48, 56, 33, 39, 22, 44, 24, 5, 46, 6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37, 9, 25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7]])
# The argmax of the output (ie the most likely next move from each position)
sample_output = torch.tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,
         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,
         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,
         14, 15,  8,  7,  8]])
model(sample_input).argmax(dim=-1)

tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,
         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,
         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,
         14, 15,  8,  7,  8]], device='cuda:0')

## Loading Othello Content
Boring setup code to load in 100K sample Othello games, the linear probe, and some utility functions.

In [10]:

if IN_COLAB:
    !git clone https://github.com/likenneth/othello_world
    OTHELLO_ROOT = Path("/content/othello_world/")
    import sys
    sys.path.append(str(OTHELLO_ROOT/"mechanistic_interpretability"))
    from mech_interp_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState
else:
    OTHELLO_ROOT = Path("/workspace/othello_world/")
    from tl_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState


Cloning into 'othello_world'...
remote: Enumerating objects: 80, done.
remote: Counting objects: 100% (37/37), done.
remote: Compressing objects: 100% (22/22), done.
remote: Total 80 (delta 20), reused 19 (delta 15), pack-reused 43
Unpacking objects: 100% (80/80), 10.14 MiB | 3.71 MiB/s, done.


We load in a big tensor of 100,000 games, each with 60 moves. This is in the format the model wants, with 1-60 representing the 60 moves, and 0 representing pass.

We also load in the same set of games, in the same order, but in "string" format: still a tensor of ints, but referring to moves with numbers from 0 to 63 rather than in the model's compressed format of 1 to 60.

In [11]:
board_seqs_int = torch.tensor(np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_int_small.npy"), dtype=torch.long)
board_seqs_string = torch.tensor(np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_string_small.npy"), dtype=torch.long)

num_games, length_of_game = board_seqs_int.shape
print("Number of games:", num_games,)
print("Length of game:", length_of_game)

Number of games: 100000
Length of game: 60


In [12]:
stoi_indices = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
]
alpha = "ABCDEFGH"


def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"


board_labels = list(map(to_board_label, stoi_indices))
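A minimal sketch of the token/square mapping that `stoi_indices` implies, assuming token `t` (1-60) corresponds to square `stoi_indices[t - 1]`, which is how `log_probs[1:]` gets scattered onto the board further down. `playable_squares`, `square_to_token`, `token_to_square` and `square_to_label` are hypothetical helpers for illustration only, not functions from `mech_interp_othello_utils` (in practice use its `to_int` / `int_to_label`).

```python
# Hypothetical helpers (not from mech_interp_othello_utils) for converting between
# board squares (0-63, "string" format) and model tokens (1-60, "int" format).
# Assumption: token t <-> square playable_squares[t - 1]; token 0 means "pass".
CENTER_SQUARES = (27, 28, 35, 36)  # D3, D4, E3, E4 are occupied at the start and never played
playable_squares = [sq for sq in range(64) if sq not in CENTER_SQUARES]  # same list as stoi_indices
ROW_NAMES = "ABCDEFGH"


def square_to_token(square: int) -> int:
    """Board square (0-63) -> model token (1-60). Raises ValueError for center squares."""
    return playable_squares.index(square) + 1


def token_to_square(token: int) -> int:
    """Model token (1-60) -> board square (0-63)."""
    if not 1 <= token <= 60:
        raise ValueError("token 0 is 'pass' and has no square")
    return playable_squares[token - 1]


def square_to_label(square: int) -> str:
    """Board square (0-63) -> label such as 'E0' (row letter, then column digit)."""
    return f"{ROW_NAMES[square // 8]}{square % 8}"


e0_square = 4 * 8 + 0  # row E, column 0
print(square_to_token(e0_square), square_to_label(e0_square))
```

If the assumption holds, `square_to_token` should agree with the notebook's own `to_int` for the same label (e.g. E0 and C0, the two squares used in the patching experiment below).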

## Running the Model

The model's context length is 59, not 60, because it's trained to receive the first 59 moves and predict the final 59 moves (ie `[0:-1]` and `[1:]`). Let's run the model on the first 30 moves of game 0!

In [13]:
moves_int = board_seqs_int[0, :30]

# This is implicitly converted to a batch of size 1
logits = model(moves_int)
print("logits:", logits.shape)

logits: torch.Size([1, 30, 61])
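A toy sketch of the `[0:-1]` / `[1:]` setup described above (an illustration, not the original training code). The "game" is a stand-in sequence of tokens, not a real Othello game.

```python
# Toy sketch of the next-move training setup: a full game has 60 move tokens; the model
# reads the first 59 and, at each position i, is trained to predict the move at position i + 1.
game = list(range(1, 61))  # stand-in for one 60-move game in int format (tokens 1-60)
inputs, targets = game[:-1], game[1:]

n_ctx = len(inputs)  # 59, matching cfg.n_ctx
pairs = list(zip(inputs, targets))  # (move the model has just seen, move it should predict)
print(n_ctx, pairs[:3])
```

This is also why the 30-move run above returns one row of logits per input position (`[1, 30, 61]`): every position makes its own next-move prediction over the 61-token vocabulary (pass plus 60 squares).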


We take the final vector of logits, convert it to log probs, and then remove the first element (corresponding to passing; we've filtered out all games with passing), leaving 60 log probs. There are 60 = 64 - 4 because the model's vocab is compressed: the center 4 squares can't be played.

We then convert it to an 8 x 8 grid and plot it, with some tensor magic.

In [14]:
logit_vec = logits[0, -1]
log_probs = logit_vec.log_softmax(-1)
# Remove passing
log_probs = log_probs[1:]
assert len(log_probs)==60

temp_board_state = torch.zeros(64, device=logit_vec.device)
# Set all cells to -13 by default, for a very negative log prob - this means the middle cells don't show up as mattering
temp_board_state -= 13.
temp_board_state[stoi_indices] = log_probs

We can now plot this as a board state! We see a crisp distinction between a set of moves that the model clearly thinks are valid (at near uniform probabilities) and a bunch that it thinks aren't. Note that by training the model to predict a *uniformly* chosen next move, we incentivise it to be careful about making all valid logits be uniform!

In [15]:
def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Takes a square input (8 by 8) and plot it as a board. Can do a stack of boards via facet_col=0"""
    if diverging_scale:
        imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], color_continuous_scale="RdBu", color_continuous_midpoint=0., aspect="equal", **kwargs)
    else:
        imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], color_continuous_scale="Blues", color_continuous_midpoint=None, aspect="equal", **kwargs)
plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False, title="Example Log Probs")
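As a rough reference point for the plot above: if the next move is drawn uniformly from the k legal moves, an ideal predictor assigns each legal move log prob -log(k) and every illegal move probability zero. A minimal sketch; `ideal_log_probs` is a hypothetical helper and the legal squares are made up for illustration.

```python
import math


def ideal_log_probs(legal_moves, num_squares=64, illegal_log_prob=-math.inf):
    """Log probs of an ideal predictor when the next move is uniform over `legal_moves`.

    Illegal squares get `illegal_log_prob` (-inf by default, i.e. probability zero).
    """
    k = len(legal_moves)
    out = [illegal_log_prob] * num_squares
    for sq in legal_moves:
        out[sq] = -math.log(k)  # each of the k legal moves gets probability 1/k
    return out


example = ideal_log_probs([19, 26, 37, 44])  # four hypothetical legal squares
print(example[19], example[0])
```

So the more legal moves there are, the lower (but still equal) the log prob of each one should be, which is the "near uniform" pattern visible on the board.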

## Making some utilities

At this point, I'll stop and get some aggregate data that will be useful later - a tensor of valid moves, of board states, and a cache of all model activations across 50 games (in practice, you want as much as will comfortably fit into GPU memory). It's really convenient to have the ability to quickly run an experiment across a bunch of games! And one of the great things about small models on algorithmic tasks is that you can just do stuff like this. 

For lack of a more creative name, let's call these the **focus games**.

In [16]:
num_games = 50
focus_games_int = board_seqs_int[:num_games]
focus_games_string = board_seqs_string[:num_games]

A big stack of each move's board state and a big stack of the valid moves in each game (one-hot encoded to be a nice tensor).

In [17]:
def one_hot(list_of_ints, num_classes=64):
    out = torch.zeros((num_classes,), dtype=torch.float32)
    out[list_of_ints] = 1.
    return out
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = torch.zeros((num_games, 60, 64), dtype=torch.float32)
for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())
print("focus states:", focus_states.shape)
print("focus_valid_moves", focus_valid_moves.shape)


focus states: (50, 60, 8, 8)
focus_valid_moves torch.Size([50, 60, 64])


A cache of every model activation and the logits

In [18]:
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].cuda())

## Using the probe

The training of this probe was kind of a mess, and I'd do a bunch of things differently if doing it again.

<details><summary>Info dump of technical details:</summary>

mode==0 was trained on black to play, ie odd moves, and the classes are \[blank, white, black\] ie \[blank, their colour, my colour\] (I *think*, they could easily be the other way round. This should be easy to sanity check)

mode==1 was trained on white to play, ie even moves, and the classes are \[blank, black, white\] ie \[blank, their colour, my colour\] (I *think*)

mode==2 was trained on all moves, and just doesn't work very well.


The probes were trained on moves 5 to 54 (ie not the early or late moves, because these are weird). I literally did AdamW against cross-entropy loss for each board cell, nothing fancy. You really didn't need to train on 4M games lol, it plateaued well before the end. Which is to be expected, it's just logistic regression!

</details>

But it works!


Let's load in the probe. The shape is [modes, d_model, row, col, options]. The 3 modes are "black to play/odd moves", "white to play/even moves", and "all moves". The 3 options are empty, white and black in that order.

The black to play probe basically just works for the even moves too, once you realise that it's detecting my colour vs their colour! So rather than picking one, we combine the black to play and white to play probes into a single probe.

This means that the options are "empty", "theirs" and "mine" in that order.

In [19]:
full_linear_probe = torch.load(OTHELLO_ROOT/"mechanistic_interpretability/main_linear_probe.pth")

To get a single probe that works on every move, we average the black to play and white to play probes, indexing each one so that its classes line up as [blank, theirs, mine].

In [20]:
rows = 8
cols = 8 
options = 3
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2
linear_probe = torch.zeros(cfg.d_model, rows, cols, options, device="cuda")
linear_probe[..., blank_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 0] + full_linear_probe[white_to_play_index, ..., 0])
linear_probe[..., their_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 1] + full_linear_probe[white_to_play_index, ..., 2])
linear_probe[..., my_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 2] + full_linear_probe[white_to_play_index, ..., 1])
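The averaging above can be written as a lookup from each mode's raw class index to its relative meaning. This sketch uses toy NumPy arrays rather than the real probe, and assumes, as the indexing in the cell above implies, that both modes use the absolute class order [blank, white, black] stated for the probe: "theirs" is class 1 when black is to play and class 2 when white is to play. `MODES`, `RAW_CLASS` and `relative_probe` are hypothetical names.

```python
import numpy as np

MODES = {"black_to_play": 0, "white_to_play": 1}
# Raw probe class for each relative option, per mode, assuming absolute classes [blank, white, black]
RAW_CLASS = {
    "black_to_play": {"blank": 0, "theirs": 1, "mine": 2},  # theirs = white, mine = black
    "white_to_play": {"blank": 0, "theirs": 2, "mine": 1},  # theirs = black, mine = white
}
RELATIVE_OPTIONS = ("blank", "theirs", "mine")


def relative_probe(full_probe: np.ndarray) -> np.ndarray:
    """[mode, d_model, row, col, raw_class] -> [d_model, row, col, (blank, theirs, mine)]."""
    options = []
    for option in RELATIVE_OPTIONS:
        per_mode = [full_probe[MODES[m], ..., RAW_CLASS[m][option]] for m in MODES]
        options.append(np.mean(per_mode, axis=0))  # average the two colour-specific probes
    return np.stack(options, axis=-1)


toy_full_probe = np.random.default_rng(0).normal(size=(3, 16, 8, 8, 3))  # 3 modes, toy d_model=16
toy_probe = relative_probe(toy_full_probe)
print(toy_probe.shape)
```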

In [21]:
blank_probe = linear_probe[..., 0] - linear_probe[..., 1] * 0.5 - linear_probe[..., 2] * 0.5
my_probe = linear_probe[..., 2] - linear_probe[..., 1]
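Why these combinations make sense: the probe tensor has no bias dimension, so each cell's three class logits are just dot products of the residual stream with three weight vectors, and any linear combination of those vectors reads off the same combination of logits. So a positive `blank_probe` score means the probe favours "blank" over the average of "theirs" and "mine", and `my_probe` reads "mine minus theirs". A toy NumPy check with random weights (not the real probe):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 512
w = rng.normal(size=(d_model, 3))  # toy probe weights for one cell; columns: blank, theirs, mine
resid = rng.normal(size=d_model)   # toy residual stream vector

blank_dir = w[:, 0] - 0.5 * w[:, 1] - 0.5 * w[:, 2]  # same combination as blank_probe
mine_dir = w[:, 2] - w[:, 1]                          # same combination as my_probe

logits = resid @ w  # the probe's three class logits: [blank, theirs, mine]
print(np.isclose(resid @ blank_dir, logits[0] - 0.5 * (logits[1] + logits[2])))
print(np.isclose(resid @ mine_dir, logits[2] - logits[1]))
```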

# Zoom in on one example: current_move == E0

## Activation patching

We can start by zooming in on a specific move for one game, and apply activation patching. This problem seems especially well suited to activation patching since we can easily corrupt the information we care about (the current move), while keeping the rest of the board the same. 
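For intuition, here is activation patching on a tiny hand-built NumPy "model" (an illustration of the idea, not TransformerLens): cache a hidden activation from the clean input, overwrite one unit of the corrupted run with its clean value, and measure how far the output moves back towards the clean output. In this toy the output is linear in the hidden layer, so the per-unit effects add up exactly to the full clean-minus-corrupted difference; in a real transformer, nonlinearities and layer norm mean they generally won't.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(3, 4))  # input -> hidden
W2 = rng.normal(size=4)       # hidden -> scalar output


def hidden_acts(x):
    return np.tanh(x @ W1)


clean_x = np.array([1.0, 0.0, 0.0])
corrupted_x = np.array([0.0, 1.0, 0.0])
clean_hidden, corrupted_hidden = hidden_acts(clean_x), hidden_acts(corrupted_x)
clean_out, corrupted_out = clean_hidden @ W2, corrupted_hidden @ W2


def patch_unit(unit):
    """Run the corrupted input, but overwrite one hidden unit with its clean value."""
    hidden = corrupted_hidden.copy()
    hidden[unit] = clean_hidden[unit]
    return hidden @ W2


effects = [patch_unit(u) - corrupted_out for u in range(4)]
# The output is linear in the hidden layer here, so per-unit effects sum to the full effect
print(np.isclose(sum(effects), clean_out - corrupted_out))
```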

In [22]:
game_index = 4
move = 20
plot_single_board(focus_games_string[game_index, :move+1])
plot_single_board(focus_games_string[game_index, :move].tolist()+[16], title="Corrupted Game (black plays C0)")

In [23]:
clean_input = focus_games_int[game_index, :move+1].clone()
print("Last 5 clean moves:", int_to_label(clean_input)[-5:])

corrupted_input = focus_games_int[game_index, :move+1].clone()
corrupted_input[-1] = to_int('C0')
print("Last 5 corrupted moves:", int_to_label(corrupted_input)[-5:])

Last 5 clean moves: ['F6', 'B2', 'F4', 'B3', 'E0']
Last 5 corrupted moves: ['F6', 'B2', 'F4', 'B3', 'C0']


In [24]:
clean_logits, clean_cache = model.run_with_cache(clean_input)
print(clean_logits.shape)
print(clean_cache)

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_input)
print(corrupted_logits.shape)
print(corrupted_cache)

torch.Size([1, 21, 61])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_

In [25]:
print(blank_probe.shape)
e0_blank_probe_dir = blank_probe[:, 4, 0]
print(e0_blank_probe_dir.shape)

torch.Size([512, 8, 8])
torch.Size([512])


Since we want to know how the model knows the current cell is not blank, we can use the probe score (the dot product of resid_post after L4 with the E0 blank probe direction) as a metric; a more negative score means the probe reads E0 as less likely to be blank. I'm using resid_post at layer 4 since the model seems to have learned the blank world model very well by then (see Neel Nanda's original demo).

In [26]:
resid_post_layer = 4
clean_resid_post = clean_cache['resid_post', resid_post_layer][0, -1, :]
clean_e0_blank_score = (clean_resid_post * e0_blank_probe_dir).sum(dim=0)
print("clean E0 blank score:", clean_e0_blank_score.item())

corrupted_resid_post = corrupted_cache['resid_post', resid_post_layer][0, -1, :]
corrupted_e0_blank_score = (corrupted_resid_post * e0_blank_probe_dir).sum(dim=0)
print("corrupted E0 blank score:", corrupted_e0_blank_score.item())

def patching_metric(patched_resid_post):
    if patched_resid_post.ndim == 3:
        patched_resid_post = patched_resid_post[0, -1, :]
    patched_e0_blank_score = (patched_resid_post * e0_blank_probe_dir).sum(dim=0)
    return (patched_e0_blank_score - corrupted_e0_blank_score) / (clean_e0_blank_score - corrupted_e0_blank_score)

print("Clean metric:", patching_metric(clean_resid_post).item())
print("Corrupted metric:", patching_metric(corrupted_resid_post).item())

clean E0 blank score: -7.698649883270264
corrupted E0 blank score: 13.965202331542969
Clean metric: 1.0
Corrupted metric: -0.0
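A standalone version of the normalised metric, plugging in the two E0 blank scores printed above. It maps the corrupted score to 0 and the clean score to 1; the `-0.0` above is just floating-point signed zero (a zero numerator divided by the negative denominator). `normalised_metric` is a hypothetical name for illustration.

```python
import math

# E0 blank scores printed above (probe score at resid_post, layer 4, final position)
CLEAN_SCORE = -7.698649883270264
CORRUPTED_SCORE = 13.965202331542969


def normalised_metric(patched_score: float) -> float:
    """0 means no better than the corrupted run; 1 means the clean score is fully restored.

    Values below 0 or above 1 are possible if a patch overshoots in either direction.
    """
    return (patched_score - CORRUPTED_SCORE) / (CLEAN_SCORE - CORRUPTED_SCORE)


halfway = 0.5 * (CLEAN_SCORE + CORRUPTED_SCORE)
print(normalised_metric(CLEAN_SCORE), normalised_metric(CORRUPTED_SCORE))
print(math.isclose(normalised_metric(halfway), 0.5))
```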


In [27]:
model.reset_hooks()

attn_out_patching_results = []
def patch_attn_out(attn_out, hook):
    attn_out[0, -1, :] = clean_cache[hook.name][0, -1, :]
    return attn_out

resid_post_name_filter = lambda name: 'resid_post' in name
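for layer in range(5):
    model.reset_hooks()  # hooks added with add_hook can persist between runs, so clear the previous patch first
    model.add_hook(utils.get_act_name('attn_out', layer), patch_attn_out)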
    _, patched_cache = model.run_with_cache(
        corrupted_input,
        names_filter=resid_post_name_filter,
        return_type=None,
    )
    patched_resid_post = patched_cache['resid_post', resid_post_layer]
    attn_out_patching_results.append(patching_metric(patched_resid_post))


mlp_out_patching_results = []
def patch_mlp_out(mlp_out, hook):
    mlp_out[0, -1, :] = clean_cache[hook.name][0, -1, :]
    return mlp_out

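for layer in range(5):
    model.reset_hooks()  # hooks added with add_hook can persist between runs, so clear the previous patch first
    model.add_hook(utils.get_act_name('mlp_out', layer), patch_mlp_out)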
    _, patched_cache = model.run_with_cache(
        corrupted_input,
        names_filter=resid_post_name_filter,
        return_type=None
    )
    patched_resid_post = patched_cache['resid_post', resid_post_layer]
    mlp_out_patching_results.append(patching_metric(patched_resid_post))

line(
    [attn_out_patching_results, mlp_out_patching_results],
     title="attn_out and mlp_out patching results",
     xaxis="Layer", yaxis="Patching metric",
     line_labels=["attn_out", "mlp_out"]
)

This suggests that MLP0 plays the biggest role in determining that E0 is not blank. A natural next step is to zoom in on MLP0 by applying activation patching to its individual neurons.

In [28]:
model.reset_hooks()

layer_0_neuron_patching_results = []
def patch_post(post, hook, neuron):
    post[0, -1, neuron] = clean_cache[hook.name][0, -1, neuron]
    return post

layer = 0
for neuron in range(cfg.d_mlp):
    # Hooks added with add_hook can persist between runs, so clear the previous neuron's patch first
    model.reset_hooks()
    hook_fn = partial(patch_post, neuron=neuron)
    model.add_hook(utils.get_act_name('post', layer), hook_fn)
    _, patched_cache = model.run_with_cache(
        corrupted_input,
        names_filter=resid_post_name_filter,
        return_type=None
    )
    patched_resid_post = patched_cache['resid_post', resid_post_layer]
    layer_0_neuron_patching_results.append(patching_metric(patched_resid_post))
model.reset_hooks()

line(
    layer_0_neuron_patching_results,
     title="L0 neuron patching results",
     xaxis="Neuron", yaxis="Patching metric"
)

L0N774 stands out. Let's stare at this neuron's activations over the focus games to get a sense of when it activates.

## Stare at L0N774 acts

In [29]:
layer, neuron = 0, 774
neuron_acts_post = focus_cache['post', 0][..., neuron]

imshow(
    neuron_acts_post,
    title=f"L{layer}N{neuron} acts over focus games",
    xaxis="Move", yaxis="Game"
)

It appears to activate very strongly once per game. Perhaps it only activates when E0 is the current move?

In [30]:
neuron_activated_mask = neuron_acts_post > 0.0
current_move_when_active = focus_games_int[:, :-1][neuron_activated_mask.cpu()]
print("Current move when L0774 activates:", int_to_label(current_move_when_active))
print("Fraction of time current_move==E0 when L0N774 activates:", (current_move_when_active == to_int('E0')).float().mean().item())

Current move when L0774 activates: ['E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0']
Fraction of time current_move==E0 when L0N774 activates: 1.0


We can further sanity check this hypothesis by looking at max activating dataset examples.

## Max activating dataset examples

In [31]:
def state_stack_to_one_hot(state_stack):
    one_hot = torch.zeros(
        state_stack.shape[0], # num games
        state_stack.shape[1], # num moves
        8, # rows
        8, # cols
        3, # the three options (empty, white, black)
        device=state_stack.device,
        dtype=torch.int,
    )
    one_hot[..., 0] = state_stack == 0 # empty
    one_hot[..., 1] = state_stack == -1 # white
    one_hot[..., 2] = state_stack == 1 # black
    
    return one_hot

# We first convert the board states to be in terms of my (+1) and their (-1)
alternating = np.array([-1 if i%2 == 0 else 1 for i in range(focus_games_int.shape[1])])
flipped_focus_states = focus_states * alternating[None, :, None, None]

# We now convert to one hot
focus_states_flipped_one_hot = state_stack_to_one_hot(torch.tensor(flipped_focus_states))

# Take the argmax
focus_states_flipped_value = focus_states_flipped_one_hot.argmax(dim=-1)

# focus_states_flipped_value encodes the same thing, but with theirs==1 and mine==2 rather than -1 and +1. So let's convert; this is boring and you don't need to understand it.
focus_states_flipped_pm1 = torch.zeros_like(focus_states_flipped_value, device="cuda")
focus_states_flipped_pm1[focus_states_flipped_value==2] = 1.
focus_states_flipped_pm1[focus_states_flipped_value==1] = -1.

In [32]:
top_moves = neuron_acts_post.flatten() > neuron_acts_post.flatten().quantile(0.99)
plot_square_as_board(
    focus_states_flipped_pm1[:, :-1].reshape(-1, rows, cols)[top_moves].float().mean(dim=0),
    title="Max activating dataset examples for L0N774",
)
plot_square_as_board(
    focus_states_flipped_pm1[:, :-1].abs().reshape(-1, rows, cols)[top_moves].float().mean(dim=0),
    title="Max activating dataset examples for L0N774 (abs value)",
)

This confirms that E0 is not blank (specifically it’s theirs) in 100% of max activating dataset examples.

A natural hypothesis is that this neuron represents the feature "current_move==E0". We should be able to check this with a spectrum plot:


## Spectrum Plot

In [33]:
import pandas as pd

In [34]:
layer, neuron = 0, 774
neuron_acts_pre = focus_cache['pre', layer][..., neuron]

label = focus_games_int[:, :-1] == to_int('E0')

df = pd.DataFrame({"pre_act": neuron_acts_pre.flatten().tolist(), "label": label.flatten().tolist()})

px.histogram(
    df,
    nbins=100, histnorm="percent",
    x="pre_act", color="label",
    title=f"L{layer}N{neuron} spectrum plot testing current_move==E0"
)

This is extremely clear cut: across the focus games, L0N774 activates very strongly when current_move==E0, and never activates otherwise.
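A spectrum plot can also be summarised numerically, which is handy when checking many neurons at once. A minimal sketch: `spectrum_summary` is a hypothetical helper, and "active" here means a positive pre-activation. On the notebook's tensors it could be called as `spectrum_summary(neuron_acts_pre.cpu().numpy(), label.numpy())`; below it runs on toy data.

```python
import numpy as np


def spectrum_summary(pre_acts, labels) -> dict:
    """Fraction of examples with positive pre-activation, split by a boolean label."""
    pre_acts = np.asarray(pre_acts, dtype=float).ravel()
    labels = np.asarray(labels, dtype=bool).ravel()
    return {
        "frac_positive_when_true": float((pre_acts[labels] > 0).mean()),
        "frac_positive_when_false": float((pre_acts[~labels] > 0).mean()),
    }


# Toy data standing in for one neuron's pre-activations over 50 games x 59 moves,
# with labels marking roughly one position in 50 as "current_move == X".
rng = np.random.default_rng(0)
toy_labels = rng.random(2950) < 0.02
toy_pre_acts = np.where(toy_labels, rng.normal(3.0, 0.5, 2950), rng.normal(-2.0, 0.5, 2950))
print(spectrum_summary(toy_pre_acts, toy_labels))
```

A clean "current_move == X" neuron should give a fraction near 1 for the first entry and near 0 for the second.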


Another thing to check is which earlier parts of the model activate this neuron. Since it only requires information from the current position, it shouldn't need attention to move information between positions, so I suspect it is being activated by the token embedding through the residual connection. We can check this with a technique analogous to direct logit attribution: decompose MLP0's input into its components and take the dot product of each with this neuron's input weights.
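Why the per-component dot product is meaningful: with `LNPre` normalization (center, then divide by a per-position scale, with no learned weights), the only nonlinearity before the neuron is the division by the scale. If that scale is frozen at its value from the actual run, MLP0's input is linear in the residual components, so the per-component contributions add up exactly to the neuron's pre-activation (bias omitted). A toy NumPy check with random stand-ins:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 512
components = rng.normal(size=(3, d_model))  # stand-ins for embed, pos_embed, 0_attn_out
w_in = rng.normal(size=d_model)             # stand-in for one neuron's input weights

resid = components.sum(axis=0)  # the residual stream is the sum of its components
centered = resid - resid.mean()
scale = np.sqrt((centered ** 2).mean() + 1e-5)  # layer-norm scale, frozen from this "clean run"
full_pre_act = (centered / scale) @ w_in        # neuron pre-activation (bias omitted)

# Center each component and divide by the same frozen scale, then dot with w_in
per_component = ((components - components.mean(axis=-1, keepdims=True)) / scale) @ w_in
print(per_component, np.isclose(per_component.sum(), full_pre_act))
```

The decomposition is only exact if every component is divided by the same scale the real forward pass used at that layer norm.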

## How does this neuron get activated?

In [35]:
decomposed_resid_stack, labels = clean_cache.decompose_resid(layer=1, pos_slice=-1, return_labels=True)
# MLP0's input is normalised by block 0's ln2 scale, so use layer=0 with mlp_input=True
decomposed_resid_stack = clean_cache.apply_ln_to_stack(decomposed_resid_stack, layer=0, mlp_input=True, pos_slice=-1)
decomposed_resid_stack = decomposed_resid_stack[:3, -1, :]
print(decomposed_resid_stack.shape)
labels = labels[:-1]
print(labels)

torch.Size([3, 512])
['embed', 'pos_embed', '0_attn_out']


In [36]:
decomposed_dot_w_in = einops.einsum(
    decomposed_resid_stack,
    model.W_in[layer, :, neuron],
    "n_comp d_model, d_model -> n_comp"
)

line(
    decomposed_dot_w_in,
     x=labels,
     yaxis="dot product", xaxis="component",
     title=f"per compoment dot product with L{layer}N{neuron} input weights"
)

Here we see that the embedding is mostly responsible for activating this neuron, as expected, although the attention output also has non-zero attribution.


We can further check that this neuron is looking for E0 token embedding by taking the cosine similarity between all token embeddings and this neuron's input weights.

## E0 token embedding aligns with input weights

In [37]:
model.W_E.shape

torch.Size([61, 512])

In [38]:
layer, neuron = 0, 774

w_in = model.W_in[layer, :, neuron]
w_in /= w_in.norm()

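# .detach() returns a view of the model's embedding weights, so normalise out-of-place to avoid modifying the model
W_E_norm = model.W_E[1:, :].detach()
W_E_norm = W_E_norm / W_E_norm.norm(dim=-1, keepdim=True)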


state = torch.zeros((rows*cols,), device="cuda")
state[stoi_indices] = (W_E_norm @ w_in)

plot_square_as_board(
    state.reshape(rows,cols,),
    title=f"Cosine similarity of token embeddings and L{layer}N{neuron} input weights",

)

Here we see that L0N774's input weights are unusually aligned (0.36 cosine sim) with the E0 token embedding, while being nearly orthogonal to everything else.

## Reading off neuron weights

We can also read off this neuron's input / output weights in terms of probe directions. We mainly expect the output weights to write in the "E0 is not blank" direction, based on our activation patching results from earlier.

In [39]:
# Scale the probes down to be unit norm per cell
blank_probe_normalised = blank_probe / blank_probe.norm(dim=0, keepdim=True)
my_probe_normalised = my_probe / my_probe.norm(dim=0, keepdim=True)
# Set the center blank probes to 0, since they're never blank so the probe is meaningless
blank_probe_normalised[:, [3, 3, 4, 4], [3, 4, 3, 4]] = 0.

In [40]:
layer, neuron = 0, 774

w_in = model.W_in[layer, :, neuron]
w_in /= w_in.norm()

w_out = model.W_out[layer, neuron, :]
w_out /= w_out.norm()

plot_square_as_board(
    [(w_in[:, None, None] * blank_probe_normalised).sum(dim=0),
    (w_in[:, None, None] * my_probe_normalised).sum(dim=0)],
    title=f"Cosine sim of L{layer}N{neuron} input weights and probes",
    facet_col=0,
    facet_labels=["Blank in", "My in"]
)

plot_square_as_board(
    [(w_out[:, None, None] * blank_probe_normalised).sum(dim=0),
    (w_out[:, None, None] * my_probe_normalised).sum(dim=0)],
    title=f"Cosine sim of L{layer}N{neuron} output weights and probes",
    facet_col=0,
    facet_labels=["Blank out", "My out"]
)

Visually, "E0 is not blank" in the blank output weights does stand out as expected, although the cosine sim is a bit smaller than I would have guessed.


(We also see that the neuron writes pretty strongly in the "E1 is theirs" direction. I'm confused about this because this neuron occasionally activates when E1 is blank, but I'm just going to focus on how the model knows E0 is not blank in this notebook.)

We should also check what fraction of this neuron's input and output weights is captured by the probes:

In [41]:
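# torch.svd is deprecated (and returns V, not Vh); torch.linalg.svd returns U, S, Vh
U, S, Vh = torch.linalg.svd(torch.cat([my_probe.reshape(cfg.d_model, 64), blank_probe.reshape(cfg.d_model, 64)], dim=1), full_matrices=False)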
# Remove the final four dimensions of U, as the 4 center cells are never blank and so the blank probe is meaningless there
probe_space_basis = U[:, :-4]

print("Fraction of input weights in probe basis:", (w_in @ probe_space_basis).norm().item()**2)
print("Fraction of output weights in probe basis:", (w_out @ probe_space_basis).norm().item()**2)

Fraction of input weights in probe basis: 0.16757284552099438
Fraction of output weights in probe basis: 0.1506617463001625
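Why `(w @ probe_space_basis).norm()**2` is a fraction: for a unit vector and a basis with orthonormal columns (as the columns of `U` are), the squared norm of the coordinates is the share of the vector's squared norm lying inside the subspace. A toy check with a random orthonormal basis of the same size (124 = 128 probe directions minus the 4 dropped ones); `fraction_in_subspace` is a hypothetical helper.

```python
import torch

torch.manual_seed(0)
d_model, k = 512, 124
basis, _ = torch.linalg.qr(torch.randn(d_model, k))  # toy orthonormal basis (columns), shape [512, 124]


def fraction_in_subspace(v: torch.Tensor, basis: torch.Tensor) -> float:
    """Fraction of v's squared norm inside span(basis); basis columns must be orthonormal."""
    v = v / v.norm()
    return (v @ basis).norm().item() ** 2


inside = basis @ torch.randn(k)                  # lies entirely in the subspace
outside = torch.randn(d_model)
outside = outside - basis @ (basis.T @ outside)  # remove the in-subspace component
print(fraction_in_subspace(inside, basis), fraction_in_subspace(outside, basis))
```

For reference, a uniformly random direction has an expected fraction of k / d_model in a k-dimensional subspace, which is roughly 124 / 512 ≈ 0.24 here.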


We see that the input and output weights are both somewhat explained by the probes, but the fractions (about 17% and 15%) are surprisingly low. This suggests that although this neuron plays a big role in determining that E0 is not blank, it's likely not the whole story. To emphasize this point, we can see that the "E0 blank score" continues to change well after MLP0 (despite the score already being quite low after L0).

In [42]:
e0_blank_scores = []
for layer in range(cfg.n_layers):
    e0_blank_scores.append((clean_cache['resid_post', layer][0, -1, :] * e0_blank_probe_dir).sum(dim=0))

line(
    e0_blank_scores,
     title="E0 blank score after each layer",
     xaxis="Layer", yaxis="e0 blank score"
)

# Find neuron family

I feel pretty convinced that this neuron plays a non-trivial part in identifying that E0 is not blank when E0 is the current move. Intuitively, it would make sense for the model to have analogous neurons for the other squares: MLP0 has 2048 neurons, and there are only 60 legal moves.


A very nice property of this problem is that it's extremely easy to search automatically for similar neurons. To do this, we can look for MLP0 neurons that are only ever active on one particular current move. I'll also require them to be active in at least 70% of the focus games (35 of 50) to weed out more specific features (ex: "current_move==X AND _").
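The search criterion can be sketched on toy data first. Note that the threshold counts active positions rather than distinct games; for a "current_move == X" neuron these coincide, since each square is played at most once per game. `find_single_move_neurons` and the toy arrays are hypothetical stand-ins for `focus_cache['post', 0]` and `focus_games_int[:, :-1]`.

```python
import numpy as np


def find_single_move_neurons(post_acts, current_moves, min_active=35):
    """post_acts: [games, moves, neurons]; current_moves: [games, moves] move ids."""
    found = []
    for neuron in range(post_acts.shape[-1]):
        moves_when_active = current_moves[post_acts[..., neuron] > 0.0]
        # Keep neurons that are active often enough and only ever on a single current move
        if len(moves_when_active) >= min_active and np.all(moves_when_active == moves_when_active[0]):
            found.append((neuron, moves_when_active[0].item()))
    return found


games, moves = 50, 59
current_moves = np.tile(np.arange(1, moves + 1), (games, 1))  # toy: every game plays tokens 1..59 once
post_acts = np.full((games, moves, 3), -0.1)
post_acts[..., 0][current_moves == 31] = 1.0                           # fires only on token 31
post_acts[..., 1][(current_moves == 31) | (current_moves == 5)] = 1.0  # fires on two different moves
post_acts[:10, :, 2][current_moves[:10] == 7] = 1.0                    # too rare: only 10 games
print(find_single_move_neurons(post_acts, current_moves))
```

Only neuron 0 passes: neuron 1 fails the single-move check and neuron 2 fails the 35-activation threshold.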

In [43]:
layer = 0
min_active = 35  # 35 active positions = 70% of the 50 focus games (each square is played at most once per game)
current_move_not_blank_neurons = []
for neuron in range(cfg.d_mlp):
    neuron_acts = focus_cache['post', layer][..., neuron]
    neuron_active_mask = neuron_acts > 0.0
    current_moves = focus_games_int[:, :-1][neuron_active_mask.cpu()]
    # Keep neurons that are active often enough and only ever on a single current move
    if len(current_moves) >= min_active and (current_moves == current_moves[0]).all().item():
        cur_move = int_to_label(current_moves[0])
        current_move_not_blank_neurons.append([neuron, cur_move])

print(len(current_move_not_blank_neurons))
print(current_move_not_blank_neurons[:5])

57
[[38, 'C1'], [40, 'H4'], [110, 'B2'], [196, 'F2'], [225, 'A5']]


One interesting fact is that this search found 57 neurons (out of 60 possible squares), and they all seem to represent unique squares:

In [44]:
print("unique squares:", len(set([sq for neuron, sq in current_move_not_blank_neurons])))

unique squares: 57


Now we should just be able to automatically apply the same experiments that we used on E0. Starting with max dataset examples:

## Max activating dataset examples

In [45]:
def plot_max_activating_dataset_examples(layer, neuron, sq_label):
    neuron_acts_post = focus_cache['post', layer][..., neuron]
    top_moves = neuron_acts_post.flatten() > neuron_acts_post.flatten().quantile(0.99)
    plot_square_as_board(
        [focus_states_flipped_pm1[:, :-1].reshape(-1, 64)[top_moves].float().mean(0).reshape(8, 8),
         focus_states_flipped_pm1[:, :-1].abs().reshape(-1, 64)[top_moves].float().mean(0).reshape(8, 8)],
        title=f"L{layer}N{neuron} Max activating dataset examples (expecting {sq_label})",
        facet_col=0,
        facet_labels=["normal", "abs"]
    )

layer = 0
for neuron, sq_label in current_move_not_blank_neurons:
    plot_max_activating_dataset_examples(layer, neuron, sq_label)

Every single plot above shows that the square of choice is non-empty in 100% of cases, although this isn't surprising given how we identified the neurons in the first place. Next, we can use spectrum plots to check whether they all monosemantically represent "current_move == X", as expected.
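The max-activating-examples computation reduces to three steps: threshold the activations at a high quantile, select the board states at those positions, and average them. Below is a NumPy sketch on synthetic data. It assumes boards are flattened to 64 cells with values in {-1, 0, +1}, mirroring `focus_states_flipped_pm1`, but the data here is made up.

```python
import numpy as np

def max_activating_mean_board(acts: np.ndarray, boards: np.ndarray, q: float = 0.99):
    """Average board state over positions where a neuron's activation exceeds its q-quantile.

    acts:   (n_positions,) neuron activations, flattened over games and moves
    boards: (n_positions, 64) board states with values in {-1, 0, +1}
    Returns (mean_board, mean_abs_board), each (8, 8); mean_abs_board[r, c] is the
    fraction of top positions where cell (r, c) is occupied.
    """
    top = acts > np.quantile(acts, q)
    selected = boards[top]
    return selected.mean(0).reshape(8, 8), np.abs(selected).mean(0).reshape(8, 8)

# Synthetic data: the neuron's 20 largest activations all coincide with cell 0 being -1
rng = np.random.default_rng(0)
boards = rng.integers(-1, 2, size=(1000, 64)).astype(float)
boards[:, 0] = 0.0
acts = rng.uniform(0, 1, size=1000)
boards[:20, 0] = -1.0
acts[:20] = 10.0 + np.arange(20)
mean_board, mean_abs = max_activating_mean_board(acts, boards)
print(mean_board[0, 0], mean_abs[0, 0])
```

If every activation were identical, no position would exceed the quantile and the means would be NaN. That doesn't happen with real activations, but it is worth knowing.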

## Spectrum plots

In [46]:
def plot_spectrum_plot(layer, neuron, sq_label):
    neuron_acts_pre = focus_cache['pre', layer][..., neuron]

    label = focus_games_int[:, :-1] == to_int(sq_label)

    df = pd.DataFrame({"pre_act": neuron_acts_pre.flatten().tolist(), "label": label.flatten().tolist()})

    px.histogram(
        df,
        nbins=100, histnorm="percent",
        x="pre_act", color="label",
        title=f"L{layer}N{neuron} spectrum plot testing current_move=={sq_label}"
    ).show()

layer = 0
for neuron, sq_label in current_move_not_blank_neurons:
    plot_spectrum_plot(layer, neuron, sq_label)

These are also pretty clear-cut: almost all of them are monosemantic neurons representing the feature we expected. We do, however, find some strange plots: L0N326, L0N706 and L0N807. Each of these has a non-trivial percentage of positions where "current_move==X" is true but the pre-activation is negative, suggesting that these 3 might not belong to this neuron family.
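The "non-trivial percentage" read off these histograms can be made into a number. Below is a NumPy sketch. It assumes a threshold of 0 on the pre-activation, which fits a GELU-style activation (GELU(x) > 0 exactly when x > 0). In the notebook, the inputs would be `focus_cache['pre', 0][..., neuron]` and `focus_games_int[:, :-1] == to_int(sq_label)`, moved to CPU and converted to NumPy.

```python
import numpy as np

def spectrum_mismatch_rates(pre_acts, labels, threshold: float = 0.0):
    """Quantify how cleanly a neuron separates label-True from label-False positions.

    Returns (missed, spurious):
      missed   = fraction of label-True positions with pre_act <= threshold
      spurious = fraction of label-False positions with pre_act > threshold
    """
    pre_acts = np.asarray(pre_acts, dtype=float).ravel()
    labels = np.asarray(labels, dtype=bool).ravel()
    missed = float(np.mean(pre_acts[labels] <= threshold)) if labels.any() else float("nan")
    spurious = float(np.mean(pre_acts[~labels] > threshold)) if (~labels).any() else float("nan")
    return missed, spurious

missed, spurious = spectrum_mismatch_rates([2.0, 3.0, -1.0, -0.5, 0.2, -2.0],
                                           [True, True, True, False, False, False])
print(missed, spurious)
```

Sorting the neuron family by `missed` is one way to flag outliers like the three above without inspecting every histogram by eye.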


We can now check how the token embeddings align with each neuron's input weights as we did before.

## Embedding alignment

In [47]:
def plot_embedding_cosine_sims(layer, neuron, sq_label):
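    w_in = model.W_in[layer, :, neuron]
    w_in = w_in / w_in.norm()

    # Normalise out of place: .detach() shares storage with the embedding parameter,
    # so an in-place `/=` here would overwrite the model's token embeddings
    W_E = model.W_E[1:, :].detach()
    W_E_norm = W_E / W_E.norm(dim=-1, keepdim=True)

    # Scatter the per-token cosine sims onto the board (cells with no token stay 0)
    state = torch.zeros((rows*cols,), device=W_E_norm.device)
    state[stoi_indices] = W_E_norm @ w_in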

    plot_square_as_board(
        state.reshape(rows,cols,),
        title=f"Cosine similarity of token embeddings and L{layer}N{neuron} input weights (expecting {sq_label})",
    )

layer = 0
for neuron, sq_label in current_move_not_blank_neurons:
    plot_embedding_cosine_sims(layer, neuron, sq_label)

Every neuron's input weights clearly align most with the expected square, and the peak cosine similarities are also similar across neurons (roughly 0.3-0.5).
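The embedding-alignment check is a row-wise cosine similarity followed by an argmax. Here is a NumPy sketch with toy embeddings. The one substantive design choice is normalising out of place. This matters in PyTorch because `.detach()` returns a tensor that shares storage with the original parameter, so dividing it in place would silently change the model.

```python
import numpy as np

def embedding_cosine_sims(W_E: np.ndarray, w_in: np.ndarray) -> np.ndarray:
    """Cosine similarity between each token embedding (row of W_E) and one neuron's input weights."""
    # Out-of-place normalisation: new arrays are created and W_E / w_in are left untouched
    E_unit = W_E / np.linalg.norm(W_E, axis=-1, keepdims=True)
    w_unit = w_in / np.linalg.norm(w_in)
    return E_unit @ w_unit

W_E = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])  # 3 toy tokens, d_model = 2
sims = embedding_cosine_sims(W_E, np.array([3.0, 0.0]))
print(sims.round(3), sims.argmax())
```

For each neuron in the family, `sims.argmax()` should map back to that neuron's square.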


Finally, we can read off the neuron weights using the probes. I'll just look at the output weights, to check whether they all write a non-trivial amount along the "X is blank" probe direction:

## Reading off neuron weights

In [48]:
def dot_weights_with_probes(layer, neuron, sq_label):
    w_out = model.W_out[layer, neuron, :]
    w_out = w_out / w_out.norm()

    plot_square_as_board(
        [(w_out[:, None, None] * blank_probe_normalised).sum(dim=0),
        (w_out[:, None, None] * my_probe_normalised).sum(dim=0)],
        title=f"Cosine sim of L{layer}N{neuron} output weights and probes (expecting {sq_label})",
        facet_col=0,
        facet_labels=["Blank out", "My out"]
    )

layer = 0
for neuron, sq_label in current_move_not_blank_neurons:
    dot_weights_with_probes(layer, neuron, sq_label)

Notice that they all seem to write along the "X is blank" probe direction, similar to L0N774. We can also look at what fraction of each neuron's input and output weights lies in the subspace spanned by the probes:
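Reading weights through probes comes down to one cosine similarity per square. Below is a NumPy sketch with a toy probe tensor. It assumes probes are stored as `(d_model, rows, cols)`, which matches how `blank_probe_normalised` is indexed above. Each square's direction is renormalised here, which is harmless if the probes are already unit length. A negative value means the neuron writes against "square is blank", i.e. towards "square is not blank".

```python
import numpy as np

def weight_probe_cosine_sims(w_out: np.ndarray, probes: np.ndarray) -> np.ndarray:
    """Cosine similarity of a neuron's output weights with one probe direction per square.

    w_out:  (d_model,) output weights of one neuron
    probes: (d_model, rows, cols) probe directions, e.g. "square is blank"
    Returns a (rows, cols) board of cosine similarities.
    """
    w_unit = w_out / np.linalg.norm(w_out)
    probe_unit = probes / np.linalg.norm(probes, axis=0, keepdims=True)  # unit direction per square
    return np.einsum("d,drc->rc", w_unit, probe_unit)

# Toy probes: d_model = 4 and a 1 x 2 "board"
probes = np.zeros((4, 1, 2))
probes[0, 0, 0] = 1.0   # square (0, 0) uses direction e0
probes[1, 0, 1] = 5.0   # square (0, 1) uses direction e1 (deliberately unnormalised)
board = weight_probe_cosine_sims(np.array([-2.0, 0.0, 0.0, 0.0]), probes)
print(board)
```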

## Fraction of weights in probe basis

In [49]:
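# Orthonormal basis for the span of the "mine" and "blank" probe directions (64 + 64 columns).
# Note that torch.svd returns (U, S, V), with V itself rather than its conjugate transpose
U, S, V = torch.svd(torch.cat([my_probe.reshape(cfg.d_model, 64), blank_probe.reshape(cfg.d_model, 64)], dim=1))
# Remove the final four dimensions of U (smallest singular values), as the 4 center cells are never blank and so the blank probe is meaningless there
probe_space_basis = U[:, :-4]

layer = 0
for neuron, sq_label in current_move_not_blank_neurons:
    w_in = model.W_in[layer, :, neuron]
    w_in = w_in / w_in.norm()

    w_out = model.W_out[layer, neuron, :]
    w_out = w_out / w_out.norm()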
    print(f"L{layer}N{neuron} Fraction of weights in probe basis: input: {(w_in @ probe_space_basis).norm().item()**2}, output: {(w_out @ probe_space_basis).norm().item()**2}")

L0N38 Fraction of weights in probe basis: input: 0.15633016792720866, output: 0.15004583526653992
L0N40 Fraction of weights in probe basis: input: 0.21677476977717802, output: 0.19391775179020332
L0N110 Fraction of weights in probe basis: input: 0.175046855301475, output: 0.16088858133585493
L0N196 Fraction of weights in probe basis: input: 0.14801793828391752, output: 0.13849427231957012
L0N225 Fraction of weights in probe basis: input: 0.15649982453727862, output: 0.12975682746727912
L0N259 Fraction of weights in probe basis: input: 0.16448562315198334, output: 0.1650257932826804
L0N260 Fraction of weights in probe basis: input: 0.20810927982497507, output: 0.175526836556835
L0N305 Fraction of weights in probe basis: input: 0.16390309308537, output: 0.17455420804077448
L0N317 Fraction of weights in probe basis: input: 0.21535058510915306, output: 0.1667296841697672
L0N318 Fraction of weights in probe basis: input: 0.1383051265277766, output: 0.15293584461284127
L0N326 Fraction of wei

We see very similar values for all neurons (roughly 0.15-0.25), weakly suggesting that they may play analogous roles for different squares.
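The "fraction in probe basis" is the squared norm of a unit vector's projection onto an orthonormal basis. The NumPy sketch below computes it and compares it with a random baseline. For a uniformly random unit vector in d_model dimensions, the expected fraction in a k-dimensional subspace is k / d_model. Here the probe subspace has 128 - 4 = 124 dimensions. If d_model is 512 (an assumption, since this section doesn't print it), the random baseline would be about 124 / 512 ≈ 0.24. The toy dimensions below are arbitrary.

```python
import numpy as np

def orthonormal_basis(directions: np.ndarray, drop_last: int = 0) -> np.ndarray:
    """Orthonormal basis (d_model, k) for the span of the columns of `directions`,
    optionally dropping the `drop_last` directions with the smallest singular values."""
    U, S, Vh = np.linalg.svd(directions, full_matrices=False)  # singular values sorted descending
    return U[:, : U.shape[1] - drop_last]  # U[:, :-0] would be empty, hence this form

def fraction_in_subspace(w: np.ndarray, basis: np.ndarray) -> float:
    """Fraction of a vector's squared norm that lies in the span of an orthonormal basis."""
    w_unit = w / np.linalg.norm(w)
    return float(np.sum((w_unit @ basis) ** 2))

rng = np.random.default_rng(0)
d_model, k = 64, 16
basis = orthonormal_basis(rng.normal(size=(d_model, k)))
inside = fraction_in_subspace(basis @ rng.normal(size=k), basis)  # a vector in the span
random_mean = np.mean([fraction_in_subspace(v, basis) for v in rng.normal(size=(2000, d_model))])
print(round(inside, 6), round(random_mean, 3), k / d_model)
```

Normalised Gaussian vectors are uniformly distributed on the sphere, so `random_mean` should land close to `k / d_model`.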

# Summary

We find that neuron L0N774 plays a big role in determining that E0 is not blank when E0 is the current move, by implementing a lookup table: "if the current move is E0, write 'E0 is not blank'". We then extend the same techniques to automatically find an analogous family of monosemantic neurons in MLP0: "if the current move is X, write 'X is not blank'".
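To make the lookup-table picture concrete, here is a purely illustrative toy with hand-set weights, not the trained model's. The toy's assumptions are one-hot token embeddings, ReLU instead of GELU, a single negative bias, and orthonormal "square is blank" directions. Neuron i reads token i and writes against square i's blank direction.

```python
import numpy as np

def lookup_table_update(current_move: int, blank_dirs: np.ndarray, bias: float = -0.5) -> np.ndarray:
    """Residual-stream update from a toy MLP0-style lookup table.

    blank_dirs: (n_squares, d_model) hypothetical unit "square i is blank" directions,
                with d_model >= n_squares so one-hot embeddings fit.
    """
    n_squares, d_model = blank_dirs.shape
    W_E = np.eye(n_squares, d_model)             # one-hot token embeddings (toy)
    W_in = W_E.T                                 # (d_model, n_squares): neuron i reads token i's embedding
    W_out = -blank_dirs                          # (n_squares, d_model): neuron i writes "square i is NOT blank"
    resid = W_E[current_move]                    # residual stream = embedding of the current move
    acts = np.maximum(resid @ W_in + bias, 0.0)  # ReLU: only neuron `current_move` clears the bias
    return acts @ W_out                          # (d_model,) update added to the residual stream

blank_dirs = np.eye(4, 6)                        # 4 squares, orthonormal toy "is blank" directions
update = lookup_table_update(2, blank_dirs)
print(blank_dirs @ update)                       # blank score per square: only square 2 goes negative
```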


The lines of evidence we used for L0N774:
- Activation patching: when we corrupt a game with current_move==E0 to current_move==C0 (all previous moves stay the same), patching in the output of L0N774 recovers the most "E0 blank score"
- Staring at neurons: across the 50 focus games, this neuron only activates when current_move==E0
- Max activating dataset examples: in 100% of max activating dataset examples, E0 is non-empty and "theirs"
- Neuron attribution: when we attribute this neuron's activation to earlier components, the token embedding at the current position stands out
- Spectrum plot: in a spectrum plot for L0N774 testing "current_move==E0", it very clearly activates strongly when current_move==E0 and doesn't activate otherwise
- Input weight alignment with embeddings: when we take the cosine similarity of every token embedding with the input weights of L0N774, E0's embedding aligns the most
- Output weights: they have a relatively large negative cosine similarity with the "E0 is blank" probe direction
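The activation-patching step above can be summarised with a "fraction recovered" metric: (patched - corrupted) / (clean - corrupted). Below is a NumPy sketch that patches one neuron at a time. Its main toy assumption is that the probe score is a linear readout of the neuron activations (`acts @ W_out @ probe_dir`, with no later layers), which makes the per-neuron fractions sum to 1. In the real model, later layers make patching effects non-additive.

```python
import numpy as np

def fraction_recovered_per_neuron(clean_acts, corrupt_acts, W_out, probe_dir):
    """Patch each neuron's clean activation into the corrupted run, one at a time, and
    return the fraction of the clean-vs-corrupted probe-score gap that is recovered.

    Toy assumption: score(acts) = acts @ W_out @ probe_dir (linear, no downstream layers).
    Requires the clean and corrupted scores to differ.
    """
    def score(acts):
        return float(acts @ W_out @ probe_dir)

    clean, corrupt = score(clean_acts), score(corrupt_acts)
    recovered = np.empty(len(clean_acts))
    for i in range(len(clean_acts)):
        patched = corrupt_acts.copy()
        patched[i] = clean_acts[i]                 # patch a single neuron
        recovered[i] = (score(patched) - corrupt) / (clean - corrupt)
    return recovered

rec = fraction_recovered_per_neuron(np.array([1.0, 0.0, 2.0]), np.zeros(3), np.eye(3), np.ones(3))
print(rec, rec.argmax())
```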


The lines of evidence that the model has a family of these neurons for almost every square:
- When we searched for neurons that only ever activate on one current move, we found 57 neurons, each for a different square (out of 60 legal squares)
- All of these neurons have max activating dataset examples in which the square of choice is non-empty (specifically "theirs") 100% of the time
- Almost all have a clear divide in their spectrum plot for the square they test (activating strongly when current_move==X, and not otherwise)
- Each neuron's input weights have a similarly high cosine similarity (0.3-0.5) with X's token embedding
- Each neuron's output weights have a similar, relatively high cosine similarity (0.15-0.25) with the "X is blank" probe direction
- Each of these neurons has a similar, non-trivial fraction (0.15-0.25) of its input / output weights in the probe basis


## General techniques you can apply to other problems

* Activation patching
* Staring at neuron activations
* Spectrum plots
* Direct Attribution in probe directions
* Direct Attribution in neuron input weight directions
* Max activating dataset examples for a neuron
* Reading off neuron weights using probes