# Grokking

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timaeus-research/devinterp/blob/main/examples/grokking.ipynb)

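This notebook shows how to calibrate LLC estimation in a simple grokking example, ending with a moderately interesting result. It was originally written for modular addition; the code below has been adapted to a fraction-of-x task, where the model predicts, at each position of a letter sequence, the fraction of 'x' tokens seen so far.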

We'll start off with some standard grokking code, adapted loosely from Nina Panickssery and Dmitry Vaintrob's [modular addition learning coefficient post](https://www.alignmentforum.org/posts/4v3hMuKfsGatLXPgt/investigating-the-learning-coefficient-of-modular-addition) and [github code repo](https://github.com/nrimsky/devinterp). (Thank you for your help!)

In [None]:
%pip install devinterp nbformat
%pip install devinterp[vis]

In [None]:
import random
from copy import deepcopy
from dataclasses import dataclass

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from devinterp.optim.sgld import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import evaluate_mse

# Prefer CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
@dataclass
class ExperimentParams:
    """Hyperparameters for the fraction-of-x sequence task and its training run.

    Adapted from a modular-addition grokking setup: the prime modulus ``p`` was
    replaced by ``n_samples`` random letter sequences of length ``seq_length``.
    """
    vocab_size: int = 26  # number of token types (the letters a-z)
    seq_length: int = 10  # tokens per input sequence
    n_samples: int = 2000  # total sequences to generate (replaces the old modulus p)
    # Previously: p = 53 (prime modulus, learning addition mod 53).
    n_batches: int = 50000  # number of training steps (previously 25000)
    n_save_model_checkpoints: int = 100  # approx. number of model snapshots train() keeps
    print_times: int = 100  # approx. number of evaluations/log rows train() records
    lr: float = 0.01  # Adam learning rate; raised to help the memorization phase of grokking
    batch_size: int = 256  # training batch size (previously 128)
    hidden_size: int = 48  # feed-forward width inside each transformer layer
    embed_dim: int = 64  # token/positional embedding width (previously 12)
    # Previously: train_frac = 0.4.
    train_frac: float = 0.3  # fraction of samples used for training; smaller favours memorization
    # Per the author's experiments, the grokking / LLC curve behavior is robust to the
    # seed, but not every seed groks within the first 100 checkpoints.
    random_seed: int = 0
    device: str = DEVICE
    # Previously: weight_decay = 0.0002.
    weight_decay: float = 0.1  # stronger regularization favours simpler, generalizing solutions

class MLP(nn.Module):  # historical name kept for compatibility; this is a transformer
    """Causal transformer predicting, at every position, the running fraction of 'x' tokens.

    Architecture: learned token + positional embeddings, a stack of
    ``nn.TransformerEncoderLayer`` blocks run with a causal mask, and a scalar
    linear head squashed through a sigmoid.

    Args:
        params: config object (e.g. ``ExperimentParams``) providing ``vocab_size``,
            ``seq_length``, ``embed_dim`` and ``hidden_size``. Optional attributes
            ``n_heads`` (default 4) and ``n_layers`` (default 2) override the
            attention-head and layer counts; ``embed_dim`` must be divisible by
            the number of heads.
    """

    def __init__(self, params):
        super().__init__()
        self.embedding = nn.Embedding(params.vocab_size, params.embed_dim)
        self.pos_encoding = nn.Embedding(params.seq_length, params.embed_dim)

        n_heads = getattr(params, "n_heads", 4)
        n_layers = getattr(params, "n_layers", 2)
        self.transformer_layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=params.embed_dim,
                    nhead=n_heads,
                    dim_feedforward=params.hidden_size,
                    batch_first=True,
                )
                for _ in range(n_layers)
            ]
        )

        # One scalar per position: the predicted fraction (before the sigmoid).
        self.output = nn.Linear(params.embed_dim, 1)

        # Deliberately large init (gain=2) on every parameter with more than one
        # dimension (weight matrices and embedding tables); presumably meant to
        # encourage an initial memorization phase before grokking.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p, gain=2.0)

    def forward(self, x):
        """Map token indices ``x`` of shape ``(batch, seq_len)`` to ``(batch, seq_len)``
        predictions in ``(0, 1)``; position ``i`` only attends to positions ``0..i``.
        """
        batch_size, seq_len = x.shape

        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        h = self.embedding(x) + self.pos_encoding(positions)

        # Boolean causal mask built directly on the input's device (no per-call
        # host->device copy). True entries, strictly above the diagonal, mark
        # future positions that may not be attended to.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()

        for layer in self.transformer_layers:
            h = layer(h, src_mask=causal_mask)

        return torch.sigmoid(self.output(h).squeeze(-1))

def test(model, dataset, device):
    """Evaluate ``model`` one sample at a time and return ``(avg_mae, avg_loss)``.

    Args:
        model: module mapping a ``(1, seq_len)`` batch to ``(1, seq_len)`` predictions.
        dataset: sized iterable of ``(x, y)`` pairs, each a 1-D tensor of length ``seq_len``.
        device: device each sample is moved to before the forward pass.

    Returns:
        ``(avg_mae, avg_loss)``: the mean absolute error over every position of every
        sample, and the per-sample MSE averaged over samples. Both are ``nan`` when
        ``dataset`` is empty.

    The model's previous train/eval mode is restored on exit, so calling this from
    inside a training loop does not leave dropout switched off for later steps.
    """
    if len(dataset) == 0:
        return float("nan"), float("nan")

    loss_fn = nn.MSELoss()
    total_loss = 0.0
    total_mae = 0.0
    total_positions = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in dataset:
                # Add a batch dimension: (seq_len,) -> (1, seq_len).
                x = x.to(device).unsqueeze(0)
                y = y.to(device).unsqueeze(0)

                out = model(x)
                total_loss += loss_fn(out, y).item()
                total_mae += torch.abs(out - y).sum().item()
                total_positions += y.shape[1]
    finally:
        model.train(was_training)

    return total_mae / total_positions, total_loss / len(dataset)


# This evaluation helper happens to be named ``test``; stop pytest from collecting
# it as a test function if this code is ever imported into a test module.
test.__test__ = False


def train(train_dataset, test_dataset, params, verbose=True):
    """Train a fresh ``MLP`` with Adam on an MSE objective, checkpointing and logging.

    Args:
        train_dataset: sequence of ``(x, y)`` tensor pairs used for gradient steps.
        test_dataset: held-out ``(x, y)`` pairs, used only for evaluation.
        params: ``ExperimentParams``-like config (device, lr, weight_decay,
            batch_size, n_batches, print_times, n_save_model_checkpoints and the
            model sizes read by ``MLP``). A non-positive ``print_times`` or
            ``n_save_model_checkpoints`` disables logging or checkpointing.
        verbose: show a tqdm progress bar and print the final metrics.

    Returns:
        ``(all_models, df)``: deep copies of the model taken at regular intervals
        (stored in eval mode), and a DataFrame with one row per evaluation and
        columns ``batch``, ``train_loss``, ``train_mae``, ``val_loss``, ``val_mae``.
    """
    all_models = []
    model = MLP(params).to(params.device)
    optimizer = torch.optim.Adam(
        model.parameters(), weight_decay=params.weight_decay, lr=params.lr
    )
    loss_fn = torch.nn.MSELoss()

    train_loader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True)

    # Clamp intervals to >= 1: integer division gives 0 when the requested count
    # exceeds n_batches, which would crash the modulo below (printing) or silently
    # disable checkpointing.
    print_every = None
    if params.print_times > 0:
        print_every = max(1, params.n_batches // params.print_times)
    checkpoint_every = None
    if params.n_save_model_checkpoints > 0:
        checkpoint_every = max(1, params.n_batches // params.n_save_model_checkpoints)

    loss_data = []
    pbar = tqdm(total=params.n_batches, desc="Training") if verbose else None
    for i in range(params.n_batches):
        step = i + 1
        # A fresh iterator over the shuffled loader yields one random batch per step.
        X, Y = next(iter(train_loader))
        X, Y = X.to(params.device), Y.to(params.device)

        # ``test`` may leave the model in eval mode; re-enable training behaviour
        # (e.g. dropout) before every gradient step.
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(model(X), Y)
        loss.backward()
        optimizer.step()

        if checkpoint_every and step % checkpoint_every == 0:
            # Snapshot in eval mode so later use of every checkpoint is consistent.
            all_models.append(deepcopy(model).eval())

        if print_every and step % print_every == 0:
            val_mae, val_loss = test(model, test_dataset, params.device)
            train_mae, train_loss = test(model, train_dataset, params.device)

            loss_data.append(
                {
                    "batch": step,
                    "train_loss": train_loss,
                    "train_mae": train_mae,
                    "val_loss": val_loss,
                    "val_mae": val_mae,
                }
            )
            if pbar is not None:
                pbar.set_postfix(
                    {
                        "train_loss": f"{train_loss:.4f}",
                        "train_mae": f"{train_mae:.4f}",
                        "val_loss": f"{val_loss:.4f}",
                        "val_mae": f"{val_mae:.4f}",
                    }
                )
                pbar.update(print_every)
    if pbar is not None:
        pbar.close()

    df = pd.DataFrame(loss_data)
    train_mae, train_loss = test(model, train_dataset, params.device)
    val_mae, val_loss = test(model, test_dataset, params.device)
    if verbose:
        print(f"Final Train MAE: {train_mae:.4f} | Final Train Loss: {train_loss:.4f}")
        print(f"Final Val MAE: {val_mae:.4f} | Final Val Loss: {val_loss:.4f}")
    return all_models, df


def deterministic_shuffle(lst, seed):
    """Shuffle ``lst`` in place with an RNG seeded by ``seed`` and return it.

    A private ``random.Random(seed)`` yields the same permutation as seeding the
    global generator would, but without resetting the module-level ``random``
    state as a hidden side effect.
    """
    random.Random(seed).shuffle(lst)
    return lst


def get_all_pairs(p):
    """Return the set of every ordered pair ``(i, j)`` with ``0 <= i, j < p``.

    Appears to be unused by the fraction-of-x pipeline in this notebook.
    """
    return {(i, j) for i in range(p) for j in range(p)}

# Fraction-of-x task: label every position with the running fraction of 'x' tokens.
def make_dataset(p, *, seq_length=10, rng=None):  # ``p`` = number of samples (name kept for compatibility)
    """Generate ``p`` random letter sequences labelled with their running fraction of 'x'.

    Each sequence has ``seq_length`` tokens over the 26 lowercase letters (token
    index = alphabet position, so 'x' is 23). Between 0 and ``seq_length // 2``
    tokens (inclusive) are 'x' -- a sequence may contain no 'x' at all -- and the
    rest are uniformly random non-'x' letters; the sequence is then shuffled.
    The target at position ``i`` is ``(number of 'x' in seq[0..i]) / (i + 1)``.

    Args:
        p: number of sequences to generate.
        seq_length: tokens per sequence (default 10).
        rng: optional ``random.Random`` for reproducible generation; defaults to the
            module-level ``random`` generator.

    Returns:
        list of ``(seq, targets)`` pairs: a ``torch.long`` tensor of token indices
        and a ``torch.float32`` tensor of fractions, both of length ``seq_length``.
    """
    rng = random if rng is None else rng
    vocab = list("abcdefghijklmnopqrstuvwxyz")
    char_to_idx = {char: idx for idx, char in enumerate(vocab)}
    x_idx = char_to_idx["x"]
    # Loop-invariant: the candidate filler letters never change.
    other_letters = [c for c in vocab if c != "x"]

    data = []
    for _ in range(p):
        num_x = rng.randint(0, seq_length // 2)  # 0 to seq_length // 2 x's

        seq_chars = ["x"] * num_x
        for _ in range(seq_length - num_x):
            seq_chars.append(rng.choice(other_letters))
        # Shuffle so the x's are distributed randomly through the sequence.
        rng.shuffle(seq_chars)

        seq = [char_to_idx[c] for c in seq_chars]

        targets = []
        x_count = 0
        for i, token_idx in enumerate(seq):
            if token_idx == x_idx:
                x_count += 1
            targets.append(x_count / (i + 1))

        data.append(
            (
                torch.tensor(seq, dtype=torch.long),
                torch.tensor(targets, dtype=torch.float32),
            )
        )

    return data

def train_test_split(dataset, train_split_proportion, seed):
    """Split ``dataset`` into ``(train, test)`` lists via a seeded shuffle of its indices.

    The first ``int(train_split_proportion * len(dataset))`` shuffled indices form
    the training list; all remaining indices form the test list.
    """
    n_items = len(dataset)
    cutoff = int(train_split_proportion * n_items)
    order = deterministic_shuffle(list(range(n_items)), seed)
    train_part = [dataset[k] for k in order[:cutoff]]
    test_part = [dataset[k] for k in order[cutoff:]]
    return train_part, test_part

In [None]:
params = ExperimentParams()
# Seed both generators for reproducibility: make_dataset draws from Python's
# ``random`` module, while model initialisation and batch shuffling use torch.
random.seed(params.random_seed)
torch.manual_seed(params.random_seed)

# n_samples random sequences (this replaces the old grid of all p*p pairs).
dataset = make_dataset(params.n_samples)
train_data, test_data = train_test_split(dataset, params.train_frac, params.random_seed)

In [None]:
# Train from scratch; keep every saved checkpoint plus the DataFrame of logged metrics.
all_checkpointed_models, df = train(train_data, test_data, params)

In [None]:
# Test curve first, then train curve, so colours match the other plots.
for column, label in (("val_mae", "test"), ("train_mae", "train")):
    plt.plot(df[column], label=label)
plt.legend()
plt.ylabel("Mean Absolute Error")
plt.xlabel("Checkpoint")
plt.title(f"Train & test MAE for fraction-of-x with vocab_size={params.vocab_size}")

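In the original modular-addition version, this plot showed the classic grokking behavior: train accuracy became perfect after a few iterations, while test accuracy took many more steps to meaningfully improve. For the fraction-of-x task we plot mean absolute error instead, so the analogous signature is train MAE dropping quickly while test MAE lags far behind. (Note that a near-perfect train metric is not the same statement as the train loss being perfect; see the plot below.)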

In [None]:
# Train/test MSE loss over training (the task is fraction-of-x, not modular addition).
plt.plot(df["val_loss"], label="test")
plt.plot(df["train_loss"], label="train")
plt.legend()
plt.ylabel("MSE loss")
plt.xlabel("Checkpoint")
plt.title(f"Train & test loss for fraction-of-x with vocab_size={params.vocab_size}")

In [None]:
# Sanity-check the dataset: decode the first sample and recompute its targets by hand.
test_sample = dataset[0]
x, y = test_sample
print(f"Input indices: {x}")
print(f"Target fractions: {y}")
print("\nConverting back to chars:")
vocab = list("abcdefghijklmnopqrstuvwxyz")
chars = [vocab[i] for i in x]
print(f"Sequence: {chars}")
print(f"Targets:  {[f'{v:.3f}' for v in y.tolist()]}")

# Manual verification: count of 'x' seen so far divided by the number of tokens seen.
x_count = 0
for i, char in enumerate(chars):
    x_count = chars[: i + 1].count("x")
    expected = x_count / (i + 1)
    print(f"Position {i}: '{char}' -> x_count={x_count}, expected={expected:.3f}, got={y[i]:.3f}")

## LLC estimation hyperparameter tuning (NOTE: not yet implemented for the fraction-of-x task)

In order to get LLC estimates for this simple grokking model over training, we first need to choose hyperparameters. The most important ones to calibrate are epsilon (the SGLD learning rate / step size) and n\*beta (the effective inverse temperature). Let's run a quick sweep over a wide range of epsilon and n\*beta, and look for a region within it where the LLC estimates change little as we vary epsilon and nbeta. We can use `devinterp.vis_utils.EpsilonBetaAnalyzer` for this.

In [None]:
import typing
from typing import Type

import numpy as np


def estimate_llc_given_model(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    evaluate: typing.Callable,
    epsilon: float,
    beta: float,
    sampling_method: Type[torch.optim.Optimizer] = SGLD,
    localization: float = 5.0,
    num_chains: int = 2,
    num_draws: int = 500,
    num_burnin_steps: int = 0,
    num_steps_bw_draws: int = 1,
    device: torch.device = DEVICE,
    online: bool = True,
    verbose: bool = False,
):
    """Run devinterp's LLC estimator for a single (epsilon, beta) setting.

    Adapter used by ``EpsilonBetaAnalyzer``: ``epsilon`` becomes the sampler's
    ``lr`` and ``beta`` its ``nbeta``. ``num_chains`` independent chains each
    produce ``num_draws`` samples, after discarding ``num_burnin_steps`` steps and
    taking ``num_steps_bw_draws`` steps between samples. The summary's
    ``"llc/trace"`` entry is converted to a NumPy array before returning.
    """
    sampler_kwargs = {"lr": epsilon, "localization": localization, "nbeta": beta}
    stats = estimate_learning_coeff_with_summary(
        model,
        loader=loader,
        evaluate=evaluate,
        sampling_method=sampling_method,
        optimizer_kwargs=sampler_kwargs,
        num_chains=num_chains,
        num_draws=num_draws,
        num_burnin_steps=num_burnin_steps,
        num_steps_bw_draws=num_steps_bw_draws,
        device=device,
        online=online,
        verbose=verbose,
    )

    stats["llc/trace"] = np.array(stats["llc/trace"])
    return stats

In [None]:
from devinterp.vis_utils import EpsilonBetaAnalyzer

loader = DataLoader(train_data, shuffle=True, batch_size=params.batch_size)
analyzer = EpsilonBetaAnalyzer()
# Sweep 5 epsilon values in [3e-5, 3e-1] and 5 beta values on the final checkpoint.
# min_beta / max_beta are left as None (presumably the analyzer's own default
# range -- confirm in the devinterp documentation).
analyzer.configure_sweep(
    llc_estimator=estimate_llc_given_model,
    llc_estimator_kwargs={
        "model": all_checkpointed_models[-1],
        "evaluate": evaluate_mse,
        "device": DEVICE,
        "loader": loader,
    },
    min_epsilon=3e-5,
    max_epsilon=3e-1,
    epsilon_samples=5,
    min_beta=None,
    max_beta=None,
    beta_samples=5,
    dataloader=loader,
)
analyzer.sweep()

In [None]:
# Visualize the epsilon / nbeta sweep results.
analyzer.plot()

From this, we can see that the final LLC flattens out if epsilon > 0.001, so that's the epsilon parameter range we should go for. But we also have some dependence of the llc on beta, which is maybe linear from the looks of it? We get our LLC estimates by taking (sampled_loss - initial_loss) * nbeta, so maybe that final nbeta term is what we're seeing here. Let's divide it out to see this better.

(Note that this does not quite mean the LLC curve should be fully linear in nbeta, as the choice of nbeta can and does influence the SGLD sampling process and so can change the sampled loss.)

In [None]:
# Same sweep plot, with the nbeta factor divided out of the estimates.
analyzer.plot(div_out_beta=True)

From this, we can see that the effective sampled loss for low-ish nbetas (<100) shows very little dependence on the exact choice of nbeta. So let's pick an nbeta in this flat region (~1) and a high-but-still-in-the-flat-region epsilon (0.03), so that we don't need to run many draws but our samples still depend little on epsilon. (The values actually used are set in the next code cell.)

Let's check that the loss chain for these hyperparams looks decent, and then run LLC estimation on all trained checkpoints if it does.

In [None]:
# SGLD hyperparameters for the per-checkpoint LLC estimates below.
lr = 1e-3  # epsilon: SGLD step size (passed as ``lr``)
gamma = 5  # localization strength (passed as ``localization``)
nbeta = 2.0  # effective inverse temperature n*beta (passed as ``nbeta``)
num_draws = 75  # samples drawn per chain
num_chains = 2  # number of SGLD chains; reported in the plot titles below

In [None]:
# Inspect the SGLD loss trace on the final checkpoint using the same step size,
# inverse temperature and localization (lr, nbeta, gamma) as the per-checkpoint
# estimates below, but with more draws and chains so convergence is easy to judge.
learning_coeff_stats = estimate_learning_coeff_with_summary(
    all_checkpointed_models[-1],
    loader=DataLoader(train_data, batch_size=params.batch_size, shuffle=True),
    evaluate=evaluate_mse,
    sampling_method=SGLD,
    optimizer_kwargs=dict(lr=lr, nbeta=nbeta, localization=gamma),
    num_chains=3,
    num_draws=1500,
    device=DEVICE,
    online=True,
)
trace = learning_coeff_stats["loss/trace"]

In [None]:
from devinterp.utils import plot_trace

# Average the per-chain LLC means for the plot title.
llc_chain_means = learning_coeff_stats["llc/means"]
avg_llc = sum(llc_chain_means) / len(llc_chain_means)

plot_trace(
    trace,
    "Loss",
    x_axis="Step",
    title=f"Loss Trace, avg LLC = {avg_llc:.2f}",
    plot_mean=False,
    plot_std=False,
    fig_size=(12, 9),
    true_lc=None,
)

This looks good! The loss flattens out nicely, well within the number of draws used here. It looks like we could get away with about 500 draws, as the loss trace has flattened out by then. (The per-checkpoint estimates below use the `num_draws` value set earlier.)

In [None]:
# One LLC estimate per saved checkpoint, using the hyperparameters set above.
# ``num_chains`` is passed through so the plot titles (which report it) are accurate.
# Building the loader once is equivalent: each estimate iterates it afresh.
llc_loader = DataLoader(train_data, batch_size=params.batch_size, shuffle=True)
llcs = [
    estimate_learning_coeff_with_summary(
        model_checkpoint,
        loader=llc_loader,
        evaluate=evaluate_mse,
        sampling_method=SGLD,
        optimizer_kwargs=dict(lr=lr, nbeta=nbeta, localization=gamma),
        num_chains=num_chains,
        num_draws=num_draws,
        device=DEVICE,
        online=False,
    )
    for model_checkpoint in all_checkpointed_models
]

In [None]:
lambdahats = [llc["llc/mean"] for llc in llcs]
run_desc = (
    f"vocab={params.vocab_size}, train_frac={params.train_frac}, nβ={nbeta:.1f}, "
    f"ε={lr}, γ={gamma}, num_draws={num_draws}, num_chains={num_chains}"
)

# Figure 1: LLC estimate against train/test mean absolute error.
fig, ax1 = plt.subplots()
ax1.set_title(f"Lambdahat vs MAE for fraction token\n{run_desc}")
ax2 = ax1.twinx()
ax1.plot(df["val_mae"], label="test mae")
ax1.plot(df["train_mae"], label="train mae")
ax2.plot(lambdahats, color="g", label="Lambdahat")
ax1.set_xlabel("Checkpoint no.")
ax1.set_ylabel("MAE")
ax2.set_ylabel("Lambdahat")
fig.legend(loc="center right")

fig.show()

# Figure 2: LLC estimate against train/test MSE loss.
fig, ax1 = plt.subplots()
ax1.set_title(f"Lambdahat vs loss for fraction token\n{run_desc}")
ax2 = ax1.twinx()
ax1.plot(df["val_loss"], label="test loss")
ax1.plot(df["train_loss"], label="train loss")
ax2.plot(lambdahats, color="g", label="Lambdahat")
ax1.set_xlabel("Checkpoint no.")
ax1.set_ylabel("MSE loss")
ax2.set_ylabel("Lambdahat")
fig.legend(loc="center right")

That's interesting!

In the first plot, we see that the LLC first increases during memorization and then decreases ~smoothly afterward, flattening out after the model is done grokking. This is basically what we would expect from a simple reading of phase transitions in the free energy formula.

From the second plot, we see that the LLC, which was measured only on the train set, tracks the test loss pretty well. That was a big surprise for me when I made this notebook, and I don't know what it means.

Anyway, I hope this notebook clarifies how one can use the devinterp library and LLC estimation more generally to gain insight in the development of structure in neural networks. Thanks for reading!