attn-ox

A small OxCaml program that learns single-hex-digit addition with a transformer. The architecture (embedding, multi-head attention, layer norm, residual connections, feed-forward) is the same as in GPT-style models, scaled down to 13,760 parameters. Training takes seconds on a laptop.
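As a rough sketch of the core building block listed above, here is single-head scaled dot-product attention over plain float arrays. This is illustrative only, not the repository's code: the real kernels use OxCaml's SIMD types, and names here are made up.

```ocaml
(* Numerically stable softmax over one row of scores. *)
let softmax row =
  let m = Array.fold_left max neg_infinity row in
  let e = Array.map (fun x -> exp (x -. m)) row in
  let s = Array.fold_left ( +. ) 0. e in
  Array.map (fun x -. x /. s |> fun _ -> x /. s) e

(* q, k, v : seq_len x d_head matrices.
   out.(i) = sum_j softmax(q_i . k_j / sqrt d) * v_j *)
let attention q k v =
  let seq = Array.length q and d = Array.length q.(0) in
  let scale = 1. /. sqrt (float_of_int d) in
  Array.init seq (fun i ->
      let scores =
        Array.init seq (fun j ->
            let dot = ref 0. in
            for t = 0 to d - 1 do
              dot := !dot +. q.(i).(t) *. k.(j).(t)
            done;
            !dot *. scale)
      in
      let w = softmax scores in
      Array.init d (fun t ->
          let acc = ref 0. in
          for j = 0 to seq - 1 do
            acc := !acc +. w.(j) *. v.(j).(t)
          done;
          !acc))
```

Each output row is a convex combination of the value rows, so when v rows are one-hot the output row sums to 1.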

A full walkthrough of how the model works, what it learns, and how grokking emerges in a 376-parameter variant is on the blog.

The forward pass, backward pass, and AdamW optimizer are written by hand. The hot matmul kernels use AVX 256-bit vectors via OxCaml's Float64x4. Every step of the manual backward pass is verified against finite-difference gradient checks in test/grad_check.ml.
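The idea behind a finite-difference gradient check can be sketched as follows. This is a generic illustration, not the repository's test/grad_check.ml: for a scalar loss f, each analytic partial derivative is compared against the central difference (f(x+h) - f(x-h)) / 2h.

```ocaml
(* Central-difference numeric gradient of f : float array -> float.
   Mutates x temporarily, restoring each coordinate afterwards. *)
let numeric_grad f x =
  let h = 1e-5 in
  Array.mapi
    (fun i _ ->
      let save = x.(i) in
      x.(i) <- save +. h;
      let fp = f x in
      x.(i) <- save -. h;
      let fm = f x in
      x.(i) <- save;
      (fp -. fm) /. (2. *. h))
    x

(* Example check: f(x) = sum x_i^2 has analytic gradient 2x. *)
let () =
  let f x = Array.fold_left (fun a v -> a +. (v *. v)) 0. x in
  let x = [| 0.5; -1.0; 2.0 |] in
  let g_num = numeric_grad f x in
  let g_ana = Array.map (fun v -> 2. *. v) x in
  Array.iter2 (fun a b -> assert (abs_float (a -. b) < 1e-6)) g_num g_ana
```

The same pattern extends to a hand-written backward pass: perturb one parameter at a time, recompute the loss, and compare against the analytic gradient within a small tolerance.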

Build and run

Requires an OxCaml opam switch (this project was developed against 5.2.0+ox).

opam exec --switch 5.2.0+ox -- dune build
opam exec --switch 5.2.0+ox -- dune runtest                   # 12 tests, ~50ms

opam exec --switch 5.2.0+ox -- dune exec bin/main.exe         # train + inspect
opam exec --switch 5.2.0+ox -- dune exec bin/holdout.exe      # held-out generalisation
opam exec --switch 5.2.0+ox -- dune exec bin/long_train.exe   # long-training / grokking
opam exec --switch 5.2.0+ox -- dune exec bin/sweep.exe        # d_model sweep
opam exec --switch 5.2.0+ox -- dune exec bench/bench.exe      # scalar vs SIMD matmul

Source layout

lib/    architecture + ops (forward, backward, SIMD matmul)
bin/    binaries: main, holdout, long_train, sweep
bench/  scalar vs SIMD matmul benchmark
test/   gradient checks and sanity tests
