# Problem

Interpret a 2L MLP (one hidden layer) trained to do modular addition.

# Setup
(No need to read)

In [45]:
!git clone https://github.com/ckkissane/mech-interp-practice.git

fatal: destination path 'mech-interp-practice' already exists and is not an empty directory.


In [46]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [47]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [48]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [49]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [50]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [51]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

In [52]:
def visualize_attn_patterns(heads, local_tokens, local_cache, title: str = ""):
    labels = []
    patterns = []
    batch_index = 0

    for head in heads:
        if isinstance(head, tuple):
            layer, head_index = head
        else:
            layer, head_index = head // model.cfg.n_heads, head % model.cfg.n_heads
        patterns.append(local_cache["pattern", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    patterns = torch.stack(patterns, dim=-1)
    attn_viz = pysvelte.AttentionMulti(tokens=model.to_str_tokens(local_tokens[batch_index]), attention=patterns, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attn_viz.show()

# Load Model

This model is a 2L MLP (one hidden layer). An embedding is applied to the inputs (a, b) and these embedding vectors are added before being passed through a standard MLP layer with ReLU activation function:

![picture](https://drive.google.com/uc?id=1mA_0VmZfQ8QnviY8mzuDrss8CXz_iG1g)

It has already been trained and is loaded into this notebook below:

In [53]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7b193d5ce9e0>

In [54]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [55]:
p = 113
@dataclass
class Config:
    d_model: int = 128
    d_mlp: int = 512
    d_vocab: int = p
    d_vocab_out: int = p
    device: str = device
    seed: int = 0

cfg = Config()
print(cfg)

Config(d_model=128, d_mlp=512, d_vocab=113, d_vocab_out=113, device='cuda', seed=0)


In [56]:
class OneLayerMLP(HookedRootModule):
    def __init__(self, cfg):
        super().__init__()
        torch.manual_seed(cfg.seed)

        self.W_E = nn.Parameter(torch.randn(cfg.d_vocab, cfg.d_model) / np.sqrt(cfg.d_model))
        self.W_in = nn.Parameter(torch.randn(cfg.d_model, cfg.d_mlp) / np.sqrt(cfg.d_model))
        self.W_U = nn.Parameter(torch.randn(cfg.d_mlp, cfg.d_vocab_out) / np.sqrt(cfg.d_mlp))

        self.hook_embed = HookPoint()
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()

        super().setup()

    def forward(self, data):
        a = data[:, 0]
        b = data[:, 1]
        embed = self.hook_embed(self.W_E[a] + self.W_E[b])
        pre = self.hook_pre(embed @ self.W_in)
        post = self.hook_post(F.relu(pre))
        logits = post @ self.W_U
        return logits

model = OneLayerMLP(cfg).to(cfg.device)
print(model)

OneLayerMLP(
  (hook_embed): HookPoint()
  (hook_pre): HookPoint()
  (hook_post): HookPoint()
)


In [57]:
filename = "mech-interp-practice/models/modular_addition_2l_mlp_model.pt"
state_dict = torch.load(filename, map_location=cfg.device)
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# Task description

The model was trained to do modular addition. Given integers (a,b), it computes (a + b) mod p where p=113, and 'a' and 'b' are both integers in the range [0, p-1] inclusive.
The model is given inputs of the form


```
[a, b]
```


Under the hood, both a and b are mapped to embedding vectors, which are added together before being input to the MLP layer (see the diagram in the 'Load Model' section above). Note this differs from a normal transformer: the MLP is applied to a single residual stream vector, rather than to one residual stream vector per sequence position.
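Because the two embeddings are summed before anything else happens, the network is symmetric in its inputs: swapping a and b cannot change the logits. The sketch below checks this with NumPy on small random stand-in weights (hypothetical sizes and values, not the trained model); `forward` mirrors `OneLayerMLP.forward` but is an illustration, not the notebook's code.

```python
import numpy as np

rng = np.random.default_rng(0)
p, d_model, d_mlp = 113, 16, 32  # hypothetical small sizes, not the trained model's

# Random stand-ins for the trained W_E, W_in, W_U (illustration only)
W_E = rng.standard_normal((p, d_model))
W_in = rng.standard_normal((d_model, d_mlp))
W_U = rng.standard_normal((d_mlp, p))


def forward(a, b):
    """Mirrors OneLayerMLP.forward: sum the two embeddings, ReLU MLP, unembed."""
    embed = W_E[a] + W_E[b]
    post = np.maximum(embed @ W_in, 0.0)
    return post @ W_U


a = rng.integers(0, p, size=100)
b = rng.integers(0, p, size=100)
logits_ab = forward(a, b)
logits_ba = forward(b, a)
print(np.allclose(logits_ab, logits_ba))  # True: swapping a and b cannot change the logits
```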


Below I provide the following datasets that you can use to investigate the model:


* 'dataset' is the full dataset: every possible (a, b) pair. 'labels' are the corresponding answers, (a + b) mod p, for each pair in 'dataset'.
* 'train_data' is the subset of 'dataset' used for training. This was about 30% of the full dataset. 'train_labels' are the corresponding labels.
* 'test_data' consists of the remaining (a, b) pairs in 'dataset' not seen during training. 'test_labels' are the corresponding labels.

In [58]:
a_vec = einops.repeat(torch.arange(p), "i -> (i j)", j=p)
print(a_vec.shape)
print(a_vec[:5])

b_vec = einops.repeat(torch.arange(p), "i -> (j i)", j=p)
print(b_vec.shape)
print(b_vec[:5])

dataset = torch.stack([a_vec, b_vec], dim=-1).to(device)
print(dataset.shape)
print(dataset)

torch.Size([12769])
tensor([0, 0, 0, 0, 0])
torch.Size([12769])
tensor([0, 1, 2, 3, 4])
torch.Size([12769, 2])
tensor([[  0,   0],
        [  0,   1],
        [  0,   2],
        ...,
        [112, 110],
        [112, 111],
        [112, 112]], device='cuda:0')
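The plots below reshape per-example quantities to `(p, p)` and label the rows `a` and the columns `b`. That works because row `i` of `dataset` is the pair `(i // p, i % p)`. A NumPy sketch of the same construction, using `np.repeat` / `np.tile` as stand-ins for the two `einops.repeat` patterns:

```python
import numpy as np

p = 113
# np.repeat / np.tile reproduce the einops patterns "i -> (i j)" and "i -> (j i)"
a_vec = np.repeat(np.arange(p), p)   # 0,0,...,0,1,1,...
b_vec = np.tile(np.arange(p), p)     # 0,1,...,p-1,0,1,...
dataset = np.stack([a_vec, b_vec], axis=-1)

# Row i of the dataset is the pair (i // p, i % p)
idx = np.arange(p * p)
assert np.array_equal(dataset[:, 0], idx // p)
assert np.array_equal(dataset[:, 1], idx % p)

# So reshaping any per-example quantity to (p, p) indexes it as [a, b]
grid = dataset[:, 0].reshape(p, p)   # the 'a' value of each example
print(grid[5, 7], dataset[:, 1].reshape(p, p)[5, 7])  # prints "5 7"
```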


In [59]:
labels = (dataset[:, 0] + dataset[:, 1]) % p
print(labels.shape)
print(labels[:5])

torch.Size([12769])
tensor([0, 1, 2, 3, 4], device='cuda:0')


In [60]:
DATA_SEED = 0
torch.manual_seed(DATA_SEED)

train_frac = 0.3
cutoff = int(train_frac * dataset.shape[0])

indices = torch.randperm(dataset.shape[0])

train_indices = indices[:cutoff]
train_data = dataset[train_indices]
train_labels = labels[train_indices]
print(train_data.shape)
print(train_data[:5])
print(train_labels.shape)
print(train_labels[:5])

test_indices = indices[cutoff:]
test_data = dataset[test_indices]
test_labels = labels[test_indices]
print(test_data.shape)
print(test_data[:5])
print(test_labels.shape)
print(test_labels[:5])

torch.Size([3830, 2])
tensor([[ 37,   1],
        [ 23,   9],
        [  2,  43],
        [  7,  34],
        [ 42, 101]], device='cuda:0')
torch.Size([3830])
tensor([38, 32, 45, 41, 30], device='cuda:0')
torch.Size([8939, 2])
tensor([[ 55,  55],
        [ 69, 109],
        [ 13,  17],
        [103,  54],
        [ 78, 112]], device='cuda:0')
torch.Size([8939])
tensor([110,  65,  30,  44,  77], device='cuda:0')


The model was trained to minimize cross entropy loss. The exact loss function used during training is provided below:

In [61]:
def loss_fn(logits, labels):
    if logits.ndim == 3:
        logits = logits[:, -1, :]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])
    return -correct_log_probs.mean()
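`loss_fn` is standard mean cross-entropy over the final-position logits (for 2D logits it should match `F.cross_entropy` with its default mean reduction, apart from the float64 upcast). A NumPy mirror, written here for illustration, with `np.take_along_axis` playing the role of `torch.gather`:

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def loss_fn_np(logits, labels):
    """NumPy mirror of loss_fn: mean negative log-probability of the correct class."""
    logits = logits.astype(np.float64)  # same float64 upcast as loss_fn
    log_probs = log_softmax(logits)
    correct = np.take_along_axis(log_probs, labels[:, None], axis=-1)  # like torch.gather
    return -correct.mean()


p = 113
labels = np.array([0, 5, 17, 112])

# Uniform logits: every class has probability 1/p, so the loss is log(p)
uniform = np.zeros((4, p))
print(loss_fn_np(uniform, labels), np.log(p))

# A very large logit on the correct class drives the loss to (essentially) 0
confident = np.zeros((4, p))
confident[np.arange(4), labels] = 100.0
print(loss_fn_np(confident, labels))
```

The uniform case gives log(113), about 4.727, a useful reference point for what "no information" looks like.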

# Solution

## Sanity check

The model was only trained on 30% of the data. Let's first sanity check that it generalizes to the held-out test set:

In [62]:
with torch.inference_mode():
    test_logits = model(test_data)
    preds = test_logits.argmax(dim=-1)
    test_acc = (preds == test_labels).float().mean()
    print("Test accuracy:", test_acc.item())

Test accuracy: 1.0


## Stare at neuron activations

A common technique in mechanistic interpretability is just staring at interpretable activations (logits, attn patterns, neurons). Since this model is an MLP, a natural first step here is to stare at the post-relu neuron activations over the entire dataset.

In [63]:
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.shape)
print(cache.keys())

torch.Size([12769, 113])
dict_keys(['hook_embed', 'hook_pre', 'hook_post'])


In [64]:
neuron_acts_post = cache['hook_post']
print(neuron_acts_post.shape)

torch.Size([12769, 512])


In [65]:
for _ in range(3):
    r = random.randint(0, cfg.d_mlp-1)
    imshow(
        neuron_acts_post[..., r].reshape(p, p),
        title=f"N{r} activations over entire dataset",
        xaxis="b", yaxis="a"
    )

There's clearly some structure here: notice that the activations appear periodic.

When trying to guess what the network is doing, it's natural to connect periodicity to sines / cosines / general trig stuff. One clever idea from Neel Nanda's Grokking work is to check if the model is working in the fourier basis. To do this, we can just apply a change of basis to these activations and see if we can understand them as a linear combination of trig terms.

## Fourier basis

In [66]:
fourier_basis = torch.ones((p, p), device=device)
fourier_basis_names = ["Const"]
for k in range(1, p//2+1):
    w_k = 2 * torch.pi * k / p
    fourier_basis[2*k-1] = torch.sin(w_k * torch.arange(p))
    fourier_basis[2*k] = torch.cos(w_k * torch.arange(p))

    fourier_basis_names.append(f'Sin {k}')
    fourier_basis_names.append(f'Cos {k}')

fourier_basis /= fourier_basis.norm(dim=-1, keepdim=True)
print(fourier_basis.shape)
print(fourier_basis_names[:5])

torch.Size([113, 113])
['Const', 'Sin 1', 'Cos 1', 'Sin 2', 'Cos 2']
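For odd p, the constant vector plus the sin / cos pairs for k = 1, ..., (p-1)/2 give exactly p mutually orthogonal vectors, so after normalization `fourier_basis` is an orthonormal matrix and `fft2d` is an invertible change of basis. A quick NumPy check of that property (this re-implements the basis construction for illustration; it is not the notebook's code):

```python
import numpy as np


def make_fourier_basis(p):
    """Rows: Const, Sin 1, Cos 1, ..., Sin p//2, Cos p//2, each normalized to unit length."""
    x = np.arange(p)
    basis = [np.ones(p)]
    names = ["Const"]
    for k in range(1, p // 2 + 1):
        w_k = 2 * np.pi * k / p
        basis += [np.sin(w_k * x), np.cos(w_k * x)]
        names += [f"Sin {k}", f"Cos {k}"]
    basis = np.stack(basis)
    return basis / np.linalg.norm(basis, axis=-1, keepdims=True), names


F, names = make_fourier_basis(113)
print(F.shape)                            # (113, 113): 1 + 2 * 56 rows for odd p
print(np.allclose(F @ F.T, np.eye(113)))  # True: the rows are orthonormal
```

For even p the "Sin p/2" row would be identically zero (sin(pi x) = 0 at integers), so the normalization step would divide by zero; p = 113 avoids this.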


In [67]:
lines(
    fourier_basis[:5],
    title="First 5 fourier basis terms",
    labels=fourier_basis_names[:5]
)

Since the neuron activation heatmaps above are a function of both a and b, we write a function to do a 2D fourier transform.

In [68]:
def fft2d(acts):
    """
    Args:
        acts: tensor which we apply the 2D FFT. Has shape [p, p, ...]
    Returns:
        out: acts after 2D FFT. Has shape [p, p, ...]
    """
    return einops.einsum(
        acts, fourier_basis, fourier_basis,
        "px py ..., i px, j py -> i j ..."
    )

Note that this is just a change of basis: we can apply it to any 2D heatmap of shape [p, p] (in other words, any p x p heatmap can be written as a linear combination of the $p^2$ fourier basis heatmaps). But we expect this framing to be especially fruitful for this model, since we suspect it's deliberately using some trig-based algorithm.
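A NumPy sketch of `fft2d` makes the change of basis concrete: `out[i, j] = sum_{a,b} F[i, a] F[j, b] acts[a, b]`, which for a 2D heatmap is just `F @ acts @ F.T`. A heatmap that is a pure product of frequency-k terms should land on a single 2D coefficient, and `F.T @ coeffs @ F` should recover the heatmap. The basis is rebuilt locally with the same row ordering as `fourier_basis`:

```python
import numpy as np

p = 113
x = np.arange(p)
# Same normalized basis as fourier_basis (rows: Const, Sin 1, Cos 1, ...)
rows = [np.ones(p)]
for k in range(1, p // 2 + 1):
    rows += [np.sin(2 * np.pi * k * x / p), np.cos(2 * np.pi * k * x / p)]
F = np.stack(rows)
F /= np.linalg.norm(F, axis=-1, keepdims=True)


def fft2d_np(acts):
    """out[i, j, ...] = sum_{a, b} F[i, a] * F[j, b] * acts[a, b, ...], as in fft2d."""
    return np.einsum("ab...,ia,jb->ij...", acts, F, F)


# A pure product of frequency-k terms lands on a single 2D coefficient
k = 3
w = 2 * np.pi * k / p
heatmap = np.outer(np.cos(w * x), np.sin(w * x))  # heatmap[a, b] = cos(w a) sin(w b)
coeffs = fft2d_np(heatmap)
i, j = np.unravel_index(np.abs(coeffs).argmax(), coeffs.shape)
print(i, j)  # 6 5, i.e. (Cos 3, Sin 3)

# The change of basis is invertible: F^T coeffs F recovers the heatmap
print(np.allclose(F.T @ coeffs @ F, heatmap))  # True
```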

## Stare at neurons in fourier basis

Now we can just apply a 2D fourier transform to the same neuron activations, and see if they are interpretable in the fourier basis:

In [69]:
for _ in range(3):
    r = random.randint(0, cfg.d_mlp-1)
    imshow(
        fft2d(neuron_acts_post[..., r].reshape(p, p)),
        title=f"N{r} activations over entire dataset, in 2D Fourier Basis",
        xaxis="b component", yaxis="a component",
        x=fourier_basis_names, y=fourier_basis_names
    )

Notice that the neuron activations are very sparse in the fourier basis. If we look closely, we can view most individual neurons as a linear combination of 9 terms:
$Const, cos(w_ka), sin(w_ka), cos(w_kb), sin(w_kb), cos(w_ka)cos(w_kb), cos(w_ka)sin(w_kb), sin(w_ka)cos(w_kb), sin(w_ka)sin(w_kb)$, for some frequency $w_k = \frac{2\pi k}{p}$.

We can also take the squared norm of each 2D fourier component across all neurons:

In [70]:
# Center the neurons to remove the constant term
neuron_acts_centered = neuron_acts_post - einops.reduce(neuron_acts_post, 'batch neuron -> 1 neuron', 'mean')
# Note that fourier_neuron_acts[0, 0, i] == 0 for all i, because we centered the activations
fourier_neuron_acts = fft2d(neuron_acts_centered.reshape(p, p, -1))

imshow(fourier_neuron_acts.pow(2).sum(-1),
       title="Norm of 2D Fourier Components of all neuron activations",
       xaxis="b component", yaxis="a component",
       x=fourier_basis_names, y=fourier_basis_names
)

Here we also see that only a subset of "key frequencies" is used. With all these sin / cos terms, we might guess that the model is using trig identities to represent modular addition as rotations around the unit circle, analogous to what Neel Nanda found in his grokking work (Diagram stolen from Lawrence Chan, read from the bottom up):

![picture](https://drive.google.com/uc?id=1YNe7z0rQJ0DVZ3YQ0Ne4e6fEObWzjMlW)

Notice that $sin(w_k(a+b))$ and $cos(w_k(a+b))$ are just linear combinations of the quadratic terms ($cos(w_ka)cos(w_kb), cos(w_ka)sin(w_kb), sin(w_ka)cos(w_kb), sin(w_ka)sin(w_kb)$) that we noticed in the neuron activations. However, if this hypothesis is correct, the model shouldn't need the linear terms ($cos(w_ka), sin(w_ka), cos(w_kb), sin(w_kb)$) to compute the answer. One way to check that the model really is using this trig identity is to look at the logits in the fourier basis: we expect to see the quadratic terms but not the linear terms.
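Under this hypothesis, it is worth checking what the target itself looks like in the 2D fourier basis. The sketch below (NumPy, frequency k = 3 chosen arbitrarily) transforms the heatmap cos(w_k(a+b)) and checks that it puts equal and opposite weight on the (Cos k, Cos k) and (Sin k, Sin k) entries, with nothing in the Const row or column where the linear terms live:

```python
import numpy as np

p, k = 113, 3
x = np.arange(p)
rows = [np.ones(p)]
for j in range(1, p // 2 + 1):
    rows += [np.sin(2 * np.pi * j * x / p), np.cos(2 * np.pi * j * x / p)]
F = np.stack(rows)
F /= np.linalg.norm(F, axis=-1, keepdims=True)

a, b = np.meshgrid(x, x, indexing="ij")  # a varies along rows, b along columns
w = 2 * np.pi * k / p
target = np.cos(w * (a + b))             # = cos(wa)cos(wb) - sin(wa)sin(wb)
coeffs = F @ target @ F.T                # same as fft2d for a 2D heatmap

sin_k, cos_k = 2 * k - 1, 2 * k
print(coeffs[cos_k, cos_k], coeffs[sin_k, sin_k])  # equal magnitude, opposite sign
# Row 0 / column 0 hold the Const x trig ("linear") terms: all numerically zero
print(np.abs(coeffs[0, :]).max() < 1e-9, np.abs(coeffs[:, 0]).max() < 1e-9)
```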

## Stare at logits in fourier basis

In [71]:
original_logits.shape

torch.Size([12769, 113])

In [72]:
for _ in range(3):
    r = random.randint(0, cfg.d_vocab_out-1)
    imshow(
        fft2d(original_logits[..., r].reshape(p, p)),
        title=f"logits[{r}] over entire dataset, in 2D fourier basis",
        xaxis="b component", yaxis="a component",
        x=fourier_basis_names, y=fourier_basis_names
    )

Notice that the 2x2 blocks of quadratic terms (one block per frequency) now dominate the linear terms along the edges (the Const row and column), as expected.

We also see that the model is only using the same subset of the "key frequencies" that we saw in the neurons. Once again we can plot the norm across all logits to see this more clearly:

In [73]:
logits_centered = original_logits - original_logits.mean(dim=0, keepdim=True)
fourier_logits = fft2d(logits_centered.reshape(p,p,-1))

imshow(
    fourier_logits.pow(2).sum(-1),
    title="Norm of logits in 2D Fourier Basis",
    xaxis="b component", yaxis="a component",
    x=fourier_basis_names, y=fourier_basis_names
)

## Understanding $W_U$

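To get further evidence that the logits are approximating $cos(w_k(a+b-\textbf{c}))$, we can stare at the weights of $W_U$, which map the post-relu neuron activations to the logits. Let F be the [p, p] matrix whose columns are the fourier basis vectors $\textbf{f}_0, \dots, \textbf{f}_{p-1}$, indexed as in the code ($\textbf{f}_0$ is Const, $\textbf{f}_{2k-1} \propto sin(w_k\textbf{c})$, $\textbf{f}_{2k} \propto cos(w_k\textbf{c})$). Since F is orthonormal, we can write $W_U = USF^T$, where U is a [d_mlp, p] matrix with unit-norm columns $\boldsymbol{\mu}_i$ and S is a diagonal [p, p] matrix of their norms $\sigma_i$. (Unlike in an SVD, the columns of U need not be orthogonal.) We can rewrite this as a sum of outer products:


$$
\begin{aligned}
W_U &= USF^T\\
&= \sum_{i=0}^{p-1} \sigma_{i}\boldsymbol{\mu}_{i}\textbf{f}_i^T\\
\end{aligned}
$$


I claim we can further approximate this by keeping only the directions corresponding to the key frequencies: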


$$
\begin{aligned}
W_U &\approx \sum_{k\in keyfreqs}^{} \sigma_{2k-1}\boldsymbol{\mu}_{2k-1}sin(w_k\textbf{c}) + \sigma_{2k}\boldsymbol{\mu}_{2k}cos(w_k\textbf{c})\\
\end{aligned}
$$


We can show this by staring at the $W_U$ weights in the fourier basis, and seeing that only these directions matter.

In [74]:
W_U = model.W_U
print(W_U.shape)

torch.Size([512, 113])


In [75]:
imshow(
    W_U @ fourier_basis.T,
    aspect='auto',
    xaxis="Fourier Basis Component", yaxis="d_mlp",
    title="W_U in fourier basis",
    x=fourier_basis_names
)

We can see this even more clearly if we look at the norms of each component.

In [76]:
line(
    (W_U @ fourier_basis.T).pow(2).sum(dim=0),
    xaxis="Fourier basis component", yaxis="norm",
     title="Norms of Fourier Components of W_U ",
    x=fourier_basis_names
)

For a more quantitative measure, we can show that keeping only these directions recovers about 98% of the Frobenius norm of $W_U$:

In [77]:
key_freqs = [1, 3, 17, 19, 21, 23, 28, 29, 33, 35, 37, 42, 44]
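The notebook hardcodes `key_freqs`. One plausible way to pick such a list automatically (an assumption, not necessarily how these values were chosen) is to rank frequencies by the squared norm of their Sin and Cos components and keep the smallest set covering most of the total. `find_key_freqs` below is a hypothetical helper, checked on a toy matrix with energy planted at two frequencies:

```python
import numpy as np


def find_key_freqs(W_fourier, frac=0.95):
    """Hypothetical helper (not from the original notebook).

    W_fourier: array whose last axis is in the fourier basis (rows Const, Sin 1, Cos 1, ...),
    e.g. W_U @ fourier_basis.T converted to NumPy.
    Returns the smallest set of frequencies whose Sin + Cos squared norm covers `frac`
    of the total non-constant squared norm.
    """
    sq = (W_fourier ** 2).reshape(-1, W_fourier.shape[-1]).sum(axis=0)
    n_freqs = (W_fourier.shape[-1] - 1) // 2
    per_freq = np.array([sq[2 * k - 1] + sq[2 * k] for k in range(1, n_freqs + 1)])
    order = np.argsort(per_freq)[::-1]  # frequency indices by decreasing squared norm
    cum = np.cumsum(per_freq[order]) / per_freq.sum()
    n_keep = int(np.searchsorted(cum, frac)) + 1
    return sorted(int(k) + 1 for k in order[:n_keep])


# Toy check: energy planted at frequencies 3 and 17, plus small noise
rng = np.random.default_rng(0)
p = 113
W = 1e-3 * rng.standard_normal((64, p))
for k in (3, 17):
    W[:, 2 * k - 1] += rng.standard_normal(64)
    W[:, 2 * k] += rng.standard_normal(64)
print(find_key_freqs(W))  # [3, 17]
```

On the real model one could call it as `find_key_freqs(utils.to_numpy(W_U @ fourier_basis.T))`; the threshold `frac` is a free choice, so the result should be compared against the plots above rather than trusted blindly.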

In [78]:
US = W_U @ fourier_basis.T
print(US.shape)

torch.Size([512, 113])


In [79]:
W_U_approx = torch.zeros_like(W_U)
for k in key_freqs:
    W_U_approx += torch.outer(US[:, 2*k-1], fourier_basis[2*k-1]) + torch.outer(US[:, 2*k], fourier_basis[2*k])

print("W_U_approx norm:", torch.norm(W_U_approx).item())
print("W_U norm:", torch.norm(W_U).item())
print("ratio:", torch.norm(W_U_approx).item() / torch.norm(W_U).item())

W_U_approx norm: 36.203739166259766
W_U norm: 36.85184860229492
ratio: 0.9824131092301623
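Because the fourier basis is orthonormal, keeping a subset of directions is an orthogonal projection, so squared Frobenius norms add up: the squared norm of W_U_approx plus the squared norm of the residual W_U - W_U_approx equals the squared norm of W_U. The NumPy sketch below checks this identity with a random orthonormal basis and a random stand-in for W_U (both hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
p, d_mlp = 113, 64
F = np.linalg.qr(rng.standard_normal((p, p)))[0].T  # any orthonormal basis (as rows) works here
W = rng.standard_normal((d_mlp, p))                 # stand-in for W_U

US = W @ F.T                                        # coordinates of W's rows in the basis
keep = rng.choice(p, size=20, replace=False)        # hypothetical subset of directions
W_approx = sum(np.outer(US[:, i], F[i]) for i in keep)
resid = W - W_approx

# Orthogonal projection: squared Frobenius norms add up exactly
lhs = np.linalg.norm(W_approx) ** 2 + np.linalg.norm(resid) ** 2
print(np.isclose(lhs, np.linalg.norm(W) ** 2))      # True
```

Applied to the ratio printed above, 0.982 squared is about 0.965, so the key-frequency directions hold about 96.5% of the squared Frobenius norm of $W_U$, and the residual has roughly 19% of its norm.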


Now I feel pretty convinced that $W_U$ is approximating this sum, with the $sin(w_kc)$ and $cos(w_kc)$ terms encoded in the weights. But how do we know the model actually uses the $sin(w_k(a+b))$ and $cos(w_k(a+b))$ terms? If it does, it must be reading them from the neuron activations.


We want to show:


$$
\begin{aligned}
logits &= n_{post} @ W_U \approx \sum_{k\in keyfreqs}^{} cos(w_k(a+b-\textbf{c}))\\
\end{aligned}
$$


where $n_{post}$ denotes the post-ReLU neuron activations. Substituting $W_U$ as the sum from earlier, and expanding the right-hand side with the identity $cos(x-y) = cos(x)cos(y) + sin(x)sin(y)$, we can write:


$$
\begin{aligned}
\sum_{k\in keyfreqs}^{} n_{post} @ \sigma_{2k-1}\boldsymbol{\mu}_{2k-1}sin(w_k\textbf{c}) + n_{post} @ \sigma_{2k}\boldsymbol{\mu}_{2k}cos(w_k\textbf{c}) &= \sum_{k\in keyfreqs}^{} sin(w_k(a+b))sin(w_k\textbf{c}) + cos(w_k(a+b))cos(w_k\textbf{c})\\
\end{aligned}
$$


Matching terms, we are left to show:


$$
\begin{aligned}
n_{post} @ \sigma_{2k-1}\boldsymbol{\mu}_{2k-1} &= sin(w_k(a+b)) \\
n_{post} @ \sigma_{2k}\boldsymbol{\mu}_{2k} &= cos(w_k(a+b)) \\
\end{aligned}
$$


In [80]:
for k in key_freqs[:3]:
    sin_components = fft2d((neuron_acts_post @ US[:, 2*k-1]).reshape(p, p))
    cos_components = fft2d((neuron_acts_post @ US[:, 2*k]).reshape(p, p))
    imshow(
        [sin_components, cos_components],
        title=f"Frequency {k} fourier components from projecting onto neurons acts",
        facet_col=0,
        facet_labels=["Sin(w_k(a+b)) term", "Cos(w_k(a+b)) term"],
        x=fourier_basis_names, y=fourier_basis_names,
        xaxis="b component", yaxis="a component"
    )

Notice that we roughly see:

$$
\begin{aligned}
n_{post} @ \sigma_{2k-1}\boldsymbol{\mu}_{2k-1} &= sin(w_ka)cos(w_kb) + cos(w_ka)sin(w_kb) = sin(w_k(a+b)) \\
n_{post} @ \sigma_{2k}\boldsymbol{\mu}_{2k} &= cos(w_ka)cos(w_kb) - sin(w_ka)sin(w_kb) = cos(w_k(a+b)) \\
\end{aligned}
$$


as expected. For a more quantitative measure, we can compute the fraction of variance explained (FVE) by these trig terms.



In [81]:
for k in key_freqs:
    sin_components = fft2d((neuron_acts_post @ US[:, 2*k-1]).reshape(p, p))
    cos_components = fft2d((neuron_acts_post @ US[:, 2*k]).reshape(p, p))

    square_of_all_sin_terms = sin_components.pow(2).sum()
    square_of_all_cos_terms = cos_components.pow(2).sum()

    square_of_sin_terms = (torch.tensor([sin_components[2*k-1, 2*k], sin_components[2*k, 2*k-1]])).pow(2).sum()
    square_of_cos_terms = (torch.tensor([cos_components[2*k, 2*k], cos_components[2*k-1, 2*k-1]])).pow(2).sum()
    sin_fve = square_of_sin_terms / square_of_all_sin_terms
    cos_fve = square_of_cos_terms / square_of_all_cos_terms
    print(f"Freq {k} Sin terms FVE: {sin_fve.item()}, Cos terms FVE: {cos_fve.item()}")

Freq 1 Sin terms FVE: 0.8760449290275574, Cos terms FVE: 0.8736253976821899
Freq 3 Sin terms FVE: 0.9625068306922913, Cos terms FVE: 0.9590926170349121
Freq 17 Sin terms FVE: 0.95650315284729, Cos terms FVE: 0.9463934898376465
Freq 19 Sin terms FVE: 0.8403603434562683, Cos terms FVE: 0.8526453375816345
Freq 21 Sin terms FVE: 0.9805721640586853, Cos terms FVE: 0.9833682179450989
Freq 23 Sin terms FVE: 0.8515714406967163, Cos terms FVE: 0.8635261058807373
Freq 28 Sin terms FVE: 0.9233057498931885, Cos terms FVE: 0.9262790679931641
Freq 29 Sin terms FVE: 0.9492931962013245, Cos terms FVE: 0.9513300061225891
Freq 33 Sin terms FVE: 0.9452261328697205, Cos terms FVE: 0.9418636560440063
Freq 35 Sin terms FVE: 0.8353185653686523, Cos terms FVE: 0.8411898612976074
Freq 37 Sin terms FVE: 0.9169767498970032, Cos terms FVE: 0.9158449769020081
Freq 42 Sin terms FVE: 0.9668862223625183, Cos terms FVE: 0.9696757197380066
Freq 44 Sin terms FVE: 0.9108324646949768, Cos terms FVE: 0.909629762172699
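The FVE numbers above are the share of each projection's squared 2D fourier mass that sits on the two entries predicted by the trig identity (by Parseval, this equals the share of its uncentered sum of squares). The sketch below repeats the computation on a synthetic heatmap, sin(w_k(a+b)) plus Gaussian noise (hypothetical data, frequency 21 chosen arbitrarily), so the metric can be sanity-checked against a known answer:

```python
import numpy as np

p, k = 113, 21
x = np.arange(p)
rows = [np.ones(p)]
for j in range(1, p // 2 + 1):
    rows += [np.sin(2 * np.pi * j * x / p), np.cos(2 * np.pi * j * x / p)]
F = np.stack(rows)
F /= np.linalg.norm(F, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
a, b = np.meshgrid(x, x, indexing="ij")
w = 2 * np.pi * k / p
# Synthetic stand-in for (neuron_acts_post @ US[:, 2k-1]).reshape(p, p)
signal = np.sin(w * (a + b)) + 0.3 * rng.standard_normal((p, p))
coeffs = F @ signal @ F.T  # 2D fourier coefficients, as fft2d computes them

sin_k, cos_k = 2 * k - 1, 2 * k
# sin(w(a+b)) = sin(wa)cos(wb) + cos(wa)sin(wb): the (Sin, Cos) and (Cos, Sin) entries
explained = coeffs[sin_k, cos_k] ** 2 + coeffs[cos_k, sin_k] ** 2
fve = explained / (coeffs ** 2).sum()
print(f"FVE: {fve:.3f}")
```

Here the noise carries roughly 15% of the total energy, so the FVE should come out near 0.85, in the same range as several of the real frequencies.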


Together, we have shown that the logits approximate a weighted sum of the form (per-frequency weights omitted for readability):

$$
\begin{aligned}
logits &\approx \sum_{k\in keyfreqs}^{} cos(w_k(a+b-\textbf{c}))\\
\end{aligned}
$$
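The argmax claim can be checked directly. The sketch below builds idealized logits, the sum over key frequencies of cos(w_k(a+b-c)) with unit weight for every frequency (an assumption; the trained model's weights differ per frequency), and confirms that the argmax over c is (a + b) mod p. Since these logits depend only on a + b, it is enough to enumerate s = a + b in [0, 2p-2]:

```python
import numpy as np

p = 113
key_freqs = [1, 3, 17, 19, 21, 23, 28, 29, 33, 35, 37, 42, 44]  # copied from the notebook

s = np.arange(2 * p - 1)                 # every possible value of a + b with a, b in [0, p-1]
c = np.arange(p)                         # candidate answers
w = 2 * np.pi * np.array(key_freqs) / p  # w_k for each key frequency

# Idealized logits[s, c] = sum_k cos(w_k (s - c)), unit weight per frequency (assumption)
phase = (s[:, None] - c[None, :])[..., None] * w  # shape [2p-1, p, n_freqs]
logits = np.cos(phase).sum(axis=-1)               # shape [2p-1, p]

preds = logits.argmax(axis=-1)
print(np.array_equal(preds, s % p))                      # True
print(np.allclose(logits.max(axis=-1), len(key_freqs)))  # True: the winning logit is 13
```

The correct class scores exactly the number of frequencies (every cosine equals 1), while any other c scores strictly less because p is prime; the same argument works for any positive per-frequency weights.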

## Stare at pre-relu neuron activations

I now feel pretty convinced that the network has learned this trig algorithm to compute the answer as the argmax of $logits = \sum_{k\in keyfreqs}^{} cos(w_k(a+b-\textbf{c}))$, but how did it compute the quadratic trig terms in the first place, and why do we also see linear trig terms in the neuron activations if they aren't used to compute the answer?

First I'll check the neuron activations pre ReLU to see if the ReLU is responsible for any of these terms.

In [82]:
neuron_acts_pre = cache['hook_pre']
print(neuron_acts_pre.shape)

torch.Size([12769, 512])


In [83]:
for _ in range(3):
    r = random.randint(0, cfg.d_mlp-1)
    imshow(
        [fft2d(neuron_acts_pre[..., r].reshape(p, p)),fft2d(neuron_acts_post[..., r].reshape(p, p))],
        title=f"N{r} pre and post relu neuron activations over entire dataset",
        xaxis="b component", yaxis="a component",
        x=fourier_basis_names, y=fourier_basis_names,
        facet_col=0,
        facet_labels=["pre", "post"]
    )

These suggest that the pre-ReLU neuron activations already contain the linear cos / sin terms, and that the ReLU is solely responsible for computing the quadratic terms. We can check this more clearly by taking the norm across all neurons:

In [84]:
neuron_acts_pre_centered = neuron_acts_pre - neuron_acts_pre.mean(dim=0, keepdim=True)
fourier_neuron_acts_pre = fft2d(neuron_acts_pre_centered.reshape(p, p, -1))

imshow(
    [fourier_neuron_acts_pre.pow(2).sum(dim=-1), fourier_neuron_acts.pow(2).sum(dim=-1)],
    facet_col=0,
    title="Norms of Pre and Post neuron acts in 2D fourier basis",
    facet_labels=["Pre", "Post"],
    xaxis="b component", yaxis="a component",
    x=fourier_basis_names, y=fourier_basis_names
)

It's faint, but notice we can only see the quadratic terms in the post-relu activations. The pre-relu acts are solely linear combinations of the linear trig terms. It's wild that the ReLU can do this, but also makes sense: multiplying two trig components is nonlinear, and this ReLU is the only non-linearity in our network.
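A toy example shows how a ReLU can manufacture quadratic terms out of purely linear ones. The hypothetical neuron below has pre-activation cos(w_k a) + cos(w_k b), which in the 2D fourier basis lives only in the Const row and column; after the ReLU, the (Cos k, Cos k) entry, i.e. the cos(w_k a)cos(w_k b) term, becomes clearly nonzero (NumPy, k = 3 chosen arbitrarily):

```python
import numpy as np

p, k = 113, 3
x = np.arange(p)
rows = [np.ones(p)]
for j in range(1, p // 2 + 1):
    rows += [np.sin(2 * np.pi * j * x / p), np.cos(2 * np.pi * j * x / p)]
F = np.stack(rows)
F /= np.linalg.norm(F, axis=-1, keepdims=True)

a, b = np.meshgrid(x, x, indexing="ij")
w = 2 * np.pi * k / p
pre = np.cos(w * a) + np.cos(w * b)  # toy neuron: only "linear" trig terms in a and b
post = np.maximum(pre, 0.0)          # ReLU

pre_f = F @ pre @ F.T                # 2D fourier coefficients, as fft2d computes them
post_f = F @ post @ F.T
cos_k = 2 * k
print(abs(pre_f[cos_k, cos_k]) < 1e-9)  # True: no cos(wa)cos(wb) term before the ReLU
print(post_f[cos_k, cos_k] > 1.0)       # True: the ReLU creates one
```

One way to see why: relu(x) = (x + |x|)/2, and |cos(w_k a) + cos(w_k b)| is large when the two cosines share a sign and small when they don't, which correlates positively with their product.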


Now that we know that the ReLU can compute the quadratic terms, we can guess that $W_E$ and $W_{in}$ are responsible for the linear terms.

## Understanding $W_EW_{in}$

Doing some math, we can write the pre-relu neuron activations as a linear map from the tokens:

$$
\begin{aligned}
n_{pre} &= e @ W_{in} \\
&= (t_aW_E + t_bW_E) @ W_{in} \\
&= t_aW_EW_{in} + t_bW_EW_{in} \\
\end{aligned}
$$

Where $t_a$ and $t_b$ are the inputs $a$ and $b$ as one hot encoded vectors. Define $W_{neur} = W_EW_{in}$ as the linear map that takes tokens to pre-relu neuron activations.

Since $t_a$ and $t_b$ are one hot encoded vectors, we should be able to think of $W_{neur}$ as a lookup table which looks up $sin(w_ka), cos(w_ka), sin(w_kb), cos(w_kb)$ given $a$ and $b$. To check this, we can look at $W_{neur}$ in the fourier basis.
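The lookup-table view follows from the fact that multiplying a one-hot row vector by a matrix selects a row. The sketch below checks with NumPy and random stand-in weights (hypothetical sizes, not the trained model) that $n_{pre} = W_{neur}[a] + W_{neur}[b]$:

```python
import numpy as np

rng = np.random.default_rng(0)
p, d_model, d_mlp = 113, 16, 32          # hypothetical small sizes
W_E = rng.standard_normal((p, d_model))  # random stand-ins for the trained weights
W_in = rng.standard_normal((d_model, d_mlp))
W_neur = W_E @ W_in                      # [p, d_mlp]

a, b = 5, 42
t_a, t_b = np.eye(p)[a], np.eye(p)[b]    # one-hot row vectors

n_pre = (t_a @ W_E + t_b @ W_E) @ W_in   # the model's computation, written with one-hots
print(np.allclose(n_pre, W_neur[a] + W_neur[b]))  # True: two row lookups in W_neur
```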

In [85]:
W_E = model.W_E
print(W_E.shape)
W_in = model.W_in
print(W_in.shape)

torch.Size([113, 128])
torch.Size([128, 512])


In [86]:
W_neur = W_E @ W_in
print(W_neur.shape)

torch.Size([113, 512])


In [87]:
imshow(
    fourier_basis @ W_neur,
    title="W_neur in fourier basis",
    yaxis="fourier component", xaxis="d_mlp",
    y=fourier_basis_names
)

line(
    (fourier_basis @ W_neur).pow(2).sum(dim=-1),
    title="W_neur norms in fourier basis",
    yaxis="norm", xaxis="fourier component",
    x=fourier_basis_names
)

Notice that it is also sparse in the fourier basis, and has the exact same key frequencies we observed in the rest of the network.

In [88]:
key_freqs

[1, 3, 17, 19, 21, 23, 28, 29, 33, 35, 37, 42, 44]

# Summary

We found that a one-hidden-layer MLP can grok modular addition by learning the same trig based algorithm that Neel Nanda found in his Grokking work:


1. Given tokens $a$ and $b$, $W_{E}W_{in}$ looks up $sin(w_ka), cos(w_ka), sin(w_kb), cos(w_kb)$ for key frequencies $w_k = \frac{2\pi k}{p}$.
2. The ReLU, the only nonlinearity, combines these to compute the quadratic terms $sin(w_ka)sin(w_kb), sin(w_ka)cos(w_kb), cos(w_ka)sin(w_kb), cos(w_ka)cos(w_kb)$.
3. The neuron-logit map $W_U$ reads $sin(w_k(a+b))$, $cos(w_k(a+b))$ from the neuron activations to compute $logits = \sum_{k\in keyfreqs} sin(w_k(a+b))sin(w_k\textbf{c}) + cos(w_k(a+b))cos(w_k\textbf{c}) = \sum_{k\in keyfreqs} cos(w_k(a+b-\textbf{c}))$. Thus the index of the highest logit is $(a + b) \bmod p$.


The lines of evidence we used are:
- Neuron activations are sparse in the 2D fourier basis, only corresponding to 9 trig terms of key frequencies.
- Logits are also sparse in the 2D fourier basis, only corresponding to quadratic trig terms of key frequencies.
- The logit map $W_U$ is sparse in the fourier basis, with its only non-trivial directions corresponding to sin / cos terms of key frequencies (accounting for >98% of its total Frobenius norm).
- Projecting post-relu neuron acts onto the $W_U$ "in directions" for key frequencies gives linear combinations of quadratic terms corresponding to $sin(w_k(a+b)), cos(w_k(a+b))$ (>83% FVE for each key freq).
- Norms of pre-relu neuron activations in the 2D fourier basis clearly only have linear trig terms, while post-relu acts have both linear and quadratic trig terms.
- $W_{E}W_{in}$ is also sparse in the fourier basis, and uses the exact same key frequencies as the rest of the model.

## General Techniques you can apply to other problems

* Staring at neuron acts (both post and pre ReLU)
* Rewriting weight matrices as sums of outer products (SVD-style, here in the fourier basis) to determine important directions
* Using fraction of variance explained and Frobenius norm ratios to quantitatively verify hypotheses about weights
* Reframing activations / weights with a change of basis
* Taking norms over all neuron / logit activations