A small reference implementation that pairs Monarch-matrix MLPs
(Hazy Research's M2) with Newton–Muon-style right-preconditioning
(Du & Su 2026) and shows that the M2 block structure yields an exact
block-diagonal preconditioner with O(d^{3/2}) storage and O(d^2)
total refresh/apply work for d = b², much cheaper than the baseline's
dense d × d (ZZᵀ + γI)⁻¹.
The repo also implements three other optimizer families (FOOF/GDS, Schedule-Free, plain AdamW) all using the same activation-capture and inverse-refresh machinery, for head-to-head comparison.
- references.md — paper/repo deep dive (read for algorithm details).
- PROGRESS.md — current project state, results, open questions, next steps (read first if picking up work).
- AGENTS.md — hardware notes + four-GPU conventions.
m2gdn/
├── m2gdn/
│ ├── monarch.py # square Monarch2 + rectangular BlockdiagButterfly layers
│ ├── inverse.py # batched damped Cholesky inverse + verification
│ ├── newton_schulz.py # BF16 Newton–Schulz (Polar-Express coefficients)
│ ├── optimizers.py # 7 optimizers + capture-window mixin
│ └── model.py # TinyGPT with Monarch MLPs
├── tests/
│ ├── test_monarch.py
│ ├── test_inverse.py
│ ├── test_newton_schulz.py
│ ├── test_optimizers.py # all 7 optimizers covered
│ ├── test_covariance.py
│ ├── test_data_source.py
│ └── test_sweep_resume.py
├── bench_inverse.py # dense vs block-diagonal inverse benchmark
├── bench_newton_schulz.py # standard vs Gram NS benchmark
├── bench_upstream_gram_ns_kernels.py # optional upstream CuTeDSL kernel bench
├── compare_optimizers.py # default optimizer bakeoff report from artifacts
├── train.py # side-by-side benchmark (NewtonMuon dense vs monarch)
├── sweep.py # NewtonMuon LR sweep
├── sweep_gds.py # cross-optimizer sweep (knows all 7 variants)
├── dual_sweep_gds.py # multi-GPU sweep_gds launcher + result merger
├── references.md # deep-dive notes on 4 papers + 5 repos
├── PROGRESS.md # project state for handoff
├── AGENTS.md # AI/agent conventions (READ FIRST: hardware + four-GPU)
└── papers/ # 4 polished paper markdowns + extract pipeline
All seven live in m2gdn/optimizers.py and share the same _CovState,
_CaptureWindowMixin, and BF16 inverse-cache infrastructure.
| class | preconditioner | momentum | NS | use case |
|---|---|---|---|---|
| NewtonMuonDense | dense (ZZᵀ+γI)⁻¹ | Nesterov-Muon | yes | paper baseline |
| NewtonMuonMonarch | exact M2 block-diag | Nesterov-Muon | yes | the headline optimizer |
| GDSDense | dense | AdamW EMA | no | FOOF + AdamW first moment |
| GDSMonarch | exact M2 | AdamW EMA | no | GDS + structured-matrix trick |
| ScheduleFreeGDSDense | dense | SF interpolation+averaging | no | matches paper reference |
| ScheduleFreeGDSMonarch | exact M2 | SF interpolation+averaging | no | SF base + M2 trick |
| AdamW | per-coordinate | AdamW | no | vanilla baseline |
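The shared capture machinery amounts to forward hooks that accumulate an FP32 activation second moment during a short window before each inverse refresh. A minimal sketch of the idea (the repo's `_CovState`/`_CaptureWindowMixin` are more elaborate; the names below are illustrative):

```python
import torch
import torch.nn as nn

class CovState:
    """FP32 second-moment accumulator for one layer's input activations."""
    def __init__(self, dim: int, device: torch.device):
        self.cov = torch.zeros(dim, dim, device=device)
        self.count = 0

    def update(self, x: torch.Tensor) -> None:
        x = x.detach().float().reshape(-1, x.shape[-1])  # flatten to (N, dim)
        self.cov += x.T @ x
        self.count += x.shape[0]

def attach_capture(layer: nn.Linear, state: CovState):
    # Forward hook records the layer's inputs while the window is open;
    # call .remove() on the returned handle to close the window.
    def hook(_module, inputs, _output):
        state.update(inputs[0])
    return layer.register_forward_hook(hook)
```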
For Newton-Muon variants, kind="plain" parameters use an auxiliary AdamW
path (aux_lr=1e-3 by default) instead of being pushed through Muon/NS.
Packed attention QKV weights are row-split into Q/K/V slices before
Newton-Schulz by default (ns_split=3 in the parameter group); pass
--disable-qkv-ns-split to sweep_gds.py to reproduce the older packed-NS
behavior.
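A sketch of the row-split, assuming a packed (3d, d) QKV gradient and using `newton_schulz` as a stand-in for the repo's BF16 Newton-Schulz routine:

```python
import torch

# Hypothetical sketch of ns_split=3: row-split a packed (3d, d) QKV gradient
# so each (d, d) slice is orthogonalized independently.
def split_qkv_ns(packed_grad: torch.Tensor, newton_schulz, ns_split: int = 3):
    slices = packed_grad.chunk(ns_split, dim=0)  # Q, K, V row blocks
    return torch.cat([newton_schulz(g) for g in slices], dim=0)
```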
For Monarch block-preconditioned variants, block_damping_policy can be
per_block (default) or global; d=576 sweeps found no meaningful tuned-LR
quality difference, so the block-local heuristic remains the default.
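A hedged sketch of the two policies (the repo's exact damping scale may differ; `damp_blocks` is a hypothetical name):

```python
import torch

# A: (b, blk, blk) stack of per-block covariances.
def damp_blocks(A: torch.Tensor, gamma: float, policy: str = "per_block"):
    diag_mean = A.diagonal(dim1=-2, dim2=-1).mean(-1)      # (b,) per-block scale
    if policy == "global":
        diag_mean = diag_mean.mean().expand_as(diag_mean)  # one shared scale
    eye = torch.eye(A.size(-1), device=A.device, dtype=A.dtype)
    return A + (gamma * diag_mean)[:, None, None] * eye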
For GDS/SF-GDS variants, precond_strength blends raw and activation-
preconditioned gradients (1.0 is the historical full-preconditioned path);
precond_norm_cap optionally caps the blended direction's norm relative to the
raw gradient. sweep_gds.py exposes these as --precond-strengths and
--precond-norm-caps. Schedule-Free variants expose --weight-lr-powers and
--rs for the Schedule-Free averaging kernel; current d=1024 FineWeb SF-GDS
probes prefer weight_lr_power=-0.5, with no-aux runs favoring r=2.5 and
attention/dense aux runs favoring the stabler high-LR r=2 cell, over the
paper-default weight_lr_power=2, r=0. Schedule-Free GDS variants also expose
--precond-strength-warmup-steps, which ramps the raw-to-preconditioned blend
from raw gradients to the target strength over optimizer steps, and
--precond-strength-lows, which controls the low value used by *_low
structural schedules. --sf-update-precond-strengths is an experimental
decoupled-SF knob: omitting it preserves the exact coupled z/y update direction,
while explicit scalar values let the Schedule-Free averaged iterate use a
raw-to-GDS blend different from the z/base step.
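A minimal sketch of the blend-and-cap behavior (the function name is hypothetical, not the repo's API):

```python
import torch

def blend_precond(g_raw, g_gds, strength=1.0, norm_cap=None, eps=1e-12):
    d = (1.0 - strength) * g_raw + strength * g_gds  # strength=1.0: full GDS path
    if norm_cap is not None:
        # Cap the blended direction's norm at norm_cap * ||g_raw||.
        scale = norm_cap * g_raw.norm() / (d.norm() + eps)
        d = d * torch.clamp(scale, max=1.0)
    return d
```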
--sf-update-precond-schedules and --sf-update-precond-transition-steps add
time-varying y/update schedules; the current d=1024 best uses
cosine_to_z to ramp y from raw to the z/base GDS direction over 200 steps.
--sf-update-lr-scales, --sf-update-lr-scale-schedules, and
--sf-update-lr-scale-transition-steps independently scale the
Schedule-Free y correction; the best d=1024 SF-GDS row found so far starts
that y correction at 0.65x and ramps it back to 1.0x with
cosine_to_one over 200 steps.
Adaptive experimental variants (cosine_align*_to_z, cosine_norm*_to_z) gate
that ramp using raw/preconditioned direction agreement, but current FineWeb
confirmations did not beat the fixed cosine_to_z schedule.
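A sketch of the ramp shape these schedules plausibly share, assumed from the descriptions above: the weight is 0 before `start`, then eases from 0 to 1 over `transition` optimizer steps.

```python
import math

def cosine_ramp(step: int, start: int = 100, transition: int = 200) -> float:
    """Weight in [0, 1] used to blend raw -> target over the window."""
    if step < start:
        return 0.0
    t = min((step - start) / transition, 1.0)
    return 0.5 * (1.0 - math.cos(math.pi * t))
```

Under this reading, cosine_to_z blends the y/update direction as (1 − w) · raw + w · z_direction, and cosine_to_one interpolates the y-correction scale from its starting value (e.g. 0.65) back to 1.0.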
--sf-average-start-steps and --sf-average-ramp-steps are tail-averaging
research knobs; current d=1024 FineWeb sweeps found hard delayed starts much
worse and ckp1 ramps slightly worse than the default averaging kernel.
For SF-GDS aux AdamW groups, --sf-aux-average tracks a separate SF-style
average and --sf-aux-eval-lerp selects the eval-time blend from live aux
weights (0) to averaged aux weights (1). Full aux averaging worsened the
current d=1024 row, but a partial blend at 0.45 gives the best SF-GDS eval
loss tested so far.
Adaptive policies are also available via
--precond-strength-policies fixed,norm_ratio,norm_cosine,
--precond-target-ratios, and --precond-min-cosines; current FineWeb probes
found the adaptive norm/cos policies slower and worse than fixed partial
strength. Static low-overhead module schedules are exposed with
--precond-strength-schedules; current d=1024 FineWeb probes found that
lowering packed-QKV preconditioning while keeping non-QKV weights strongly
preconditioned improves SF-GDS without the per-step norm/cos overhead. Q/K/V
row-sliced schedules
(q_low, k_low, v_low, qk_low, qv_low, kv_low, plus the matching
*_raw forms) are available for packed QKV localization; current probes show
the scalar qkv_low schedule is still better and faster than slice-low
variants. Projection-specific schedules (proj_low, proj_raw,
qkv_raw_proj_low, qkv_raw_proj_raw) isolate the attention output projection
from packed QKV and MLP output weights.
For true auxiliary-optimizer ablations, --attention-aux-modes qkv,proj,attn
moves selected attention weights out of the GDS/SF-GDS update and into the
auxiliary AdamW path controlled by --aux-lrs. On the current d=1024,
4-layer FineWeb grid, all-attention aux AdamW plus the refined SF averaging
kernel closes most of the AdamW gap while remaining faster than dense AdamW.
--attention-aux-modes attn_split additionally supports separate packed-QKV
and attention-projection LRs via --attention-qkv-aux-lrs and
--attention-proj-aux-lrs; the best split row is only nominally ahead of the
shared attn row and remains inside CI.
Non-attention dense ablations are exposed separately with
--dense-aux-modes mlp_gate,mlp_out,mlp,nonattn_dense,all_dense and optional
per-group --dense-aux-lrs. In the current d=1024, 4-layer FineWeb grid,
attention_aux=attn or attn_split plus dense_aux=mlp_gate is the strongest
SF-GDS setting tested so far and cuts warm step time by roughly 9-10%; moving
the MLP output or all remaining dense preconditioned weights to aux AdamW
regresses held-out loss, especially for SF-GDS.
BlockdiagButterfly is the rectangular Monarch-style layer for d -> 3d,
3d -> d, and other divisible projections; it exposes per-factor capture
shapes for the same block preconditioner machinery.
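For intuition, a single rectangular block-diagonal factor looks roughly like the sketch below; the actual BlockdiagButterfly composes two such factors around an interleaving permutation.

```python
import torch
import torch.nn as nn

class BlockdiagFactor(nn.Module):
    """Sketch of one rectangular block-diagonal factor: in_dim -> out_dim with
    nblocks independent (out_dim/nblocks) x (in_dim/nblocks) blocks."""
    def __init__(self, in_dim: int, out_dim: int, nblocks: int):
        super().__init__()
        assert in_dim % nblocks == 0 and out_dim % nblocks == 0
        self.nblocks = nblocks
        self.weight = nn.Parameter(
            torch.randn(nblocks, out_dim // nblocks, in_dim // nblocks) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lead = x.shape[:-1]
        xb = x.reshape(-1, self.nblocks, self.weight.shape[-1])  # (N, b, in/b)
        yb = torch.einsum("nbi,boi->nbo", xb, self.weight)       # batched blocks
        return yb.reshape(*lead, -1)                             # (..., out_dim)
```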
sweep_gds.py --attention-kind dense,monarch_qkv,monarch_proj,monarch uses
that layer to replace packed attention QKV and/or the attention output
projection. Current d=2304 FineWeb checks show full Monarch attention reduces
parameters and optimizer-step time for Monarch-MLP GDS/SF-GDS, while the
Newton-Muon path needs a higher LR than the initial lr=0.01 profile.
QKV-only replacement carries most of the GDS/SF-GDS benefit; projection-only is
not a good GDS/SF-GDS path, though it is the least damaging Newton-Muon
short-profile ablation.
uv venv .venv --python 3.12
uv pip install --python .venv/bin/python --pre torch \
--index-url https://download.pytorch.org/whl/nightly/cu130
uv pip install --python .venv/bin/python -e . --no-deps
uv pip install --python .venv/bin/python numpy einops tqdm pytest pip
Dependencies: torch>=2.4, numpy, einops, tqdm.
CUDA_VISIBLE_DEVICES=0,1,2,3 PYTHONPATH=. .venv/bin/python run_tests.py
The runner compiles the main scripts, runs every standalone test, and reports CUDA availability plus CUDA-only self-skips on CPU-only machines. The individual scripts remain directly runnable:
PYTHONPATH=. .venv/bin/python tests/test_monarch.py
PYTHONPATH=. .venv/bin/python tests/test_inverse.py
PYTHONPATH=. .venv/bin/python tests/test_newton_schulz.py
PYTHONPATH=. .venv/bin/python tests/test_optimizers.py
PYTHONPATH=. .venv/bin/python tests/test_covariance.py
PYTHONPATH=. .venv/bin/python tests/test_data_source.py
PYTHONPATH=. .venv/bin/python tests/test_gpu_launcher.py
PYTHONPATH=. .venv/bin/python tests/test_sweep_resume.py
Expected: run_tests.py reports PASS for py_compile and all eight scripts.
The optimizer test suite includes a side-by-side check that
ScheduleFreeGDSDense matches the upstream
SGDScheduleFreeReference implementation to within FP32 noise.
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. .venv/bin/python bench_inverse.py
On a Blackwell RTX PRO 6000 (97 GB), you should see something like:
d b | inv_blk(ms) inv_dense(ms) speedup | apply_blk(ms) apply_dense(ms) speedup | mem_blk(MB) mem_den(MB)
256 16 | 0.091 0.234 2.6x | 0.0042 0.0082 2.0x | 0.016 0.250
1024 32 | 0.152 1.018 6.7x | 0.0042 0.0474 11.4x | 0.125 4.000
4096 64 | 0.257 6.946 27.1x | 0.0063 2.0648 330.2x | 1.000 64.000
9216 96 | 0.403 47.034 116.7x | 0.0104 25.4295 2440.5x | 3.375 324.000
- The inverse of the block-diagonal factor is 27× faster than a d × d Cholesky at d = 4096, growing to >100× at d = 9216.
- Applying the inverse to a gradient (one torch.bmm vs one large matmul) is even cheaper in the block form.
- Memory is √d times smaller; at d = 9216 we hold a 3.4 MB per-layer preconditioner instead of 324 MB.
Correctness residuals ‖(A + γI) · A⁻¹ − I‖ are in the 1e-6 range at
all sizes, printed by the same benchmark.
All commands below assume the repo root as the working directory.
# Full verification; local GPU run is normally under 10 seconds.
CUDA_VISIBLE_DEVICES=0,1,2,3 PYTHONPATH=. .venv/bin/python run_tests.py --quiet
# CPU-hidden skip-path check for CUDA-only tests.
CUDA_VISIBLE_DEVICES= PYTHONPATH=. .venv/bin/python run_tests.py --no-compile --quiet
# Core inverse primitive benchmark.
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. .venv/bin/python bench_inverse.py
# Current d=576 GDS seeded comparison.
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
--variants gds_dense,gds_monarch --lrs 0.2,0.25,0.3 \
--out-dir sweeps_seed_d576_gds_tuned \
--seeds 0,1,2 --momentums 0.85 \
--d 576 --heads 4 --layers 2 --seq-len 64 --batch 16 --steps 400
# Current d=576 SF-GDS seeded comparison.
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
--variants sf_gds_dense,sf_gds_monarch --lrs 1.3,1.6 \
--out-dir sweeps_seed_d576_sf_gds_tuned \
--seeds 0,1,2 --momentums 0.9 --warmup-steps 50,100 \
--d 576 --heads 4 --layers 2 --seq-len 64 --batch 16 --steps 400
# Cached FineWeb-Edu GPT-2 smoke, split across four GPUs.
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
--variants adamw_dense_mlp,adamw --lrs 0.0003,0.0006 \
--out-dir sweeps_fineweb_smoke_d1024 \
--data fineweb --data-dir /home/catid/ensemble2/data/fineweb_edu_gpt2_probe2 \
--d 1024 --heads 8 --layers 1 --seq-len 32 --batch 2 --steps 5 \
--eval-batches 2
# Fast synthetic SF-GDS/M2 profile with compiled model and compiled Markov data.
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
--variants sf_gds_monarch --lrs 9.5 \
--out-dir sweeps_sf_m2_compiled_profile \
--seeds 0,1,2,3 --split-by runs \
--steps 60 --d 1024 --heads 8 --layers 1 --seq-len 64 --batch 8 \
--vocab-size 4096 --eval-batches 1 --profile-timings \
--lm-head-nblocks 32 --lm-head-hidden-block-size 64 --lm-head-terms 16 \
--tie-m2-embedding --precond-strength-schedules mlp_head_raw \
--precond-strength-low 0.25 --momentums 0.9 --warmup-steps 10 \
--compile-model --compile-data --matmul-precision high
# Latest retained artifact: /tmp/m2gdn_sf_m2_profile_head_sum_einsum
# with tail-median end-to-end mean about 1.73 ms across four GPUs.
# Export seeded comparison tables for reports.
PYTHONPATH=. .venv/bin/python analyze_profiles.py \
sweeps_seed_d576_gds_tuned/results.json \
sweeps_seed_d576_sf_gds_tuned/results.json \
--seed-summary-csv sweeps_seed_d576_seed_summaries.csv \
--seed-summary-md sweeps_seed_d576_seed_summaries.md
# Generate the default cross-optimizer comparison report.
PYTHONPATH=. .venv/bin/python compare_optimizers.py \
--out-md optimizer_comparison.md \
--out-csv optimizer_comparison.csv

Rerunning the same dual_sweep_gds.py command resumes from
<out-dir>/gpu*/runs/*.json and skips matching run keys. Add --force only
when deliberately recomputing cached configurations.
PYTHONPATH=. python3 bench_newton_schulz.py \
--shapes 576x576,1728x576,2304x576,576x2304,3x576x576
NewtonMuonDense and NewtonMuonMonarch support ns_backend="standard",
"gram", and "auto"; sweep_gds.py exposes this as --ns-backends.
The pure-PyTorch Gram path is numerically compatible and falls back on square
matrices. Small d=576 timings were slower, but large rectangular Blackwell
timings are now favorable: d=4096 packed QKV (12288,4096) is 18.65 ms
standard vs 10.88 ms Gram, and 4x-MLP-shaped (16384,4096) is 24.86 ms
standard vs 12.71 ms Gram. Square and split-QKV batched shapes remain neutral.
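For intuition, the Gram trick works because each quintic NS step X ← aX + bX(XᵀX) + cX(XᵀX)² equals X · q(XᵀX), so for tall X you can iterate entirely on the n × n Gram matrix and apply the accumulated polynomial once. A sketch with one well-known quintic coefficient set (the repo uses Polar-Express coefficients, and its numerics differ in detail):

```python
import torch

def gram_newton_schulz(X: torch.Tensor, steps: int = 5, eps: float = 1e-7):
    """Sketch: orthogonalize tall X (m >= n) via its n x n Gram matrix."""
    a, b, c = 3.4445, -4.7750, 2.0315          # one quintic NS coefficient set
    X = X / (X.norm() + eps)                   # crude spectral-norm rescale
    G = X.mT @ X                               # (n, n) Gram matrix
    I = torch.eye(G.size(-1), device=G.device, dtype=G.dtype)
    Q = I.clone()                              # accumulated transfer polynomial
    for _ in range(steps):
        P = a * I + b * G + c * (G @ G)        # q(G) for one NS step
        Q = Q @ P                              # factors commute (polys in G)
        G = P @ G @ P                          # Gram of the updated iterate
    return X @ Q                               # single (m, n) apply at the end
```

Per step this swaps the standard iteration's m-scaled matmuls for O(n³) work plus one m × n apply at the end, which is why it only pays off for strongly rectangular shapes like packed QKV.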
CuTeDSL kernel status: the base env intentionally has no gram_newton_schulz,
quack, or Cutlass DSL imports. An isolated .venv_gram_ns install can import
them and run upstream kernels with the runtime SM120 compatibility shim in
bench_upstream_gram_ns_kernels.py, but those kernels are slower than PyTorch
on the local RTX PRO 6000 Blackwell box. Example d=4096 timings: packed QKV
(12288,4096) is 13.95 ms standard torch, 10.15 ms Gram torch,
20.64 ms standard kernel, and 17.93 ms Gram kernel. Kernel use should stay
import- and benchmark-gated, not enabled by GPU compute capability alone.
auto is conservative: it uses Gram only for rectangular matrices with smaller
side at least 1024, and standard NS otherwise. For large packed QKV, auto
also keeps the packed matrix intact instead of applying the default ns_split=3,
because split-QKV is square and cannot benefit from Gram. In a d=4096
3-seed/300-step nm_dense_mlp check, split+standard averaged 135.0 ms
optimizer time and min loss 4.6877 ± 0.0972; packed-QKV auto averaged
118.6 ms and min loss 4.5038 ± 0.0697. Treat this as an explicit
large-width experiment knob. FineWeb d=4096/l2 validation cuts end-to-end time
by about 13% and ties eval within CI, but train min/final loss is worse, so
auto is not the default yet.
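The heuristic in pseudocode (thresholds as described above; the exact values live in the optimizer code):

```python
def pick_ns_backend(m: int, n: int, min_gram_side: int = 1024) -> str:
    if m != n and min(m, n) >= min_gram_side:
        return "gram"       # large rectangular: Gram NS pays off
    return "standard"       # square (incl. split QKV) stays on standard NS
```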
Quick 2-way check (NewtonMuon dense vs Monarch only):
PYTHONPATH=. python3 train.py --steps 200 --d 256 --seq-len 128 --batch 32
Full LR sweep across multiple optimizers (split across all four GPUs):
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
--variants nm_dense_mlp,nm_monarch_aux,gds_dense,sf_gds_dense,adamw,nm_monarch,gds_monarch,sf_gds_monarch \
--lrs 0.03,0.1,0.3,1.0 --d 576 --steps 400 --out-dir sweeps_4gpu
Each variant has different sensible LR ranges; see PROGRESS.md for the
best-known hyperparameters and the full results inventory.
dual_sweep_gds.py records child commands in commands.json, per-GPU logs in
gpu*/stdout.log, and merged results in the top-level results.json.
sweep_gds.py is resumable: each finished config writes
<out-dir>/runs/<run_key>.json immediately, reruns skip matching configs by
default, and --force recomputes them. Pass --seeds 0,1,2 for repeated
trials; the output results.json includes seed_summaries with mean/std/95%
CI for min loss, final loss, eval@x loss, warm step time, divergence rate, and
global/per-variant *_within_*_ci95 flags for noise-scale differences.
Use --data fineweb --data-dir <token-shard-dir> to train on cached GPT-style
token streams with train.bin, val.bin, and optional meta.json; if
--data-dir is omitted, the runner checks M2GDN_FINEWEB_DIR, repo-local
data/fineweb_edu_gpt2*, and the known local /home/catid/ensemble2/data
cache. FineWeb runs infer vocab_size and token_dtype from metadata and
report held-out eval_loss from val.bin.
When a sweep includes Monarch-model variants, --d must be a perfect square
such as 1024, 2304, or 4096.
Export those rows for reports with:
PYTHONPATH=. python3 analyze_profiles.py sweeps_seed/results.json \
--seed-summary-csv seed_summaries.csv --seed-summary-md seed_summaries.md
The exported seed-summary tables keep non-empty tuning identity columns such as
momentum, warmup_steps, r, damping policy, NS backend, and low-rank knobs,
so rows from multi-knob sweeps remain distinguishable outside results.json.
Monarch block damping policy sweeps are explicit and recorded in JSON:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 sweep_gds.py \
--variants nm_monarch --lrs 0.01,0.03,0.05 \
--block-damping-policies per_block,global --d 576 --steps 400 \
--out-dir sweeps_damping_nm
Newton-Schulz backend sweeps are also explicit:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 sweep_gds.py \
--variants nm_dense_mlp --lrs 0.01 --aux-lrs 0.001 \
--ns-backends standard,gram,auto --disable-qkv-ns-split \
--d 576 --steps 400 --profile-timings --out-dir sweeps_ns_backend_dense
Covariance diagnostics are opt-in and summarize why block preconditioning helps or loses information:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 sweep_gds.py \
--variants nm_monarch --lrs 0.01 --aux-lrs 0.001 \
--cov-diagnostics --d 576 --steps 64 --out-dir sweeps_covdiag_d576_monarch
PYTHONPATH=. python3 analyze_profiles.py sweeps_covdiag_d576_monarch/results.json \
--out-json cov_summary.json --out-csv cov_summary.csv --cov-plot cov.png
The current d=576 diagnostic run shows dense activation covariances have high cross-block energy (~0.87-0.89), so channel permutation or low-rank residual corrections are more promising next experiments than more single-seed LR grids.
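Cross-block energy here means the fraction of the covariance's squared Frobenius norm lying outside the diagonal blocks; a sketch of the metric (the diagnostic code may normalize differently):

```python
import torch

def cross_block_energy(C: torch.Tensor, nblocks: int) -> float:
    """Fraction of ||C||_F^2 outside the block-diagonal (sketch)."""
    blk = C.size(0) // nblocks
    mask = torch.zeros_like(C, dtype=torch.bool)
    for j in range(0, C.size(0), blk):
        mask[j:j + blk, j:j + blk] = True      # mark the diagonal blocks
    energy = C.square()
    return (energy[~mask].sum() / energy.sum()).item()
```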
Block-permutation probes and sweeps:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 probe_block_permutations.py \
--mlp-kind monarch --d 576 --heads 8 --layers 2 --seq-len 64 \
--batch 16 --batches 4 --max-dim 576 --random-perms 8 \
--out sweeps_blockperm_d576_monarch.json
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. python3 sweep_gds.py \
--variants sf_gds_monarch --lrs 1.3 --warmup-steps 100 \
--monarch-perm transpose_grid --d 576 --steps 400 \
--out-dir sweeps_perm_transpose_sf
The current result is mixed: transpose_grid lowers covariance cross-block
energy for many Monarch activations and slightly improves SF-GDS eval@x, but it
does not improve Newton-Muon or GDS min loss. It is an experiment knob, not a
new default.
Low-rank residual probe:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 probe_lowrank_block.py \
--mlp-kind monarch --d 576 --heads 8 --layers 2 --seq-len 64 \
--batch 16 --batches 4 --max-dim 576 --ranks 4,8,16,32,64 \
--methods raw_residual,bmetric_residual,bmetric_inverse,sketched_residual \
--out sweeps_lowrank_block_inverse_ranked_d576_monarch.json
Current result: residual-energy modes improve K @ K_approx^-1 ≈ I but barely
move inverse/gradient error at low rank. Ranking block-metric modes by inverse
impact, |lambda / (1 + lambda)|, is more promising: at d=576, rank-64 lowers
sampled preconditioned-gradient error from ~0.72 to ~0.61/0.62, and rank-192
reaches ~0.40/0.42. Sketch modes need conditioning guards before integration.
The optimizer prototype exposes this for dense preconditioner groups:
PYTHONPATH=. python3 sweep_gds.py \
--variants gds_monarch --lrs 0.3 --d 256 --steps 400 \
--dense-inverse-modes dense,block_lowrank \
--lowrank-ranks 16 --lowrank-methods bmetric_inverse \
--profile-timings --out-dir sweeps_lowrank_opt_d256_gds_monarch
At d=256, rank-16 block_lowrank slightly improves tuned GDS loss but is much
slower because the prototype still materializes a full inverse. Newton-Muon
does not show a clear quality gain. block_lowrank_factored stores Woodbury
factors instead of a materialized inverse and matches the same loss, but it is
not faster in eager PyTorch because refresh still uses dense eigensolves and
factored apply adds overhead. The next implementation step is faster mode
selection/fused apply, not larger full-inverse sweeps.
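The factored apply follows the Woodbury identity (B + U Uᵀ)⁻¹ g = B⁻¹g − B⁻¹U (I + UᵀB⁻¹U)⁻¹ UᵀB⁻¹g; a sketch assuming `Binv` applies the cheap block-diagonal B⁻¹ to vectors or matrices (names hypothetical):

```python
import torch

def woodbury_apply(Binv, U: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Apply (B + U U^T)^-1 to g without materializing the full inverse."""
    Bg = Binv(g)                                       # B^-1 g
    BU = Binv(U)                                       # B^-1 U, shape (d, k)
    S = torch.eye(U.size(1), device=U.device, dtype=U.dtype) + U.T @ BU
    return Bg - BU @ torch.linalg.solve(S, U.T @ Bg)   # only a k x k solve
```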
Newton-Muon baseline names are intentionally explicit:
- nm_dense_mlp: true dense-MLP TinyGPT architecture with dense Newton-Muon preconditioners.
- nm_monarch_aux: Monarch-MLP TinyGPT, but Monarch factors use auxiliary AdamW instead of an M2 preconditioner.
- nm_monarch_muon: Monarch-MLP TinyGPT, but Monarch factors use plain Muon/NS without the M2 block preconditioner. This is the historical control.
- nm_monarch: Monarch-MLP TinyGPT with exact M2 block preconditioners.
Both NewtonMuon variants report per-step loss and wall-clock. The Monarch
path replaces a dense O(d^3) Cholesky/apply path with sqrt(d) independent
small-block solves, O(d^2) total work, and O(d^{3/2}) storage, so the
relative speed advantage of the Monarch inverse grows with d. The
end-to-end optimizer speed advantage is masked at very large d by
Newton-Schulz on the (still-dense) attention matrices — see the
"Known limitations" section.
Given a Monarch weight W = P · L · P · R · P with L, R
block-diagonal (b = √d blocks of size b × b each):
- The gradient wrt L is also block-structured: each block L_j only sees per-batch inputs x_j. Thus the per-block activation covariance A_j = (1/N) · xᵀ_j x_j ∈ ℝ^{b × b} is the exact second moment the optimizer needs.
- Cholesky-inverting b blocks of size b × b costs O(d²) total work and O(d^{3/2}) memory. The blocks are independent, so the GPU batched-kernel wall time is much closer to a small-block solve than to a dense d × d solve.
- Applying the inverse is a batched bmm with b small blocks, one of the fastest operations on a GPU for this size range.
The baseline's (ZZᵀ + γI)⁻¹ of shape d × d is both approximate
(it ignores the Monarch factorisation of W) and expensive
(O(d³) in time, O(d²) in memory). Newton-Muon mitigates the cost by
refreshing only every 16 steps; our path removes the cost entirely.
All activations/gradients stay in BF16; covariance accumulation and Cholesky run in FP32 (same policy as the upstream Newton-Muon repo).
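Concretely, the refresh/apply pair is a batched damped Cholesky plus a batched solve; a sketch under the FP32-covariance policy above (function names are hypothetical; inverse.py additionally verifies residuals):

```python
import torch

def refresh_block_factors(x_blocks: torch.Tensor, gamma: float = 1e-4):
    """x_blocks: (b, N, blk) per-block activations -> FP32 Cholesky factors."""
    x = x_blocks.float()                               # covariance in FP32
    A = x.mT @ x / x.size(1)                           # (b, blk, blk) covariances
    eye = torch.eye(A.size(-1), device=A.device)
    return torch.linalg.cholesky(A + gamma * eye)      # batched damped Cholesky

def apply_block_inverse(L: torch.Tensor, g_blocks: torch.Tensor):
    # Batched (A_j + gamma*I)^-1 g_j; g_blocks: (b, blk, cols) gradient blocks.
    return torch.cholesky_solve(g_blocks.float(), L).to(g_blocks.dtype)
```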
See PROGRESS.md for the full picture and prioritized next-steps. Headline
limits:
- TinyGPT + Markov-1 is a demo problem. It fits on any GPU in seconds but is
  insufficient to differentiate optimizer quality cleanly. The current
  d=1024, 4-layer, 1000-step FineWeb bakeoff has a new quality leader:
  nm_monarch with ns_backend=auto, lr 0.03, reaches 6.617 ± 0.058 eval loss
  versus adamw_dense_mlp at 6.7077 ± 0.0623. That gain costs wall-clock: the
  profiled optimizer/e2e medians are about 11.50/18.24 ms for nm_monarch auto
  versus 5.64/9.35 ms for adamw_dense_mlp and 3.28/7.34 ms for gds_dense.
  GDS/SF-GDS remain the faster lower-quality frontier on this setting. Partial
  preconditioning improves dense GDS (precond_strength=0.25, lr 0.15), but
  still trails AdamW; simple norm caps and per-step adaptive norm/cos policies
  worsened validation loss. Static schedules are more promising: gds_dense
  with raw attention gradients and full-strength non-attention preconditioning
  reaches 7.1044 ± 0.0506. The best no-aux sf_gds_dense reaches
  6.8641 ± 0.0214 with lr 0.4, weight_lr_power=-0.5, r=2.5, and qkv_low low
  0.2, but the current best SF-GDS row moves attention to aux AdamW and MLP
  gate weights to aux AdamW, then ramps the Schedule-Free y/update direction
  from raw to coupled GDS while reducing the early y-correction magnitude:
  lr 0.5875, warmup 75, weight_lr_power=0.5, r=2.5, attention_aux=attn_split,
  QKV/proj aux LRs 7.5e-5/1.25e-5, dense_aux=mlp_gate at 7.5e-5,
  sf_update_precond_schedule=cosine_to_z with start 100 and transition 200,
  plus sf_update_lr_scale=0.65, sf_update_lr_scale_schedule=cosine_to_one,
  start 100, transition 200, and sf_aux_average=True, sf_aux_eval_lerp=0.45
  reaches 6.6629 ± 0.0280 eval at 3.79 ms warm step. This improves the
  previous SF-GDS row (6.6842 ± 0.0265) by about 0.021 eval loss, but the
  intervals still overlap, so treat it as the new best candidate rather than a
  conclusive optimizer ranking shift. The same row without aux eval blending
  was 6.6683 ± 0.0268 at 3.65 ms, so the blend is a small quality-for-time
  tradeoff. The previous raw-to-GDS y-direction-only row was 6.7218 ± 0.0236;
  the previous exact-coupled split row was 6.7410 ± 0.0301; fixed raw-y alone
  was only a stability knob and did not beat the coupled default. This is
  nominally ahead of AdamW dense-MLP (6.7077 ± 0.0623) but inside CI, and
  still behind Newton-Muon Monarch; SF-GDS is materially faster than both.
  Moving MLP output/all remaining dense weights to aux AdamW regresses
  validation. The refined aux recipe transfers to the higher-signal d=2304,
  4-layer, 1000-step setting, but at a lower LR: sf_gds_dense with lr 0.18,
  weight_lr_power=-0.5, r=2, attention_aux=attn, aux_lr=5e-5, and
  dense_aux=mlp_gate at 7.5e-5 reaches 7.0897 ± 0.0638 versus AdamW
  7.2405 ± 0.0656, with faster warm steps (10.84 ms vs 14.79 ms). The no-gate
  refined row is close (7.1010 ± 0.0643) but slower (12.69 ms); higher LRs
  0.2-0.25 are seed-sensitive. Transferring the d=1024 cosine_to_z y/update
  schedule to this d=2304 row is neutral: lr 0.18, transition 100 reaches a
  nominally lower 7.0861 ± 0.0562, but the 0.0036 gain is far inside CI, so
  the simpler no-schedule row remains the practical default. The d=1024
  y-correction LR-scale trick did not transfer on seed 0 at d=2304; the best
  tested cell was 7.1237 eval, behind the confirmed 7.0861 row. Sweep JSON now
  logs the exact Markov entropy floor because duplicate successor samples can
  make it slightly lower than ln(sparsity).
- Newton-Schulz on the big attention matrices is the wall-clock bottleneck at
  d ≥ 4096, not the preconditioner refresh. The Monarch inverse savings
  (bench_inverse.py shows 27× at d=4096, 117× at d=9216) are real but
  invisible against NS on (12288, 4096) QKV tensors. The rectangular
  BlockdiagButterfly layer is now wired into TinyGPT attention. At d=2304,
  full Monarch attention cuts Newton-Muon mean NS time from about 13.46 ms to
  5.01 ms, mean end-to-end time from 37.68 ms to 25.92 ms, and params from
  296.2M to 255.0M; the same untuned Newton-Muon profile has worse loss. After
  tuning, full Monarch attention at lr 0.07/0.15 matches dense attention
  within 3-seed CI on d=2304/l2 while reducing mean end-to-end time by about
  31% and mean NS time by about 63%. QKV/projection ablation shows QKV-only
  gets most of the speedup but hurts Newton-Muon quality most, while
  projection-only preserves short eval better but keeps most dense-QKV NS
  cost. At d=4096/l2, full Monarch attention cuts warm step time by about 53%,
  mean end-to-end time by about 50%, and mean NS time by about 69%, but the
  300-step FineWeb quality gap remains (9.88 dense eval versus 10.04 best
  full-Monarch eval). Pure-PyTorch Gram-NS is now measured as shape-dependent:
  neutral for square/split-QKV matrices, but 1.6-2.0x faster for large packed
  rectangular matrices. The current safe default remains QKV split plus
  standard NS until packed-QKV Gram passes longer quality checks; CuTeDSL
  kernels are blocked by missing quack and Cutlass DSL packages in the base
  env.
- Schedule-Free needs high LR plus warmup. With warmup_steps=100 and LR around
  1.3-1.6, SF-GDS reaches near the Markov-1 floor on the 400-step toy
  benchmark, and the 10k-step single-seed check improves to about 1.383 y-loss
  / 1.388 x-loss. On the short FineWeb grid, the useful SF-GDS LR scale is
  much lower (0.3-0.4 with warmup 50), and it trails AdamW on validation loss.
- Global vs per-block block damping is a wash on the current benchmark. Both
  policies are implemented and verified; tuned d=576 losses differ by ≤0.001
  in single-seed sweeps.
- No distributed training. Out of scope for this reference implementation.
- Fully Monarch sequence mixing is still open. Attention projections can now be structural Monarch layers, and raw/aux attention schedules remain useful dense-attention controls. M2-BERT's sequence mixer in HazyResearch/m2 uses FFT Hyena filters, not explicit Monarch multiplication, so a fully-Monarch architecture would still require implementing Theorem 3 of the paper (causal parameterisation), which is not in the upstream repo.