A small OxCaml program that learns single-hex-digit addition with a transformer. The architecture (embedding, multi-head attention, layer norm, residual connections, feed-forward) is the same as in GPT-style models, scaled down to 13,760 parameters. Training takes seconds on a laptop.
A full walkthrough of how the model works, what it learns, and how grokking emerges in a 376-parameter variant is on the blog.
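For orientation, here is a minimal sketch of how the parameter set of such a one-block model might be laid out in OCaml. The record fields, names, and shapes are illustrative assumptions, not the repository's actual types (those live in lib/):

```ocaml
(* Hypothetical parameter layout for a one-block, GPT-style model.
   Field names and shapes are illustrative only. *)
type params = {
  tok_emb : float array array;  (* vocab x d_model token embeddings *)
  pos_emb : float array array;  (* seq_len x d_model positional embeddings *)
  w_qkv   : float array array;  (* d_model x 3*d_model attention projections,
                                   split across heads *)
  w_attn_out : float array array;  (* d_model x d_model attention output *)
  ln_gain : float array;        (* layer-norm scale, d_model *)
  ln_bias : float array;        (* layer-norm shift, d_model *)
  w_ff1   : float array array;  (* d_model x d_ff feed-forward up-projection *)
  w_ff2   : float array array;  (* d_ff x d_model feed-forward down-projection *)
  w_head  : float array array;  (* d_model x vocab output head *)
}
```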
The forward pass, backward pass, and AdamW optimizer are written by hand.
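For reference, the update rule a hand-written AdamW implements, sketched generically in OCaml. This is the standard algorithm with conventional hyperparameter names and defaults (`beta1`, `beta2`, `eps`, `weight_decay`); the repository's exact code and settings may differ:

```ocaml
(* Generic AdamW step for one parameter tensor, flattened to a float array.
   [m] and [v] are first/second-moment accumulators; [t] is the step count
   (starting at 1). Defaults are the conventional ones, not necessarily
   the repo's. *)
let adamw_step ~lr ?(beta1 = 0.9) ?(beta2 = 0.999) ?(eps = 1e-8)
    ?(weight_decay = 0.01) ~t ~theta ~grad ~m ~v () =
  let bc1 = 1. -. (beta1 ** float_of_int t) in  (* bias corrections *)
  let bc2 = 1. -. (beta2 ** float_of_int t) in
  Array.iteri
    (fun i g ->
      m.(i) <- (beta1 *. m.(i)) +. ((1. -. beta1) *. g);
      v.(i) <- (beta2 *. v.(i)) +. ((1. -. beta2) *. g *. g);
      let m_hat = m.(i) /. bc1 and v_hat = v.(i) /. bc2 in
      (* Decoupled weight decay: applied directly to the parameter. *)
      theta.(i) <-
        theta.(i)
        -. (lr
            *. ((m_hat /. (sqrt v_hat +. eps)) +. (weight_decay *. theta.(i)))))
    grad
```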
The hot matmul kernels use AVX 256-bit vectors via OxCaml's Float64x4.
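A rough sketch of what a 4-wide inner loop for one output row of C = A * B looks like. Caveat: the `Float64x4.splat`/`load`/`store`/`fma` names below are assumptions for illustration; OxCaml's actual SIMD API (unboxed vector types and operations) differs in detail, so treat this as pseudocode for the technique rather than the repo's kernel:

```ocaml
(* Computes one row of C = A * B, with [b] a row-major flattened n x p
   matrix. Float64x4 operation names are ASSUMED; see lib/ for the
   kernel actually used. *)
let matmul_row_simd ~n ~p a_row b c_row =
  let j = ref 0 in
  while !j + 4 <= p do
    let acc = ref (Float64x4.splat 0.0) in
    for k = 0 to n - 1 do
      let b_vec = Float64x4.load b ((k * p) + !j) in
      (* acc += a_row.(k) * B[k, j..j+3], four lanes at a time *)
      acc := Float64x4.fma (Float64x4.splat a_row.(k)) b_vec !acc
    done;
    Float64x4.store c_row !j !acc;
    j := !j + 4
  done;
  (* Scalar tail for p not divisible by 4. *)
  for jj = !j to p - 1 do
    let s = ref 0.0 in
    for k = 0 to n - 1 do
      s := !s +. (a_row.(k) *. b.((k * p) + jj))
    done;
    c_row.(jj) <- !s
  done
```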
Every step of the manual backward pass is verified against finite-difference gradient checks in test/grad_check.ml.
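The technique is simple: perturb one parameter at a time and compare the analytic gradient against the central difference (f(x+h) - f(x-h)) / 2h. A generic sketch of the idea, not the code in test/grad_check.ml:

```ocaml
(* Central-difference gradient check. [f] maps parameters to a scalar loss;
   [grad] is the analytic gradient under test. Step size and tolerance are
   illustrative defaults. *)
let grad_matches ?(h = 1e-5) ?(tol = 1e-4) ~f ~grad x =
  let g = grad x in                         (* analytic gradient at x *)
  let ok = ref true in
  Array.iteri
    (fun i xi ->
      x.(i) <- xi +. h;
      let fp = f x in
      x.(i) <- xi -. h;
      let fm = f x in
      x.(i) <- xi;                          (* restore the parameter *)
      let fd = (fp -. fm) /. (2. *. h) in   (* finite-difference estimate *)
      let err = abs_float (fd -. g.(i)) /. (1. +. abs_float fd) in
      if err > tol then ok := false)
    x;
  !ok
```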
Requires an OxCaml opam switch (this project was developed against 5.2.0+ox).
```
opam exec --switch 5.2.0+ox -- dune build
opam exec --switch 5.2.0+ox -- dune runtest                  # 12 tests, ~50ms
opam exec --switch 5.2.0+ox -- dune exec bin/main.exe        # train + inspect
opam exec --switch 5.2.0+ox -- dune exec bin/holdout.exe     # held-out generalisation
opam exec --switch 5.2.0+ox -- dune exec bin/long_train.exe  # long-training / grokking
opam exec --switch 5.2.0+ox -- dune exec bin/sweep.exe       # d_model sweep
opam exec --switch 5.2.0+ox -- dune exec bench/bench.exe     # scalar vs SIMD matmul
```

```
lib/    architecture + ops (forward, backward, SIMD matmul)
bin/    binaries: main, holdout, long_train, sweep
bench/  scalar vs SIMD matmul benchmark
test/   gradient checks and sanity tests
```