<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Grokking_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Dev Interp - Grokking Modular Addition and Multiplication

# Setup

In [2]:
# Master switch for the notebook: when True the training loop runs and a checkpoint
# is saved; when False training is skipped and a saved checkpoint is loaded instead.
TRAIN_MODEL = True

In [3]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [4]:
# Pick the Plotly renderer: "colab" on Colab (or whenever development mode is off),
# otherwise "notebook_connected" for local notebooks.
import plotly.io as pio

pio.renderers.default = "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: notebook_connected


In [5]:
# Enlarge fonts on the "plotly" template: 20 for both axis titles, 30 for the figure title.
_plotly_layout = pio.templates['plotly'].layout
_plotly_layout.xaxis.title.font.size = 20
_plotly_layout.yaxis.title.font.size = 20
_plotly_layout.title.font.size = 30

In [6]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [7]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"device: {device}")

device: cpu


Plotting helper functions:

In [8]:
from collections import deque

def rolling_average(values, window_size):
    """Return the trailing moving average of ``values``.

    Element ``i`` of the result is the mean of the last ``window_size`` items seen
    up to and including ``values[i]``; the first few entries average over however
    many items are available so far.

    Args:
        values: Iterable of numbers.
        window_size: Maximum number of most-recent items to average over.

    Returns:
        A list with one average per input item.
    """
    recent = deque(maxlen=window_size)

    def _push_and_mean(item):
        # The bounded deque silently evicts the oldest item once it is full.
        recent.append(item)
        return sum(recent) / len(recent)

    return [_push_and_mean(item) for item in values]

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a heatmap on a red/blue colour scale centred at zero.

    Args:
        tensor: Data to plot; converted with ``utils.to_numpy`` first.
        renderer: Passed to the figure's ``show`` call.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to ``px.imshow``.
    """
    axis_labels = {"x":xaxis, "y":yaxis}
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a line plot.

    Args:
        tensor: Data to plot; converted with ``utils.to_numpy`` first.
        renderer: Passed to the figure's ``show`` call.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to ``px.line``.
    """
    axis_labels = {"x":xaxis, "y":yaxis}
    fig = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of ``y`` against ``x``.

    Args:
        x: Horizontal coordinates; converted with ``utils.to_numpy``.
        y: Vertical coordinates; converted with ``utils.to_numpy``.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Passed to the figure's ``show`` call.
        **kwargs: Forwarded unchanged to ``px.scatter``.
    """
    x_values, y_values = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x":xaxis, "y":yaxis, "color":caxis}
    fig = px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs)
    fig.show(renderer)

# Model Training

## Config

In [9]:


# Seed passed as random_state to every DataFrame.sample call that builds and
# shuffles the train/test splits, so the splits are reproducible.
DATA_SEED = 598

## Define Task
* Define modular addition and modular multiplication datasets
* Combine the two datasets
* We can vary the proportion of addition to multiplication in the training dataset

In [14]:
import pandas as pd
from torch.utils.data import TensorDataset

# Both operands range over [0, max_nums); labels are reduced modulo mod_value.
max_nums = 113
mod_value = 113
# batch_size = 128

# Fraction of each operation's full (a, b) table used for training.
add_frac = 0.65
multi_frac = 0.65

def create_dataset(max_nums: int, addition: bool, mod_value: int):
    """Build the full table of one modular operation over all operand pairs.

    Args:
        max_nums: Both operands ``a`` and ``b`` range over ``range(max_nums)``.
        addition: True for ``(a + b) % mod_value``, False for ``(a * b) % mod_value``.
        mod_value: Modulus applied to the result.

    Returns:
        A DataFrame with columns ``"input"`` (the list ``[a, op, b]`` where ``op``
        is 1 for addition and 0 for multiplication) and ``"label"``.
    """
    op_token = int(addition)
    rows = [
        [[a, op_token, b], ((a + b) if addition else (a * b)) % mod_value]
        for a in range(max_nums)
        for b in range(max_nums)
    ]
    return pd.DataFrame(rows, columns=["input", "label"])

# Build the complete tables for both operations (max_nums * max_nums rows each).
addition_df = create_dataset(max_nums, True, mod_value)
multiplication_df = create_dataset(max_nums, False, mod_value)

print(f"Addition dataset size = {len(addition_df)}")
print(f"Multiplication dataset size = {len(multiplication_df)}")

# Calculate the sizes for train datasets (int() truncates any fractional row count)
add_train_size = int(len(addition_df) * add_frac)
multi_train_size = int(len(multiplication_df) * multi_frac)

# Determine the size for test datasets (use the smaller of the two remaining sets)
test_size = min(len(addition_df) - add_train_size, len(multiplication_df) - multi_train_size)

# Create train datasets
# NOTE(review): both frames have the same length and are sampled with the same
# random_state, so when the fractions match the same (a, b) rows are presumably
# picked for both operations - confirm this is intended.
add_train_df = addition_df.sample(n=add_train_size, random_state=DATA_SEED)
multi_train_df = multiplication_df.sample(n=multi_train_size, random_state=DATA_SEED)

# Create test datasets with equal size, drawn only from rows not used for training
add_test_df = addition_df.drop(add_train_df.index).sample(n=test_size, random_state=DATA_SEED)
multi_test_df = multiplication_df.drop(multi_train_df.index).sample(n=test_size, random_state=DATA_SEED)

# Print sizes for verification
print(f"Addition train size = {len(add_train_df)}")
print(f"Addition test size = {len(add_test_df)}")
print(f"Multiplication train size = {len(multi_train_df)}")
print(f"Multiplication test size = {len(multi_test_df)}")

# Share of each operation in the combined training set, as percentages.
print(f"Addition: {len(add_train_df)/(len(add_train_df) + len(multi_train_df))*100:0.1f}% \
Multiplication: {len(multi_train_df)/(len(add_train_df) + len(multi_train_df))*100:0.1f}%")

# Combine and shuffle the datasets (sample(frac=1) permutes all rows; the index is then reset)
train_df = pd.concat([add_train_df, multi_train_df], ignore_index=True).sample(frac=1, random_state=DATA_SEED).reset_index(drop=True)
test_df = pd.concat([add_test_df, multi_test_df], ignore_index=True).sample(frac=1, random_state=DATA_SEED).reset_index(drop=True)

print(f"Combined dataset = {len(train_df) + len(test_df)}")

print(f"Train size = {len(train_df)}")
print(f"Test size = {len(test_df)}")

# Create the dataloaders
def get_dataloader(df, batch_size, shuffle):
    """Wrap a DataFrame's ``"input"`` and ``"label"`` columns in a DataLoader.

    Args:
        df: DataFrame with an ``"input"`` column of equal-length lists and a
            ``"label"`` column of numbers.
        batch_size: Number of rows per batch.
        shuffle: Whether the loader reshuffles rows on each pass.

    Returns:
        A DataLoader over a TensorDataset yielding ``(inputs, labels)`` batches.
    """
    input_tensor, label_tensor = (
        torch.tensor(df[column].tolist()) for column in ("input", "label")
    )
    return DataLoader(
        TensorDataset(input_tensor, label_tensor),
        batch_size=batch_size,
        shuffle=shuffle,
    )

# Mini-batch loader for training (batches of 1024, reshuffled each pass) ...
train_loader = get_dataloader(train_df, 1024, shuffle=True)
# ... and single-batch, unshuffled loaders for the combined and per-operation test sets.
test_loader = get_dataloader(test_df, len(test_df), shuffle=False)
add_test_loader = get_dataloader(add_test_df, len(add_test_df), shuffle=False)
multi_test_loader = get_dataloader(multi_test_df, len(multi_test_df), shuffle=False)

# The complete splits as plain tensors, for full-batch use.
train_data, train_labels = (
    torch.tensor(train_df[column].tolist()) for column in ("input", "label")
)
test_data, test_labels = (
    torch.tensor(test_df[column].tolist()) for column in ("input", "label")
)

print(f"train_data.shape = {train_data.shape}")
print(f"train_labels.shape = {train_labels.shape}")
print(f"test_data.shape = {test_data.shape}")
print(f"test_labels.shape = {test_labels.shape}")


Addition dataset size = 12769
Multiplication dataset size = 12769
Addition train size = 8299
Addition test size = 4470
Multiplication train size = 8299
Multiplication test size = 4470
Addition: 50.0% Multiplication: 50.0%
Combined dataset = 25538
Train size = 16598
Test size = 8940
train_data.shape = torch.Size([16598, 3])
train_labels.shape = torch.Size([16598])
test_data.shape = torch.Size([8940, 3])
test_labels.shape = torch.Size([8940])


In [349]:
# Disabled alternative dataset construction (einops-based, single random split).
# It is held in a string literal, so none of it executes; as the only expression
# in the cell, the string itself is echoed as the cell output.
'''
num_of_ops = 2
max_nums = 113
mod_value = 113

# Generate all combinations
a_vector = einops.repeat(torch.arange(max_nums), "a -> (a b o)", b=max_nums, o=num_of_ops)
b_vector = einops.repeat(torch.arange(max_nums), "b -> (a b o)", a=max_nums, o=num_of_ops)
operations = einops.repeat(torch.arange(num_of_ops), "o -> (a b o)", a=max_nums, b=max_nums)
mod_vector = torch.full((max_nums * max_nums * num_of_ops,), mod_value)

# Stack the vectors to create the
if num_of_ops == 1:
    dataset = torch.stack([a_vector, b_vector], dim=1)
else:
    dataset = torch.stack([a_vector, operations, b_vector], dim=1)

print(f"dataset.shape = {dataset.shape}")
print(f"dataset[:5] = {dataset[:5]}")
print(f"dataset[len(dataset) - 10:] = {dataset[len(dataset) - 10:]}")

if num_of_ops == 1:
    labels = (dataset[:, 0] + dataset[:, 1]) % mod_value
else:
    labels = ((dataset[:, 0] + dataset[:, 2]) * (1 - dataset[:, 1]) + 
                (dataset[:, 0] * dataset[:, 2]) * dataset[:, 1]) % mod_value
print(labels.shape)
print(labels[:100])

frac_train = 0.65
torch.manual_seed(DATA_SEED)
indices = torch.randperm(len(dataset))
cutoff = int(len(indices)*frac_train)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = dataset[train_indices]
train_labels = labels[train_indices]
test_data = dataset[test_indices]
test_labels = labels[test_indices]
print(f"train_data.shape = {train_data.shape}")
print(f"train_labels.shape = {train_labels.shape}")
print(f"test_data.shape = {test_data.shape}")
print(f"test_labels.shape = {test_labels.shape}")
'''


'\nnum_of_ops = 2\nmax_nums = 113\nmod_value = 113\n\n# Generate all combinations\na_vector = einops.repeat(torch.arange(max_nums), "a -> (a b o)", b=max_nums, o=num_of_ops)\nb_vector = einops.repeat(torch.arange(max_nums), "b -> (a b o)", a=max_nums, o=num_of_ops)\noperations = einops.repeat(torch.arange(num_of_ops), "o -> (a b o)", a=max_nums, b=max_nums)\nmod_vector = torch.full((max_nums * max_nums * num_of_ops,), mod_value)\n\n# Stack the vectors to create the\nif num_of_ops == 1:\n    dataset = torch.stack([a_vector, b_vector], dim=1)\nelse:\n    dataset = torch.stack([a_vector, operations, b_vector], dim=1)\n\nprint(f"dataset.shape = {dataset.shape}")\nprint(f"dataset[:5] = {dataset[:5]}")\nprint(f"dataset[len(dataset) - 10:] = {dataset[len(dataset) - 10:]}")\n\nif num_of_ops == 1:\n    labels = (dataset[:, 0] + dataset[:, 1]) % mod_value\nelse:\n    labels = ((dataset[:, 0] + dataset[:, 2]) * (1 - dataset[:, 1]) + \n                (dataset[:, 0] * dataset[:, 2]) * dataset[:, 1

## Define Model

In [350]:

# One-layer transformer configuration; vocabulary and context sizes come from the task data.
cfg = HookedTransformerConfig(
    # Architecture
    n_layers=1,
    d_model=128,
    n_heads=4,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type="LN",
    # Sizes derived from the dataset
    d_vocab=max_nums + 1,
    d_vocab_out=mod_value,
    n_ctx=train_data.shape[1],
    # Initialisation and placement
    init_weights=True,
    seed=999,
    device=device,
)

model = HookedTransformer(cfg)

Freeze the biases (they stay at their initial values and are never trained), as we don't need them for this task and it makes things easier to interpret.

In [351]:
# Freeze every parameter whose name contains "b_" so the optimizer never updates it.
for name, param in model.named_parameters():
    if "b_" not in name:
        continue
    param.requires_grad = False


## Define Optimizer + Loss

In [352]:
# Optimizer config
lr = 1e-3
wd = 1.
betas = (0.90, 0.98)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

In [353]:
def loss_fn(logits, labels):
    """Mean negative log-likelihood of ``labels`` under ``logits``, computed in float64.

    Args:
        logits: Either a 2-D tensor whose last axis holds class scores, or a 3-D
            tensor, in which case only index -1 along axis 1 is scored.
        labels: 1-D integer tensor of target class indices, one per row of
            ``logits``. It may live on a different device than ``logits``.

    Returns:
        Scalar float64 tensor: the negated mean log-probability of the targets.
    """
    if len(logits.shape)==3:
        logits = logits[:, -1]
    # gather() needs its index on the same device as the source tensor; the labels
    # are built on the CPU while the model output may be on the GPU.
    labels = labels.to(logits.device)
    # Upcast before the softmax so very small losses are not lost to float32 rounding.
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()

# Disabled smoke test of loss_fn on one batch from each loader. It is held in a
# string literal, so nothing here executes; the string is echoed as the cell output.
# NOTE: if re-enabled, it would rebind train_data/train_labels and
# test_data/test_labels to a single batch drawn from the loaders.
'''
# Get one batch from train_loader
train_data, train_labels = next(iter(train_loader))

# Compute train loss
train_logits = model(train_data)
train_loss = loss_fn(train_logits, train_labels)
print(f"Train loss: {train_loss.item()}")

# Get one batch from test_loader
test_data, test_labels = next(iter(test_loader))

# Compute test loss
model.eval()
with torch.no_grad():
    test_logits = model(test_data)
    test_loss = loss_fn(test_logits, test_labels)
    print(f"Test loss: {test_loss.item()}")
'''

'\n# Get one batch from train_loader\ntrain_data, train_labels = next(iter(train_loader))\n\n# Compute train loss\ntrain_logits = model(train_data)\ntrain_loss = loss_fn(train_logits, train_labels)\nprint(f"Train loss: {train_loss.item()}")\n\n# Get one batch from test_loader\ntest_data, test_labels = next(iter(test_loader))\n\n# Compute test loss\nmodel.eval()\nwith torch.no_grad():\n    test_logits = model(test_data)\n    test_loss = loss_fn(test_logits, test_labels)\n    print(f"Test loss: {test_loss.item()}")\n'

## Actually Train

In [354]:
# Training schedule and bookkeeping: per-epoch losses plus periodic weight snapshots.
num_epochs = 4000
checkpoint_every = 100
train_losses, test_losses = [], []
model_checkpoints, checkpoint_epochs = [], []

if TRAIN_MODEL:
    for epoch in tqdm.tqdm(range(num_epochs)):
        # One full-batch optimisation step on the whole training split.
        model.train()
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_losses.append(train_loss.item())

        # Evaluate on the whole test split without tracking gradients.
        model.eval()
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
        test_losses.append(test_loss.item())

        # Snapshot the weights and report progress every `checkpoint_every` epochs.
        if (epoch + 1) % checkpoint_every != 0:
            continue
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

  0%|          | 0/4000 [00:00<?, ?it/s]

Epoch 99 Train Loss 2.7838976809995835 Test Loss 5.42454083928484
Epoch 199 Train Loss 0.7165533204930278 Test Loss 6.316207744018711
Epoch 299 Train Loss 0.17908328071398139 Test Loss 7.165945008941403
Epoch 399 Train Loss 0.06829362298653581 Test Loss 7.67208767453406
Epoch 499 Train Loss 0.028300404749086756 Test Loss 7.968013660346753
Epoch 599 Train Loss 0.031500062684813576 Test Loss 7.738093698335426
Epoch 699 Train Loss 0.020929337261837527 Test Loss 7.870526332706938
Epoch 799 Train Loss 0.0635983767037389 Test Loss 7.474254963355664
Epoch 899 Train Loss 0.03263485904236465 Test Loss 7.190318190637799
Epoch 999 Train Loss 0.019759603696776363 Test Loss 7.133379203119872
Epoch 1099 Train Loss 0.03839824552882295 Test Loss 6.409300465237553
Epoch 1199 Train Loss 0.029663208432289357 Test Loss 6.1865673985865035
Epoch 1299 Train Loss 0.05311028475869496 Test Loss 5.380830525169386
Epoch 1399 Train Loss 0.03673764538977396 Test Loss 5.068143673210882
Epoch 1499 Train Loss 0.056128

In [159]:
# Where the trained model, its config, the checkpoints and the loss curves are written.
PTH_LOCATION = "../saves/grokking_add_and_multi.pth"
if TRAIN_MODEL:
    # Make sure the target folder exists before writing.
    os.makedirs(Path(PTH_LOCATION).parent, exist_ok=True)

    print(f"len(train_losses) = {len(train_losses)} len(test_losses) = {len(test_losses)} len(model_checkpoints) = {len(model_checkpoints)}")
    # Everything is bundled into one dict so a single torch.load restores the run.
    torch.save(
        obj=dict(
            model=model.state_dict(),
            config=model.cfg,
            checkpoints=model_checkpoints,
            checkpoint_epochs=checkpoint_epochs,
            test_losses=test_losses,
            train_losses=train_losses,
        ),
        f=PTH_LOCATION,
    )

len(train_losses) = 155 len(test_losses) = 155 len(model_checkpoints) = 1


In [165]:
# Checkpoint to restore when training is skipped.
LOAD_LOCATION = "../saves/grokking_add_65_and_multi_65.pth"
if not TRAIN_MODEL:
    # Read from LOAD_LOCATION (previously this cell defined it but loaded PTH_LOCATION).
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    # NOTE: weights_only=False unpickles arbitrary Python objects (the saved dict holds
    # a config object, not just tensors) - only load checkpoint files you trust.
    cached_data = torch.load(LOAD_LOCATION, map_location=device, weights_only=False)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    print(f"len(train_losses) = {len(train_losses)} len(test_losses) = {len(test_losses)} len(model_checkpoints) = {len(model_checkpoints)}")

len(train_losses) = 155 len(test_losses) = 155 len(model_checkpoints) = 1


## Show Model Training Statistics, Check that it groks!

In [355]:
from neel_plotly.plot import line
# Plot every `step`-th epoch of the raw train/test losses on one figure.
# `line` here is the neel_plotly helper imported at the top of this cell,
# which shadows the notebook's own `line` helper.
step = 1
# 10-epoch trailing averages of both curves.
train_losses_avg = rolling_average(train_losses, 10)
test_losses_avg = rolling_average(test_losses, 10)

line(
    [train_losses[::step], test_losses[::step]],
    x=np.arange(0, len(train_losses), step),
    xaxis="Epoch",
    yaxis="Loss",
    log_y=False,
    title="Training Curve for Modular Arithmetic",
    line_labels=['train', 'test'],
    toggle_x=True,
    toggle_y=True,
)