In [1]:
!pip install numpy torch sympy mod blobfile

import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from contextlib import suppress
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Callable, Literal, Optional

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
import wandb

from tqdm.notebook import tqdm
from grokking.dataset import DEFAULT_MODULUS, ModularArithmetic, Operator
from grokking.transformer import Transformer
from grokking.utils import generate_run_name


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


# Unifying Grokking & Deep Double Descent (DD)

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@dataclass
class Config:
    """All hyper-parameters for one training run.

    The whole object is passed through ``asdict`` to ``generate_run_name`` and to
    ``wandb.init(config=...)``, so every field ends up in the run metadata.
    """

    # Model
    num_layers: int = 1
    num_heads: int = 4
    d_model: int = 128
    # modulus + 1 tokens; presumably one extra special token beyond the residues — confirm
    d_vocab: int = DEFAULT_MODULUS + 1
    d_mlp: int = 4 * 128  # 4 * d_model (hard-coded: not updated if d_model changes)
    # Per-head width passed to Transformer. NOTE: this is *not* d_model // num_heads
    # (which would be 128 // 4 = 32); it is set independently to 64.
    d_head: int = 64
    num_ctx: int = 3  # context length passed to Transformer (presumably tokens per example — confirm)
    act_fn: Callable = F.relu
    load_path: Optional[str] = None  # optional state_dict checkpoint to initialise the model from
    # use_ln: bool = True  # disabled; the matching Transformer kwarg is commented out as well

    # Dataset
    operator: Operator = "+"
    modulus: int = DEFAULT_MODULUS
    frac_label_noise: float = 0.0
    seed: int = 0
    shuffle: bool = True
    # NOTE(review): not forwarded to ModularArithmetic.generate_split in this notebook — confirm it takes effect
    frac_train: float = 0.3

    # Dataloaders
    batch_size: int = 64

    # Optimizer
    lr: float = 1e-3
    weight_decay: float = 1e-5
    use_sgd: bool = False
    # float -> SGD momentum; (beta1, beta2) tuple -> AdamW betas
    momentum: float | tuple[float, float] = (0.9, 0.98)

    # Training
    num_training_steps: int = int(3e5)
    num_jobs: int = 1  # NOTE(review): not read anywhere in this notebook

    # Logging
    wandb_project: str = "grokking"
    no_logging: bool = False  # True -> wandb mode="disabled"
    resume_run_id: Optional[str] = None  # resume this wandb run id instead of starting a new run
    log_normalized_loss: bool = True  # NOTE(review): not read anywhere in this notebook
    log_interval: int = 10  # evaluate and log every N optimiser steps

    weights_dir: str = "weights"  # directory created at startup


config = Config()

# Model

model = (
    Transformer(
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        d_model=config.d_model,
        d_vocab=config.d_vocab,
        d_mlp=config.d_mlp,
        d_head=config.d_head,
        num_ctx=config.num_ctx,
        act_fn=config.act_fn,
        # use_ln=config.use_ln,
    )
    .float()
    .to(DEVICE)
)

if config.load_path is not None:
    # map_location remaps saved tensors onto DEVICE, so a checkpoint written on a
    # GPU machine also loads on a CPU-only one (and vice versa).
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files
    # (consider weights_only=True if the installed torch version supports it).
    model.load_state_dict(torch.load(config.load_path, map_location=DEVICE))

# Create the weights directory; an existing directory is fine.
os.makedirs(config.weights_dir, exist_ok=True)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model has {num_params} trainable parameters.")

# Dataset
# NOTE(review): config.frac_train is not passed here — confirm whether
# generate_split accepts it or applies its own default train fraction.
split_kwargs = dict(
    operator=config.operator,
    modulus=config.modulus,
    frac_label_noise=config.frac_label_noise,
    seed=config.seed,
    shuffle=config.shuffle,
)
train_dataset, val_dataset = ModularArithmetic.generate_split(**split_kwargs)

# Dataloaders: identical settings for both splits. DataLoader's default
# shuffle=False means batch order is the same every epoch.
train_dataloader, val_dataloader = (
    DataLoader(split, batch_size=config.batch_size)
    for split in (train_dataset, val_dataset)
)

# Optimizer
# `use_sgd` picks the optimizer, and `momentum` must have the matching form:
# a float (SGD momentum) or a (beta1, beta2) tuple (AdamW betas).
if config.use_sgd:
    if isinstance(config.momentum, tuple):
        # Without this check a tuple would silently select AdamW even though
        # use_sgd=True (and the run name would still be labelled "SGD").
        raise ValueError(
            "Invalid optimizer configuration: use_sgd=True needs a float momentum, "
            f"got {config.momentum!r}."
        )
    optimizer = optim.SGD(
        model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
        momentum=config.momentum,
    )
elif isinstance(config.momentum, tuple):
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
        betas=config.momentum,
    )
else:
    raise ValueError(
        "Invalid optimizer configuration: AdamW (use_sgd=False) needs a "
        f"(beta1, beta2) tuple for momentum, got {config.momentum!r}."
    )

# Logging
# Compact run name derived from the config: `aliases` shortens field names and
# `bool_aliases` renders use_sgd as an optimizer label.
run_name_aliases = {
    "num_layers": "L",
    "num_heads": "H",
    "d_model": "D",
    "d_vocab": "V",
    "d_mlp": "M",
    "d_head": "d",
    "num_ctx": "C",
    "lr": "lr",
    "weight_decay": "wd",
    "momentum": "mom",
}
name = generate_run_name(
    asdict(config),
    aliases=run_name_aliases,
    bool_aliases={"use_sgd": {True: "SGD", False: "Adam"}},
    append_hash=True,
)
date_time = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
mode = None if not config.no_logging else "disabled"

# Arguments shared by fresh and resumed runs.
init_kwargs = dict(
    project=config.wandb_project,
    settings=wandb.Settings(start_method="thread"),
    name=name,
    config=asdict(config),
    mode=mode,
)
if config.resume_run_id is not None:
    # Reattach to an existing run by its id.
    init_kwargs.update(id=config.resume_run_id, resume="must")
else:
    # Fresh run: the microsecond-resolution timestamp serves as the run id.
    init_kwargs["id"] = date_time

wandb.init(**init_kwargs)
wandb.watch(model)

# Training
Reduction = Literal["mean", "sum"]

def criterion(logits, labels, reduction: Reduction="sum"):
    # only look at predictions of last numbers
    logits = logits[:,-1]
    # compute individual and summed losses for final number
    logprobs = F.log_softmax(logits.to(torch.float64), dim=-1)
    prediction_logprobs = torch.gather(logprobs, index=labels.unsqueeze(1), dim=-1)

    if reduction == "mean":
        loss = -torch.mean(prediction_logprobs)
    elif reduction == "sum":
        loss = -torch.sum(prediction_logprobs)
    else:
        raise ValueError("Invalid reduction argument.")

    return loss

# Global optimiser-step counter, advanced once per batch by the training loop.
step = 0
# Batches in one pass over the training set (len of a DataLoader is its batch count).
steps_per_epoch = len(train_dataloader)

def validate(
    model: nn.Module,
    trainloader: DataLoader,
    testloader: DataLoader,
    criterion: Callable[..., torch.Tensor],
    device: Optional[torch.device] = None,
) -> dict[str, float]:
    """Evaluate ``model`` on the train and test loaders and return wandb-ready metrics.

    Args:
        model: network mapping a batch ``x`` to logits; only the logits at the
            last sequence position are used for accuracy (``y_hat[:, -1]``).
        trainloader: loader yielding ``(x, y)`` batches of the training split.
        testloader: loader yielding ``(x, y)`` batches of the held-out split.
        criterion: called as ``criterion(y_hat, y, reduction="sum")``; must return
            a scalar tensor summed over the batch.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        Dict with the per-sample mean loss and accuracy of each split, the global
        L2 norm of all parameters, and "efficiency" = mean loss / weight norm
        (NaN when the norm is zero or a split is empty).

    The model's train/eval mode is restored before returning.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    def _validate(loader):
        # Sum over every example and divide by the number of *examples*;
        # len(loader) would give the number of batches instead.
        total_loss = 0.0
        num_correct = 0
        num_samples = 0

        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                y_hat = model(x)
                total_loss += criterion(y_hat, y, reduction="sum").item()

                y_pred = y_hat[:, -1].argmax(dim=-1)
                num_correct += (y_pred == y).sum().item()
                num_samples += y.shape[0]

        if num_samples == 0:
            return float("nan"), float("nan")
        return total_loss / num_samples, num_correct / num_samples

    def _weight_norm(model):
        # Global L2 norm over every parameter (trainable or not), without building
        # an autograd graph; reduced on each parameter's own device.
        with torch.no_grad():
            norm_squared = sum(
                p.detach().double().pow(2).sum().item() for p in model.parameters()
            )
        return float(norm_squared) ** 0.5

    def _ratio(numerator, denominator):
        # Guard against a zero weight norm (e.g. an all-zero model).
        return numerator / denominator if denominator else float("nan")

    was_training = model.training
    model.eval()
    try:
        train_loss, train_acc = _validate(trainloader)
        test_loss, test_acc = _validate(testloader)
        weight_norm = _weight_norm(model)
    finally:
        model.train(was_training)

    # "Efficiency" is the mean loss (negative mean log-prob of the correct label)
    # divided by the norm of the weights.
    return {
        "train/loss": train_loss,
        "train/acc": train_acc,
        "train/efficiency": _ratio(train_loss, weight_norm),
        "test/loss": test_loss,
        "test/acc": test_acc,
        "test/efficiency": _ratio(test_loss, weight_norm),
        "weight_norm": weight_norm,
    }

# train
# Whole epochs that fit in num_training_steps (floor, so at most that many updates).
num_epochs = config.num_training_steps // steps_per_epoch
for epoch in tqdm(range(1, num_epochs + 1)):
    for x, y in train_dataloader:
        # Evaluate both splits every `log_interval` steps. The check runs before
        # the update, so the metrics logged at `step` reflect `step` completed updates.
        if step % config.log_interval == 0:
            wandb.log(
                validate(model, train_dataloader, val_dataloader, criterion),
                step=step,
            )

        # validate() switches the model to eval mode; switch back before updating.
        model.train()
        x, y = x.to(DEVICE), y.to(DEVICE)
        y_hat = model(x)
        loss = criterion(y_hat, y, reduction="mean")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1

# The in-loop check fires only before an update, so the state after the final
# update would otherwise never be evaluated or logged.
wandb.log(
    validate(model, train_dataloader, val_dataloader, criterion),
    step=step,
)


Model has 288256 trainable parameters.


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670592183315118, max=1.0…

  0%|          | 0/2542 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Epoch-wise

## Model-wise

## Sample-wise

## Regularization-wise

## Interpolation

### Can we induce grokking in CIFAR-10?

### Can we interpolate just by varying initialization scale and label noise?

## Miscellaneous


### Can we induce epoch-/regularization-wise DD in shallow models?

### Can we induce epoch-wise DD in transformers?