From 3eabdb77fc4c2034e22ac5bdd36d072a987ee4c4 Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Thu, 21 May 2026 18:22:03 +0900 Subject: [PATCH] perf(gemma3n): improve M5 decode bandwidth with pretransposed weights MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit On M5-class (Neural Accelerator) hardware the non-quantized Gemma3n decode GEMVs — the MLP gate/up projections and the tied LM head — stream their weights faster when MLX is handed an already-materialized transposed weight rather than transposing on the fly inside every decode step. This materializes those transposes once at load time on M5-class hardware only, leaving every other code path untouched. `MLP::gate_proj`/`up_proj` become a new `MlpInputProjection` enum. The `Standard` variant wraps `UnifiedLinear` exactly as before and is used for quantized layers and for all non-M5 hardware. The `Pretransposed` variant is selected only when the hardware is an M5-class part and the weight is not quantized: it transposes the projection weight, makes it contiguous, evaluates it at load time, and then `forward` is a plain `matmul` (plus optional bias) against the prepared weight. The tied LM head gets the same treatment through `pretranspose_large_m5_embedding`, which materializes the wide embedding transpose under the same M5-and-non-quantized guard; `Gemma3nLanguageModel` caches it in `embed_tokens_weight_t` and a new `lm_head` helper uses it when present, otherwise falling back to `embed_tokens.as_linear`. Quantized models keep their specialized 4-bit paths unchanged, and the split-path per-layer update now consumes the corrected planes by iterator instead of copying the first plane. This also adds an opt-in Metal GPU capture hook for profiling: when the process is launched with `MTL_CAPTURE_ENABLED=1` and `MLXCEL_CAPTURE_DECODE=` is set, the generator captures exactly one warm decode token (at the second step, after the decode kernels are JIT-cached) to a `.gputrace` bundle and exits, so per-kernel timings are directly comparable with mlx-lm's `mx.metal.start_capture` script. The hook is inert unless both environment variables are set. This complements the M5 split-path dispatch from #61: M5-class hardware now both avoids the fused path and gets the pretransposed decode weights. Verified with the gemma3n helper unit tests, a clean clippy build, and a coherent generation on gemma3n-e4b-bf16 (this M1 Ultra host takes the standard, non-pretransposed branch, so the dispatch and `lm_head` fallback are exercised end to end). --- src/lib/mlxcel-core/src/generate.rs | 19 +++++ src/models/gemma3n.rs | 114 +++++++++++++++++++++++++--- 2 files changed, 123 insertions(+), 10 deletions(-) diff --git a/src/lib/mlxcel-core/src/generate.rs b/src/lib/mlxcel-core/src/generate.rs index 3241388..d8bb797 100644 --- a/src/lib/mlxcel-core/src/generate.rs +++ b/src/lib/mlxcel-core/src/generate.rs @@ -1028,6 +1028,25 @@ impl CxxGenerator { ffi::export_to_dot_pair(&path, &next_tok, &next_log); } } + // Optional Metal GPU capture of one warm decode token for + // per-kernel profiling vs mlx-lm. Fires at n==2 so + // all decode kernels are JIT-cached. Requires the process to be + // launched with `MTL_CAPTURE_ENABLED=1`; writes a `.gputrace` + // bundle to the given path, comparable with mlx-lm's + // `mx.metal.start_capture`. + if n == 2 { + if let Ok(path) = std::env::var("MLXCEL_CAPTURE_DECODE") { + ffi::metal_start_capture(&path); + ffi::eval(&next_tok); + ffi::metal_stop_capture(); + // Exit immediately so the GPU trace document finalizes + // with exactly one captured decode token and no further + // GPU work polluting it (mirrors mlx-lm's capture-script + // lifecycle). Capture mode is a profiling-only path. + eprintln!("[capture] wrote one decode token to {path}"); + std::process::exit(0); + } + } if force_sync { ffi::eval(&next_tok); } else { diff --git a/src/models/gemma3n.rs b/src/models/gemma3n.rs index 69becf5..b1dbed2 100644 --- a/src/models/gemma3n.rs +++ b/src/models/gemma3n.rs @@ -665,13 +665,74 @@ impl Gemma3nAttention { // MLP with gelu_topk activation. pub struct MLP { - pub gate_proj: UnifiedLinear, - pub up_proj: UnifiedLinear, + pub gate_proj: MlpInputProjection, + pub up_proj: MlpInputProjection, pub down_proj: UnifiedLinear, pub activation_sparsity: f32, pub std_multiplier: f32, } +// M5 non-quantized Gemma3n decode GEMVs stream gate/up weights faster when MLX +// sees materialized transposed weights. Quantized layers keep UnifiedLinear so +// their specialized 4bit path is unchanged. +pub enum MlpInputProjection { + Standard(UnifiedLinear), + Pretransposed { + weight_t: UniquePtr, + bias: Option>, + }, +} + +impl MlpInputProjection { + fn from_weights_maybe_pretransposed( + weights: &WeightMap, + prefix: &str, + group_size: i32, + bits: i32, + ) -> Result { + let hw = mlxcel_core::hardware::get_hardware(); + let is_m5_na = hw.has_neural_accelerator && hw.macos_supports_na; + let scales_name = format!("{}.scales", prefix); + if !is_m5_na || weights.contains_key(&scales_name) { + return Ok(Self::Standard(UnifiedLinear::from_weights( + weights, prefix, group_size, bits, + )?)); + } + + let weight_name = format!("{}.weight", prefix); + let weight = weights + .get(&weight_name) + .ok_or_else(|| format!("Weight not found: {}", weight_name))?; + let weight_t = mlxcel_core::transpose(weight); + let weight_t = mlxcel_core::contiguous(&weight_t, false); + mlxcel_core::eval(&weight_t); + + let bias_name = format!("{}.bias", prefix); + let bias = weights.get(&bias_name).map(|b| mlxcel_core::copy(b)); + Ok(Self::Pretransposed { weight_t, bias }) + } + + fn forward(&self, x: &MlxArray) -> UniquePtr { + match self { + Self::Standard(linear) => linear.forward(x), + Self::Pretransposed { weight_t, bias } => { + let out = mlxcel_core::matmul(x, weight_t); + match bias { + Some(bias) => mlxcel_core::add(&out, bias), + None => out, + } + } + } + } + + fn regular_weight(&self) -> Option<&Linear> { + match self { + Self::Standard(linear) => linear.regular_weight(), + Self::Pretransposed { .. } => None, + } + } +} + /// #60 introduced a fused Gemma3n decode path (stacked AltUp predict/ /// correct plus the `gemma3n_mlp_forward` bridge call) that cuts Rust↔C++ /// graph-construction overhead. It improves decode on Apple Silicon without a @@ -775,14 +836,18 @@ impl MLP { .map(|q| q.bits as i32) .unwrap_or(4); - let gate_proj = UnifiedLinear::from_weights( + let gate_proj = MlpInputProjection::from_weights_maybe_pretransposed( weights, &format!("{}.gate_proj", prefix), group_size, bits, )?; - let up_proj = - UnifiedLinear::from_weights(weights, &format!("{}.up_proj", prefix), group_size, bits)?; + let up_proj = MlpInputProjection::from_weights_maybe_pretransposed( + weights, + &format!("{}.up_proj", prefix), + group_size, + bits, + )?; let down_proj = UnifiedLinear::from_weights( weights, &format!("{}.down_proj", prefix), @@ -962,9 +1027,12 @@ impl DecoderLayer { // Add first_prediction to corrected[1:]. let mut result = Vec::with_capacity(corrected.len()); - result.push(mlxcel_core::copy(&corrected[0])); - for item in corrected.iter().skip(1) { - result.push(mlxcel_core::add(item, &first_prediction)); + let mut corrected = corrected.into_iter(); + if let Some(first) = corrected.next() { + result.push(first); + } + for item in corrected { + result.push(mlxcel_core::add(&item, &first_prediction)); } result @@ -1068,6 +1136,7 @@ impl DecoderLayer { // Language Model. pub struct Gemma3nLanguageModel { pub embed_tokens: UnifiedEmbedding, + pub embed_tokens_weight_t: Option>, pub embed_tokens_per_layer: UnifiedEmbedding, pub per_layer_model_projection: UnifiedLinear, pub per_layer_projection_norm: RMSNorm, @@ -1082,6 +1151,29 @@ pub struct Gemma3nLanguageModel { } impl Gemma3nLanguageModel { + fn pretranspose_large_m5_embedding( + embedding: &UnifiedEmbedding, + ) -> Option> { + let hw = mlxcel_core::hardware::get_hardware(); + if embedding.is_quantized() || !(hw.has_neural_accelerator && hw.macos_supports_na) { + return None; + } + + // The tied LM head is a very wide decode GEMV; materializing the + // transpose improves M5 bandwidth on non-quantized Gemma3n. + let weight_t = mlxcel_core::transpose(embedding.weight()); + let weight_t = mlxcel_core::contiguous(&weight_t, false); + mlxcel_core::eval(&weight_t); + Some(weight_t) + } + + fn lm_head(&self, out: &MlxArray) -> UniquePtr { + match &self.embed_tokens_weight_t { + Some(weight_t) => mlxcel_core::matmul(out, weight_t), + None => self.embed_tokens.as_linear(out), + } + } + pub fn forward(&self, inputs: &MlxArray, caches: &mut [KVCache]) -> UniquePtr { // Embed tokens let h = self.embed_tokens.forward(inputs); @@ -1177,7 +1269,7 @@ impl Gemma3nLanguageModel { }; // LM head (tied embeddings) - let mut logits = self.embed_tokens.as_linear(&out); + let mut logits = self.lm_head(&out); // Apply logit softcapping if configured if let Some(cap) = self.config.final_logit_softcapping { @@ -1338,7 +1430,7 @@ impl Gemma3nLanguageModel { } else { mlxcel_core::astype(&out, mlxcel_core::array_dtype(self.embed_tokens.weight())) }; - let mut logits = self.embed_tokens.as_linear(&out); + let mut logits = self.lm_head(&out); if let Some(cap) = self.config.final_logit_softcapping { logits = apply_softcap(&logits, cap); @@ -1377,6 +1469,7 @@ impl Gemma3nLanguageModel { group_size, bits, )?; + let embed_tokens_weight_t = Self::pretranspose_large_m5_embedding(&embed_tokens); let embed_tokens_per_layer = UnifiedEmbedding::from_weights( weights, &format!("{}.embed_tokens_per_layer", prefix), @@ -1473,6 +1566,7 @@ impl Gemma3nLanguageModel { Ok(Self { embed_tokens, + embed_tokens_weight_t, embed_tokens_per_layer, per_layer_model_projection, per_layer_projection_norm,