# Abstract

I took Neel's modular addition task, (a + b) mod 113, a step further and proposed a harder challenge: predicting (a^m + b^n) mod 113, where m and n are fixed positive integers from 1 to 9. This notebook is based on [Neel's Grokking demo](https://github.com/neelnanda-io/TransformerLens/blob/main/demos/Grokking_Demo.ipynb). Building on [Neel's paper](https://arxiv.org/abs/2301.05217), my hypothesis is that when m = n the model is capable of extending trigonometry to the [polar coordinate system](https://en.wikipedia.org/wiki/Polar_coordinate_system) and "groks" [De Moivre's formula](https://en.wikipedia.org/wiki/De_Moivre%27s_formula), an extension of [Euler's formula](https://en.wikipedia.org/wiki/Euler%27s_formula), to map (a, b) correctly to (a^m + b^n) in polar coordinates.


*   a and b range from 0 to 112; exponents above 9 would overflow torch.int64.
*   The exponents are fixed for each run: there is not enough RAM to train full-batch on a p^3-sized dataset.
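The int64 bound in the first bullet can be checked directly. Since a and b take values 0 to 112 (from `torch.arange(p)`), the worst-case sum is 2·112^n. The sketch below uses plain Python integers, which never overflow, to find the largest exponent whose worst case still fits in int64. It assumes m = n; with m ≠ n, 112^m + 112^n is at most 2·112^max(m, n), so the larger exponent is the one that matters.

```python
INT64_MAX = 2**63 - 1  # equals torch.iinfo(torch.int64).max


def max_sum(p: int, n: int) -> int:
    """Largest value of a**n + b**n over a, b in range(p), i.e. 2 * (p - 1)**n."""
    return 2 * (p - 1) ** n


def largest_safe_exponent(p: int, limit: int = 30) -> int:
    """Largest n in 1..limit whose worst-case sum still fits in a signed 64-bit integer."""
    safe = [n for n in range(1, limit + 1) if max_sum(p, n) <= INT64_MAX]
    return max(safe)


print(largest_safe_exponent(113))  # 9
```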





# Hypothesis
In Neel's paper on the (a + b) mod 113 task, the model was found to independently discover the discrete Fourier transform, which indicates that it picks up trigonometric structure during training. My hypothesis, based mainly on the model's immediate and accurate generalization when m equals n, is that the model extends this trigonometric knowledge to the polar coordinate system, and that this happens early in training, likely as the model begins to grasp trigonometry. With m = n = 4, the model generalized instantly and precisely. In the more complex case m = n = 9, the model initially overfitted but then quickly adapted and generalized perfectly. This was particularly surprising, because the model performed better in this more complex setting than in a simpler one (m = 4, n = 3).


Cases:
1. When m equals n, the expression (a^m + b^n) has a geometric elegance: it maps onto polar coordinates and exploits the rotational symmetry inherent in complex numbers. De Moivre's formula, which extends Euler's formula, describes exactly these rotations in the complex plane. This suggests that the model intuitively understands and applies complex number theory, in particular the geometric operation of rotating points in the complex plane, to predict these expressions efficiently.

2. When m is an integer multiple of n (m = kn), computing (a^m + b^n) also has a geometric interpretation, since a^m = (a^k)^n and both terms are then n-th powers. In the complex plane, powers of complex numbers follow specific geometric patterns: the repeated multiplication implied by exponentiation can be visualized as scaled rotations.

3. In scenarios where m and n do not align with the special conditions of being equal (m = n) or one being an integer multiple of the other (m = kn), computing (a^m + b^n) poses significant challenges. Without the geometric clarity provided by rotational symmetry in the complex plane (as seen in the m = n or m = kn cases), the model lacks a straightforward complex number method for prediction. Algebraically, the task becomes increasingly complex with larger exponents, as no simplifying symmetries or patterns emerge to ease computation. From my observations, the absence of these specific conditions leads to slower generalization and difficulty in accurately predicting (a^m + b^n), unlike in cases where m and n adhere to these special conditions.
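De Moivre's formula, (cos θ + i sin θ)^n = cos(nθ) + i sin(nθ), says that raising a point on the unit circle to the n-th power multiplies its angle by n. Below is a quick numerical check at the angles 2πk/113, for the exponents 1 to 9 used in this notebook. It only illustrates the formula itself; it is not evidence about what the trained model computes.

```python
import cmath
import math


def rotate_by_power(theta: float, n: int) -> complex:
    """Left side of De Moivre: (cos θ + i sin θ)^n, by repeated complex multiplication."""
    return complex(math.cos(theta), math.sin(theta)) ** n


def rotate_once(theta: float, n: int) -> complex:
    """Right side of De Moivre: cos(nθ) + i sin(nθ) = e^{i n θ}, a single rotation by nθ."""
    return cmath.exp(1j * n * theta)


p = 113
# Compare both sides at the angles 2πk/p for every exponent from 1 to 9
max_err = max(
    abs(rotate_by_power(2 * math.pi * k / p, n) - rotate_once(2 * math.pi * k / p, n))
    for k in range(p)
    for n in range(1, 10)
)
print(max_err < 1e-9)  # True: the two sides agree up to floating-point round-off
```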








# Setup

In [None]:
TRAIN_MODEL = True

In [None]:
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
if IN_COLAB:
    !pip install transformer_lens
    #%pip install circuitsvis
    !pip install git+https://github.com/neelnanda-io/neel-plotly


In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [None]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import neel_plotly as npx

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Model Training

## Config

In [None]:
frac_train = 0.3

p = 113

# Optimizer config
lr = 1e-3
wd = 1.
betas = (0.9, 0.98)

num_epochs = 25000
checkpoint_every = 100

DATA_SEED = 598

## Define Task
(a^m + b^n) mod 113
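A fact worth keeping in mind for all three cases: (a^m + b^n) mod p depends on a only through a^m mod p, and on b only through b^n mod p. The task is therefore modular addition applied to two per-token lookup tables, and when m = n both positions share the same table. A brute-force check of this identity (plain Python, exact integers):

```python
def check_decomposition(p: int, m: int, n: int) -> bool:
    """True if (a^m + b^n) mod p == ((a^m mod p) + (b^n mod p)) mod p for all a, b in range(p)."""
    for a in range(p):
        for b in range(p):
            direct = (a**m + b**n) % p
            via_lookup = (pow(a, m, p) + pow(b, n, p)) % p
            if direct != via_lookup:
                return False
    return True


print(check_decomposition(113, 8, 9))  # True
```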

In [None]:
max_int64 = torch.iinfo(torch.int64).max
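# Python ints have arbitrary precision, so the worst-case sums can be compared with the int64 limit directly
print(max(max_int64, 113**10 + 113**10))  # prints the sum: exponent 10 would overflow int64
print(max(max_int64, 113**9 + 113**9))    # prints max_int64: exponents up to 9 are safe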

678913477984444629698
9223372036854775807


In [None]:
a_vector = einops.repeat(torch.arange(p), "i -> (i j)", j=p)
b_vector = einops.repeat(torch.arange(p), "j -> (i j)", i=p)
equals_vector = einops.repeat(torch.tensor(p), " -> (i j)", i=p, j=p)  # token p (= 113) is the "=" token

dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1)
# Fixed exponents: labels are (a^m + b^n) mod p. Keep m, n <= 9 to avoid int64 overflow.
m = 8
n = 9
labels = (dataset[:, 0]**m + dataset[:, 1]**n) % p
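The cell above computes a^m and b^n exactly in int64 before taking the remainder, which is why the exponents are capped at 9. A minimal sketch of one way around the cap (not what this notebook does): reduce modulo p during exponentiation with Python's built-in three-argument `pow`, so intermediate values never exceed p². The flattened order matches `dataset` (a in the outer loop, b in the inner loop). To use it here one could wrap the result in `torch.tensor(...)`; that step is left out so the sketch runs without PyTorch.

```python
def make_labels(p: int, m: int, n: int) -> list[int]:
    """(a^m + b^n) mod p for every (a, b), flattened with index a * p + b like `dataset`."""
    # pow(x, k, p) reduces modulo p as it goes, so any exponent is safe
    a_pows = [pow(a, m, p) for a in range(p)]
    b_pows = [pow(b, n, p) for b in range(p)]
    return [(a_pows[a] + b_pows[b]) % p for a in range(p) for b in range(p)]


labels_8_9 = make_labels(113, 8, 9)
labels_20_15 = make_labels(113, 20, 15)  # exponents that would overflow int64 in the cell above
print(len(labels_8_9))  # 12769
```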

Convert this to a train + test set - 30% in the training set

In [None]:
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p*p)
cutoff = int(p*p*frac_train)

train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = dataset[train_indices].cuda()
train_labels = labels[train_indices].cuda()
test_data = dataset[test_indices].cuda()
test_labels = labels[test_indices].cuda()

## Define Model

In [None]:
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 128,
    d_head = 32,
    d_mlp = 512,
    act_fn = "relu",
    normalization_type=None,
    d_vocab=p+1,
    d_vocab_out=p,
    n_ctx=3,
    init_weights=True,
    device="cuda",
    seed = 999,
)

In [None]:
model = HookedTransformer(cfg).cuda()

Moving model to device:  cuda


Disable the biases, as we don't need them for this task and it makes things easier to interpret.

In [None]:
for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False


## Define Optimizer + Loss

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

In [None]:
def loss_fn(logits, labels):
    if len(logits.shape)==3:
        logits = logits[:, -1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()
train_logits = model(train_data)
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_logits = model(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)

tensor(4.7332, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(4.7343, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
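Both initial losses are close to ln 113 ≈ 4.727, the cross-entropy of a uniform guess over the 113 possible answers, which is roughly what a freshly initialised model is expected to produce. A plain-Python sketch of the same computation as `loss_fn` (log-softmax, then the mean negative log-probability of the correct label) makes the baseline explicit:

```python
import math


def cross_entropy(logits: list[list[float]], labels: list[int]) -> float:
    """Mean negative log-probability of the correct label, mirroring loss_fn above."""
    total = 0.0
    for row, label in zip(logits, labels):
        # Numerically stable log-sum-exp: subtract the row max before exponentiating
        mx = max(row)
        log_z = mx + math.log(sum(math.exp(x - mx) for x in row))
        total += log_z - row[label]
    return total / len(labels)


p = 113
uniform_logits = [[0.0] * p for _ in range(4)]  # a model that prefers no answer
print(cross_entropy(uniform_logits, [0, 1, 2, 3]))  # ln(113), about 4.727
```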


## Actually Train

**Weird Decision:** Training the model with full-batch training rather than stochastic gradient descent. We do this to make training smoother and reduce the number of slingshots.

In [None]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
if TRAIN_MODEL:
    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())

        optimizer.step()
        optimizer.zero_grad()

        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        if ((epoch+1)%checkpoint_every)==0:
            checkpoint_epochs.append(epoch)
            model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")


Epoch 99 Train Loss 2.6091830243313203 Test Loss 3.5536351797228947
Epoch 199 Train Loss 2.0558338910735423 Test Loss 3.8354116362670583
Epoch 299 Train Loss 1.2408248554461447 Test Loss 4.701145341135779
Epoch 399 Train Loss 0.8289972302662411 Test Loss 5.687701053324773
Epoch 499 Train Loss 0.5677488649032718 Test Loss 6.346257516163142
Epoch 599 Train Loss 0.31792136883912037 Test Loss 6.888970880181434
Epoch 699 Train Loss 0.0998254439462482 Test Loss 7.177476676341397
Epoch 799 Train Loss 0.008788213008298346 Test Loss 7.215221536944322
Epoch 899 Train Loss 0.0024206201506079337 Test Loss 6.598347792538869
Epoch 999 Train Loss 0.0007822396063566605 Test Loss 6.177143411256893
Epoch 1099 Train Loss 0.000257540698828453 Test Loss 5.847113952510939
Epoch 1199 Train Loss 8.630414289532207e-05 Test Loss 5.525009172994505
Epoch 1299 Train Loss 2.9406459102383222e-05 Test Loss 5.203618451567342
Epoch 1399 Train Loss 1.024608410685164e-05 Test Loss 4.869057738776905
Epoch 1499 Train Loss 

## Show Model Training Statistics

In [None]:
npx.line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=False, title="Training Curve for (a^8 + b^9) mod 113 (m=8, n=9)", line_labels=['train', 'test'], toggle_x=True, toggle_y=True)

In [None]:
import os
# Define the location to save the model, using a relative path
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Create the directory if it does not exist
os.makedirs(Path(PTH_LOCATION).parent, exist_ok=True)

In [None]:
torch.save(
    {
        "model":model.state_dict(),
        "config": model.cfg,
        "checkpoints": model_checkpoints,
        "checkpoint_epochs": checkpoint_epochs,
        "test_losses": test_losses,
        "train_losses": train_losses,
        "train_indices": train_indices,
        "test_indices": test_indices,
    },
    PTH_LOCATION)

In [None]:
if not TRAIN_MODEL:
    cached_data = torch.load(PTH_LOCATION)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    train_indices = cached_data["train_indices"]
    test_indices = cached_data["test_indices"]

In [None]:
from neel_plotly.plot import line
line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Training Curve for (a^m + b^n) mod 113", line_labels=['train', 'test'], toggle_x=True, toggle_y=True)



In [None]:
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.numel())

4328691


In [None]:
W_E = model.embed.W_E[:-1]
print("W_E", W_E.shape)
W_neur = W_E @ model.blocks[0].attn.W_V @ model.blocks[0].attn.W_O @ model.blocks[0].mlp.W_in
print("W_neur", W_neur.shape)
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

W_E torch.Size([113, 128])
W_neur torch.Size([4, 113, 512])
W_logit torch.Size([512, 113])


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
original_loss = loss_fn(original_logits.to(device), labels.to(device)).item()
print("Original Loss:", original_loss)

Original Loss: 0.09419264991408313


In [None]:
pattern_a = cache["pattern", 0, "attn"][:, :, -1, 0]
pattern_b = cache["pattern", 0, "attn"][:, :, -1, 1]
neuron_acts = cache["post", 0, "mlp"][:, -1, :]
neuron_pre_acts = cache["pre", 0, "mlp"][:, -1, :]

In [None]:
for param_name, param in cache.items():
    print(param_name, param.shape)

hook_embed torch.Size([12769, 3, 128])
hook_pos_embed torch.Size([12769, 3, 128])
blocks.0.hook_resid_pre torch.Size([12769, 3, 128])
blocks.0.attn.hook_q torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_k torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_v torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_attn_scores torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_pattern torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_z torch.Size([12769, 3, 4, 32])
blocks.0.hook_attn_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_mid torch.Size([12769, 3, 128])
blocks.0.mlp.hook_pre torch.Size([12769, 3, 512])
blocks.0.mlp.hook_post torch.Size([12769, 3, 512])
blocks.0.hook_mlp_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_post torch.Size([12769, 3, 128])


In [None]:
imshow(cache["pattern", 0].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
imshow(cache["pattern", 0][5][:, -1, :], title="Attention Pattern per Head for Input 5 (a=0, b=5)", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
imshow(cache["pattern", 0][:, 0, -1, 0].reshape(p, p), title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a")

In [None]:
imshow(
    einops.rearrange(cache["pattern", 0][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p),
    title="Attention from a -> = for each head", xaxis="b", yaxis="a", facet_col=0)

In [None]:
cache["post", 0, "mlp"].shape

torch.Size([12769, 3, 512])

In [None]:
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

In [None]:
W_E.shape

torch.Size([113, 128])

In [None]:
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
line(S, title="Singular Values")
imshow(U, title="Principal Components on the Input")

In [None]:
# Control - random Gaussian matrix
U, S, Vh = torch.linalg.svd(torch.randn_like(W_E), full_matrices=False)
line(S, title="Singular Values Random")
imshow(U, title="Principal Components Random")

In [None]:
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
line(U[:, :8].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")

In [None]:
fourier_basis = []
fourier_basis_names = []
fourier_basis.append(torch.ones(p))
fourier_basis_names.append("Constant")
for freq in range(1, p//2+1):
    fourier_basis.append(torch.sin(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Sin {freq}")
    fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Cos {freq}")
fourier_basis = torch.stack(fourier_basis, dim=0).cuda()
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)
imshow(fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names)
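For reference, here is a dependency-free sketch of the same basis for any odd p: a constant row followed by interleaved Sin/Cos rows, each normalised to unit length. Two properties used later are easy to check with it: the rows are orthonormal, and frequency f sits at row 2f - 1 (Sin) and row 2f (Cos), which is where indices such as 33/34 for frequency 17 come from.

```python
import math


def make_fourier_basis(p: int) -> tuple[list[list[float]], list[str]]:
    """Constant row, then Sin f and Cos f rows for f = 1..p//2, each scaled to unit norm."""
    rows = [[1.0] * p]
    names = ["Constant"]
    for freq in range(1, p // 2 + 1):
        rows.append([math.sin(2 * math.pi * freq * x / p) for x in range(p)])
        names.append(f"Sin {freq}")
        rows.append([math.cos(2 * math.pi * freq * x / p) for x in range(p)])
        names.append(f"Cos {freq}")
    basis = []
    for row in rows:
        norm = math.sqrt(sum(v * v for v in row))
        basis.append([v / norm for v in row])
    return basis, names


basis, names = make_fourier_basis(113)
print(len(basis), names[33], names[34])  # 113 Sin 17 Cos 17
```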

In [None]:
line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")

In [None]:
imshow(fourier_basis @ fourier_basis.T, title="All Fourier Vectors are Orthogonal")

In [None]:
imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")

In [None]:
line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")
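Rather than reading the key frequencies off this plot by eye, one could rank them numerically. The helper below is hypothetical (not part of the original demo): it combines each frequency's Sin and Cos norms and returns the k strongest, assuming the basis ordering used above. In the notebook, its input could be something like `(fourier_basis @ W_E).norm(dim=-1).detach().cpu().tolist()`.

```python
import math


def top_frequencies(component_norms: list[float], k: int) -> list[int]:
    """Hypothetical helper: the k frequencies with the largest combined Sin/Cos norm.

    Assumes the basis ordering used above: index 0 is the constant term,
    index 2f - 1 is Sin f and index 2f is Cos f.
    """
    n_freqs = (len(component_norms) - 1) // 2
    strength = {
        f: math.hypot(component_norms[2 * f - 1], component_norms[2 * f])
        for f in range(1, n_freqs + 1)
    }
    strongest = sorted(strength, key=strength.get, reverse=True)[:k]
    return sorted(strongest)


# Toy input for p = 9: frequency 2 is strongest, then frequency 4
toy_norms = [0.0, 0.1, 0.1, 3.0, 4.0, 0.0, 0.0, 1.0, 1.0]
print(top_frequencies(toy_norms, 2))  # [2, 4]
```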

In [None]:
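# Key frequencies carried over from the original modular-addition demo. For this model they may
# not be the dominant ones (compare the near-zero cosine similarities further down).
key_freqs = [17, 25, 32, 47]
# Row indices of (Sin f, Cos f) in fourier_basis: Sin f is at 2f-1, Cos f at 2f
key_freq_indices = [i for f in key_freqs for i in (2 * f - 1, 2 * f)]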
fourier_embed = fourier_basis @ W_E
key_fourier_embed = fourier_embed[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)
imshow(key_fourier_embed @ key_fourier_embed.T, title="Dot Product of embedding of key Fourier Terms")

key_fourier_embed torch.Size([8, 128])


In [None]:
line(fourier_basis[[34, 50, 64, 94]], title="Cos of key freqs", line_labels=[34, 50, 64, 94])

In [None]:
line(fourier_basis[[34, 50, 64, 94]].mean(0), title="Constructive Interference")

In [None]:
imshow(
    einops.rearrange(neuron_acts[:, 0], "(a b) -> a b", a=p, b=p),
    title="First neuron act", xaxis="b", yaxis="a",)

In [None]:
imshow(fourier_basis[94][None, :] * fourier_basis[94][:, None], title="Cos 47a * cos 47b")

In [None]:
imshow(fourier_basis[94][None, :] * fourier_basis[0][:, None], title="Cos 47a * const")

In [None]:
imshow(fourier_basis @ neuron_acts[:, 0].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 0", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
imshow(fourier_basis @ neuron_acts[:, 5].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 5", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
imshow(fourier_basis @ torch.randn_like(neuron_acts[:, 0]).reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of RANDOM", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean - doesn't matter!
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

fourier_neuron_acts torch.Size([512, 113, 113])


In [None]:
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()
for freq in range(0, p//2):
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
neuron_freq_norm = neuron_freq_norm / fourier_neuron_acts.pow(2).sum(dim=[-1, -2])[None, :]
imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")
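The loop above scores each neuron by how much of its (centred) 2D Fourier energy lies on the rows and columns {0, 2f - 1, 2f}, i.e. the constant, Sin f and Cos f terms for both a and b. Here is the same computation for a single neuron as a plain-Python sketch over a square coefficient matrix, assuming the [0][0] entry has already been zeroed as in the centring step above:

```python
def frac_explained(coeffs: list[list[float]], freq: int) -> float:
    """Fraction of a neuron's squared 2D Fourier coefficients on the constant/Sin freq/Cos freq rows and columns.

    coeffs[x][y] is the coefficient for component x of a and component y of b, in the
    basis order used above (0 = constant, 2f - 1 = Sin f, 2f = Cos f). The [0][0]
    (mean) entry is assumed to be zeroed already.
    """
    idx = [0, 2 * freq - 1, 2 * freq]
    total = sum(v * v for row in coeffs for v in row)
    on_freq = sum(coeffs[x][y] ** 2 for x in idx for y in idx)
    return on_freq / total


# Toy 5x5 example (p = 5): energy 9 on (Sin 1, Sin 1) and 16 on (Sin 2, Sin 2)
toy = [[0.0] * 5 for _ in range(5)]
toy[1][1] = 3.0
toy[3][3] = 4.0
print(frac_explained(toy, 1), frac_explained(toy, 2))  # 0.36 0.64
```

Cross terms such as (Sin 1, Sin 2) count towards neither frequency, so the per-frequency fractions need not sum to 1.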

In [None]:
line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")

In [None]:
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

W_logit torch.Size([512, 113])


In [None]:
line((W_logit @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit in the Fourier Basis")

In [None]:
neurons_17 = neuron_freq_norm[17-1]>0.85
neurons_17.shape

torch.Size([512])

In [None]:
neurons_17.sum()

tensor(0, device='cuda:0')

In [None]:
line((W_logit[neurons_17] @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit for freq 17 neurons in the Fourier Basis")

In [None]:
freq = 17
W_logit_fourier = W_logit @ fourier_basis.T  # rows of fourier_basis are components, so project with the transpose
neurons_sin_17 = W_logit_fourier[:, 2*freq-1]
line(neurons_sin_17)

In [None]:
neuron_acts.shape

torch.Size([12769, 512])

In [None]:
inputs_sin_17c = neuron_acts @ neurons_sin_17
imshow(fourier_basis @ inputs_sin_17c.reshape(p, p) @ fourier_basis.T, title="Fourier Heatmap over inputs for sin17c", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
def embed_to_cos_sin(fourier_embed):
    if len(fourier_embed.shape) == 1:
        return torch.stack([fourier_embed[1::2], fourier_embed[2::2]])
    else:
        return torch.stack([fourier_embed[:, 1::2], fourier_embed[:, 2::2]], dim=1)

from neel_plotly.plot import melt

def plot_embed_bars(
    fourier_embed,
    title="Norm of embedding of each Fourier Component",
    return_fig=False,
    **kwargs
):
    cos_sin_embed = embed_to_cos_sin(fourier_embed)
    df = melt(cos_sin_embed)
    # display(df)
    group_labels = {0: "sin", 1: "cos"}
    df["Trig"] = df["0"].map(lambda x: group_labels[x])
    fig = px.bar(
        df,
        barmode="group",
        color="Trig",
        x="1",
        y="value",
        labels={"1": "$w_k$", "value": "Norm"},
        title=title,
        **kwargs
    )
    fig.update_layout(dict(legend_title=""))

    if return_fig:
        return fig
    else:
        fig.show()

In [None]:
def test_logits(logits, bias_correction=False, original_logits=None, mode="all"):
    # Calculates cross entropy loss of logits representing a batch of all p^2
    # possible inputs
    # Batch dimension is assumed to be first
    if logits.shape[1] == p * p:
        logits = logits.T
    if logits.shape == torch.Size([p * p, p + 1]):
        logits = logits[:, :-1]
    logits = logits.reshape(p * p, p)
    if bias_correction:
        # Applies bias correction - we correct for any missing bias terms,
        # independent of the input, by centering the new logits along the batch
        # dimension, and then adding the average original logits across all inputs
        logits = (
            einops.reduce(original_logits - logits, "batch ... -> ...", "mean") + logits
        )
    if mode == "train":
        return loss_fn(logits[train_indices].to(device), labels[train_indices].to(device))
    elif mode == "test":
        return loss_fn(logits[test_indices].to(device), labels[test_indices].to(device))
    elif mode == "all":
        return loss_fn(logits.to(device), labels.to(device))

In [None]:
metric_cache = {}

In [None]:
def get_metrics(model, metric_cache, metric_fn, name, reset=False):
    # Evaluate metric_fn on every saved checkpoint and cache the results under `name`
    if reset or (name not in metric_cache) or (len(metric_cache[name]) == 0):
        metric_cache[name] = []
        for sd in tqdm.tqdm(model_checkpoints):
            model.reset_hooks()
            model.load_state_dict(sd)
            # No gradients are needed for metrics; disabling autograd saves GPU memory
            with torch.no_grad():
                out = metric_fn(model)
            if isinstance(out, torch.Tensor):
                out = utils.to_numpy(out)
            metric_cache[name].append(out)
        # Restore the final checkpoint
        model.load_state_dict(model_checkpoints[-1])
        # Stack into one numpy array first: building a tensor from a list of arrays is very slow
        metric_cache[name] = torch.tensor(np.array(metric_cache[name]))

In [None]:
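# Phase boundaries as used in the original modular-addition demo; they may need adjusting for this task
memorization_end_epoch = 1500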
circuit_formation_end_epoch = 13300
cleanup_end_epoch = 16600

In [None]:
def add_lines(figure):
    figure.add_vline(memorization_end_epoch, line_dash="dash", opacity=0.7)
    figure.add_vline(circuit_formation_end_epoch, line_dash="dash", opacity=0.7)
    figure.add_vline(cleanup_end_epoch, line_dash="dash", opacity=0.7)
    return figure

In [None]:
fig = line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Training Curve for (a^m + b^n) mod 113", line_labels=['train', 'test'], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

In [None]:
all_logits = original_logits[:, -1, :]
print(all_logits.shape)
all_logits = einops.rearrange(all_logits, "(a b) c -> a b c", a=p, b=p)
print(all_logits.shape)

torch.Size([12769, 113])
torch.Size([113, 113, 113])


In [None]:
coses = {}
for freq in key_freqs:
    print("Freq:", freq)
    a = torch.arange(p)[:, None, None]
    b = torch.arange(p)[None, :, None]
    c = torch.arange(p)[None, None, :]
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()
    cube_predicted_logits /= cube_predicted_logits.norm()
    coses[freq] = cube_predicted_logits

Freq: 17
Freq: 25
Freq: 32
Freq: 47
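These templates encode the modular-addition prediction cos(w(a + b - c)). A natural variant for this task, offered here only as an untested hypothesis rather than something the notebook checks, replaces a + b with a^m + b^n. Because cos(2πf·x/p) has period p in the integer x, reducing a^m and b^n modulo p first gives the same values mathematically while keeping the float arguments small. A plain-Python sketch for small p (the notebook's own tensors would be built with `torch` instead):

```python
import math


def predicted_logit_cube(p: int, freq: int, m: int = 1, n: int = 1) -> list[list[list[float]]]:
    """Unit-norm cube with cube[a][b][c] proportional to cos(2*pi*freq*(a^m + b^n - c)/p).

    m = n = 1 gives the modular-addition template built above. Other exponents are a
    hypothetical variant for this task, not something the notebook tests.
    """
    # cos(2*pi*freq*x/p) has period p in the integer x, so reducing the powers mod p
    # first changes nothing mathematically but keeps the float arguments small
    a_pow = [pow(a, m, p) for a in range(p)]
    b_pow = [pow(b, n, p) for b in range(p)]
    w = 2 * math.pi * freq / p
    cube = [[[math.cos(w * (a_pow[a] + b_pow[b] - c)) for c in range(p)]
             for b in range(p)]
            for a in range(p)]
    norm = math.sqrt(sum(v * v for plane in cube for row in plane for v in row))
    return [[[v / norm for v in row] for row in plane] for plane in cube]


cube = predicted_logit_cube(7, freq=2, m=8, n=9)
print(len(cube), len(cube[0]), len(cube[0][0]))  # 7 7 7
```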


In [None]:
approximated_logits = torch.zeros_like(all_logits)
for freq in key_freqs:
    print("Freq:", freq)
    coeff = (all_logits * coses[freq]).sum()
    print("Coeff:", coeff)
    cosine_sim = coeff / all_logits.norm()
    print("Cosine Sim:", cosine_sim)
    approximated_logits += coeff * coses[freq]
residual = all_logits - approximated_logits
print("Residual size:", residual.norm())
print("Residual fraction of norm:", residual.norm()/all_logits.norm())

Freq: 17
Coeff: tensor(-2.0116, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(-0.0001, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 25
Coeff: tensor(-0.5212, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(-3.6056e-05, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 32
Coeff: tensor(0.1556, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(1.0766e-05, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 47
Coeff: tensor(12.3170, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.0009, device='cuda:0', grad_fn=<DivBackward0>)
Residual size: tensor(14454.8984, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Residual fraction of norm: tensor(1.0000, device='cuda:0', grad_fn=<DivBackward0>)
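The cell above treats each unit-norm template as a direction: the coefficient is the dot product with the logits, the cosine similarity divides that by the norm of the logits, and the residual is whatever the templates leave unexplained. Because templates at different frequencies are orthogonal, summing the individual projections is the same as a joint least-squares fit. A minimal plain-Python sketch of that bookkeeping on flat vectors, assuming orthonormal templates:

```python
import math


def project(x: list[float], templates: list[list[float]]) -> tuple[list[float], list[float], float]:
    """Project x onto orthonormal templates.

    Returns (coefficients, cosine similarities, residual norm as a fraction of ||x||).
    """
    x_norm = math.sqrt(sum(v * v for v in x))
    coeffs = [sum(xi * ti for xi, ti in zip(x, t)) for t in templates]
    cos_sims = [c / x_norm for c in coeffs]
    # Subtract the reconstruction sum_k coeff_k * template_k
    residual = list(x)
    for c, t in zip(coeffs, templates):
        residual = [r - c * ti for r, ti in zip(residual, t)]
    residual_frac = math.sqrt(sum(r * r for r in residual)) / x_norm
    return coeffs, cos_sims, residual_frac


coeffs, sims, frac = project([3.0, 4.0, 12.0], [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(coeffs)  # [3.0, 4.0]; the residual fraction is 12/13
```

A residual fraction of 1.0, as in the output above, means the chosen templates explain essentially none of the logits.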


In [None]:
random_logit_cube = torch.randn_like(all_logits)
print((all_logits * random_logit_cube).sum()/random_logit_cube.norm()/all_logits.norm())

tensor(-0.0017, device='cuda:0', grad_fn=<DivBackward0>)


In [None]:
test_logits(all_logits.to(device))

tensor(0.0942, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)

In [None]:
test_logits(approximated_logits)

tensor(4.7274, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)

In [None]:
cos_cube = []
for freq in range(1, p//2 + 1):
    a = torch.arange(p)[:, None, None]
    b = torch.arange(p)[None, :, None]
    c = torch.arange(p)[None, None, :]
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()
    cube_predicted_logits /= cube_predicted_logits.norm()
    cos_cube.append(cube_predicted_logits)
cos_cube = torch.stack(cos_cube, dim=0)
print(cos_cube.shape)

torch.Size([56, 113, 113, 113])


In [None]:
def get_cos_coeffs(model):
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    return vals


get_metrics(model, metric_cache, get_cos_coeffs, "cos_coeffs")
print(metric_cache["cos_coeffs"].shape)


torch.Size([250, 56])






In [None]:
fig = line(metric_cache["cos_coeffs"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Coefficients with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Coefficient", return_fig=True)
add_lines(fig)

In [None]:
def get_cos_sim(model):
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    return vals / logits.norm()

get_metrics(model, metric_cache, get_cos_sim, "cos_sim") # You may need a big GPU. If you don't have one and can't work around this, raise an issue for help!
print(metric_cache["cos_sim"].shape)

fig = line(metric_cache["cos_sim"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Cosine Sim with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)


torch.Size([250, 56])


In [None]:
def get_residual_cos_sim(model):
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    residual = logits - (vals[:, None, None, None] * cos_cube).sum(dim=0)
    return residual.norm() / logits.norm()

get_metrics(model, metric_cache, get_residual_cos_sim, "residual_cos_sim")
print(metric_cache["residual_cos_sim"].shape)

fig = line([metric_cache["cos_sim"][:, i] for i in range(p//2)]+[metric_cache["residual_cos_sim"]], line_labels=[f"Freq {i}" for i in range(1, p//2+1)]+["residual"], title="Cosine Sim with Predicted Logits + Residual", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)


torch.Size([250])


In [None]:
neuron_acts.shape

torch.Size([12769, 512])

In [None]:
neuron_acts_square = einops.rearrange(neuron_acts, "(a b) neur -> a b neur", a=p, b=p).clone()
# Center it
neuron_acts_square -= einops.reduce(neuron_acts_square, "a b neur -> 1 1 neur", "mean")
neuron_acts_square_fourier = einsum("a b neur, fa a, fb b -> fa fb neur", neuron_acts_square, fourier_basis, fourier_basis)
imshow(neuron_acts_square_fourier.norm(dim=-1), xaxis="Fourier Component b", yaxis="Fourier Component a", title="Norms of neuron activations by Fourier Component", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.numel())
neuron_acts = cache["post", 0, "mlp"][:, -1, :]

4328691


In [None]:
approx_neuron_acts = torch.zeros_like(neuron_acts)
approx_neuron_acts += neuron_acts.mean(dim=0)
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()
    cos_apb_vec /= cos_apb_vec.norm()
    cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
    sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()
    sin_apb_vec /= sin_apb_vec.norm()
    sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
restricted_logits = approx_neuron_acts @ W_logit
print(loss_fn(restricted_logits[test_indices], test_labels))

tensor(5.5029, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)


In [None]:
# all_logits was reshaped to (a, b, c); flatten it back to (a*b, c) so loss_fn doesn't treat it as (batch, pos, vocab)
print(loss_fn(all_logits.reshape(p * p, p).to(device), labels.to(device)))

In [None]:
def get_restricted_loss(model):
    logits, cache = model.run_with_cache(dataset)
    logits = logits[:, -1, :]
    neuron_acts = cache["post", 0, "mlp"][:, -1, :]
    approx_neuron_acts = torch.zeros_like(neuron_acts)
    approx_neuron_acts += neuron_acts.mean(dim=0)
    a = torch.arange(p)[:, None]
    b = torch.arange(p)[None, :]
    for freq in key_freqs:
        cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()
        cos_apb_vec /= cos_apb_vec.norm()
        cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
        approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
        sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()
        sin_apb_vec /= sin_apb_vec.norm()
        sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
        approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
    restricted_logits = approx_neuron_acts @ model.blocks[0].mlp.W_out @ model.unembed.W_U
    # Add bias term
    restricted_logits += logits.mean(dim=0, keepdim=True) - restricted_logits.mean(dim=0, keepdim=True)
    return loss_fn(restricted_logits[test_indices], test_labels)
get_restricted_loss(model)

tensor(4.8611, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
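`get_restricted_loss` builds logits only from the projected neuron activations through `W_out` and `W_U`, so it leaves out the rest of the residual stream (and any bias terms). The "Add bias term" line compensates with a per-class constant: it shifts the restricted logits so that, averaged over the dataset, each output class has the same mean logit as the full model. A NumPy sketch of just that step (`match_class_means` is a hypothetical name):

```python
import numpy as np

def match_class_means(restricted_logits, logits):
    """Shift each output-class column so its batch mean equals that of the full logits."""
    return (restricted_logits
            + logits.mean(axis=0, keepdims=True)
            - restricted_logits.mean(axis=0, keepdims=True))

rng = np.random.default_rng(0)
toy_full = rng.normal(size=(50, 7))        # stand-in for the model's logits
toy_restricted = rng.normal(size=(50, 7))  # stand-in for the restricted logits
shifted = match_class_means(toy_restricted, toy_full)
```

The shift is identical for every input row, so any input-dependent structure in the restricted logits still comes only from the key-frequency components.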

In [None]:
get_metrics(model, metric_cache, get_restricted_loss, "restricted_loss", reset=True)
print(metric_cache["restricted_loss"].shape)


OutOfMemoryError: ignored
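The `get_metrics` call above ran out of GPU memory. `loss_fn` is defined earlier in the notebook; the float64 dtype and `NegBackward0` in the printed losses suggest a mean negative log-probability of the correct label, which is what the NumPy stand-in `loss_fn_np` below assumes. Because that loss is a mean over rows, it can be accumulated chunk by chunk with size weights and still equal a single full-batch call, which is one way to limit how many rows of logits are held at once. This is a sketch, not a diagnosis of this particular OOM: the projection coefficients in `get_restricted_loss` are still sums over the full batch of neuron activations, and `logits_for_rows` is a hypothetical callback.

```python
import numpy as np

def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def loss_fn_np(logits, labels):
    """Mean negative log-probability of the correct label, computed in float64."""
    log_probs = log_softmax(logits.astype(np.float64))
    return -log_probs[np.arange(len(labels)), labels].mean()

def chunked_loss(logits_for_rows, labels, chunk_size):
    """Same value as loss_fn_np on all rows, but only chunk_size rows of logits exist at once.

    logits_for_rows(start, stop) is a hypothetical callback returning logits for rows start..stop-1.
    """
    n = len(labels)
    total = 0.0
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        # Weight each chunk's mean by its size so the final division gives the global mean
        total += loss_fn_np(logits_for_rows(start, stop), labels[start:stop]) * (stop - start)
    return total / n

rng = np.random.default_rng(0)
toy_logits = rng.normal(size=(1000, 13))
toy_labels = rng.integers(0, 13, size=1000)
full = loss_fn_np(toy_logits, toy_labels)
chunked = chunked_loss(lambda s, e: toy_logits[s:e], toy_labels, chunk_size=128)
```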

In [None]:
fig = line([train_losses[::100], test_losses[::100], metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Restricted Loss Curve", line_labels=['train', 'test', "restricted_loss"], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

KeyError: ignored

In [None]:
fig = line([torch.tensor(test_losses[::100])/metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Restricted Loss to Test Loss Ratio", toggle_x=True, toggle_y=True, return_fig=True)
# WARNING: breaks if training was cancelled partway through
add_lines(fig)

TypeError: ignored

In [None]:
approx_neuron_acts = torch.zeros_like(neuron_acts)
# approx_neuron_acts += neuron_acts.mean(dim=0)  # left out here, so the mean stays in the excluded activations
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()
    cos_apb_vec /= cos_apb_vec.norm()
    cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
    sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()
    sin_apb_vec /= sin_apb_vec.norm()
    sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
excluded_neuron_acts = neuron_acts - approx_neuron_acts
excluded_logits = excluded_neuron_acts @ W_logit
print(loss_fn(excluded_logits[train_indices], train_labels))

OutOfMemoryError: ignored
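This cell subtracts the key-frequency projection from the raw activations. Because the `neuron_acts.mean(dim=0)` line is commented out, the dataset mean stays in the excluded part. What remains is orthogonal to every cos(w(a+b)) and sin(w(a+b)) direction used, which is what makes the excluded loss a test of how much the model depends on those directions. A NumPy check of that orthogonality on random toy activations (`key_freq_directions` is a hypothetical helper):

```python
import numpy as np

def key_freq_directions(p, key_freqs):
    """Unit vectors cos/sin(w(a+b)) per key frequency, as columns of a (p*p, 2*len(key_freqs)) matrix."""
    a = np.arange(p)[:, None]
    b = np.arange(p)[None, :]
    cols = []
    for freq in key_freqs:
        for trig in (np.cos, np.sin):
            vec = trig(freq * 2 * np.pi / p * (a + b)).ravel()
            cols.append(vec / np.linalg.norm(vec))
    return np.stack(cols, axis=1)

toy_p, toy_freqs = 11, [2, 5]
rng = np.random.default_rng(0)
toy_acts = rng.normal(size=(toy_p * toy_p, 6))  # toy stand-in for neuron_acts
toy_dirs = key_freq_directions(toy_p, toy_freqs)
toy_approx = toy_dirs @ (toy_dirs.T @ toy_acts)  # same as the loop, without the mean term
toy_excluded = toy_acts - toy_approx
```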

In [None]:
def get_excluded_loss(model):
    logits, cache = model.run_with_cache(dataset)
    logits = logits[:, -1, :]
    neuron_acts = cache["post", 0, "mlp"][:, -1, :]
    approx_neuron_acts = torch.zeros_like(neuron_acts)
    # approx_neuron_acts += neuron_acts.mean(dim=0)
    a = torch.arange(p)[:, None]
    b = torch.arange(p)[None, :]
    for freq in key_freqs:
        cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()
        cos_apb_vec /= cos_apb_vec.norm()
        cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
        approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
        sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()
        sin_apb_vec /= sin_apb_vec.norm()
        sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
        approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
    excluded_neuron_acts = neuron_acts - approx_neuron_acts
    residual_stream_final = excluded_neuron_acts @ model.blocks[0].mlp.W_out + cache["resid_mid", 0][:, -1, :]
    excluded_logits = residual_stream_final @ model.unembed.W_U
    return loss_fn(excluded_logits[train_indices], train_labels)
get_excluded_loss(model)

OutOfMemoryError: ignored
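`get_excluded_loss` rebuilds the final residual stream as `resid_mid` plus the excluded activations passed through `W_out`, then unembeds. Assuming neither the MLP output nor the unembedding adds a bias (the model config is not shown in this section), the path from neuron activations to logits is linear, so the full logits split exactly into an excluded part and a key-frequency part. A NumPy check with random matrices standing in for `W_out`, `W_U` and `resid_mid` (all toy sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_mlp, d_model, d_vocab_out = 40, 16, 8, 7    # toy sizes
toy_neuron_acts = rng.normal(size=(n, d_mlp))
toy_approx = rng.normal(size=(n, d_mlp))         # stand-in for the key-frequency part
toy_excluded = toy_neuron_acts - toy_approx
toy_resid_mid = rng.normal(size=(n, d_model))
toy_W_out = rng.normal(size=(d_mlp, d_model))
toy_W_U = rng.normal(size=(d_model, d_vocab_out))

# Assumes no b_out and no b_U
full_logits = (toy_resid_mid + toy_neuron_acts @ toy_W_out) @ toy_W_U
excluded_logits = (toy_excluded @ toy_W_out + toy_resid_mid) @ toy_W_U  # as in get_excluded_loss
key_freq_logits = toy_approx @ toy_W_out @ toy_W_U                      # restricted path, no mean shift
```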

In [None]:
get_metrics(model, metric_cache, get_excluded_loss, "excluded_loss", reset=True)
print(metric_cache["excluded_loss"].shape)


OutOfMemoryError: ignored

In [None]:
fig = line([train_losses[::100], test_losses[::100], metric_cache["excluded_loss"], metric_cache["restricted_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Excluded and Restricted Loss Curve", line_labels=['train', 'test', "excluded_loss", "restricted_loss"], toggle_x=True, toggle_y=True, return_fig=True)

add_lines(fig)