# Dev Interp - Grokking Modular Addition and Multiplication

# Setup

In [73]:
# True: run the training loop and save a checkpoint to PTH_LOCATION.
# False: skip training and load a previously saved checkpoint from LOAD_LOCATION instead.
TRAIN_MODEL = False

In [74]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
# True when running under GitHub Actions CI (which sets GITHUB_ACTIONS=true).
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab  # noqa: F401 -- only importable inside a Colab runtime
    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    # Only a missing Colab module means "not Colab"; any other error should surface, not be swallowed.
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel.
    # get_ipython() returns None when this is executed as a plain Python script, so guard the magics.
    # run_line_magic replaces the deprecated InteractiveShell.magic("...") API.
    if ipython is not None:
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")

# Install the interpretability libraries only on Colab / CI runners (presumably they are not
# preinstalled there). These are IPython %pip magics, so they target the kernel's environment.
# NOTE(review): versions are unpinned -- consider pinning (e.g. transformer_lens==X.Y.Z) for reproducibility.
if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload








In [75]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Use the "colab" renderer under Colab or whenever DEVELOPMENT_MODE is switched off;
# otherwise fall back to "notebook_connected".
use_colab_renderer = IN_COLAB or not DEVELOPMENT_MODE
pio.renderers.default = "colab" if use_colab_renderer else "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: notebook_connected


In [76]:
# Enlarge the axis-title and figure-title fonts of the default "plotly" template.
_plotly_layout = pio.templates['plotly'].layout
for _axis in (_plotly_layout.xaxis, _plotly_layout.yaxis):
    _axis.title.font.size = 20
_plotly_layout.title.font.size = 30

In [77]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [78]:
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"device: {device}")

device: cpu


Plotting helper functions:

In [79]:
from collections import deque

def rolling_average(values, window_size):
    """Trailing moving average of ``values``.

    Element ``i`` of the result is the mean of the (at most) ``window_size``
    most recent values up to and including position ``i``; before the window
    has filled, it averages however many values have been seen so far.

    Args:
        values: iterable of numbers.
        window_size: maximum number of trailing values averaged per output.

    Returns:
        A list with one average per input value.
    """
    trailing = deque(maxlen=window_size)

    def _push_and_mean(value):
        trailing.append(value)
        return sum(trailing) / len(trailing)

    return [_push_and_mean(value) for value in values]

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a heatmap on the red-blue "RdBu" scale with its midpoint at 0."""
    axis_labels = {"x": xaxis, "y": yaxis}
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    fig.show(renderer)


# NOTE: this helper is shadowed later by `from neel_plotly.plot import line` in the training-curve cell.
def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a plotly line chart with labelled x / y axes."""
    fig = px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs)
    fig.show(renderer)


def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of ``y`` against ``x``; ``caxis`` labels the colour axis."""
    x_np, y_np = utils.to_numpy(x), utils.to_numpy(y)
    fig = px.scatter(
        x=x_np,
        y=y_np,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    )
    fig.show(renderer)

# Model Training

In [80]:
# Seed passed as `random_state` to every pandas `.sample` call that builds the train/test split,
# so the split is reproducible across runs.
DATA_SEED = 598

## Create the dataset
* Define the modular addition and modular multiplication datasets
* Combine the two datasets
* Vary the proportion of addition vs. multiplication examples in the training set via `addition_frac`

In [81]:
import pandas as pd
from torch.utils.data import TensorDataset

# Operands a and b each range over 0..max_nums-1. Note max_nums > mod_value, so operands are
# not pre-reduced modulo mod_value.
max_nums = 130
# Labels are (a + b) % mod_value or (a * b) % mod_value, so there are mod_value output classes.
mod_value = 113

# The fraction of the data that should be used for training (not all of the rest will be used for testing)
train_frac = 0.475

# The fraction of the training data that should be addition (the rest will be multiplication)
addition_frac = 0.684

def create_dataset(max_nums: int, addition: bool, mod_value: int):
    """Enumerate every (a, b) pair with 0 <= a, b < max_nums for a single operation.

    Each row's ``input`` is the token triple ``[a, op, b]``, where ``op`` is
    ``int(addition)`` (1 for addition, 0 for multiplication). Each row's ``label``
    is ``(a + b) % mod_value`` for addition or ``(a * b) % mod_value`` otherwise.
    Rows are ordered by ``a`` first, then ``b``.

    Returns:
        A DataFrame with columns ``["input", "label"]`` and ``max_nums ** 2`` rows.
    """
    op_token = int(addition)

    def _label(a, b):
        return (a + b) % mod_value if addition else (a * b) % mod_value

    rows = [
        [[a, op_token, b], _label(a, b)]
        for a in range(max_nums)
        for b in range(max_nums)
    ]
    return pd.DataFrame(rows, columns=["input", "label"])

addition_df = create_dataset(max_nums, True, mod_value)
multiplication_df = create_dataset(max_nums, False, mod_value)

print(f"Addition dataset size = {len(addition_df)}")
print(f"Multiplication dataset size = {len(multiplication_df)}")

total_train_size = int((len(addition_df) + len(multiplication_df)) * train_frac)

print(f"Total train size = {total_train_size}")
# Calculate the sizes for train datasets based on the desired proportion
add_train_size = int(total_train_size * addition_frac)
multi_train_size = total_train_size - add_train_size

# Fail early with a readable message (instead of a generic pandas sampling error) when the
# requested fractions ask for more examples of one operation than exist.
if add_train_size > len(addition_df) or multi_train_size > len(multiplication_df):
    raise ValueError(
        f"train_frac={train_frac} / addition_frac={addition_frac} request {add_train_size} addition and "
        f"{multi_train_size} multiplication training examples, but only {len(addition_df)} and "
        f"{len(multiplication_df)} exist."
    )

# Determine the size for test datasets (use the remaining data, but ensure equal sizes)
test_size = min(len(addition_df) - add_train_size, len(multiplication_df) - multi_train_size)
if test_size < 1:
    # An empty test split would otherwise only fail later (e.g. a DataLoader with batch_size=0).
    raise ValueError("No examples are left over for the test split; lower train_frac or adjust addition_frac.")

# Create train datasets
# NOTE(review): both frames have the same (a, b) pair at every index and are sampled with the same
# random_state, so (with current pandas/numpy sampling) the multiplication training indices are a
# prefix of the addition ones -- every multiplication training pair (a, b) is also an addition
# training pair. Use distinct seeds if the two operations' training pairs should be independent.
add_train_df = addition_df.sample(n=add_train_size, random_state=DATA_SEED)
multi_train_df = multiplication_df.sample(n=multi_train_size, random_state=DATA_SEED)

# Create test datasets with equal size from the rows not used for training
add_test_df = addition_df.drop(add_train_df.index).sample(n=test_size, random_state=DATA_SEED)
multi_test_df = multiplication_df.drop(multi_train_df.index).sample(n=test_size, random_state=DATA_SEED)

# Print sizes for verification
print(f"Addition train size = {len(add_train_df)}")
print(f"Addition test size = {len(add_test_df)}")
print(f"Multiplication train size = {len(multi_train_df)}")
print(f"Multiplication test size = {len(multi_test_df)}")

train_total = len(add_train_df) + len(multi_train_df)
print(
    f"Addition: {len(add_train_df)/train_total*100:0.1f}% "
    f"Multiplication: {len(multi_train_df)/train_total*100:0.1f}%"
)

# Combine and shuffle the datasets
train_df = pd.concat([add_train_df, multi_train_df], ignore_index=True).sample(frac=1, random_state=DATA_SEED).reset_index(drop=True)
test_df = pd.concat([add_test_df, multi_test_df], ignore_index=True).sample(frac=1, random_state=DATA_SEED).reset_index(drop=True)

print(f"Combined dataset = {len(train_df) + len(test_df)}")

print(f"Train size = {len(train_df)}")
print(f"Test size = {len(test_df)}")

# Create the dataloaders
def get_dataloader(df, batch_size, shuffle):
    """Wrap ``df`` in a DataLoader that yields ``(inputs, labels)`` tensor batches.

    ``df['input']`` and ``df['label']`` are each converted to a tensor via
    ``torch.tensor(column.tolist())`` and paired through a TensorDataset.
    """
    input_tensor, label_tensor = (torch.tensor(df[col].tolist()) for col in ('input', 'label'))
    return DataLoader(TensorDataset(input_tensor, label_tensor), batch_size=batch_size, shuffle=shuffle)

# Shuffled mini-batches (1024) of the training split -- this is the loader handed to the LLC estimator --
# and a single full-size batch for the test split.
train_loader = get_dataloader(train_df, 1024, shuffle=True)
test_loader = get_dataloader(test_df, len(test_df), shuffle=False)


def _frame_tensors(frame):
    """Return ``(inputs, labels)`` tensors built from a frame's 'input' and 'label' columns."""
    return torch.tensor(frame['input'].tolist()), torch.tensor(frame['label'].tolist())


# Whole-split tensors; the training loop uses these directly (full-batch), not the loaders.
train_data, train_labels = _frame_tensors(train_df)
test_data, test_labels = _frame_tensors(test_df)
add_test_data, add_test_labels = _frame_tensors(add_test_df)
multi_test_data, multi_test_labels = _frame_tensors(multi_test_df)

for _tensor_name, _tensor in (
    ("train_data", train_data),
    ("train_labels", train_labels),
    ("test_data", test_data),
    ("test_labels", test_labels),
):
    print(f"{_tensor_name}.shape = {_tensor.shape}")
print(f"train_data[:10] = {train_data[:10]}")


Addition dataset size = 16900
Multiplication dataset size = 16900
Total train size = 16055
Addition train size = 10981
Addition test size = 5919
Multiplication train size = 5074
Multiplication test size = 5919
Addition: 68.4% Multiplication: 31.6%
Combined dataset = 27893
Train size = 16055
Test size = 11838
train_data.shape = torch.Size([16055, 3])
train_labels.shape = torch.Size([16055])
test_data.shape = torch.Size([11838, 3])
test_labels.shape = torch.Size([11838])
train_data[:10] = tensor([[ 63,   1,  53],
        [ 58,   0,  22],
        [ 37,   1,   6],
        [ 58,   1, 103],
        [ 70,   1,  58],
        [120,   1, 119],
        [ 28,   1, 111],
        [ 49,   0,  80],
        [118,   1,  96],
        [ 44,   1,   2]])


## Define Model

In [82]:

cfg = HookedTransformerConfig(
    # Architecture: one transformer block, 4 heads of width 32, ReLU MLP, LayerNorm.
    n_layers=1,
    n_heads=4,
    d_head=32,
    d_model=128,
    d_mlp=512,
    act_fn="relu",
    normalization_type="LN",
    # Input tokens: operands 0..max_nums-1 (the op token reuses ids 0/1); outputs are residues mod mod_value.
    d_vocab=max_nums + 1,
    d_vocab_out=mod_value,
    # Each example is the 3-token sequence [a, op, b].
    n_ctx=train_data.shape[1],
    # Initialisation and placement.
    init_weights=True,
    seed=999,
    device=device,
)

model = HookedTransformer(cfg)

Freeze the bias parameters (turn off their gradients), as we don't need them for this task and it makes things easier to interpret.

In [83]:
# Freeze every bias parameter so the optimizer never updates it (its grad stays None).
# TransformerLens names attention / MLP / unembed biases `b_*` (e.g. `b_Q`, `b_in`, `b_U`) but
# LayerNorm biases just `b` (e.g. `blocks.0.ln1.b`, `ln_final.b`). The old substring test
# `"b_" in name` therefore left the LayerNorm biases trainable; match on the final name
# component instead. No weight's final component starts with "b" (they are `W_*` or LayerNorm `w`).
for name, param in model.named_parameters():
    if name.rsplit(".", 1)[-1].startswith("b"):
        param.requires_grad_(False)


## Define Optimizer + Loss

In [84]:
# Optimizer config: AdamW with decoupled weight decay.
lr = 1e-3
wd = 1.0
betas = (0.9, 0.98)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)

In [85]:
def loss_fn(logits, labels):
    """Mean cross-entropy of ``labels`` under ``logits``, computed in float64.

    Args:
        logits: either ``(batch, vocab)`` or a 3-D tensor whose last position along
            dim 1 is used (``logits[:, -1]``) -- i.e. the prediction after the final token.
        labels: integer class indices, one per batch element.

    Returns:
        A scalar float64 tensor: the negative mean log-probability of the correct labels.
    """
    final_logits = logits[:, -1] if logits.ndim == 3 else logits
    # Upcast before log_softmax for numerical precision.
    log_probs = final_logits.double().log_softmax(dim=-1)
    correct = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return -correct.mean()

# Sanity check before any training: losses of the freshly initialised model on both splits.
train_logits = model(train_data)
train_loss = loss_fn(logits=train_logits, labels=train_labels)
print("train_loss = " + format(train_loss))

test_logits = model(test_data)
test_loss = loss_fn(logits=test_logits, labels=test_labels)
print("test_loss = " + format(test_loss))

train_loss = 5.047642391787601
test_loss = 5.06579144378938


## Actually Train

In [86]:
num_epochs = 6000
checkpoint_every = 100
# The per-operation (addition / multiplication) test losses are only evaluated every this many epochs.
split_eval_every = 10
train_losses = []
test_losses = []
add_test_losses = []
multi_test_losses = []
model_checkpoints = []
checkpoint_epochs = []
if TRAIN_MODEL:
    for epoch in tqdm.tqdm(range(num_epochs)):
        # One full-batch optimisation step on the entire training set.
        model.train()
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())
        optimizer.step()
        optimizer.zero_grad()

        model.eval()
        with torch.inference_mode():
            # General test loss (every epoch)
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

            if (epoch + 1) % split_eval_every == 0:
                # Addition test loss
                add_test_logits = model(add_test_data)
                add_test_loss = loss_fn(add_test_logits, add_test_labels)
                add_test_losses.append(add_test_loss.item())

                # Multiplication test loss
                multi_test_logits = model(multi_test_data)
                multi_test_loss = loss_fn(multi_test_logits, multi_test_labels)
                multi_test_losses.append(multi_test_loss.item())

        if (epoch + 1) % checkpoint_every == 0:
            checkpoint_epochs.append(epoch)
            model_checkpoints.append(copy.deepcopy(model.state_dict()))
            # Report the most recent per-operation losses recorded so far. Reading add_test_loss /
            # multi_test_loss directly raised NameError when a checkpoint came before the first
            # per-operation evaluation (checkpoint_every not a multiple of split_eval_every).
            last_add = add_test_losses[-1] if add_test_losses else float("nan")
            last_multi = multi_test_losses[-1] if multi_test_losses else float("nan")
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()} Add Test Loss {last_add} Multi Test Loss {last_multi}")

In [87]:
# NOTE(review): the filename hard-codes the split settings (and rounds 68.4% to 68%);
# keep it in sync with addition_frac / train_frac when changing them.
PTH_LOCATION = "../saves/grokking_add_multi_68%_addition_47.5%_train.pth"
if TRAIN_MODEL:
    # Create the directory if it does not exist
    Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

    print(f"len(train_losses) = {len(train_losses)} len(test_losses) = {len(test_losses)} len(model_checkpoints) = {len(model_checkpoints)}")
    torch.save(
        {
            "model": model.state_dict(),
            "config": model.cfg,
            "checkpoints": model_checkpoints,
            "checkpoint_epochs": checkpoint_epochs,
            "test_losses": test_losses,
            "train_losses": train_losses,
            "add_test_losses": add_test_losses,
            "multi_test_losses": multi_test_losses,
            "max_nums": max_nums,
            "mod_value": mod_value,
            "train_frac": train_frac,
            "addition_frac": addition_frac,
            # Seed of the pandas sampling that built the split; without it the exact
            # train/test partition cannot be rebuilt when this checkpoint is reloaded.
            "data_seed": DATA_SEED,
        },
        PTH_LOCATION,
    )

In [88]:
LOAD_LOCATION = "../saves/grokking_add_multi_0.7.pth"
if not TRAIN_MODEL:
    import warnings

    # weights_only=False unpickles arbitrary Python objects (e.g. the saved config); only load trusted files.
    cached_data = torch.load(LOAD_LOCATION, weights_only=False)

    # The datasets, loaders and model in this notebook were built from its own settings, and they are
    # NOT rebuilt when the checkpoint's settings are assigned below. If they differ, train_data /
    # test_data / train_loader are not the checkpoint's real split (the "test" set may contain
    # examples the loaded model trained on), so warn before anything is evaluated on them.
    current_settings = {
        "max_nums": max_nums,
        "mod_value": mod_value,
        "train_frac": train_frac,
        "addition_frac": addition_frac,
    }
    mismatched = {
        key: (value, cached_data[key])
        for key, value in current_settings.items()
        if cached_data[key] != value
    }
    if mismatched:
        warnings.warn(
            "Loaded checkpoint was trained with different dataset settings "
            f"(notebook value, checkpoint value): {mismatched}. The notebook's train/test "
            "tensors do not match the checkpoint's split; re-run the dataset cells with the "
            "checkpoint's settings before evaluating."
        )

    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    add_test_losses = cached_data['add_test_losses']
    multi_test_losses = cached_data['multi_test_losses']
    max_nums = cached_data['max_nums']
    mod_value = cached_data['mod_value']
    train_frac = cached_data['train_frac']
    addition_frac = cached_data['addition_frac']
    print(f"train_frac = {train_frac} addition_frac = {addition_frac}")
    print(f"len(train_losses) = {len(train_losses)} len(test_losses) = {len(test_losses)} len(model_checkpoints) = {len(model_checkpoints)}")

train_frac = 0.5 addition_frac = 0.7
len(train_losses) = 5000 len(test_losses) = 5000 len(model_checkpoints) = 50


In [89]:
# Held-out loss of the current model (trained above or loaded from LOAD_LOCATION).
test_logits = model(test_data)
test_loss = loss_fn(logits=test_logits, labels=test_labels)
print("test_loss = " + format(test_loss))

test_loss = 0.30149553907995896


## Show Model Training Statistics, Check that it groks!

In [90]:
from neel_plotly.plot import line
step = 10
average_window = 200
train_losses_avg = rolling_average(train_losses, average_window)
test_losses_avg = rolling_average(test_losses, average_window)
# The per-operation losses are recorded once every 10 epochs during training, so a window of
# average_window//10 entries spans roughly the same number of epochs as the other curves.
add_test_losses_avg = rolling_average(add_test_losses, average_window // 10)
multi_test_losses_avg = rolling_average(multi_test_losses, average_window // 10)

# Settings shared by both figures.
curve_kwargs = dict(
    x=np.arange(0, len(train_losses_avg), step),
    xaxis="Epoch",
    log_y=False,
    title=f"Training Curve for Modular Arithmetic - {addition_frac*100:.0f}% Addition",
    toggle_x=True,
    toggle_y=True,
)

line(
    [train_losses[::step], test_losses[::step], add_test_losses, multi_test_losses],
    yaxis="Loss",
    line_labels=['train loss', 'test loss', 'add loss', 'multi loss'],
    **curve_kwargs,
)
line(
    [train_losses_avg[::step], test_losses_avg[::step], add_test_losses_avg, multi_test_losses_avg],
    yaxis="Rolling Avg Loss",
    line_labels=['train', 'test', 'add loss', 'multi loss'],
    **curve_kwargs,
)

### Estimating the Local Learning Coefficient (LLC, a local version of the RLCT)

In [120]:
from devinterp.slt.sampler import estimate_learning_coeff_with_summary, SGLD

def evaluate(model, data):
    """Evaluation callback passed to devinterp's LLC estimator.

    Args:
        model: the network whose parameters are being sampled.
        data: one batch from the loader, unpacked as ``(inputs, labels)``.

    Returns:
        ``(loss, {"logits": logits})``: the scalar loss from ``loss_fn`` plus the
        raw model output as auxiliary information.
    """
    inputs, outputs = data
    # A single forward pass feeds both the loss and the auxiliary logits; the model used
    # to be run twice per batch, doubling the cost of every sampler step.
    logits = model(inputs)
    return loss_fn(logits, outputs), {"logits": logits}

# SGLD step hyperparameters for the LLC estimate. A previously tried setting was
# {"lr": 1e-5, "localization": 100.0, "noise_level": 1.0} with num_draws = 400.
sgld_kwargs = dict(lr=4e-4, localization=100.0)

# NOTE(review): 5 draws after 10 burn-in steps is very short -- the recorded llc/means are still
# drifting at the final step. TODO: increase num_draws / num_burnin_steps before trusting the value.
# NOTE(review): train_loader is built from this notebook's own split, which may not be the split
# the loaded checkpoint was trained on.
results = estimate_learning_coeff_with_summary(
    model,
    loader=train_loader,
    evaluate=evaluate,
    sampling_method=SGLD,
    optimizer_kwargs=sgld_kwargs,
    num_chains=3,           # independent chains
    num_draws=5,            # samples drawn per chain
    num_burnin_steps=10,    # initial steps discarded per chain
    num_steps_bw_draws=1,   # steps taken between consecutive draws
    device=device,
    online=True,
)



Moving model to device:  cpu
Moving model to device:  cpu
Moving model to device:  cpu


Chain 0: 100%|██████████| 15/15 [00:00<00:00, 28.92it/s]


Moving model to device:  cpu


Chain 1: 100%|██████████| 15/15 [00:00<00:00, 32.70it/s]


Moving model to device:  cpu


Chain 2: 100%|██████████| 15/15 [00:00<00:00, 25.40it/s]


In [119]:
# With online=True, "llc/means" holds one value per sampler step rather than a single scalar
# (15 entries = 10 burn-in + 5 draws in the run shown) -- TODO confirm against devinterp docs.
estimate = results["llc/means"]
for _label, _value in (
    ("estimate", estimate),
    ("len(estimate)", len(estimate)),
    ("results.keys()", results.keys()),
):
    print(f"{_label} = {_value}")

estimate = [6.5082170e-02 1.0739037e+03 1.0843871e+03 9.7386206e+02 8.9099689e+02
 8.0364478e+02 8.0303638e+02 8.0884467e+02 7.9820660e+02 7.8117023e+02
 7.7506818e+02 7.6936987e+02 7.7463477e+02 7.8124323e+02 7.8629517e+02]
len(estimate) = 15
results.keys() = dict_keys(['init_loss', 'llc/means', 'llc/stds', 'llc/trace', 'loss/trace'])
