# Grokking MLP
---
This notebook will demonstrate the grokking phenomenon on a simple MLP trained on a small modular addition dataset.

Based on OpenAI's Grokking paper: [Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets](https://arxiv.org/abs/2201.02177)

In [1]:
from tqdm.auto import tqdm
import time
import numpy as np
import einops
from einops.layers.torch import Rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
from sklearn.decomposition import PCA
import plotly.graph_objects as go
from plotly.subplots import make_subplots

## Create Dataset

The dataset is modular addition, `a + b (mod p)`, with `p = 97`. Every symbol is its own token that feeds into an embedding layer: each input is the four-token sequence "a" "operation" "b" "=", and the target is the token "y".

In [2]:
# Modulus of the addition task and the symbol used for the operator token
p = 97
operation = "+"

# Vocabulary: the residues "0".."p-1" map to ids 0..p-1, then the operator
# and "=" take the next two ids
char2idx = {str(residue): residue for residue in range(p)}
char2idx.update({operation: p, "=": p + 1})
idx2char = dict(zip(char2idx.values(), char2idx.keys()))

# Enumerate every (a, b) pair: `a` holds each residue p times in a row,
# `b` cycles through all residues p times
residues = torch.arange(p)
a = einops.repeat(residues, "i -> (i j)", j=p)
b = einops.repeat(residues, "i -> (j i)", j=p)
y = (a + b) % p

# Constant columns for the operator token and the "=" token
num_examples = p * p
op = torch.full((num_examples,), char2idx[operation])
eq = torch.full((num_examples,), char2idx["="])

# One row per equation: [a, operator, b, "="]; labels are one-hot over the full vocabulary
features = torch.stack((a, op, b, eq), dim=1)
y_ohe = F.one_hot(y, num_classes=len(idx2char)).to(torch.float32)

print(a.shape, b.shape, y.shape, y.dtype, features.shape, features.dtype)

torch.Size([9409]) torch.Size([9409]) torch.Size([9409]) torch.int64 torch.Size([9409, 4]) torch.int64


In [3]:
def convert_to_str(tokens):
    """Render a tensor vector of token ids as a space-separated string of characters."""
    chars = (idx2char[token_id] for token_id in tokens.tolist())
    return " ".join(chars)

In [4]:
# Print 10 randomly chosen examples, first as raw token ids and then as a readable equation
sample_indices = np.random.randint(0, features.shape[0], size=10)
for i in sample_indices:
    tokens, answer = features[i], y[i]
    print(tokens, answer)
    # Append the answer token to the 4 input tokens before decoding
    print(convert_to_str(torch.cat([tokens, answer.unsqueeze(0)])))

tensor([32, 97, 72, 98]) tensor(7)
32 + 72 = 7
tensor([56, 97, 81, 98]) tensor(40)
56 + 81 = 40
tensor([43, 97, 41, 98]) tensor(84)
43 + 41 = 84
tensor([75, 97, 33, 98]) tensor(11)
75 + 33 = 11
tensor([58, 97, 88, 98]) tensor(49)
58 + 88 = 49
tensor([52, 97, 91, 98]) tensor(46)
52 + 91 = 46
tensor([53, 97, 15, 98]) tensor(68)
53 + 15 = 68
tensor([70, 97, 94, 98]) tensor(67)
70 + 94 = 67
tensor([37, 97, 88, 98]) tensor(28)
37 + 88 = 28
tensor([ 6, 97, 74, 98]) tensor(80)
6 + 74 = 80


In [5]:
# Split into train/val sets
rng_generator = torch.Generator().manual_seed(21)
batch_size = 512

# Create train and val set randomly (80% / 20%)
data = TensorDataset(features, y_ohe)
train_data, val_data = random_split(data, [0.8, 0.2], generator=rng_generator)
assert len(train_data) + len(val_data) == len(data), "split must cover every example exactly once"

# Create train and val set based on condition
# mask = (y == 1) | (y == 2)
# train_data = TensorDataset(features[mask], y_ohe[mask])
# val_data = TensorDataset(features[~mask], y_ohe[~mask])

# Training batches are reshuffled every epoch using the seeded generator
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    generator=rng_generator
)

# Validation is iterated in a fixed order and does not touch `rng_generator`:
# the metrics do not depend on batch order, and sharing the generator would let
# every validation pass advance the RNG that decides the training batch order.
val_dataloader = DataLoader(
    val_data,
    batch_size=batch_size,
    shuffle=False
)

print(f"Total samples in train set: {len(train_data):,}")
print(f"Total samples in val set: {len(val_data):,}")

print(f"Total number of batches in train dataloader: {len(train_dataloader):,}")
print(f"Total number of batches in validation dataloader: {len(val_dataloader):,}")

Total samples in train set: 7,528
Total samples in val set: 1,881
Total number of batches in train dataloader: 15
Total number of batches in validation dataloader: 4


## Create Model

The original paper trained a 2-layer decoder-only transformer with about 400,000 non-embedding parameters; this notebook uses a much smaller MLP instead.

In [6]:
class MLP(nn.Module):
    """Token-embedding MLP: embed each token, concatenate the embeddings, then two linear layers.

    Args:
        vocab_size: Number of distinct tokens; also the number of output logits.
            Defaults to ``len(idx2char)``, looked up when the model is built.
        seq_len: Number of tokens in each input sequence.
        d_model: Width of each token embedding.
        d_hidden: Width of the hidden layer.

    Calling ``MLP()`` with no arguments builds the same architecture as before
    these sizes were made configurable (128-wide embeddings, 4 tokens, 128 hidden units).
    """

    def __init__(self, vocab_size=None, seq_len=4, d_model=128, d_hidden=128):
        super().__init__()
        if vocab_size is None:
            vocab_size = len(idx2char)
        self.net = nn.Sequential(
            nn.Embedding(vocab_size, d_model),
            # Concatenate the per-token embeddings into one flat vector per example
            Rearrange("batch token d_model -> batch (token d_model)"),
            nn.Linear(seq_len * d_model, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, vocab_size)
        )
        self.initialize_params()

    def initialize_params(self):
        """Kaiming-normal init for every linear weight and zeros for every linear bias.

        The embedding layer is not touched here, so it keeps PyTorch's default initialization.
        """
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map a (batch, seq_len) tensor of token ids to (batch, vocab_size) logits."""
        return self.net(x)

In [7]:
# Seed the global RNG before building the model so the random weight
# initialization (and therefore the whole run) is repeatable
torch.manual_seed(21)
net = MLP()

# The labels in the datasets are one-hot float vectors, which CrossEntropyLoss
# accepts as class-probability targets
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=1e-3, betas=(0.9, 0.98))

print(f"Number of parameters in model: {sum(param.numel() for param in net.parameters()):,}")

Number of parameters in model: 91,107


## Train Model

In [8]:
# Per-epoch training history: one independent, initially empty list per tracked statistic
train_stats = {
    stat_name: []
    for stat_name in (
        "epoch_train_losses",
        "epoch_train_accs",
        "epoch_val_losses",
        "epoch_val_accs",
        "epoch_layer_norm",
    )
}

In [None]:
def layer_norm(net, layer_idx=-1):
    """Return the spectral norm of one layer's weight matrix.

    Args:
        net: Model exposing its layers as the indexable ``net.net``.
        layer_idx: Index into ``net.net`` of the layer to measure. Defaults to
            ``-1`` (the last layer), matching the previous hard-coded behavior.

    Returns:
        A 0-dim tensor, detached from the autograd graph. With ``ord=2`` on a
        2-D weight, ``torch.linalg.norm`` computes the matrix 2-norm (largest
        singular value), not the Frobenius norm.

    Raises:
        AttributeError: If the selected layer has no ``weight`` attribute.
    """
    weight = net.net[layer_idx].weight.detach()
    return torch.linalg.norm(weight, ord=2)

In [9]:
start = time.time()
num_epochs = 20_000
# Optional cap on the number of optimizer updates; None trains for all `num_epochs`
num_optimization_steps = None

step = 0  # optimizer updates performed so far, across all epochs
for epoch in tqdm(range(num_epochs)):
    total_correct = total_loss = 0.
    total_val_correct = total_val_loss = 0.

    # Explicitly (re-)enter training mode: a later cell calls net.eval(), so
    # re-running this cell would otherwise keep training in eval mode
    net.train()
    for inputs, labels in train_dataloader:
        step += 1
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize; the logits are upcast to float64
        # before the loss is computed
        outputs = net(inputs).type(torch.float64)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Keep track of stats. argmax over logits equals argmax over softmax,
        # so no softmax is needed; argmax over the label rows recovers the class index
        with torch.no_grad():
            total_correct += (outputs.argmax(-1) == labels.argmax(-1)).sum().item()
            total_loss += loss.item()

    # Evaluate on validation set
    # NOTE(review): validation logits are not upcast to float64, unlike training — confirm this is intended
    net.eval()
    with torch.no_grad():
        for inputs, labels in val_dataloader:
            # forward pass
            outputs = net(inputs)
            val_loss = criterion(outputs, labels)
            # calculate metrics
            total_val_correct += (outputs.argmax(-1) == labels.argmax(-1)).sum().item()
            total_val_loss += val_loss.item()

    # Store epoch-level loss + accuracy: accuracy as a percentage of samples,
    # loss as the mean of the per-batch losses
    epoch_acc = (total_correct / len(train_data)) * 100.
    epoch_loss = total_loss / len(train_dataloader)
    val_acc = (total_val_correct / len(val_data)) * 100.
    val_epoch_loss = total_val_loss / len(val_dataloader)
    train_stats["epoch_train_accs"].append(epoch_acc)
    train_stats["epoch_train_losses"].append(epoch_loss)
    train_stats["epoch_val_accs"].append(val_acc)
    train_stats["epoch_val_losses"].append(val_epoch_loss)
    train_stats["epoch_layer_norm"].append(layer_norm(net))

    # Log losses every 100 epochs
    if (epoch % 100) == 0:
        tqdm.write(f"[{epoch+1}/{num_epochs}][{step:,}]\tTrain Loss: {epoch_loss:.3f}\tVal Loss: {val_epoch_loss:.3f}")

    # Stop early once the optional optimization-step budget is exceeded
    if num_optimization_steps is not None and step > num_optimization_steps:
        break

print(f"Total training time: {(time.time()-start)/60.:.2f} minutes")

  0%|          | 0/20000 [00:00<?, ?it/s]

[1/20000][15]	Train Loss: 4.934	Val Loss: 4.702
[101/20000][1,515]	Train Loss: 0.354	Val Loss: 4.665
[201/20000][3,015]	Train Loss: 0.000	Val Loss: 3.334
[301/20000][4,515]	Train Loss: 0.000	Val Loss: 2.848
[401/20000][6,015]	Train Loss: 0.000	Val Loss: 2.601
[501/20000][7,515]	Train Loss: 0.000	Val Loss: 2.433
[601/20000][9,015]	Train Loss: 0.000	Val Loss: 2.262
[701/20000][10,515]	Train Loss: 0.000	Val Loss: 2.133
[801/20000][12,015]	Train Loss: 0.000	Val Loss: 1.979
[901/20000][13,515]	Train Loss: 0.000	Val Loss: 1.892
[1001/20000][15,015]	Train Loss: 0.000	Val Loss: 1.752
[1101/20000][16,515]	Train Loss: 0.000	Val Loss: 1.609
[1201/20000][18,015]	Train Loss: 0.000	Val Loss: 1.520
[1301/20000][19,515]	Train Loss: 0.000	Val Loss: 1.387
[1401/20000][21,015]	Train Loss: 0.000	Val Loss: 1.338
[1501/20000][22,515]	Train Loss: 0.000	Val Loss: 1.198
[1601/20000][24,015]	Train Loss: 0.000	Val Loss: 1.096
[1701/20000][25,515]	Train Loss: 0.000	Val Loss: 1.041
[1801/20000][27,015]	Train Loss:

## Visualize Results

In [10]:
# Plot loss curve: train and validation loss per epoch on shared axes
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=np.arange(len(train_stats["epoch_train_losses"])), y=train_stats["epoch_train_losses"], name="Train Loss")
)
fig.add_trace(
    go.Scatter(x=np.arange(len(train_stats["epoch_val_losses"])), y=train_stats["epoch_val_losses"], name="Val Loss")
)

# Add buttons for scaling axis.
# The dotted key "yaxis.type" changes only the axis type; relayout-ing a whole
# `yaxis` dict would replace the axis settings, dropping the axis title set below.
updatemenus = [
    dict(
        type="buttons",
        direction="left",
        buttons=[
            dict(
                args=[{"yaxis.type": "linear"}],
                label="Y Axis Linear Scale",
                method="relayout"
            ),
            dict(
                args=[{"yaxis.type": "log"}],
                label="Y Axis Log Scale",
                method="relayout"
            )
        ],
        showactive=True,
        x=0.01,
        xanchor="left",
        y=1.1,
        yanchor="top"
    )
]

# Update layout
fig.update_layout(
    height=700, width=1000,
    updatemenus=updatemenus,
    title_text="Loss Curve",
    title_x=0.5,
    xaxis_title_text="epoch",
    yaxis_title_text="loss"
)

fig.show()

In [11]:
# Plot accuracy: train and validation accuracy per epoch on shared axes
fig = go.Figure()
fig.add_trace(
    go.Scatter(x=np.arange(len(train_stats["epoch_train_accs"])), y=train_stats["epoch_train_accs"], name="Train Accuracy")
)
fig.add_trace(
    go.Scatter(x=np.arange(len(train_stats["epoch_val_accs"])), y=train_stats["epoch_val_accs"], name="Val Accuracy")
)

# Add buttons for scaling axis.
# The dotted key "xaxis.type" changes only the axis type; relayout-ing a whole
# `xaxis` dict would replace the axis settings, dropping the axis title set below.
updatemenus = [
    dict(
        type="buttons",
        direction="left",
        buttons=[
            dict(
                args=[{"xaxis.type": "linear"}],
                label="X Axis Linear Scale",
                method="relayout"
            ),
            dict(
                args=[{"xaxis.type": "log"}],
                label="X Axis Log Scale",
                method="relayout"
            )
        ],
        showactive=True,
        x=0.01,
        xanchor="left",
        y=1.1,
        yanchor="top"
    )
]

# Update layout; the stored accuracies are percentages, so the axis label says so
fig.update_layout(
    height=700, width=1000,
    updatemenus=updatemenus,
    title_text="Training Accuracy",
    title_x=0.5,
    xaxis_title_text="epoch",
    yaxis_title_text="accuracy (%)"
)

fig.show()

## Evaluate on Val Set

In [12]:
# Final pass over the validation split with the trained network.
# The result names say "test", but the data comes from `val_dataloader`.
net.eval()
test_loss = total_correct = 0.
with torch.no_grad():
    for batch, labels in tqdm(val_dataloader, total=len(val_dataloader), leave=False):
        logits = net(batch)
        # Predicted class per example vs. the class index recovered from the label rows
        predicted = logits.softmax(-1).argmax(-1)
        expected = labels.argmax(-1)
        total_correct += (predicted == expected).numpy().sum()
        test_loss += criterion(logits, labels).item()

# Accuracy as a percentage of samples; loss as the mean of the per-batch losses
test_acc = (total_correct / len(val_data)) * 100.
test_loss = test_loss / len(val_dataloader)

print(f"Val set loss: {test_loss:.3f}")
print(f"Val set accuracy: {test_acc:.2f}%")

  0%|          | 0/4 [00:00<?, ?it/s]

Val set loss: 0.000
Val set accuracy: 100.00%


## Dimensionality Reduction

Perform dimensionality reduction (PCA) on the weights of the embedding (input) layer and on the outputs of the network, then visualize both. The output plot uses the input equations `x + 8`, where `x` goes from 0 to `p - 1`. The plots show the circular structure described in the OpenAI paper. Interestingly, the only characters that do not follow the circular structure are the `+` and `=` characters. This makes sense: they are the same in every example and carry no information about the modular addition itself, so the model should be able to learn without them.

In [45]:
# 2-D PCA of the embedding table (one row per token id)
embedding_table = net.net[0].weight.data
embed_weights_reduced = PCA(n_components=2).fit_transform(embedding_table)
print(embed_weights_reduced.shape)

# 2-D PCA of the network's outputs for every example whose third token (column 2) is 8.
# Despite its name, `output_weights_reduced` holds projected outputs, not layer weights.
third_token_is_8 = features[:, 2] == 8
outputs_for_selection = net(features[third_token_is_8]).detach()
output_weights_reduced = PCA(n_components=2).fit_transform(outputs_for_selection)
print(output_weights_reduced.shape)

(99, 2)
(97, 2)


In [74]:
# Two scatter traces share one figure; the buttons below toggle which one is visible.

# Trace 0: PCA of the embedding table, each point labelled with its character
embedding_trace = go.Scatter(
    x=embed_weights_reduced[:, 0],
    y=embed_weights_reduced[:, 1],
    mode="markers+text",
    text=list(char2idx),
    textposition="top center",
    visible=True,
)

# Trace 1: PCA of the network outputs for the examples whose column 2 is 8,
# each point labelled with the example's first token (hidden until selected)
output_trace = go.Scatter(
    x=output_weights_reduced[:, 0],
    y=output_weights_reduced[:, 1],
    mode="markers+text",
    text=features[features[:, 2] == 8][:, 0],
    textposition="top center",
    visible=False,
)

fig = go.Figure(data=[embedding_trace, output_trace])

# Buttons for switching layers: each restyles the per-trace `visible` flags
layer_buttons = [
    {"args": ["visible", [True, False]], "label": "Embedding layer", "method": "restyle"},
    {"args": ["visible", [False, True]], "label": "Output layer", "method": "restyle"},
]
updatemenus = [
    {
        "type": "buttons",
        "direction": "left",
        "buttons": layer_buttons,
        "showactive": True,
        "x": 0.01,
        "xanchor": "left",
        "y": 1.1,
        "yanchor": "top",
    }
]

# Figure size, button menu and centered title
fig.update_layout(
    height=700,
    width=1000,
    updatemenus=updatemenus,
    title_text="Layer weights",
    title_x=0.5,
)

fig.show()