# Grokking Demo Notebook

# Setup
(No need to read)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git@new-demo
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: notebook_connected


In [3]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [4]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Dashiell")

In [5]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [6]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [7]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Model Training

## Config

In [8]:
p = 53
frac_train = 0.5

# Optimizer config
lr = 1e-3
wd = 1. 
betas = (0.9, 0.98)

num_epochs = 50000
checkpoint_every = 100

In [9]:
53 ** 4

7890481

## Define Task
* Define modular addition
* Define the dataset & labels

Input format:
|a|b|=|

In [10]:
# For p**4 elements add a d vector with l -> (i j k l)
a_vector = einops.repeat(torch.arange(p), "i -> (i j k l)", j=p, k=p, l=p)
b_vector = einops.repeat(torch.arange(p), "j -> (i j k l)", i=p, k=p, l=p)
c_vector = einops.repeat(torch.arange(p), "k -> (i j k l)", i=p, j=p, l=p)
d_vector = einops.repeat(torch.arange(p), "l -> (i j k l)", i=p, j=p, k=p) 
equals_vector = einops.repeat(torch.tensor(p), " -> (i j k l)", i=p, j=p, k=p, l=p)
star_vector = einops.repeat(torch.tensor(p+1), " -> (i j k l)", i=p, j=p, k=p, l=p)
plus_vector = einops.repeat(torch.tensor(p+2), " -> (i j k l)", i=p, j=p, k=p, l=p)
caret_vector = einops.repeat(torch.tensor(p+3), " -> (i j k l)", i=p, j=p, k=p, l=p)
lparen = einops.repeat(torch.tensor(p+4), " -> (i j k l)", i=p, j=p, k=p, l=p)
rparen = einops.repeat(torch.tensor(p+5), " -> (i j k l)", i=p, j=p, k=p, l=p)

In [11]:
dataset = torch.stack([
    lparen,
    a_vector,
    star_vector,
    b_vector,
    plus_vector,
    c_vector,
    rparen,
    caret_vector,
    d_vector,
    equals_vector], dim=1).cuda()
print(dataset[:5])
print(dataset.shape)

tensor([[57,  0, 54,  0, 55,  0, 58, 56,  0, 53],
        [57,  0, 54,  0, 55,  0, 58, 56,  1, 53],
        [57,  0, 54,  0, 55,  0, 58, 56,  2, 53],
        [57,  0, 54,  0, 55,  0, 58, 56,  3, 53],
        [57,  0, 54,  0, 55,  0, 58, 56,  4, 53]], device='cuda:0')
torch.Size([7890481, 10])


In [12]:
labels = ((dataset[:, 1] * dataset[:, 3] + dataset[:, 5]) ** dataset[:, 8]) % p
print(labels.shape)
print(labels[:5])

torch.Size([7890481])
tensor([1, 0, 0, 0, 0], device='cuda:0')


In [13]:
from torch.utils.data import TensorDataset, DataLoader

Convert this to a train + test set - 30% in the training set

In [14]:
torch.manual_seed(314159)
indices = torch.randperm(p**4)
cutoff = int((p**4)*frac_train)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = TensorDataset(dataset[train_indices], labels[train_indices])
test_data = TensorDataset(dataset[test_indices], labels[test_indices])
print(train_data[5])
#print(train_labels[:5])
#print(train_data.shape)
print(test_data[5])
#print(test_labels[:5])
#print(test_data.shape)

(tensor([57,  1, 54,  6, 55, 16, 58, 56,  9, 53], device='cuda:0'), tensor(34, device='cuda:0'))
(tensor([57, 39, 54,  3, 55, 25, 58, 56, 31, 53], device='cuda:0'), tensor(29, device='cuda:0'))


## Define Model

In [15]:

cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 256,
    d_head = 64,
    d_mlp = 1024,
    act_fn = "relu",
    normalization_type=None,
    d_vocab=p+6,
    d_vocab_out=p,
    n_ctx=10,
    init_weights=True,
    device="cuda",
    seed = 999,
)

In [16]:
model = HookedTransformer(cfg)

Disable the biases, as we don't need them for this task.

In [17]:
for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False


## Define Optimizer + Loss

In [18]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

In [19]:
def loss_fn(logits, labels):
    if len(logits.shape)==3:
        logits = logits[:, -1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()

#train_logits = model(train_data)
#train_loss = loss_fn(train_logits, train_labels)
#print(train_loss)
#test_logits = model(test_data)
#test_loss = loss_fn(test_logits, test_labels)
#print(test_loss)

def train_forward(model, dataloader):
    total_loss = torch.tensor(0., device='cuda', requires_grad=False)
    for batch, labels in dataloader:
        logits = model(batch)
        loss = loss_fn(logits, labels)
        loss.backward()
        total_loss += loss
    return total_loss

def test_forward(model, dataloader):
    total_loss = torch.tensor(0., device='cuda', requires_grad=False)
    for batch, labels in dataloader:
        logits = model(batch)
        loss = loss_fn(logits, labels)
        total_loss += loss
    return total_loss

In [20]:
print("Uniform loss:")
print(np.log(p))

Uniform loss:
3.970291913552122


In [21]:
batch_size = 2 ** 16
batch_size

65536

In [22]:
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

## Actually Train

**Weird Decision:** Training the model with full batch training rather than stochastic gradient descent.

In [None]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
num_epochs = 50_000
grok_threshold = 0.01

for epoch in tqdm.tqdm(range(num_epochs)):
    train_loss = train_forward(model, train_dataloader)
    train_losses.append(train_loss.item())

    optimizer.step()
    optimizer.zero_grad()

    with torch.inference_mode():
        test_loss = test_forward(model, train_dataloader)
        test_losses.append(test_loss.item())
    
    if (epoch % checkpoint_every) == 0:
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")
    if test_loss.item() <= grok_threshold:
        break

  0%|          | 0/50000 [00:00<?, ?it/s]

Epoch 0 Train Loss 242.78997802734375 Test Loss 237.64923095703125
Epoch 100 Train Loss 217.28501892089844 Test Loss 217.24752807617188
Epoch 200 Train Loss 209.68406677246094 Test Loss 209.5689697265625
Epoch 300 Train Loss 207.1466064453125 Test Loss 206.0689239501953
Epoch 400 Train Loss 204.72836303710938 Test Loss 204.74916076660156
Epoch 500 Train Loss 203.6313934326172 Test Loss 203.71217346191406
Epoch 600 Train Loss 202.8385772705078 Test Loss 202.80880737304688
Epoch 700 Train Loss 202.04086303710938 Test Loss 201.99293518066406
Epoch 800 Train Loss 201.083251953125 Test Loss 201.0811767578125
Epoch 900 Train Loss 200.2921142578125 Test Loss 200.03111267089844
Epoch 1000 Train Loss 199.1699676513672 Test Loss 199.12059020996094


In [None]:
"""
num_epochs = 150_000
grok_threshold = 0.01

for epoch in tqdm.tqdm(range(num_epochs)):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    train_losses.append(train_loss.item())

    optimizer.step()
    optimizer.zero_grad()

    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits, test_labels)
        test_losses.append(test_loss.item())
    
    if ((epoch+1)%checkpoint_every)==0:
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")
    if test_loss.item() <= grok_threshold:
        break
"""

In [None]:
torch.save(
     {
         "model":model.state_dict(),
         "config": model.cfg,
         "checkpoints": model_checkpoints,
         "checkpoint_epochs": checkpoint_epochs,
         "test_losses": test_losses,
         "train_losses": train_losses,
         "train_indices": train_indices,
         "test_indices": test_indices
     },
     "checkpoints/grokking_xyz_3333.pth")

## Show Model Training Statistics, Check that it groks!

In [None]:
len(train_losses[::100])

In [None]:
line(
    train_losses[::100],
    x=np.arange(0, len(train_losses), 100),
    xaxis="Epoch", yaxis="Loss", log_y=True, 
    title="Training Curve for Modular Addition")

# Analysing the Model

## Standard Things to Try

In [None]:
original_logits, cache = model.run_with_cache(dataset)
print(original_logits.numel())

Get key weight matrices:

In [None]:
W_E = model.embed.W_E[:-1]
print("W_E", W_E.shape)
W_neur = W_E @ model.blocks[0].attn.W_V @ model.blocks[0].attn.W_O @ model.blocks[0].mlp.W_in
print("W_neur", W_neur.shape)
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

In [None]:
original_loss = loss_fn(original_logits, labels).item()
print("Original Loss:", original_loss)

### Looking at Activations

Helper variable:

In [None]:
pattern_a = cache["pattern", 0, "attn"][:, :, -1, 0]
pattern_b = cache["pattern", 0, "attn"][:, :, -1, 1]
neuron_acts = cache["post", 0, "mlp"][:, -1, :]
neuron_pre_acts = cache["pre", 0, "mlp"][:, -1, :]

Get all shapes:

In [None]:
for param_name, param in cache.items():
    print(param_name, param.shape)

In [None]:
imshow(cache["pattern", 0].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
imshow(cache["pattern", 0][5][:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
dataset[:4]

In [None]:
imshow(cache["pattern", 0][:, 0, -1, 0].reshape(p, p, p), title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a")

In [None]:
npx.imshow(
    einops.rearrange(cache["pattern", 0][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p), 
    title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a", facet_col=0)

Plotting neuron activations

In [None]:
cache["post", 0, "mlp"].shape

In [None]:
npx.imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p), 
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

### Singular Value Decomposition

In [None]:
W_E.shape

In [None]:
U, S, Vh = torch.svd(W_E)
npx.line(S, title="Singular Values")
npx.imshow(U, title="Principal Components on the Input")

In [None]:
# Control - random Gaussian matrix
U, S, Vh = torch.svd(torch.randn_like(W_E))
npx.line(S, title="Singular Values Random")
npx.imshow(U, title="Principal Components Random")

## Explaining Algorithm

### Analyse the Embedding - It's a Lookup Table!

In [None]:
U, S, Vh = torch.svd(W_E)
npx.line(U[:, :8].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")

In [None]:
fourier_basis = []
fourier_basis_names = []
fourier_basis.append(torch.ones(p))
fourier_basis_names.append("Constant")
for freq in range(1, p//2+1):
    fourier_basis.append(torch.sin(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Sin {freq}")
    fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Cos {freq}")
fourier_basis = torch.stack(fourier_basis, dim=0).cuda()
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)
npx.imshow(fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names)

In [None]:
npx.line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
npx.line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")

In [None]:
npx.imshow(fourier_basis @ fourier_basis.T, title="All Fourier Vectors are Orthogonal")

### Analyse the Embedding

In [None]:
npx.imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")

In [None]:
npx.line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")

In [None]:
key_freqs = [17, 25, 32, 47]
key_freq_indices = [33, 34, 49, 50, 63, 64, 93, 94]
fourier_embed = fourier_basis @ W_E
key_fourier_embed = fourier_embed[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)
npx.imshow(key_fourier_embed @ key_fourier_embed.T, title="Dot Product of embedding of key Fourier Terms")

### Key Frequencies

In [None]:
npx.line(fourier_basis[[34, 50, 64, 94]], title="Cos of key freqs", line_labels=[34, 50, 64, 94])

In [None]:
npx.line(fourier_basis[[34, 50, 64, 94]].mean(0), title="Constructive Interference")

## Analyse Neurons

In [None]:
npx.imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p), 
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

In [None]:
npx.imshow(
    einops.rearrange(neuron_acts[:, 0], "(a b) -> a b", a=p, b=p), 
    title="First neuron act", xaxis="b", yaxis="a",)

In [None]:
npx.imshow(fourier_basis[94][None, :] * fourier_basis[94][:, None], title="Cos 47a * cos 47b")

In [None]:
npx.imshow(fourier_basis[94][None, :] * fourier_basis[0][:, None], title="Cos 47a * const")

In [None]:
npx.imshow(fourier_basis @ neuron_acts[:, 0].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transformer of neuron 0", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
npx.imshow(fourier_basis @ neuron_acts[:, 5].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transformer of neuron 5", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
npx.imshow(fourier_basis @ torch.randn_like(neuron_acts[:, 0]).reshape(p, p) @ fourier_basis.T, title="2D Fourier Transformer of RANDOM", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

### Neuron Clusters

In [None]:
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean - doesn't matter!
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

In [None]:
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()
for freq in range(0, p//2):
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
neuron_freq_norm = neuron_freq_norm / fourier_neuron_acts.pow(2).sum(dim=[-1, -2])[None, :]
npx.imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")

In [None]:
npx.line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")

## Read Off the Neuron-Logit Weights to Interpret

In [None]:
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

In [None]:
npx.line((W_logit @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit in the Fourier Basis")

In [None]:
neurons_17 = neuron_freq_norm[17-1]>0.85
neurons_17.shape

In [None]:
neurons_17.sum()

In [None]:
npx.line((W_logit[neurons_17] @ fourier_basis.T).norm(dim=0), x=fourier_basis_names, title="W_logit for freq 17 neurons in the Fourier Basis")

Study sin 17

In [None]:
freq = 17
W_logit_fourier = W_logit @ fourier_basis
neurons_sin_17 = W_logit_fourier[:, 2*freq-1]
npx.line(neurons_sin_17)

In [None]:
neuron_acts.shape

In [None]:
inputs_sin_17c = neuron_acts @ neurons_sin_17
npx.imshow(fourier_basis @ inputs_sin_17c.reshape(p, p) @ fourier_basis.T, title="Fourier Heatmap over inputs for sin17c", x=fourier_basis_names, y=fourier_basis_names)

# Black Box Methods + Progress Measures