perf(gemma3n): reduce bf16 decode AltUp/MLP graph overhead #60
Merged
The bf16 decode hot path built the per-layer AltUp update as a sequence of `Vec<UniquePtr<MlxArray>>` operations: `predict` cast each of the four AltUp planes to f32 individually, stacked them, ran the projection, then sliced the result back into four arrays; `correct` re-stacked those same four arrays only to add the correction and slice them apart again; and the dense MLP issued four separate Rust-to-C++ bridge calls (gate, up, activation, down). Every slice/stack/cast boundary is a fresh graph node, so a single decoder layer churned through far more graph construction than the underlying math requires.

Keep the AltUp prediction as one `[altup, B, L, hidden]` graph island for the whole layer. New `predict_stacked` and `correct_stacked` methods operate on the stacked tensor directly, casting once after the stack (matching mlx-lm's `x.astype(mx.float32)` scheduling) instead of once per plane, and `correct_stacked` adds the correction onto the existing stacked predictions rather than rebuilding them. The public `predict`/`correct` Vec-returning wrappers are preserved as thin shims that delegate to the stacked path and split, so the parity tests and any external callers are unaffected. `DecoderLayer::forward` now consumes the stacked tensor end to end and only slices the active plane when it actually needs it. Three helpers back this: `slice_altup_plane`, `split_altup_planes`, and `split_altup_after_per_layer_update` (which folds mlx-lm's `corrected_predictions[1:] += first_prediction` update into the split).

The non-quantized bf16 language MLP now runs through a single `gemma3n_mlp_forward` C++ bridge call (cast input to bf16, gate/up, gelu_approx or gelu_topk, down, cast back to bf16) instead of four Rust-side ops. Matmuls stay outside `mx::compile` for the same reason as the existing `compiled_swiglu_mlp_forward_fp16`/`compiled_gelu_mlp_forward_fp16` helpers (compiled matmul+transpose graphs can reuse the wrong per-layer constants), while the element-wise activation reuses the cached compiled kernels. Quantized weights fall back to the existing op-at-a-time path unchanged.

On the Mac Studio M1 Ultra, this lifts gemma-3n-E4B-it bf16 text decode from 34.41 to 35.65 tok/s (about +3.6%), crossing 90% mlx-lm parity (88% to 91%); the M1 Ultra benchmark page is updated to the new decode value and the parity count is bumped accordingly.

Verified with the new `slice_altup_plane`/`split_altup_after_per_layer_update` unit tests, a clean clippy build of the C++ bridge, and a coherent real-model generation on gemma3n-e4b-bf16.
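To make the `corrected_predictions[1:] += first_prediction` fold concrete, here is a minimal plain-Rust sketch of a split that leaves plane 0 untouched and adds the per-layer update to the remaining planes. It is only an illustration under stated assumptions: the `Stacked` type, the flat `Vec<f32>` layout (with `B`, `L`, and `hidden` collapsed into `plane_len`), and the function name `split_after_per_layer_update_sketch` are hypothetical stand-ins, not the project's `MlxArray`-based helpers or their signatures.

```rust
/// Hypothetical plain-Rust stand-in for a stacked `[altup, B, L, hidden]` tensor.
/// `data` holds `altup` planes of `plane_len` contiguous f32 values each.
struct Stacked {
    altup: usize,
    plane_len: usize,
    data: Vec<f32>,
}

impl Stacked {
    /// Borrow a single plane without touching the others.
    fn plane(&self, idx: usize) -> &[f32] {
        let start = idx * self.plane_len;
        &self.data[start..start + self.plane_len]
    }
}

/// Split the stacked tensor into per-plane vectors, folding the
/// `corrected_predictions[1:] += first_prediction` update into the split:
/// plane 0 is copied as-is, every later plane gets `first_prediction` added.
fn split_after_per_layer_update_sketch(
    stacked: &Stacked,
    first_prediction: &[f32],
) -> Vec<Vec<f32>> {
    assert_eq!(stacked.data.len(), stacked.altup * stacked.plane_len);
    assert_eq!(first_prediction.len(), stacked.plane_len);
    (0..stacked.altup)
        .map(|idx| {
            let plane = stacked.plane(idx);
            if idx == 0 {
                plane.to_vec()
            } else {
                plane
                    .iter()
                    .zip(first_prediction)
                    .map(|(p, f)| p + f)
                    .collect()
            }
        })
        .collect()
}

fn main() {
    let stacked = Stacked {
        altup: 4,
        plane_len: 2,
        data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
    };
    let planes = split_after_per_layer_update_sketch(&stacked, &[10.0, 20.0]);
    for (idx, plane) in planes.iter().enumerate() {
        println!("plane {idx}: {plane:?}");
    }
}
```

In the real code the operands are lazy graph arrays, so the benefit is fewer graph nodes rather than fewer copies; the sketch only mirrors the arithmetic.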
inureyes added a commit that referenced this pull request on May 21, 2026
The fused Gemma3n bf16 decode path added in #60 (stacked AltUp predict/correct plus the `gemma3n_mlp_forward` bridge call) cuts Rust-to-C++ graph-construction overhead and improves decode on Apple Silicon without a Neural Accelerator (Mac Studio M1 Ultra: about +3.6%). A same-machine A/B on the MacBook Pro M5 Max shows the opposite there: the fused path regressed gemma-3n-E4B-it bf16 same-process decode by roughly 6.3% (about 39.0 down to 36.6 tok/s). The stacked-AltUp scheduling and the single fused MLP bridge call interact poorly with M5-class (Neural Accelerator) hardware, where the pre-fused per-op path schedules better.

Gate both fused paths behind a new `use_fused_decode_path()` helper, which is true only when the hardware is not a Neural Accelerator part (`!(has_neural_accelerator && macos_supports_na)`), mirroring the same hardware predicate already used elsewhere in the core. `MLP::forward` now takes the fused `gemma3n_mlp_forward` bridge call only off NA hardware and otherwise runs the per-op bf16 path. `DecoderLayer::forward` dispatches to `forward_stacked` (the fused, stacked-AltUp layer) on non-NA hardware and to `forward_split` (the pre-fused path where `AltUp::predict`/`correct` return per-plane Vecs) on M5-class hardware. Both code paths already existed, so this is a runtime dispatch by hardware class, not a revert: non-NA Apple Silicon keeps the faster fused path, M5-class hardware avoids the regression.

On the M5 Max, gemma-3n-E4B-it bf16 decode returns to about 39 tok/s (representative 39.05) with coherent real-model output, holding near 80% of the mlx-lm reference. The M5 Max benchmark page is updated to the restored decode value, its vs-M1-Ultra ratio (39.05 / 35.65 = 1.10x) is recomputed against the M1 Ultra decode, and the note now records that M5-class hardware uses the split decode path while other Apple Silicon uses the fused path.

Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host exercises the fused branch; the split branch is the unchanged pre-fused code).
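The hardware gate described above can be sketched roughly as follows. This is an illustration, not the project's code: the `HardwareInfo` struct, the `DecodePath` enum, and `select_decode_path` are hypothetical, and the real `use_fused_decode_path()` helper takes no argument here only because it presumably reads the hardware flags from elsewhere. The boolean predicate itself is the one quoted in the commit message.

```rust
/// Hypothetical container for the two flags named in the commit message.
#[derive(Clone, Copy, Debug, PartialEq)]
struct HardwareInfo {
    has_neural_accelerator: bool,
    macos_supports_na: bool,
}

/// Which decoder-layer implementation to run (names are illustrative).
#[derive(Debug, PartialEq)]
enum DecodePath {
    /// Fused path: stacked AltUp plus the single MLP bridge call.
    Stacked,
    /// Pre-fused path: per-plane Vecs and per-op MLP.
    Split,
}

/// Predicate from the commit message: fused only when the machine is not
/// running as a Neural Accelerator part.
fn use_fused_decode_path(hw: HardwareInfo) -> bool {
    !(hw.has_neural_accelerator && hw.macos_supports_na)
}

/// Runtime dispatch by hardware class, assuming both paths are available.
fn select_decode_path(hw: HardwareInfo) -> DecodePath {
    if use_fused_decode_path(hw) {
        DecodePath::Stacked
    } else {
        DecodePath::Split
    }
}

fn main() {
    let non_na = HardwareInfo { has_neural_accelerator: false, macos_supports_na: true };
    let na = HardwareInfo { has_neural_accelerator: true, macos_supports_na: true };
    println!("non-NA hardware -> {:?}", select_decode_path(non_na));
    println!("NA hardware     -> {:?}", select_decode_path(na));
}
```

One consequence of the predicate as written: a machine that has the accelerator but runs a macOS version without Neural Accelerator support still takes the fused path.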
Summary
Reduces graph-construction overhead on the Gemma3n bf16 decode hot path by keeping the per-layer AltUp update as a single stacked tensor and collapsing the dense bf16 MLP into one C++ bridge call.
What changed
`AltUp::predict`/`correct` previously cast, stacked, projected, and then sliced the four AltUp planes back into a `Vec<UniquePtr<MlxArray>>` every layer, and `correct` re-stacked the same arrays just to add the correction. New `predict_stacked`/`correct_stacked` keep the prediction as one `[altup, B, L, hidden]` graph island, cast once after the stack (matching mlx-lm's `x.astype(mx.float32)` scheduling) instead of once per plane, and add the correction onto the existing stacked tensor. The Vec-returning `predict`/`correct` wrappers are kept as thin shims that delegate and split, so parity tests and external callers are unchanged. `DecoderLayer::forward` consumes the stacked tensor end to end and only slices the active plane when needed. Three helpers back this: `slice_altup_plane`, `split_altup_planes`, and `split_altup_after_per_layer_update` (the last folds mlx-lm's `corrected_predictions[1:] += first_prediction` update into the split).

The non-quantized bf16 language MLP now runs through a single `gemma3n_mlp_forward` C++ bridge call (cast to bf16, gate/up, gelu_approx or gelu_topk, down, cast back) instead of four Rust ops. Matmuls deliberately stay outside `mx::compile`, like the existing `compiled_swiglu_mlp_forward_fp16`/`compiled_gelu_mlp_forward_fp16` helpers, while the element-wise activation reuses the cached compiled kernels. Quantized weights keep the existing op-at-a-time path.

Performance
On the Mac Studio M1 Ultra, gemma-3n-E4B-it bf16 text decode improves from 34.41 to 35.65 tok/s (about +3.6%), crossing 90% mlx-lm parity (88% to 91%). The M1 Ultra benchmark page is updated to the new decode value with the parity count bumped to match.
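As a quick check of the quoted gain, the relative change can be recomputed from the two decode figures above. The helper name `relative_change_pct` is made up for this illustration; only the two tok/s values come from the benchmark.

```rust
/// Relative change from `before` to `after`, in percent.
fn relative_change_pct(before: f64, after: f64) -> f64 {
    (after - before) / before * 100.0
}

fn main() {
    // M1 Ultra bf16 decode throughput before and after this change, in tok/s.
    let gain = relative_change_pct(34.41, 35.65);
    println!("decode change: {gain:+.1}%");
}
```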
Testing
`slice_altup_plane_selects_stacked_prediction_plane` and `split_altup_after_per_layer_update_preserves_plane_zero_and_updates_tail` pass alongside the existing gemma3n helper tests. `cargo clippy --features metal,accelerate --lib --bin mlxcel --tests -- -D warnings` is clean (the C++ bridge compiles and links).