A small, fully reproducible empirical study of whether grokking (delayed generalization in neural networks) can be forecast from signals visible in the first few hundred steps of training, long before generalization actually occurs.
```bash
pip install torch numpy matplotlib scipy scikit-learn pandas
```

CPU-only is fine. The full sweep takes ~50 minutes on one CPU.
```bash
# 1. Run unit tests
python tests/test_basic.py
python tests/test_analysis.py

# 2. Run the main 4 × 4 × 3 sweep (train_frac × weight_decay × seed) on
#    modular addition mod 23 with a 1-layer transformer.
python experiments/run_main_grid.py

# 3. Compute per-feature univariate AUROCs and a multivariate logistic
#    forecast (leave-one-out) at multiple early-training windows.
python experiments/analyze.py

# 4. Generate paper figures.
python experiments/make_figures.py

# 5. Render and compile the LaTeX paper.
python experiments/render_paper.py --compile_pdf
```

```text
src/
  model.py            # tiny decoder-only transformer + modular-addition data
  train.py            # training loop with rich early-dynamics logging
  features.py         # 28 features summarising the early window per run
experiments/
  run_main_grid.py    # orchestrator for the 4×4×3 sweep
  run_jobs.py         # run a small explicit list of (tf,wd,seed) jobs
  analyze.py          # univariate + multivariate prediction; bootstrap CIs
  make_figures.py     # all paper figures
  render_paper.py     # fill LaTeX template from analysis JSON, compile PDF
tests/
  test_basic.py       # data, model, training, basic scaffolding
  test_analysis.py    # feature extraction, AUROC, bootstrap helpers
paper/
  paper.tex           # LaTeX paper template
results/              # generated JSON logs
figures/              # generated PNG figures
```
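The sweep trains on modular addition mod 23, split by `train_frac`. A minimal sketch of such a dataset follows, assuming each example is a triple `(a, b, (a + b) % p)` and that the split is a seeded random permutation of all `p * p` pairs. The helper `make_modular_addition` is hypothetical; `src/model.py` may build and tokenize the data differently.

```python
import itertools

import numpy as np


def make_modular_addition(p=23, train_frac=0.3, seed=0):
    """Hypothetical helper: split all (a, b, (a + b) % p) triples into train/val."""
    # Every ordered input pair (a, b) with 0 <= a, b < p appears exactly once.
    pairs = np.array(list(itertools.product(range(p), repeat=2)))
    labels = (pairs[:, 0] + pairs[:, 1]) % p
    data = np.column_stack([pairs, labels])
    # A seeded permutation makes the train/val split disjoint and reproducible.
    perm = np.random.default_rng(seed).permutation(len(data))
    n_train = int(round(train_frac * len(data)))
    return data[perm[:n_train]], data[perm[n_train:]]


train, val = make_modular_addition(p=23, train_frac=0.3, seed=0)
print(train.shape, val.shape)  # (159, 3) (370, 3)
```

Because every pair lands in exactly one of the two slices, train/val disjointness holds by construction, which is the property the data tests check.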
For each training run, at every step of the early window, we record:
- train and validation loss/accuracy
- parameter L2 norm and gradient L2 norm
- effective rank (exponential of the entropy of the normalized singular values) of the token embedding matrix W_E and the unembedding matrix W_U
- train-set logit margin (correct-class logit minus the largest other logit)
- per-step parameter update norm and the cosine alignment of consecutive parameter updates
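The effective-rank, logit-margin, and update-alignment quantities above can be computed roughly as follows. This is a NumPy sketch, not the code in `src/train.py`: it assumes weights, logits, and flattened updates are available as arrays, assumes the margin is averaged over training examples, and the function names are hypothetical.

```python
import numpy as np


def effective_rank(W, eps=1e-12):
    """exp of the Shannon entropy of W's normalized singular values."""
    s = np.linalg.svd(np.asarray(W, dtype=float), compute_uv=False)
    p = s / (s.sum() + eps)   # singular values as a probability distribution
    p = p[p > 0]              # drop zeros, using the convention 0 * log 0 = 0
    return float(np.exp(-(p * np.log(p)).sum()))


def logit_margin(logits, targets):
    """Mean of (correct-class logit - largest other logit) over examples."""
    logits = np.asarray(logits, dtype=float)
    rows = np.arange(logits.shape[0])
    correct = logits[rows, targets]
    others = logits.copy()
    others[rows, targets] = -np.inf   # mask out the correct class
    return float((correct - others.max(axis=1)).mean())


def update_cosine(prev_update, update, eps=1e-12):
    """Cosine similarity of two consecutive flattened parameter updates."""
    prev_update = np.ravel(prev_update)
    update = np.ravel(update)
    denom = np.linalg.norm(prev_update) * np.linalg.norm(update) + eps
    return float(prev_update @ update / denom)


v = np.random.default_rng(0).normal(size=(6, 1))
print(effective_rank(v @ v.T))    # close to 1 for a rank-1 matrix
print(effective_rank(np.eye(6)))  # close to 6 for a full-rank 6x6 matrix
```

Using the exponential of the entropy (rather than the raw entropy) is what gives the bounds listed in the tests: a rank-1 matrix scores about 1 and a well-conditioned full-rank matrix scores about min(m, n).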
From these series we extract a fixed 28-dimensional feature vector per run: end-of-window values, slopes, train/val gaps, optimisation-noise statistics, and a few log/ratio transforms.
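To illustrate the kinds of features described here (end-of-window values, slopes, a train/val gap, a noise statistic, a log transform), here is a sketch covering a handful of them. The feature names, the dict-of-arrays log format, and the use of the residual standard deviation around a linear fit as a noise proxy are all assumptions; the actual 28 features are defined in `src/features.py`.

```python
import numpy as np


def window_features(log, window):
    """Hypothetical subset of early-window features from per-step series."""
    feats = {}
    for name, series in log.items():
        x = np.asarray(series[:window], dtype=float)
        steps = np.arange(len(x))
        coef = np.polyfit(steps, x, 1)            # least-squares [slope, intercept]
        resid = x - np.polyval(coef, steps)
        feats[f"{name}_end"] = float(x[-1])       # end-of-window value
        feats[f"{name}_slope"] = float(coef[0])   # linear trend over the window
        feats[f"{name}_resid_std"] = float(resid.std())  # crude noise proxy
    # Train/val gap and a log transform (assumes these keys exist in `log`).
    feats["loss_gap_end"] = feats["val_loss_end"] - feats["train_loss_end"]
    feats["log_param_norm_end"] = float(np.log(feats["param_norm_end"]))
    return feats


t = np.arange(500)
log = {  # synthetic linear curves, for illustration only
    "train_loss": 2.0 - 0.01 * t,
    "val_loss": 3.0 - 0.005 * t,
    "param_norm": 10.0 + 0.1 * t,
}
feats = window_features(log, window=100)
print(round(feats["loss_gap_end"], 3))  # 1.495
```

Truncating every series to the same `window` before fitting is what lets the same function produce comparable feature vectors at 100, 200, 300, and 500 steps.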
A run is labelled `grokked = 1` if its validation accuracy reaches 0.99 at some step during training, i.e. `max_t val_acc[t] >= 0.99`, and `grokked = 0` otherwise. We compute:
- Per-feature univariate AUROC at each early-window length in {100, 200, 300, 500} steps, with features ranked by AUROC.
- Multivariate L2-regularised logistic regression on the 28 features, evaluated by leave-one-out over runs, with bootstrap 95% CIs.
- Baselines that use val_loss alone or val_acc alone at the same window. We also report a paired bootstrap test of AUROC(multivariate) − AUROC(baseline).
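A minimal sketch of the labelling rule and the leave-one-out multivariate forecast, using scikit-learn. Assumptions: the helper names `grok_label` and `loo_logistic_auroc` are hypothetical, features are standardised inside each fold, and `C=1.0` is an illustrative regularisation strength; `experiments/analyze.py` may choose these differently. The data below is synthetic, not a result from the study.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import LeaveOneOut
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def grok_label(val_acc, threshold=0.99):
    """1 if validation accuracy ever reaches `threshold`, else 0."""
    return int(np.max(val_acc) >= threshold)


def loo_logistic_auroc(X, y, C=1.0):
    """Leave-one-out probabilities from L2 logistic regression, plus pooled AUROC."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    probs = np.empty(len(y))
    for train_idx, test_idx in LeaveOneOut().split(X):
        # Fit the scaler and the (default L2-penalised) model on the other runs only.
        clf = make_pipeline(StandardScaler(), LogisticRegression(C=C, max_iter=1000))
        clf.fit(X[train_idx], y[train_idx])
        probs[test_idx] = clf.predict_proba(X[test_idx])[:, 1]
    # A single held-out run has no AUROC of its own, so pool all held-out scores.
    return probs, roc_auc_score(y, probs)


# Synthetic stand-in for 48 runs x 28 features (illustration only).
rng = np.random.default_rng(0)
y = np.repeat([0, 1], 24)
X = rng.normal(size=(48, 28))
X[:, :3] += 2.0 * y[:, None]   # make the first 3 features informative
probs, auc = loo_logistic_auroc(X, y)
print(f"LOO AUROC: {auc:.3f}")
```

Fitting the scaler inside each fold keeps the held-out run's feature values out of the normalisation statistics, so the pooled AUROC is not inflated by leakage.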
All scaffolding is unit-tested:
- modular-arithmetic data correctness; train/val split disjointness
- model forward shapes; sanity check that the model overfits a tiny task
- effective rank bounds (rank-1 → 1; full rank → close to min(m,n))
- gradient and parameter norm bookkeeping
- logit-margin sign
- training smoke test: a 200-step run yields logs with consistent shapes
- AUROC is exactly 1.0 / 0.0 on perfect/inverted scores and ~0.5 on random scores, with ties handled correctly
- bootstrap CI helper produces narrow CIs near 1 for perfect separation and a CI containing 0.5 for random scores
- paired bootstrap returns 0 when the two predictors are identical and positive when one clearly beats the other
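The AUROC and bootstrap properties tested above can be illustrated with a self-contained sketch: AUROC via the Mann-Whitney rank formulation (tied scores receive average ranks, so a tie counts as half-correct), a percentile bootstrap CI that resamples runs with replacement, and a paired bootstrap that scores both predictors on the same resamples. All function names are hypothetical; `experiments/analyze.py` may implement these differently.

```python
import numpy as np
from scipy.stats import rankdata


def auroc(y, scores):
    """AUROC via the Mann-Whitney U statistic; ties get average ranks (count 1/2)."""
    y = np.asarray(y).astype(bool)
    ranks = rankdata(scores)          # average ranks for tied scores
    n_pos, n_neg = y.sum(), (~y).sum()
    u = ranks[y].sum() - n_pos * (n_pos + 1) / 2
    return float(u / (n_pos * n_neg))


def _resample_indices(y, n_boot, rng):
    """Yield bootstrap index arrays that contain both classes."""
    made = 0
    while made < n_boot:
        idx = rng.integers(0, len(y), len(y))
        if y[idx].min() != y[idx].max():   # AUROC needs both classes
            made += 1
            yield idx


def bootstrap_ci(y, scores, n_boot=2000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for AUROC, resampling runs with replacement."""
    y, scores = np.asarray(y), np.asarray(scores)
    rng = np.random.default_rng(seed)
    stats = [auroc(y[i], scores[i]) for i in _resample_indices(y, n_boot, rng)]
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)


def paired_bootstrap(y, s_a, s_b, n_boot=2000, alpha=0.05, seed=0):
    """Mean and percentile CI of AUROC(a) - AUROC(b) over shared resamples."""
    y, s_a, s_b = np.asarray(y), np.asarray(s_a), np.asarray(s_b)
    rng = np.random.default_rng(seed)
    diffs = [auroc(y[i], s_a[i]) - auroc(y[i], s_b[i])
             for i in _resample_indices(y, n_boot, rng)]
    lo, hi = np.quantile(diffs, [alpha / 2, 1 - alpha / 2])
    return float(np.mean(diffs)), (float(lo), float(hi))


print(auroc([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9]))  # 1.0
print(auroc([0, 0, 1, 1], [0.9, 0.8, 0.2, 0.1]))  # 0.0
print(auroc([0, 0, 1, 1], [0.5, 0.5, 0.5, 0.5]))  # 0.5 (all ties)
```

Pairing matters: because both predictors are evaluated on the same resampled runs, run-to-run variation largely cancels in the difference, and two identical predictors give a difference of exactly 0 on every resample.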
Run with:

```bash
python tests/test_basic.py && python tests/test_analysis.py
```

If you find this useful, please cite it as:

```bibtex
@misc{grokking-prediction-2026,
  title = {Predicting Grokking from Early Training Dynamics:
           A Forecasting Approach to Delayed Generalization},
  year  = {2026},
  note  = {Reproducible code accompanying the paper.},
}
```