feat(generate): wire MTP speculative decode into offline generate #385
Construct and drive the MtpGenerator round loop in the offline `mlxcel generate` path for MTP-capable targets, replacing the deferred error that `--draft-kind mtp` returned for every target. The offline path now reuses the same per-target MtpTarget adapters the server burst path uses (src/models/gemma4_mtp_target.rs) and the same MtpGenerator round-loop driver, so a one-shot CLI user gets the MTP speedup without standing up a server.

What changed:

src/commands/generate.rs:

- `run_generation_mode` now takes a concrete `&LoadedModel` (instead of a generic `M: LanguageModel`) so the `--draft-kind mtp` branch can match the target family and pick the matching adapter. The only caller already passes `&LoadedModel`, and every inner call stays generic over LanguageModel (LoadedModel implements it), so the non-MTP monomorphized code is unchanged.
- The new `run_offline_mtp` resolves the concrete Gemma 4 target (text / VLM / Unified) the same way `run_mtp_burst` does, loads the assistant through `load_drafter` (an MTP assistant is a Drafter, not a full LoadedModel), runs the compat check, and binds the drafter to the same concrete target the adapter wraps. The bind is load-bearing: MtpGenerator::generate does not bind internally. It then selects `Gemma4MtpTargetAdapter` / `Gemma4VLMtpTargetAdapter` / `Gemma4UnifiedMtpTargetAdapter` and drives the loop through `drive_offline_mtp`.
- The resolved token bias (CLI `--lang-bias` plus the issue #350 multimodal-placeholder suppression) is injected into `SamplingConfig.token_bias` so the adapter applies the same bias the non-speculative CxxGenerator applies via `with_token_bias`. This is what keeps the temp-0 output byte-identical to the non-speculative path.
- A non-MTP-capable target returns a clear error instead of silently falling back, and a multimodal request under `--draft-kind mtp` is rejected with an actionable message.
- The pure routing gate `should_route_offline_mtp` keeps the loop-construction decision unit-testable without loading a model.
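To illustrate why a shared bias map matters for temp-0 parity, here is a minimal sketch of merging a CLI-supplied bias with placeholder suppression and applying it before a greedy pick. The names `merge_token_bias`, `apply_bias`, and `greedy_argmax` are hypothetical, the map type is an assumption, and the rule that suppression overrides a user bias on the same token is also an assumption, not something the commit states.

```rust
use std::collections::HashMap;

/// Hypothetical helper: combine a user-supplied bias map with a list of
/// token ids that must never be sampled (e.g. multimodal placeholders).
/// Assumption: suppression wins if both sources mention the same token.
fn merge_token_bias(cli_bias: &HashMap<u32, f32>, suppressed: &[u32]) -> HashMap<u32, f32> {
    let mut merged = cli_bias.clone();
    for &id in suppressed {
        merged.insert(id, f32::NEG_INFINITY);
    }
    merged
}

/// Add each bias to the matching logit; ids outside the vocabulary are ignored.
fn apply_bias(logits: &mut [f32], bias: &HashMap<u32, f32>) {
    for (&id, &b) in bias {
        if let Some(logit) = logits.get_mut(id as usize) {
            *logit += b;
        }
    }
}

/// Greedy (temp-0) pick: index of the largest logit, or None for empty input.
fn greedy_argmax(logits: &[f32]) -> Option<usize> {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(idx, _)| idx)
}
```

If one decode path applied this map and the other did not, the greedy pick could differ at any step where a biased or suppressed token would otherwise win, which is the divergence the shared `token_bias` avoids.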
Preserved paths: the explicit `--draft-kind dflash` / internal-mtp branches keep their deferred error, and the auto-detect classic `SpeculativeGenerator` path (no explicit `--draft-kind`) is unchanged, so existing offline speculative workflows are unaffected.

src/commands/generate_tests.rs: three model-free unit tests pin the routing gate (explicit mtp routes; auto-detected mtp and other explicit kinds do not).

Closes #166
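The routing rule those tests pin can be sketched as a pure function. This is a hedged illustration only: the `DraftKind` and `DraftSelection` types and the function signature are hypothetical stand-ins, since the commit names `should_route_offline_mtp` but does not show its parameters.

```rust
/// Hypothetical draft kinds; the real CLI type may differ.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DraftKind {
    Mtp,
    Dflash,
    InternalMtp,
}

/// Hypothetical record of how the draft kind was chosen: passed explicitly
/// via `--draft-kind`, or inferred from the draft model.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DraftSelection {
    Explicit(DraftKind),
    AutoDetected(DraftKind),
}

/// Route to the offline MTP loop only for an explicit `--draft-kind mtp`.
/// Auto-detected MTP and every other explicit kind keep their existing paths.
fn should_route_offline_mtp(selection: DraftSelection) -> bool {
    matches!(selection, DraftSelection::Explicit(DraftKind::Mtp))
}
```

Because the function takes plain values and touches no model state, it can be exercised in unit tests without loading weights.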
…parity

The MtpGenerator pushes a token and then checks for EOS, so its returned vector includes the terminal stop token, whereas the non-speculative CxxGenerator and the server burst finalizer both exclude it. `run_offline_mtp` returned the raw generator output, so every EOS-terminated generation leaked one trailing stop token (rendered visibly because `decode_generated_text` uses `skip_special_tokens = false`) and inflated the printed generated-token count. That broke the byte-identical temp-0 parity criterion of #166.

Truncate the offline MTP output at the first EOS, using the merged set of target EOS ids and sampling `stop_token_ids`, and realign `generated_tokens` / `decode_tok_per_sec` to the truncated length. Adds `strip_trailing_eos` with model-free unit tests.
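A minimal sketch of the truncation described above, assuming token ids are `u32` and the stop set is a `HashSet`. Only the name `strip_trailing_eos` comes from the commit; its signature here, and the helpers `merged_stop_ids` and `decode_tok_per_sec`, are hypothetical.

```rust
use std::collections::HashSet;

/// Hypothetical helper: union of the target's EOS ids and the sampling
/// config's stop_token_ids.
fn merged_stop_ids(target_eos: &[u32], sampling_stop: &[u32]) -> HashSet<u32> {
    target_eos.iter().chain(sampling_stop).copied().collect()
}

/// Return the prefix of `tokens` that ends just before the first stop token.
/// If no stop token is present, the whole slice is returned unchanged.
fn strip_trailing_eos<'a>(tokens: &'a [u32], stop_ids: &HashSet<u32>) -> &'a [u32] {
    match tokens.iter().position(|t| stop_ids.contains(t)) {
        Some(idx) => &tokens[..idx],
        None => tokens,
    }
}

/// Recompute throughput from the truncated token count so the printed
/// figure matches what was actually emitted. Returns 0.0 for a zero or
/// negative duration to avoid dividing by zero.
fn decode_tok_per_sec(generated_tokens: usize, decode_secs: f64) -> f64 {
    if decode_secs > 0.0 {
        generated_tokens as f64 / decode_secs
    } else {
        0.0
    }
}
```

Cutting at the first stop token, rather than only popping a final one, also covers a stop token that lands in the middle of a multi-token round.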
Security & performance review (offline MTP wiring)

Reviewed the changed files. Verdict: no CRITICAL or HIGH findings. No code changes required.

Non-blocking observations were MEDIUM / LOW only; no action taken.
Correct a wrong flag name in a generate.rs comment (line 1157 called the resolved token-bias source `--logit-bias`; the actual flag is `--lang-bias`, resolved via `resolve_cli_token_bias`). Functional behavior was already correct; this is comment accuracy only.

Add a paragraph to docs/benchmarks.md under the speculative decoding section covering the offline `mlxcel generate --draft-kind mtp` path (issue #166): which targets are supported (Gemma 4 text / VLM / Unified), the temp-0 parity condition (within the f16 / #203 batched-kernel jitter class on some hardware), and the greedy-parity-only limitation when any sampling penalty is active. The same caveat is added in brief to the `--draft-kind` CLI help text in speculative_args.rs so it is discoverable via `--help`.
PR Finalization (959afba)

Changes: comment fix, documentation, CLI help text.

Tests: existing routing + …
Summary
Wire the MTP speculative-decoding round loop into the offline `mlxcel generate` path. Previously `--draft-kind mtp` returned a deferred error for every target in the offline CLI; speculative decode only ran in the server burst path. The offline path now constructs and drives the same `MtpGenerator` round loop the server uses, reusing the same per-target `MtpTarget` adapters (`src/models/gemma4_mtp_target.rs`), so a one-shot CLI user gets the MTP speedup without standing up a server.

What changed

- `src/commands/generate.rs`: `run_generation_mode` now takes a concrete `&LoadedModel` instead of a generic `M: LanguageModel`, so the `--draft-kind mtp` branch can match the target family and select the matching adapter. The sole caller already passes `&LoadedModel`, and every inner call (`generate_standard`, `generate_with_embeddings`, `SpeculativeGenerator::generate`) stays generic over `LanguageModel` (which `LoadedModel` implements), so the non-MTP monomorphized code is unchanged.
- `run_offline_mtp`: resolves the concrete Gemma 4 target (text / VLM / Unified) the same way the server's `run_mtp_burst` does, loads the assistant via `load_drafter` (an MTP assistant is a `Drafter`, not a full `LoadedModel`), runs the compat check, and binds the drafter to the same concrete target the adapter wraps. The bind is load-bearing: `MtpGenerator::generate` does not bind internally. It then selects `Gemma4MtpTargetAdapter` / `Gemma4VLMtpTargetAdapter` / `Gemma4UnifiedMtpTargetAdapter` and drives the loop through `drive_offline_mtp`.
- The resolved token bias (CLI `--lang-bias` plus the issue #350 multimodal-placeholder suppression that `run_generation_mode` already merges) is injected into `SamplingConfig.token_bias` so the adapter applies the same bias the non-speculative `CxxGenerator` applies via `with_token_bias`. The first-bonus sample also seeds the same history-dependent-penalty `token_history` as the classic decode path.
- A multimodal (`--image` / `--audio` / `--video`) request under `--draft-kind mtp` is rejected with guidance, mirroring the server's multimodal decline.
- `should_route_offline_mtp` keeps the loop-construction decision unit-testable without loading a model.
- `src/commands/generate_tests.rs`: three model-free unit tests pin the routing gate.

Preserved paths

The explicit `--draft-kind dflash` / internal-mtp branches keep their existing deferred error, and the auto-detect classic `SpeculativeGenerator` path (no explicit `--draft-kind`) is byte-for-byte unchanged. Only the explicit `--draft-kind mtp` branch now builds the real loop.

Test plan

- `cargo check -p mlxcel --bin mlxcel --features metal,accelerate`
- `cargo test -p mlxcel --bin mlxcel --features metal,accelerate should_route_offline_mtp` (3 passed)
- `cargo clippy -p mlxcel --bin mlxcel --features metal,accelerate --tests -- -D warnings` (clean)
- `cargo fmt -p mlxcel -- --check` (clean)

Orchestrator validation (real models)

MTP speculative run (12B Unified pair):

`./target/release/mlxcel generate -m models/gemma-4-12b-it-4bit --draft-model models/gemma-4-12b-it-assistant-4bit --draft-kind mtp -p "Explain Apple Silicon's unified memory architecture in one short paragraph." -n 128`

Non-speculative baseline (must be byte-identical at temp 0, which is the default):

`./target/release/mlxcel generate -m models/gemma-4-12b-it-4bit -p "Explain Apple Silicon's unified memory architecture in one short paragraph." -n 128`

Same pattern for the 31B pair (`models/gemma-4-31b-it-4bit` + `models/gemma-4-31b-it-assistant-bf16`). The speedup is `baseline decode tok/s` vs the MTP run.

Note: per the #203 jitter class, temp-0 classic-vs-MTP can show occasional word-level near-tie flips on some hardware; that is the same jitter the server path exhibits, not a wiring defect.

Closes #166
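
For anyone scripting the validation above, a small sketch of the two comparisons involved: locating the first byte where the baseline and MTP outputs diverge, and computing the speedup as the ratio of decode throughputs. Both function names are hypothetical and not part of mlxcel; the inputs are assumed to be the captured generated text and the decode tok/s figures printed by each run.

```rust
/// Byte offset of the first difference between two outputs, or None when
/// they are byte-identical. A length mismatch counts as a divergence at
/// the end of the shorter output.
fn first_divergence(baseline: &str, mtp: &str) -> Option<usize> {
    let (a, b) = (baseline.as_bytes(), mtp.as_bytes());
    let common = a.iter().zip(b).take_while(|(x, y)| x == y).count();
    if common == a.len() && common == b.len() {
        None
    } else {
        Some(common)
    }
}

/// Speedup as MTP decode tok/s divided by baseline decode tok/s.
/// Returns None when the inputs are non-finite or the baseline is not positive.
fn decode_speedup(baseline_tok_per_sec: f64, mtp_tok_per_sec: f64) -> Option<f64> {
    if baseline_tok_per_sec.is_finite()
        && mtp_tok_per_sec.is_finite()
        && baseline_tok_per_sec > 0.0
    {
        Some(mtp_tok_per_sec / baseline_tok_per_sec)
    } else {
        None
    }
}
```

A reported divergence offset makes it easier to tell a word-level near-tie flip from a leaked trailing token, which would show up as a pure length mismatch at the end of the baseline output.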