A CPU-first Rust training package implementing a Mamba SSM + Hyperspherical Prototype Network (HPN) architecture.
This is a concrete, working reference implementation — not a blank framework. The model is a stack of Mamba selective state-space layers with an HPN cosine-distance output head and learnable prototype matrix. Teams can use it as-is or fork and replace the layer/loss internals for their own architecture.
It works as both:
- a library dependency in your Rust application
- a ready-to-run CLI training binary
No Python runtime is required.
| Capability | Status |
|---|---|
| End-to-end train loop binary | ✅ |
| Library API for embedding custom model/training logic | ✅ |
| Serializable optimizer state (AdamW) | ✅ |
| Resume-safe checkpoints (model + optimizer + step) | ✅ |
| JSONL metrics logging | ✅ |
| Configurable layer expansion and freezing | ✅ |
| Deterministic parity probe for save/load correctness | ✅ |
| SIMD math kernels for high-throughput CPU training | ✅ |
| Validation cadence + best-checkpoint tracking | ✅ |
| Early stopping support | ✅ |
| Gradient clipping controls | ✅ |
| LR warmup + cosine decay controls | ✅ |
| Non-finite update guardrails | ✅ |
| Sharded streaming dataset support | ✅ |
| Packed sequence batching on shard streams | ✅ |
| Multi-worker sharded prefetch ingestion | ✅ |
| Run-state resume for stream cursors | ✅ |
| Atomic versioned checkpoints | ✅ |
| Cross-framework parity harness (Rust vs Python/JAX) | ✅ |
Current state: production-candidate for single-node CPU training with production-critical ingestion and parity validation implemented.
What is already robust:
- deterministic resume behavior (checkpoint + optimizer state + step)
- deterministic resume behavior for streaming shard cursors (`run_state.json`)
- configurable expansion/freeze controls for staged training
- validated SIMD and backward kernels with scalar parity probes
- CI, release, and crate packaging automation
Operational note:
- the cross-framework parity runner requires `jax` to be installed in the active Python environment
Detailed roadmap and release milestones are tracked in roadmap.md.
Measured on this environment:
- CPU: Intel(R) Core(TM) Ultra 7 165H
- Logical CPUs: 22 (11 cores, 2 threads/core)
- Memory: 31 GiB
- OS: Linux 6.6.87.2-microsoft-standard-WSL2
- Binary: `./target/release/train_generic`
- Steps per run: 120
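These are smoke benchmarks for relative comparison (not a full convergence benchmark). In the tables below, `real_s` is the wall-clock time of the run in seconds, `step_s` is training steps per second, and `token_s` is tokens per second.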
| config | batch | seq | real_s | step_s | token_s |
|---|---|---|---|---|---|
| b16_s128_d512_l2 | 16 | 128 | 40.83 | 2.94 | 6019.10 |
| b16_s64_d512_l2 | 16 | 64 | 24.51 | 4.90 | 5013.46 |
| b8_s128_d512_l2 | 8 | 128 | 26.17 | 4.59 | 4695.45 |
| b16_s32_d512_l2 | 16 | 32 | 15.70 | 7.64 | 3913.38 |
| b4_s128_d512_l2 | 4 | 128 | 16.69 | 7.19 | 3681.25 |
| b8_s64_d512_l2 | 8 | 64 | 18.02 | 6.66 | 3409.54 |
| b8_s32_d512_l2 | 8 | 32 | 10.78 | 11.13 | 2849.72 |
| b4_s64_d512_l2 | 4 | 64 | 11.31 | 10.61 | 2716.18 |
| b4_s32_d512_l2 | 4 | 32 | 7.74 | 15.50 | 1984.50 |
Best in Matrix A: b16_s128_d512_l2 at 6019.10 tokens/s.
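The throughput columns follow from the wall-clock time of each run. The sketch below reproduces that arithmetic; the `Throughput` struct and `throughput` function are illustrative names, not part of the crate, and the calculation assumes every step processes exactly `batch * seq` tokens.

```rust
/// Derived throughput figures for one benchmark run.
#[derive(Debug, Clone, Copy)]
pub struct Throughput {
    pub steps_per_s: f64,
    pub tokens_per_s: f64,
}

/// Computes steps/s and tokens/s from the wall-clock time of a run.
/// Assumption: each step consumes exactly `batch * seq_len` tokens.
pub fn throughput(steps: u64, batch: u64, seq_len: u64, real_s: f64) -> Throughput {
    let steps_per_s = steps as f64 / real_s;
    let tokens_per_s = steps_per_s * (batch * seq_len) as f64;
    Throughput { steps_per_s, tokens_per_s }
}
```

For the `b16_s128_d512_l2` row, `throughput(120, 16, 128, 40.83)` gives roughly 2.94 steps/s and 6019 tokens/s, matching the table.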
| config | d_model | layers | real_s | step_s | token_s |
|---|---|---|---|---|---|
| b8_s64_d256_l2 | 256 | 2 | 7.56 | 15.87 | 8126.98 |
| b8_s64_d256_l4 | 256 | 4 | 11.05 | 10.86 | 5560.18 |
| b8_s64_d256_l6 | 256 | 6 | 15.35 | 7.82 | 4002.61 |
| b8_s64_d512_l2 | 512 | 2 | 17.80 | 6.74 | 3451.69 |
| b8_s64_d256_l8 | 256 | 8 | 18.74 | 6.40 | 3278.55 |
| b8_s64_d256_l10 | 256 | 10 | 22.07 | 5.44 | 2783.87 |
| b8_s64_d512_l4 | 512 | 4 | 26.54 | 4.52 | 2315.00 |
| b8_s64_d512_l6 | 512 | 6 | 35.70 | 3.36 | 1721.01 |
| b8_s64_d512_l8 | 512 | 8 | 42.66 | 2.81 | 1440.23 |
| b8_s64_d512_l10 | 512 | 10 | 50.38 | 2.38 | 1219.53 |
Best in Matrix B: b8_s64_d256_l2 at 8126.98 tokens/s.
Raw benchmark CSVs are committed under:
- `runs/bench_matrix_a/results.csv`
- `runs/bench_matrix_b/results.csv`
- Keep trainer internals explicit and hackable.
- Favor reproducible runs and resumability.
- Make the package easy to fork and specialize for custom architectures.
- Keep data ingestion simple at first (integer token files), then scale to streaming pipelines.
```text
src/
  lib.rs              - crate root and public exports
  generic_trainer.rs  - full trainer state, train step, checkpoint/resume
  trainer.rs          - parameter and expansion/freezing config types
  optim.rs            - AdamW optimizer primitives
  nn.rs               - layer norm and output-loss helpers
  simd_ops.rs         - SIMD kernels used by the model path
  layer.rs            - cached layer forward/backward helpers
  stack.rs            - stack-level supervised step helpers
src/bin/
  train_generic.rs    - main CLI trainer
  trainer_parity.rs   - deterministic parity/resume checker
  parity_lab.rs       - expansion/freeze behavior harness
  *_probe.rs          - low-level probes used for validation
```
```bash
git clone https://github.com/npradeep357/rust_trainer
cd rust_trainer
cargo test
```

Run a short smoke training job:

```bash
cargo run --release --bin train_generic -- \
  --steps 200 \
  --batch-size 4 \
  --seq-len 32 \
  --out-dir runs/smoke
```

Run the deterministic resume parity check:

```bash
cargo run --release --bin trainer_parity
```

Run the Rust vs Python/JAX parity check:

```bash
cargo run --release --bin cross_framework_parity
```

The default trainer accepts a whitespace-separated integer token file.
```bash
cargo run --release --bin train_generic -- \
  --token-file /path/to/your_tokens.txt \
  --out-dir runs/experiment_v1 \
  --steps 50000 \
  --batch-size 8 \
  --seq-len 64 \
  --d-model 512 \
  --d-state 16 \
  --base-layers 2 \
  --target-layers 6 \
  --placement specific:1,3,4,5 \
  --freeze first:2 \
  --lr 1e-4
```

Resume training:
```bash
cargo run --release --bin train_generic -- \
  --resume runs/experiment_v1/latest.bincode \
  --out-dir runs/experiment_v1 \
  --steps 20000
```

| Flag | Default | Description |
|---|---|---|
| `--out-dir PATH` | `runs/` | Output directory for checkpoints and metrics |
| `--steps N` | `5000` | Number of train steps |
| `--save-every N` | `200` | Checkpoint interval |
| `--log-every N` | `20` | Metric logging interval |
| `--batch-size N` | `8` | Batch size |
| `--seq-len N` | `64` | Sequence length |
| `--seed N` | `42` | RNG seed |
| `--base-layers N` | `2` | Initial layer count before expansion |
| `--target-layers N` | `6` | Final layer count after expansion |
| `--d-model N` | `512` | Hidden width |
| `--d-state N` | `16` | State width |
| `--d-conv N` | `4` | Convolution kernel width |
| `--placement STR` | `specific:1,3,4,5` | Expansion placement |
| `--freeze STR` | `first:2` | Freeze policy |
| `--lr F` | `1e-4` | AdamW learning rate |
| `--ff-lr F` | `1e-4` | Forward-Forward local learning rate for `d_skip` updates |
| `--bp-cadence-steps N` | `32` | Apply global BP every N train steps (FF runs each step) |
| `--gradient-surgery-method STR` | `pcgrad` | Conflict handling method: `pcgrad`, `gradnorm`, `cagradstep` |
| `--gradient-surgery-epsilon F` | `1e-8` | Numerical stability epsilon for surgery operations |
| `--gradnorm-alpha F` | `0.2` | GradNorm disagreement scaling factor |
| `--cagrad-lambda F` | `1.0` | CAGradStep conflict-aversion strength |
| `--freeze-embedding 1` | `false` | Freeze embedding table |
| `--token-file PATH` | none | Integer token dataset |
| `--token-dir PATH` | none | Directory of shard files for streaming training |
| `--val-token-file PATH` | none | Optional dedicated validation token dataset |
| `--val-token-dir PATH` | none | Optional validation shard directory |
| `--shard-ext EXT` | `txt` | Extension filter used with `--token-dir` / `--val-token-dir` |
| `--shuffle-shards 1` | `true` | Shuffle shard order each epoch in streaming mode |
| `--packed-sequences 1` | `true` | Use packed contiguous token windows in streaming mode |
| `--prefetch-workers N` | `0` | Number of worker threads for sharded prefetch (>1 enables multi-worker mode) |
| `--prefetch-buffer N` | `16` | Bounded channel capacity for prefetched worker batches |
| `--resume PATH` | none | Resume checkpoint |
| `--vocab-size N` | auto | Override vocab size |
| `--val-ratio F` | `0.05` | Validation split ratio when `--val-token-file` is not provided |
| `--val-every N` | `200` | Validation cadence in train steps |
| `--eval-batches N` | `8` | Number of validation batches per eval pass |
| `--early-stopping-patience N` | `0` | Stop when validation does not improve for N eval windows (0 disables) |
| `--grad-clip-norm F` | `0.0` | Global gradient clipping threshold (0 disables clipping) |
| `--fail-on-non-finite 1` | `false` | Panic on NaN/Inf detection instead of skipping the update |
| `--lr-warmup-steps N` | `0` | Linear warmup length before decay |
| `--lr-min-scale F` | `0.1` | Minimum LR floor as fraction of base LR for cosine decay |
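
To make the interaction of `--lr`, `--lr-warmup-steps`, and `--lr-min-scale` concrete, here is a minimal sketch of a linear-warmup plus cosine-decay schedule. It is a common formulation, not necessarily the exact one used by the trainer; the function name `lr_at` and the choice to decay over `total_steps - warmup_steps` are assumptions made for illustration.

```rust
use std::f64::consts::PI;

/// Learning rate at `step` (0-based) under linear warmup followed by cosine decay.
/// Assumption: decay spans the steps after warmup and bottoms out at
/// `min_scale * base_lr`; the crate's actual schedule may differ in detail.
pub fn lr_at(step: u64, total_steps: u64, warmup_steps: u64, base_lr: f64, min_scale: f64) -> f64 {
    if step < warmup_steps {
        // Linear ramp: reaches base_lr on the last warmup step.
        return base_lr * (step + 1) as f64 / warmup_steps as f64;
    }
    // Fraction of the decay phase completed, clamped to [0, 1].
    let decay_steps = total_steps.saturating_sub(warmup_steps).max(1);
    let progress = ((step - warmup_steps) as f64 / decay_steps as f64).min(1.0);
    // Cosine factor goes from 1.0 (start of decay) to 0.0 (end of decay).
    let cosine = 0.5 * (1.0 + (PI * progress).cos());
    base_lr * (min_scale + (1.0 - min_scale) * cosine)
}
```

With `base_lr = 1e-4`, `warmup_steps = 10`, `total_steps = 100`, and `min_scale = 0.1`, this ramps from `1e-5` at step 0 to `1e-4` at step 9, then decays back toward `1e-5` by step 100.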
- `latest.bincode`: latest atomic, versioned checkpoint (model + optimizer + step)
- `best.bincode`: best validation checkpoint
- `run_state.json`: resumable data-pipeline state (in-memory cursor or shard stream cursor)
- `metrics.jsonl`: train/validation metrics stream for dashboards and debugging
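
An atomic checkpoint such as `latest.bincode` is typically produced by writing to a temporary file and renaming it over the target, so a crash never leaves a half-written file behind. The sketch below shows that general technique using only the standard library; `write_atomic` is a hypothetical helper, and the crate's own checkpoint code may do this differently.

```rust
use std::fs::{self, File};
use std::io::{self, Write};
use std::path::Path;

/// Writes `bytes` to `path` so that readers see either the old file or the
/// complete new one. Hypothetical helper, not the crate's API.
pub fn write_atomic(path: &Path, bytes: &[u8]) -> io::Result<()> {
    // The temp file lives next to the target so the rename stays on one filesystem.
    let tmp = path.with_extension("tmp");
    {
        let mut file = File::create(&tmp)?;
        file.write_all(bytes)?;
        // Flush data to disk before the rename makes it visible.
        file.sync_all()?;
    }
    // Replaces any existing file at `path`.
    fs::rename(&tmp, path)
}
```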
| Value | Meaning |
|---|---|
| `append` | Add new layers at the end |
| `prepend` | Add new layers at the beginning |
| `insert:N` | Insert all new layers starting at index N |
| `specific:1,3,4,5` | Place each new layer at specific final indices |
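
One way to read the placement values is as a mask over the final layer stack that marks which positions hold newly added layers. The sketch below is a standalone illustration of that reading; `Placement`, `parse_placement`, and `new_layer_mask` are hypothetical names and do not mirror the crate's `ExpansionPlacement` type or its parser.

```rust
/// Illustrative stand-in for the placement policy (not the crate's type).
#[derive(Debug, Clone, PartialEq)]
pub enum Placement {
    Append,
    Prepend,
    Insert(usize),
    Specific(Vec<usize>),
}

/// Parses strings such as "append", "insert:2", or "specific:1,3,4,5".
pub fn parse_placement(s: &str) -> Option<Placement> {
    match s {
        "append" => Some(Placement::Append),
        "prepend" => Some(Placement::Prepend),
        _ => {
            let (kind, arg) = s.split_once(':')?;
            match kind {
                "insert" => arg.parse::<usize>().ok().map(Placement::Insert),
                "specific" => arg
                    .split(',')
                    .map(|p| p.trim().parse::<usize>().ok())
                    .collect::<Option<Vec<usize>>>()
                    .map(Placement::Specific),
                _ => None,
            }
        }
    }
}

/// Returns one flag per final layer index: `true` marks a newly added layer.
/// Returns `None` if the placement does not describe exactly `target - base`
/// distinct in-range positions.
pub fn new_layer_mask(base: usize, target: usize, p: &Placement) -> Option<Vec<bool>> {
    let added = target.checked_sub(base)?;
    let new_idx: Vec<usize> = match p {
        Placement::Append => (base..target).collect(),
        Placement::Prepend => (0..added).collect(),
        Placement::Insert(n) => (*n..*n + added).collect(),
        Placement::Specific(v) => v.clone(),
    };
    if new_idx.len() != added {
        return None;
    }
    let mut mask = vec![false; target];
    for &i in &new_idx {
        if i >= target || mask[i] {
            return None; // out of range or duplicate index
        }
        mask[i] = true;
    }
    Some(mask)
}
```

Under this reading, growing from 2 to 6 layers with `specific:1,3,4,5` leaves the two original layers at final indices 0 and 2.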
| Value | Meaning |
|---|---|
| `first:N` | Freeze first N layers |
| `indices:0,2,5` | Freeze explicit layer indices |
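
The freeze values can be pictured the same way, as a per-layer flag. The following sketch parses a policy string into such a mask; `frozen_mask` is a hypothetical helper, and whether the crate resolves indices against the expanded stack exactly like this is an assumption.

```rust
/// Returns one flag per layer: `true` means the layer's parameters are frozen.
/// Accepts "first:N" or "indices:a,b,c"; returns `None` on malformed input
/// or out-of-range indices. Hypothetical helper, not the crate's parser.
pub fn frozen_mask(policy: &str, num_layers: usize) -> Option<Vec<bool>> {
    let (kind, arg) = policy.split_once(':')?;
    let mut mask = vec![false; num_layers];
    match kind {
        "first" => {
            let n: usize = arg.parse().ok()?;
            if n > num_layers {
                return None;
            }
            mask[..n].iter_mut().for_each(|m| *m = true);
        }
        "indices" => {
            for part in arg.split(',') {
                let i: usize = part.trim().parse().ok()?;
                *mask.get_mut(i)? = true;
            }
        }
        _ => return None,
    }
    Some(mask)
}
```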
Add dependency:

```toml
[dependencies]
rust_trainer = "0.1"
```

Use the package name from your own Cargo.toml.
Minimal integration example:

```rust
use rust_trainer::generic_trainer::{
    GenericTrainer, default_trainer_config, make_batch_from_tokens,
};
use rust_trainer::{ExpansionPlacement, FreezeSelection, LayerSpec};

fn main() {
    // Per-layer dimensions, matching the CLI defaults.
    let spec = LayerSpec { d_model: 512, d_state: 16, d_conv: 4 };
    let cfg = default_trainer_config(
        8192,
        spec,
        6,
        ExpansionPlacement::SpecificPositions(vec![1, 3, 4, 5]),
        FreezeSelection::FirstN(2),
        false,
        1e-4,
    );
    let mut trainer = GenericTrainer::new_random(cfg, 2, 42);

    // Synthetic token stream; replace with your own data.
    let tokens: Vec<i64> = (0..8192).collect();
    let (ids, targets) = make_batch_from_tokens(&tokens, 0, 8, 64);

    // Run one training step, report the loss, and save a checkpoint.
    let stats = trainer.train_step(&ids, &targets);
    println!("loss: {}", stats.loss);
    trainer.save_checkpoint("checkpoint.bincode").unwrap();
}
```

The default model uses:
| Component | Implementation |
|---|---|
| Sequence layers | Mamba SSM (causal conv1d + SiLU + discretized state scan) |
| Output head | Hyperspherical Prototype Network (HPN) |
| Loss | Squared cosine distance to nearest prototype |
| Optimizer | AdamW with serializable moment buffers |
| Inference path | CPU-only; no GPU required |
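
For readers unfamiliar with the HPN head, the loss row above can be illustrated with a few lines of plain Rust. This is a minimal sketch of squared cosine distance and a nearest-prototype lookup, written under the assumption that the distance is `(1 - cos)^2`; the function names are illustrative and the crate's implementation in `src/nn.rs` may differ (for example in batching, normalization, or SIMD use).

```rust
/// Cosine similarity between two equal-length vectors.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Guard against division by zero for all-zero inputs.
    dot / (norm_a * norm_b).max(1e-12)
}

/// Squared cosine distance, (1 - cos(z, p))^2, between an output and a prototype.
pub fn squared_cosine_distance(z: &[f32], p: &[f32]) -> f32 {
    let d = 1.0 - cosine(z, p);
    d * d
}

/// Index of the closest prototype and the squared cosine distance to it.
/// Returns `None` when `prototypes` is empty.
pub fn nearest_prototype(z: &[f32], prototypes: &[Vec<f32>]) -> Option<(usize, f32)> {
    prototypes
        .iter()
        .enumerate()
        .map(|(i, p)| (i, squared_cosine_distance(z, p)))
        .min_by(|a, b| a.1.total_cmp(&b.1))
}
```

Because the distance depends only on direction, scaling the output vector leaves the loss unchanged, which is the usual motivation for placing prototypes on a hypersphere.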
The package is designed to be forked for other architectures. Replace or extend:
- Layer forward/backward path in `src/layer.rs`: swap Mamba for Transformer, LSTM, etc.
- Output loss/head logic in `src/nn.rs`: swap HPN for cross-entropy, contrastive loss, etc.
- Trainer state wiring in `src/generic_trainer.rs`: add or remove parameter groups
- Data loading logic in `src/bin/train_generic.rs`
The checkpointing, optimizer state, logging, expansion, and freeze infrastructure is architecture-independent and can be kept as-is.
Releases are tag-driven via GitHub Actions.
```bash
# bump version in Cargo.toml, commit, then:
git tag v0.2.0
git push origin v0.2.0
```

The release workflow runs tests, builds binaries, creates a GitHub Release, and can publish to crates.io when credentials are configured.
Apache-2.0. See LICENSE.