catid/m2gdn


m2gdn — Exact second-order LLM training on Monarch (M2) matrices

A small reference implementation that pairs Monarch-matrix MLPs (Hazy Research's M2) with Newton–Muon-style right-preconditioning (Du & Su 2026) and shows that the M2 block structure yields an exact block-diagonal preconditioner with O(d^{3/2}) storage and O(d^2) total refresh/apply work for d=b^2 — much cheaper than the baseline's dense (ZZᵀ + γI)⁻¹ inverse.

The repo also implements three other optimizer families (FOOF/GDS, Schedule-Free, plain AdamW) all using the same activation-capture and inverse-refresh machinery, for head-to-head comparison.

  • references.md — paper/repo deep dive (read for algorithm details).
  • PROGRESS.md — current project state, results, open questions, next-steps (read first if picking up work).
  • AGENTS.md — hardware notes + four-GPU conventions.
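To make the Monarch structure concrete, here is a minimal NumPy sketch of a Monarch product (a four-factor variant; the paper's five-factor P·L·P·R·P form differs only in fixed outer permutations). Shapes and names are illustrative, not the repo's monarch.py code:

```python
import numpy as np

rng = np.random.default_rng(0)
b = 4                          # block size; Monarch needs d = b * b
d = b * b
# Two block-diagonal factors, b blocks of b x b each: 2 * b^3 = 2 * d^{3/2} params
L_blocks = rng.standard_normal((b, b, b))
R_blocks = rng.standard_normal((b, b, b))

def monarch_apply(x):
    """Apply the factored Monarch product to a length-d vector."""
    z = x.reshape(b, b)                         # row j = channel group j
    z = np.einsum("jkl,jl->jk", R_blocks, z)    # per-group right factor: R_j x_j
    z = z.T                                     # fixed permutation (grid transpose)
    z = np.einsum("jkl,jl->jk", L_blocks, z)    # per-group left factor
    return z.T.reshape(d)                       # undo the permutation, flatten

# The factored map is linear, so it equals one dense d x d matrix W
W = np.stack([monarch_apply(e) for e in np.eye(d)], axis=1)
x = rng.standard_normal(d)
assert np.allclose(W @ x, monarch_apply(x))
```

The dense equivalent costs d² parameters; the two block-diagonal factors cost 2·d^{3/2}, which is where both the parameter and the preconditioner savings come from.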

Layout

m2gdn/
├── m2gdn/
│   ├── monarch.py         # square Monarch2 + rectangular BlockdiagButterfly layers
│   ├── inverse.py         # batched damped Cholesky inverse + verification
│   ├── newton_schulz.py   # BF16 Newton–Schulz (Polar-Express coefficients)
│   ├── optimizers.py      # 7 optimizers + capture-window mixin
│   └── model.py           # TinyGPT with Monarch MLPs
├── tests/
│   ├── test_monarch.py
│   ├── test_inverse.py
│   ├── test_newton_schulz.py
│   ├── test_optimizers.py  (all 7 optimizers covered)
│   ├── test_covariance.py
│   ├── test_data_source.py
│   └── test_sweep_resume.py
├── bench_inverse.py       # dense vs block-diagonal inverse benchmark
├── bench_newton_schulz.py # standard vs Gram NS benchmark
├── bench_upstream_gram_ns_kernels.py  # optional upstream CuTeDSL kernel bench
├── compare_optimizers.py  # default optimizer bakeoff report from artifacts
├── train.py               # side-by-side benchmark (NewtonMuon dense vs monarch)
├── sweep.py               # NewtonMuon LR sweep
├── sweep_gds.py           # cross-optimizer sweep (knows all 7 variants)
├── dual_sweep_gds.py      # multi-GPU sweep_gds launcher + result merger
├── references.md          # deep-dive notes on 4 papers + 5 repos
├── PROGRESS.md            # project state for handoff
├── AGENTS.md              # AI/agent conventions (READ FIRST: hardware + four-GPU)
└── papers/                # 4 polished paper markdowns + extract pipeline

Optimizers

All seven live in m2gdn/optimizers.py and share the same _CovState, _CaptureWindowMixin, and BF16 inverse-cache infrastructure.

class                    preconditioner         momentum                     NS   use case
NewtonMuonDense          dense (ZZᵀ+γI)⁻¹       Nesterov-Muon                yes  paper baseline
NewtonMuonMonarch        exact M2 block-diag    Nesterov-Muon                yes  the headline optimizer
GDSDense                 dense                  AdamW EMA                    no   FOOF + AdamW first moment
GDSMonarch               exact M2               AdamW EMA                    no   GDS + structured-matrix trick
ScheduleFreeGDSDense     dense                  SF interpolation+averaging   no   matches paper reference
ScheduleFreeGDSMonarch   exact M2               SF interpolation+averaging   no   SF base + M2 trick
AdamW                    per-coordinate         AdamW                        no   vanilla baseline

  • Newton-Muon variants: kind="plain" parameters use an auxiliary AdamW path (aux_lr=1e-3 by default) instead of being pushed through Muon/NS. Packed attention QKV weights are row-split into Q/K/V slices before Newton-Schulz by default (ns_split=3 in the parameter group); pass --disable-qkv-ns-split to sweep_gds.py to reproduce the older packed-NS behavior.
  • Monarch block-preconditioned variants: block_damping_policy can be per_block (default) or global; d=576 sweeps found no meaningful tuned-LR quality difference, so the block-local heuristic remains the default.
  • GDS/SF-GDS variants: precond_strength blends raw and activation-preconditioned gradients (1.0 is the historical full-preconditioned path); precond_norm_cap optionally caps the blended direction's norm relative to the raw gradient. sweep_gds.py exposes these as --precond-strengths and --precond-norm-caps.
  • Schedule-Free variants: --weight-lr-powers and --rs control the Schedule-Free averaging kernel. Current d=1024 FineWeb SF-GDS probes prefer weight_lr_power=-0.5, with no-aux runs favoring r=2.5 and attention/dense aux runs favoring the stabler high-LR r=2 cell, over the paper-default weight_lr_power=2, r=0.
  • Schedule-Free GDS variants also expose --precond-strength-warmup-steps, which ramps the raw-to-preconditioned blend from raw gradients to the target strength over optimizer steps, and --precond-strength-lows, which controls the low value used by *_low structural schedules.
  • --sf-update-precond-strengths is an experimental decoupled-SF knob: omitting it preserves the exact coupled z/y update direction, while explicit scalar values let the Schedule-Free averaged iterate use a raw-to-GDS blend different from the z/base step. --sf-update-precond-schedules and --sf-update-precond-transition-steps add time-varying y/update schedules; the current d=1024 best uses cosine_to_z to ramp y from raw to the z/base GDS direction over 200 steps.
  • --sf-update-lr-scales, --sf-update-lr-scale-schedules, and --sf-update-lr-scale-transition-steps independently scale the Schedule-Free y correction; the best d=1024 SF-GDS row found so far starts that y correction at 0.65x and ramps it back to 1.0x with cosine_to_one over 200 steps. Adaptive experimental variants (cosine_align*_to_z, cosine_norm*_to_z) gate that ramp using raw/preconditioned direction agreement, but current FineWeb confirmations did not beat the fixed cosine_to_z schedule.
  • --sf-average-start-steps and --sf-average-ramp-steps are tail-averaging research knobs; current d=1024 FineWeb sweeps found hard delayed starts much worse and ckp1 ramps slightly worse than the default averaging kernel.
  • SF-GDS aux AdamW groups: --sf-aux-average tracks a separate SF-style average and --sf-aux-eval-lerp selects the eval-time blend from live aux weights (0) to averaged aux weights (1). Full aux averaging worsened the current d=1024 row, but a partial blend at 0.45 gives the best SF-GDS eval loss tested so far.
  • Adaptive policies are available via --precond-strength-policies fixed,norm_ratio,norm_cosine, --precond-target-ratios, and --precond-min-cosines; current FineWeb probes found the adaptive norm/cos policies slower and worse than fixed partial strength.
  • Static low-overhead module schedules are exposed with --precond-strength-schedules; current d=1024 FineWeb probes found that lowering packed-QKV preconditioning while keeping non-QKV weights strongly preconditioned improves SF-GDS without the per-step norm/cos overhead. Q/K/V row-sliced schedules (q_low, k_low, v_low, qk_low, qv_low, kv_low, plus the matching *_raw forms) are available for packed QKV localization; current probes show the scalar qkv_low schedule is still better and faster than slice-low variants. Projection-specific schedules (proj_low, proj_raw, qkv_raw_proj_low, qkv_raw_proj_raw) isolate the attention output projection from packed QKV and MLP output weights.
  • True auxiliary-optimizer ablations: --attention-aux-modes qkv,proj,attn moves selected attention weights out of the GDS/SF-GDS update and into the auxiliary AdamW path controlled by --aux-lrs. On the current d=1024, 4-layer FineWeb grid, all-attention aux AdamW plus the refined SF averaging kernel closes most of the AdamW gap while remaining faster than dense AdamW. --attention-aux-modes attn_split additionally supports separate packed-QKV and attention-projection LRs via --attention-qkv-aux-lrs and --attention-proj-aux-lrs; the best split row is only nominally ahead of the shared attn row and remains inside CI. Non-attention dense ablations are exposed separately with --dense-aux-modes mlp_gate,mlp_out,mlp,nonattn_dense,all_dense and optional per-group --dense-aux-lrs. In the current d=1024, 4-layer FineWeb grid, attention_aux=attn or attn_split plus dense_aux=mlp_gate is the strongest SF-GDS setting tested so far and cuts warm step time by roughly 9-10%; moving the MLP output or all remaining dense preconditioned weights to aux AdamW regresses held-out loss, especially for SF-GDS.
  • BlockdiagButterfly is the rectangular Monarch-style layer for d -> 3d, 3d -> d, and other divisible projections; it exposes per-factor capture shapes for the same block preconditioner machinery. sweep_gds.py --attention-kind dense,monarch_qkv,monarch_proj,monarch uses that layer to replace packed attention QKV and/or the attention output projection. Current d=2304 FineWeb checks show full Monarch attention reduces parameters and optimizer-step time for Monarch-MLP GDS/SF-GDS, while the Newton-Muon path needs a higher LR than the initial lr=0.01 profile. QKV-only replacement carries most of the GDS/SF-GDS benefit; projection-only is not a good GDS/SF-GDS path, though it is the least damaging Newton-Muon short-profile ablation.

Install

uv venv .venv --python 3.12
uv pip install --python .venv/bin/python --pre torch \
  --index-url https://download.pytorch.org/whl/nightly/cu130
uv pip install --python .venv/bin/python -e . --no-deps
uv pip install --python .venv/bin/python numpy einops tqdm pytest pip

Dependencies: torch>=2.4, numpy, einops, tqdm.

Run the tests

CUDA_VISIBLE_DEVICES=0,1,2,3 PYTHONPATH=. .venv/bin/python run_tests.py

The runner compiles the main scripts, runs every standalone test, and reports CUDA availability plus CUDA-only self-skips on CPU-only machines. The individual scripts remain directly runnable:

PYTHONPATH=. .venv/bin/python tests/test_monarch.py
PYTHONPATH=. .venv/bin/python tests/test_inverse.py
PYTHONPATH=. .venv/bin/python tests/test_newton_schulz.py
PYTHONPATH=. .venv/bin/python tests/test_optimizers.py
PYTHONPATH=. .venv/bin/python tests/test_covariance.py
PYTHONPATH=. .venv/bin/python tests/test_data_source.py
PYTHONPATH=. .venv/bin/python tests/test_gpu_launcher.py
PYTHONPATH=. .venv/bin/python tests/test_sweep_resume.py

Expected: run_tests.py reports PASS for py_compile and all eight scripts. The optimizer test suite includes a side-by-side check that ScheduleFreeGDSDense matches the upstream SGDScheduleFreeReference implementation to within FP32 noise.

Inverse benchmark

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. .venv/bin/python bench_inverse.py

On a Blackwell RTX PRO 6000 (97 GB), you should see something like:

     d    b |  inv_blk(ms)  inv_dense(ms)   speedup |  apply_blk(ms)  apply_dense(ms)   speedup |  mem_blk(MB)  mem_den(MB)
   256   16 |        0.091          0.234      2.6x |         0.0042           0.0082      2.0x |        0.016        0.250
  1024   32 |        0.152          1.018      6.7x |         0.0042           0.0474     11.4x |        0.125        4.000
  4096   64 |        0.257          6.946     27.1x |         0.0063           2.0648    330.2x |        1.000       64.000
  9216   96 |        0.403         47.034    116.7x |         0.0104          25.4295   2440.5x |        3.375      324.000
  • The inverse of the block-diagonal factor is 27× faster than a d × d Cholesky at d = 4096, growing to >100× at d = 9216.
  • Applying the inverse to a gradient (one torch.bmm vs one large matmul) is even cheaper in the block form.
  • Memory is √d times smaller; at d = 9216 we hold a 3.4 MB per-layer preconditioner instead of 324 MB.

Correctness residuals ‖(A + γI) · A⁻¹ − I‖ are in the 1e-6 range at all sizes, printed by the same benchmark.
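The batched damped-Cholesky primitive that bench_inverse.py measures can be sketched in NumPy (float64 here, so residuals land near 1e-12 rather than the FP32 path's 1e-6; names and sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
b, N, gamma = 16, 64, 1e-3          # block size, sample count, damping (illustrative)

# b independent per-block activation covariances A_j = X_j^T X_j / N
X = rng.standard_normal((b, N, b))
A = np.einsum("jni,jnk->jik", X, X) / N

eye = np.eye(b)
C = np.linalg.cholesky(A + gamma * eye)        # batched over the block axis
Cinv = np.linalg.inv(C)
Ainv = np.transpose(Cinv, (0, 2, 1)) @ Cinv    # (C C^T)^-1 = C^-T C^-1

# Residual check, as printed by bench_inverse.py: ||(A + gamma I) A^-1 - I||
res = np.max(np.abs((A + gamma * eye) @ Ainv - eye))

# Applying the preconditioner is one batched matmul across all b blocks
g = rng.standard_normal((b, b))                # one gradient row-group per block
g_hat = np.einsum("jik,jk->ji", Ainv, g)
```

The GPU path does the same thing with torch.linalg.cholesky and torch.bmm; the blocks are independent, which is why the batched kernel wall time tracks the small-block cost.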

Reproducible command cookbook

All commands below assume the repo root as the working directory.

# Full verification; local GPU run is normally under 10 seconds.
CUDA_VISIBLE_DEVICES=0,1,2,3 PYTHONPATH=. .venv/bin/python run_tests.py --quiet

# CPU-hidden skip-path check for CUDA-only tests.
CUDA_VISIBLE_DEVICES= PYTHONPATH=. .venv/bin/python run_tests.py --no-compile --quiet

# Core inverse primitive benchmark.
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. .venv/bin/python bench_inverse.py

# Current d=576 GDS seeded comparison.
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
    --variants gds_dense,gds_monarch --lrs 0.2,0.25,0.3 \
    --out-dir sweeps_seed_d576_gds_tuned \
    --seeds 0,1,2 --momentums 0.85 \
    --d 576 --heads 4 --layers 2 --seq-len 64 --batch 16 --steps 400

# Current d=576 SF-GDS seeded comparison.
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
    --variants sf_gds_dense,sf_gds_monarch --lrs 1.3,1.6 \
    --out-dir sweeps_seed_d576_sf_gds_tuned \
    --seeds 0,1,2 --momentums 0.9 --warmup-steps 50,100 \
    --d 576 --heads 4 --layers 2 --seq-len 64 --batch 16 --steps 400

# Cached FineWeb-Edu GPT-2 smoke, split across four GPUs.
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
    --variants adamw_dense_mlp,adamw --lrs 0.0003,0.0006 \
    --out-dir sweeps_fineweb_smoke_d1024 \
    --data fineweb --data-dir /home/catid/ensemble2/data/fineweb_edu_gpt2_probe2 \
    --d 1024 --heads 8 --layers 1 --seq-len 32 --batch 2 --steps 5 \
    --eval-batches 2

# Fast synthetic SF-GDS/M2 profile with compiled model and compiled Markov data.
PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
    --variants sf_gds_monarch --lrs 9.5 \
    --out-dir sweeps_sf_m2_compiled_profile \
    --seeds 0,1,2,3 --split-by runs \
    --steps 60 --d 1024 --heads 8 --layers 1 --seq-len 64 --batch 8 \
    --vocab-size 4096 --eval-batches 1 --profile-timings \
    --lm-head-nblocks 32 --lm-head-hidden-block-size 64 --lm-head-terms 16 \
    --tie-m2-embedding --precond-strength-schedules mlp_head_raw \
    --precond-strength-low 0.25 --momentums 0.9 --warmup-steps 10 \
    --compile-model --compile-data --matmul-precision high
# Latest retained artifact: /tmp/m2gdn_sf_m2_profile_head_sum_einsum
# with tail-median end-to-end mean about 1.73 ms across four GPUs.

# Export seeded comparison tables for reports.
PYTHONPATH=. .venv/bin/python analyze_profiles.py \
    sweeps_seed_d576_gds_tuned/results.json \
    sweeps_seed_d576_sf_gds_tuned/results.json \
    --seed-summary-csv sweeps_seed_d576_seed_summaries.csv \
    --seed-summary-md sweeps_seed_d576_seed_summaries.md

# Generate the default cross-optimizer comparison report.
PYTHONPATH=. .venv/bin/python compare_optimizers.py \
    --out-md optimizer_comparison.md \
    --out-csv optimizer_comparison.csv

Rerunning the same dual_sweep_gds.py command resumes from <out-dir>/gpu*/runs/*.json and skips matching run keys. Add --force only when deliberately recomputing cached configurations.

Newton-Schulz backend benchmark

PYTHONPATH=. python3 bench_newton_schulz.py \
    --shapes 576x576,1728x576,2304x576,576x2304,3x576x576

NewtonMuonDense and NewtonMuonMonarch support ns_backend="standard", "gram", and "auto"; sweep_gds.py exposes this as --ns-backends. The pure-PyTorch Gram path is numerically compatible and falls back on square matrices. Small d=576 timings were slower, but large rectangular Blackwell timings are now favorable: d=4096 packed QKV (12288,4096) is 18.65 ms standard vs 10.88 ms Gram, and 4x-MLP-shaped (16384,4096) is 24.86 ms standard vs 12.71 ms Gram. Square and split-QKV batched shapes remain neutral.
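The Gram idea can be sketched in NumPy: iterate on the small n x n Gram matrix instead of the tall m x n matrix, then apply the result once. A cubic iteration is used below because its convergence is easy to verify; the repo's BF16 path uses quintic Polar-Express coefficients instead:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 256, 16                       # tall rectangular, like a packed-QKV slice
X = rng.standard_normal((m, n))
X = X / np.linalg.norm(X, 2)         # spectral normalization so NS converges

# Standard path: cubic Newton-Schulz directly on the m x n matrix.
Y = X.copy()
for _ in range(25):
    Y = 1.5 * Y - 0.5 * Y @ (Y.T @ Y)

# Gram path: coupled inverse-square-root iteration on the n x n Gram matrix,
# then a single m x n apply: U = X (X^T X)^{-1/2} is the same polar factor.
G = X.T @ X
Yg, Z = G.copy(), np.eye(n)          # Yg -> G^{1/2}, Z -> G^{-1/2}
for _ in range(40):
    T = 0.5 * (3.0 * np.eye(n) - Z @ Yg)
    Yg, Z = Yg @ T, T @ Z
U = X @ Z

s1 = np.linalg.svd(Y, compute_uv=False)
s2 = np.linalg.svd(U, compute_uv=False)
```

Both paths drive all singular values to 1 while preserving singular vectors, so they agree on the polar factor; the Gram path's per-iteration cost depends on n rather than m, which is where the rectangular speedup comes from.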

CuTeDSL kernel status: the base env intentionally has no gram_newton_schulz, quack, or Cutlass DSL imports. An isolated .venv_gram_ns install can import them and run upstream kernels with the runtime SM120 compatibility shim in bench_upstream_gram_ns_kernels.py, but those kernels are slower than PyTorch on the local RTX PRO 6000 Blackwell box. Example d=4096 timings: packed QKV (12288,4096) is 13.95 ms standard torch, 10.15 ms Gram torch, 20.64 ms standard kernel, and 17.93 ms Gram kernel. Kernel use should stay import- and benchmark-gated, not enabled by GPU compute capability alone.

auto is conservative: it uses Gram only for rectangular matrices with smaller side at least 1024, and standard NS otherwise. For large packed QKV, auto also keeps the packed matrix intact instead of applying the default ns_split=3, because split-QKV is square and cannot benefit from Gram. In a d=4096 3-seed/300-step nm_dense_mlp check, split+standard averaged 135.0 ms optimizer time and min loss 4.6877 ± 0.0972; packed-QKV auto averaged 118.6 ms and min loss 4.5038 ± 0.0697. Treat this as an explicit large-width experiment knob. FineWeb d=4096/l2 validation cuts end-to-end time by about 13% and ties eval within CI, but train min/final loss is worse, so auto is not the default yet.
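The documented selection rule is small enough to write down directly (function name is illustrative, not the repo's API):

```python
def pick_ns_backend(shape, backend="auto"):
    """Sketch of the documented 'auto' rule: Gram NS only for rectangular
    matrices whose smaller side is >= 1024, standard NS otherwise. The real
    path also keeps large packed QKV intact instead of applying the default
    ns_split=3, since split-QKV slices are square and cannot benefit from Gram."""
    m, n = shape
    if backend != "auto":
        return backend
    if m != n and min(m, n) >= 1024:
        return "gram"
    return "standard"

assert pick_ns_backend((12288, 4096)) == "gram"      # large packed QKV
assert pick_ns_backend((576, 576)) == "standard"     # square
assert pick_ns_backend((576, 2304)) == "standard"    # rectangular but small
```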

Optimizer comparison

Quick 2-way check (NewtonMuon dense vs Monarch only):

PYTHONPATH=. python3 train.py --steps 200 --d 256 --seq-len 128 --batch 32

Full LR sweep across multiple optimizers (split across all four GPUs):

PYTHONPATH=. .venv/bin/python dual_sweep_gds.py \
    --variants nm_dense_mlp,nm_monarch_aux,gds_dense,sf_gds_dense,adamw,nm_monarch,gds_monarch,sf_gds_monarch \
    --lrs 0.03,0.1,0.3,1.0 --d 576 --steps 400 --out-dir sweeps_4gpu

Each variant has different sensible LR ranges; see PROGRESS.md for the best-known hyperparameters and the full results inventory.

  • dual_sweep_gds.py records child commands in commands.json, per-GPU logs in gpu*/stdout.log, and merged results in the top-level results.json.
  • sweep_gds.py is resumable: each finished config writes <out-dir>/runs/<run_key>.json immediately, reruns skip matching configs by default, and --force recomputes them.
  • Pass --seeds 0,1,2 for repeated trials; the output results.json includes seed_summaries with mean/std/95% CI for min loss, final loss, eval@x loss, warm step time, divergence rate, and global/per-variant *_within_*_ci95 flags for noise-scale differences.
  • Use --data fineweb --data-dir <token-shard-dir> to train on cached GPT-style token streams with train.bin, val.bin, and optional meta.json; if --data-dir is omitted, the runner checks M2GDN_FINEWEB_DIR, repo-local data/fineweb_edu_gpt2*, and the known local /home/catid/ensemble2/data cache. FineWeb runs infer vocab_size and token_dtype from metadata and report held-out eval_loss from val.bin.
  • When a sweep includes Monarch-model variants, --d must be a perfect square such as 1024, 2304, or 4096.

Export those rows for reports with:

PYTHONPATH=. python3 analyze_profiles.py sweeps_seed/results.json \
    --seed-summary-csv seed_summaries.csv --seed-summary-md seed_summaries.md

The exported seed-summary tables keep non-empty tuning identity columns such as momentum, warmup_steps, r, damping policy, NS backend, and low-rank knobs, so rows from multi-knob sweeps remain distinguishable outside results.json.

Monarch block damping policy sweeps are explicit and recorded in JSON:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 sweep_gds.py \
    --variants nm_monarch --lrs 0.01,0.03,0.05 \
    --block-damping-policies per_block,global --d 576 --steps 400 \
    --out-dir sweeps_damping_nm

Newton-Schulz backend sweeps are also explicit:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 sweep_gds.py \
    --variants nm_dense_mlp --lrs 0.01 --aux-lrs 0.001 \
    --ns-backends standard,gram,auto --disable-qkv-ns-split \
    --d 576 --steps 400 --profile-timings --out-dir sweeps_ns_backend_dense

Covariance diagnostics are opt-in and summarize why block preconditioning helps or loses information:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 sweep_gds.py \
    --variants nm_monarch --lrs 0.01 --aux-lrs 0.001 \
    --cov-diagnostics --d 576 --steps 64 --out-dir sweeps_covdiag_d576_monarch

PYTHONPATH=. python3 analyze_profiles.py sweeps_covdiag_d576_monarch/results.json \
    --out-json cov_summary.json --out-csv cov_summary.csv --cov-plot cov.png

The current d=576 diagnostic run shows dense activation covariances have high cross-block energy (~0.87-0.89), so channel permutation or low-rank residual corrections are more promising next experiments than more single-seed LR grids.

Block-permutation probes and sweeps:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 probe_block_permutations.py \
    --mlp-kind monarch --d 576 --heads 8 --layers 2 --seq-len 64 \
    --batch 16 --batches 4 --max-dim 576 --random-perms 8 \
    --out sweeps_blockperm_d576_monarch.json

CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. python3 sweep_gds.py \
    --variants sf_gds_monarch --lrs 1.3 --warmup-steps 100 \
    --monarch-perm transpose_grid --d 576 --steps 400 \
    --out-dir sweeps_perm_transpose_sf

The current result is mixed: transpose_grid lowers covariance cross-block energy for many Monarch activations and slightly improves SF-GDS eval@x, but it does not improve Newton-Muon or GDS min loss. It is an experiment knob, not a new default.

Low-rank residual probe:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python3 probe_lowrank_block.py \
    --mlp-kind monarch --d 576 --heads 8 --layers 2 --seq-len 64 \
    --batch 16 --batches 4 --max-dim 576 --ranks 4,8,16,32,64 \
    --methods raw_residual,bmetric_residual,bmetric_inverse,sketched_residual \
    --out sweeps_lowrank_block_inverse_ranked_d576_monarch.json

Current result: residual-energy modes improve K @ K_approx^-1 ≈ I but barely move inverse/gradient error at low rank. Ranking block-metric modes by inverse impact, |lambda / (1 + lambda)|, is more promising: at d=576, rank-64 lowers sampled preconditioned-gradient error from ~0.72 to ~0.61/0.62, and rank-192 reaches ~0.40/0.42. Sketch modes need conditioning guards before integration.

The optimizer prototype exposes this for dense preconditioner groups:

PYTHONPATH=. python3 sweep_gds.py \
    --variants gds_monarch --lrs 0.3 --d 256 --steps 400 \
    --dense-inverse-modes dense,block_lowrank \
    --lowrank-ranks 16 --lowrank-methods bmetric_inverse \
    --profile-timings --out-dir sweeps_lowrank_opt_d256_gds_monarch

At d=256, rank-16 block_lowrank slightly improves tuned GDS loss but is much slower because the prototype still materializes a full inverse. Newton-Muon does not show a clear quality gain. block_lowrank_factored stores Woodbury factors instead of a materialized inverse and matches the same loss, but it is not faster in eager PyTorch because refresh still uses dense eigensolves and factored apply adds overhead. The next implementation step is faster mode selection/fused apply, not larger full-inverse sweeps.
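The Woodbury factors that block_lowrank_factored stores can be illustrated abstractly. A diagonal base stands in for the cheap block-diagonal inverse, and every name here is illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
n, r = 32, 4
A = np.diag(rng.uniform(1.0, 2.0, n))    # cheap-to-invert base (block-diag stand-in)
U = rng.standard_normal((n, r))
V = rng.standard_normal((n, r))
M = A + U @ V.T                           # base plus rank-r residual correction

Ainv = np.diag(1.0 / np.diag(A))
# Woodbury: (A + U V^T)^-1 = A^-1 - A^-1 U (I_r + V^T A^-1 U)^-1 V^T A^-1
core = np.linalg.inv(np.eye(r) + V.T @ Ainv @ U)

def apply_inverse(g):
    """Apply (A + U V^T)^-1 to g without materializing any n x n inverse."""
    y = Ainv @ g
    return y - Ainv @ (U @ (core @ (V.T @ y)))

g = rng.standard_normal(n)
assert np.allclose(M @ apply_inverse(g), g)
```

The factored apply is a few skinny matvecs plus the base inverse, which is why the open question is fused/faster apply rather than correctness.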

Newton-Muon baseline names are intentionally explicit:

  • nm_dense_mlp: true dense-MLP TinyGPT architecture with dense Newton-Muon preconditioners.
  • nm_monarch_aux: Monarch-MLP TinyGPT, but Monarch factors use auxiliary AdamW instead of an M2 preconditioner.
  • nm_monarch_muon: Monarch-MLP TinyGPT, but Monarch factors use plain Muon/NS without the M2 block preconditioner. This is the historical control.
  • nm_monarch: Monarch-MLP TinyGPT with exact M2 block preconditioners.

Both NewtonMuon variants in train.py report per-step loss and wall-clock time. The Monarch path replaces a dense O(d^3) Cholesky/apply path with sqrt(d) independent small-block solves, O(d^2) total work, and O(d^{3/2}) storage, so the relative speed advantage of the Monarch inverse grows with d. The end-to-end optimizer speed advantage is masked at very large d by Newton-Schulz on the (still-dense) attention matrices — see the "Known limitations" section.

What's exact here, and why

Given a Monarch weight W = P · L · P · R · P with L, R block-diagonal (b = √d blocks of size b × b each):

  • The gradient wrt L is also block-structured: each block L_j only sees per-batch inputs x_j. Thus the per-block activation covariance A_j = 1/N · xᵀ_j x_j ∈ R^{b × b} is the exact second-moment the optimizer needs.
  • Cholesky-inverting b blocks of size b × b costs O(d^2) total work and O(d^{3/2}) memory. The blocks are independent, so the GPU batched-kernel wall time is much closer to a small-block solve than to a dense d × d solve.
  • Applying the inverse is a batched bmm with b small blocks, one of the fastest operations on a GPU for this size range.
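The first bullet can be verified numerically: in the dense view, only the block-diagonal entries of the outer-product gradient touch a block-diagonal factor's parameters. A NumPy sketch with illustrative shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
b = 4                               # block size; d = b * b (illustrative)
d = b * b
x = rng.standard_normal((b, b))     # input channels split into b groups
gy = rng.standard_normal((b, b))    # upstream gradient, same grouping

# For y_j = R_j x_j and upstream gradient gy, dLoss/dR_j = gy_j x_j^T exactly:
grad_R = np.einsum("ji,jk->jik", gy, x)

# Dense view: dLoss/dW for the embedded block-diagonal matrix is the full
# outer product gy_flat x_flat^T, but only its block-diagonal entries map to
# R's parameters — the off-block entries hit structural zeros.
grad_dense = np.outer(gy.reshape(d), x.reshape(d))
for j in range(b):
    blk = grad_dense[j*b:(j+1)*b, j*b:(j+1)*b]
    assert np.allclose(blk, grad_R[j])

# Hence block j's second moment is x_j x_j^T: the exact per-block covariance.
cov_blocks = np.einsum("jn,jm->jnm", x, x)   # N = 1 sample per group here
```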

The baseline's (ZZᵀ + γI)⁻¹ of shape d × d is both approximate (it ignores the Monarch factorisation of W) and expensive (O(d³) in time, O(d²) in memory). Newton-Muon mitigates the cost by refreshing only every 16 steps; our path removes the cost entirely.

All activations/gradients stay in BF16; covariance accumulation and Cholesky run in FP32 (same policy as the upstream Newton-Muon repo).

Known limitations / next steps

See PROGRESS.md for the full picture and prioritized next-steps. Headline limits:

  • TinyGPT + Markov-1 is a demo problem: it fits on any GPU in seconds but is insufficient to differentiate optimizer quality cleanly.
  • The current d=1024, 4-layer, 1000-step FineWeb bakeoff has a new quality leader: nm_monarch with ns_backend=auto, lr 0.03, reaches 6.617 ± 0.058 eval loss versus adamw_dense_mlp at 6.7077 ± 0.0623. That gain costs wall-clock: the profiled optimizer/e2e medians are about 11.50/18.24 ms for nm_monarch auto versus 5.64/9.35 ms for adamw_dense_mlp and 3.28/7.34 ms for gds_dense. GDS/SF-GDS remain the faster lower-quality frontier on this setting.
  • Partial preconditioning improves dense GDS (precond_strength=0.25, lr 0.15) but still trails AdamW; simple norm caps and per-step adaptive norm/cos policies worsened validation loss. Static schedules are more promising: gds_dense with raw attention gradients and full-strength non-attention preconditioning reaches 7.1044 ± 0.0506.
  • The best no-aux sf_gds_dense reaches 6.8641 ± 0.0214 with lr 0.4, weight_lr_power=-0.5, r=2.5, and qkv_low low 0.2. The current best SF-GDS row moves attention and MLP gate weights to aux AdamW, then ramps the Schedule-Free y/update direction from raw to coupled GDS while reducing the early y-correction magnitude: lr 0.5875, warmup 75, weight_lr_power=0.5, r=2.5, attention_aux=attn_split, QKV/proj aux LRs 7.5e-5/1.25e-5, dense_aux=mlp_gate at 7.5e-5, sf_update_precond_schedule=cosine_to_z with start 100 and transition 200, plus sf_update_lr_scale=0.65, sf_update_lr_scale_schedule=cosine_to_one, start 100, transition 200, and sf_aux_average=True, sf_aux_eval_lerp=0.45. It reaches 6.6629 ± 0.0280 eval at 3.79 ms warm step.
  • That improves the previous SF-GDS row (6.6842 ± 0.0265) by about 0.021 eval loss, but the intervals still overlap, so treat it as the new best candidate rather than a conclusive optimizer ranking shift. The same row without aux eval blending was 6.6683 ± 0.0268 at 3.65 ms, so the blend is a small quality-for-time tradeoff. The previous raw-to-GDS y-direction-only row was 6.7218 ± 0.0236; the previous exact-coupled split row was 6.7410 ± 0.0301; fixed raw-y alone was only a stability knob and did not beat the coupled default. This is nominally ahead of AdamW dense-MLP (6.7077 ± 0.0623) but inside CI, and still behind Newton-Muon Monarch; SF-GDS is materially faster than both. Moving MLP output/all remaining dense weights to aux AdamW regresses validation.
  • The refined aux recipe transfers to the higher-signal d=2304, 4-layer, 1000-step setting, but at a lower LR: sf_gds_dense with lr 0.18, weight_lr_power=-0.5, r=2, attention_aux=attn, aux_lr=5e-5, and dense_aux=mlp_gate at 7.5e-5 reaches 7.0897 ± 0.0638 versus AdamW 7.2405 ± 0.0656, with faster warm steps (10.84 ms vs 14.79 ms). The no-gate refined row is close (7.1010 ± 0.0643) but slower (12.69 ms); higher LRs 0.2-0.25 are seed-sensitive. Transferring the d=1024 cosine_to_z y/update schedule to this d=2304 row is neutral: lr 0.18, transition 100 reaches a nominally lower 7.0861 ± 0.0562, but the 0.0036 gain is far inside CI, so the simpler no-schedule row remains the practical default. The d=1024 y-correction LR-scale trick did not transfer on seed 0 at d=2304; the best tested cell was 7.1237 eval, behind the confirmed 7.0861 row.
  • Sweep JSON now logs the exact Markov entropy floor because duplicate successor samples can make it slightly lower than ln(sparsity).
  • Newton-Schulz on the big attention matrices is the wall-clock bottleneck at d ≥ 4096, not the preconditioner refresh. The Monarch inverse savings (bench_inverse.py shows 27× at d=4096, 117× at d=9216) are real but invisible against NS on (12288, 4096) QKV tensors. The rectangular BlockdiagButterfly layer is now wired into TinyGPT attention. At d=2304, full Monarch attention cuts Newton-Muon mean NS time from about 13.46 ms to 5.01 ms, mean end-to-end time from 37.68 ms to 25.92 ms, and params from 296.2M to 255.0M; the same untuned Newton-Muon profile has worse loss. After tuning, full Monarch attention at lr 0.07/0.15 matches dense attention within 3-seed CI on d=2304/l2 while reducing mean end-to-end time by about 31% and mean NS time by about 63%. QKV/projection ablation shows QKV-only gets most of the speedup but hurts Newton-Muon quality most, while projection-only preserves short eval better but keeps most dense-QKV NS cost. At d=4096/l2, full Monarch attention cuts warm step time by about 53%, mean end-to-end time by about 50%, and mean NS time by about 69%, but the 300-step FineWeb quality gap remains (9.88 dense eval versus 10.04 best full-Monarch eval). Pure-PyTorch Gram-NS is now measured as shape-dependent: neutral for square/split-QKV matrices, but 1.6-2.0x faster for large packed rectangular matrices. The current safe default remains QKV split plus standard NS until packed-QKV Gram passes longer quality checks; CuTeDSL kernels are blocked by missing quack and Cutlass DSL packages in the base env.
  • Schedule-Free needs high LR plus warmup. With warmup_steps=100 and LR around 1.3-1.6, SF-GDS reaches near the Markov-1 floor on the 400-step toy benchmark, and the 10k-step single-seed check improves to about 1.383 y-loss / 1.388 x-loss. On the short FineWeb grid, the useful SF-GDS LR scale is much lower (0.3-0.4 with warmup 50), and it trails AdamW on validation loss.
  • Global vs per-block block damping is a wash on the current benchmark. Both policies are implemented and verified; tuned d=576 losses differ by ≤0.001 in single-seed sweeps.
  • No distributed training. Out of scope for this reference impl.
  • Fully Monarch sequence mixing is still open. Attention projections can now be structural Monarch layers, and raw/aux attention schedules remain useful dense-attention controls. M2-BERT's sequence mixer in HazyResearch/m2 uses FFT Hyena filters, not explicit Monarch multiplication, so a fully-Monarch architecture still requires implementing Theorem 3 of the paper (causal parameterisation), which is not in the upstream repo.

About

M2 Monarch + Schedule-Free Optimizer + LocoProp-S + Parallel Ensemble experiments
