In [2]:
!pip install numpy torch sympy mod blobfile

import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from contextlib import suppress
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Callable, Literal, Optional, Union, Tuple
from copy import deepcopy

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
import wandb
from tqdm.notebook import tqdm
import ipywidgets as widgets

from grokking.dataset import ModularArithmetic, Operator
from grokking.transformer import Transformer
from grokking.utils import generate_run_name
from grokking.validation import criterion, validate
from grokking.learner import Config, GrokkingLearner

DEFAULT_MODULUS = 113
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


# Unifying Grokking & DD

In [3]:
config = Config(
    lr=1e-3,
    d_model=128,
    weight_decay=0.1,
    test_acc_criterion=1.,
    device=DEVICE,  
)

In [4]:
# Dataset

train_dataset, val_dataset = ModularArithmetic.generate_split(
    operator=config.operator,
    modulus=config.modulus,
    frac_label_noise=config.frac_label_noise,
    seed=config.seed,
    shuffle=config.shuffle,
    frac_train=config.frac_train,
)

# Dataloaders

train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size)

In [5]:
# Logging
date_time = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
mode = "disabled" if config.no_logging else None

In [6]:
# Training

def train_test_run():
    learner = GrokkingLearner.create(config, train_dataloader, val_dataloader)

    if config.resume_run_id is None:
        wandb.init(
            project=config.wandb_project,
            id=date_time,
            settings=wandb.Settings(start_method="thread"),
            name=learner.name,
            config=asdict(config),
            mode=mode,
        )
    else:
        wandb.init(
            project=config.wandb_project,
            id=config.resume_run_id,
            resume="must",
            settings=wandb.Settings(start_method="thread"),
            name=learner.name,
            config=asdict(config),
            mode=mode,
        )
    wandb.watch(learner.model)
    
    try: 
        learner.train()
    except KeyboardInterrupt:
        wandb.finish()


train_test_run()

Model has 226816 trainable parameters


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjqhoogland[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/300000 [00:00<?, ?it/s]

In [15]:
# Sweep

sweep_config = {
    'method': 'grid',
    'name': 'sweep',
    'parameters': {
        k : {"value": v} for k, v in asdict(config).items()
    } | {
        'd_model': {
            'values': [32, 64, 128, 256, 512]
        },
        'seed': {
            'values': [0, 1, 2, 3, 4]
        },
        # TODO: Look at initial gain, frac_train, frac_label_noise
    }
}

def get_sweep_info(sweep_config):
    # Hash the config (make sure to alphabetize the keys)
    config_hash = hash(str(sorted(sweep_config.items())))        
    sweep_config['name'] = f'{sweep_config["name"]}-{config_hash}'
    sweep_id = wandb.sweep(sweep_config, project="dominoes")

    return sweep_config['name'], sweep_id
 

sweep_name, sweep_id = get_sweep_info(sweep_config)

print("WANDB sweep name:", sweep_name)
print("WANDB sweep ID:", sweep_id)  

Create sweep with ID: fr8286r5
Sweep URL: https://wandb.ai/jqhoogland/dominoes/sweeps/fr8286r5
WANDB sweep name: sweep-7382929675824336860
WANDB sweep ID: fr8286r5


In [None]:
# Sweep function

def train():
    wandb.init(project="dominoes")
    config_dict = wandb.config 
    config = Config(**config_dict)

    learner = GrokkingLearner.create(
        config=config,
        trainloader=train_dataloader,
        testloader=val_dataloader,
    )

    # wandb.watch(learner.model)
    metrics = learner.train()

    return metrics

In [None]:
# Run the sweep
wandb.agent(sweep_id, function=train, count=10 * 10)

## Epoch-wise

## Model-wise

## Sample-wise

## Regularization-wise

## Interpolation

### Can we induce grokking in CIFAR-10?

### Can we interpolate just by varying initialization scale and label noise?

## Miscellaneous


### Can we induce epoch-/regularization-wise DD in shallow models?

### Can we induce epoch-wise DD in transformers?