Skip to content

feat(generate): wire MTP speculative decode into offline generate#385

Merged
inureyes merged 3 commits into
mainfrom
feature/issue-166-offline-mtp-speculative
Jun 21, 2026
Merged

feat(generate): wire MTP speculative decode into offline generate#385
inureyes merged 3 commits into
mainfrom
feature/issue-166-offline-mtp-speculative

Conversation

@inureyes

Copy link
Copy Markdown
Member

Summary

Wire the MTP speculative-decoding round loop into the offline mlxcel generate path. Previously --draft-kind mtp returned a deferred error for every target in the offline CLI; speculative decode only ran in the server burst path. The offline path now constructs and drives the same MtpGenerator round loop the server uses, reusing the same per-target MtpTarget adapters (src/models/gemma4_mtp_target.rs), so a one-shot CLI user gets the MTP speedup without standing up a server.

What changed

  • src/commands/generate.rs: run_generation_mode now takes a concrete &LoadedModel instead of a generic M: LanguageModel, so the --draft-kind mtp branch can match the target family and select the matching adapter. The sole caller already passes &LoadedModel, and every inner call (generate_standard, generate_with_embeddings, SpeculativeGenerator::generate) stays generic over LanguageModel (which LoadedModel implements), so the non-MTP monomorphized code is unchanged.
  • New run_offline_mtp: resolves the concrete Gemma 4 target (text / VLM / Unified) the same way the server's run_mtp_burst does, loads the assistant via load_drafter (an MTP assistant is a Drafter, not a full LoadedModel), runs the compat check, binds the drafter to the same concrete target the adapter wraps (load-bearing: MtpGenerator::generate does not bind internally), then selects Gemma4MtpTargetAdapter / Gemma4VLMtpTargetAdapter / Gemma4UnifiedMtpTargetAdapter and drives the loop through drive_offline_mtp.
  • Byte-identical parity: the resolved token bias (CLI --logit-bias plus the issue fix(models): gemma4_unified leaks multimodal placeholder tokens (audio/image/video) into text generation #350 multimodal-placeholder suppression that run_generation_mode already merges) is injected into SamplingConfig.token_bias so the adapter applies the SAME bias the non-speculative CxxGenerator applies via with_token_bias. The first-bonus sample also seeds the same history-dependent-penalty token_history as the classic decode path.
  • Clear errors, no silent fallback: a non-MTP-capable target returns an actionable error, and a multimodal (--image / --audio / --video) request under --draft-kind mtp is rejected with guidance, mirroring the server's multimodal decline.
  • Pure routing gate should_route_offline_mtp keeps the loop-construction decision unit-testable without loading a model.
  • src/commands/generate_tests.rs: three model-free unit tests pin the routing gate.

Preserved paths

The explicit --draft-kind dflash / internal-mtp branches keep their existing deferred error, and the auto-detect classic SpeculativeGenerator path (no explicit --draft-kind) is byte-for-byte unchanged. Only the explicit --draft-kind mtp branch now builds the real loop.

Test plan

  • cargo check -p mlxcel --bin mlxcel --features metal,accelerate
  • cargo test -p mlxcel --bin mlxcel --features metal,accelerate should_route_offline_mtp (3 passed)
  • cargo clippy -p mlxcel --bin mlxcel --features metal,accelerate --tests -- -D warnings (clean)
  • cargo fmt -p mlxcel -- --check (clean)
  • Real-model validation (deferred to the orchestrator, needs the release binary + multi-GB checkpoints): byte-identical temp-0 output and decode speedup, see below.

Orchestrator validation (real models)

MTP speculative run (12B Unified pair):

./target/release/mlxcel generate -m models/gemma-4-12b-it-4bit --draft-model models/gemma-4-12b-it-assistant-4bit --draft-kind mtp -p "Explain Apple Silicon's unified memory architecture in one short paragraph." -n 128

Non-speculative baseline (must be byte-identical at temp 0, which is the default):

./target/release/mlxcel generate -m models/gemma-4-12b-it-4bit -p "Explain Apple Silicon's unified memory architecture in one short paragraph." -n 128

Same pattern for the 31B pair (models/gemma-4-31b-it-4bit + models/gemma-4-31b-it-assistant-bf16). The speedup is baseline decode tok/s vs the MTP run. Note: per the #203 jitter class, temp-0 classic-vs-MTP can show occasional word-level near-tie flips on some hardware; that is the same jitter the server path exhibits, not a wiring defect.

Closes #166

Construct and drive the MtpGenerator round loop in the offline `mlxcel generate` path for MTP-capable targets, replacing the deferred error that `--draft-kind mtp` returned for every target. The offline path now reuses the SAME per-target MtpTarget adapters the server burst path uses (src/models/gemma4_mtp_target.rs) and the same MtpGenerator round-loop driver, so a one-shot CLI user gets the MTP speedup without standing up a server.

What changed:

src/commands/generate.rs: `run_generation_mode` now takes a concrete `&LoadedModel` (instead of a generic `M: LanguageModel`) so the `--draft-kind mtp` branch can match the target family and pick the matching adapter; the only caller already passes `&LoadedModel`, and every inner call stays generic over LanguageModel (LoadedModel implements it), so the non-MTP monomorphized code is unchanged. The new `run_offline_mtp` resolves the concrete Gemma 4 target (text / VLM / Unified) the same way `run_mtp_burst` does, loads the assistant through `load_drafter` (an MTP assistant is a Drafter, not a full LoadedModel), runs the compat check, binds the drafter to the same concrete target the adapter wraps (load-bearing: MtpGenerator::generate does not bind internally), then selects `Gemma4MtpTargetAdapter` / `Gemma4VLMtpTargetAdapter` / `Gemma4UnifiedMtpTargetAdapter` and drives the loop through `drive_offline_mtp`. The resolved token bias (CLI `--logit-bias` plus the issue #350 multimodal-placeholder suppression) is injected into `SamplingConfig.token_bias` so the adapter applies the SAME bias the non-speculative CxxGenerator applies via `with_token_bias`, which is what keeps the temp-0 output byte-identical to the non-speculative path. A non-MTP-capable target returns a clear error instead of silently falling back, and a multimodal request under `--draft-kind mtp` is rejected with an actionable message. The pure routing gate `should_route_offline_mtp` keeps the loop-construction decision unit-testable without loading a model.

Preserved paths: the explicit `--draft-kind dflash` / internal-mtp branches keep their deferred error, and the auto-detect classic `SpeculativeGenerator` path (no explicit `--draft-kind`) is unchanged, so existing offline speculative workflows are unaffected.

src/commands/generate_tests.rs: three model-free unit tests pin the routing gate (explicit mtp routes; auto-detected mtp and other explicit kinds do not).

Closes #166
@inureyes inureyes added status:review Under review type:enhancement New features, capabilities, or significant additions area:models Model architectures, weights, loading, metadata priority:low Low priority labels Jun 21, 2026
…parity

The MtpGenerator pushes a token and then checks EOS, so its returned vector includes the terminal stop token, whereas the non-speculative CxxGenerator and the server burst finalizer both exclude it. run_offline_mtp returned the raw generator output, so every EOS-terminated generation leaked one trailing stop token (rendered visibly because decode_generated_text uses skip_special_tokens = false) and inflated the printed generated-token count, breaking the byte-identical temp-0 parity criterion of #166. Truncate the offline MTP output at the first EOS using the merged target eos plus sampling stop_token_ids set and realign generated_tokens / decode_tok_per_sec. Adds strip_trailing_eos with model-free unit tests.
@inureyes

Copy link
Copy Markdown
Member Author

Security & performance review (offline MTP wiring)

Reviewed the changed files (src/commands/generate.rs, src/commands/generate_tests.rs) at HEAD 6d6999b44, weighted toward performance, resource handling, and panic-safety since this is a local one-shot CLI path with minimal network surface.

Verdict: no CRITICAL or HIGH findings. No code changes required.

Confirmed

  • Hot path / no per-token overhead. run_offline_mtp does only one-time setup (load + validate + bind + a single sampling_config.clone()), then hands off to the existing MtpGenerator loop. strip_trailing_eos runs once at end-of-generation (single position() scan, O(n)), and the stats realignment runs once. Nothing is added inside the per-token round loop.
  • Monomorphization change is codegen-neutral. run_generation_mode was generic only ever instantiated with LoadedModel (the sole caller passes &LoadedModel); switching the signature to concrete &LoadedModel produces identical code and adds no clones or allocations on the non-MTP path. The MTP-only sampling.clone() is off that path.
  • Resource safety. The drafter is loaded exactly once via load_drafter; the MTP branch returns before the classic load_model(draft_model_path), so there is no double-load. A load / validate / bind failure returns Err and the partially-set-up drafter drops via RAII with no partially-initialized state and no panic. bind takes the target by shared reference, so the target model is never left in a partial state.
  • Guards precede heavy allocation. block_size < 2 and the non-MTP-capable-target gate both return Err before load_drafter runs.
  • Panic-safety of new code. strip_trailing_eos is safe for empty token vec, empty eos set, and eos at index 0 (each yields a no-op or empty result, pinned by unit tests). The stats realignment guards decode_time_ms > 0.0, so no divide-by-zero.
  • DoS / unbounded work. The round loop is strictly bounded by max_tokens (loop-head check, the per-round remaining/bs clamp, and the walk budget). A misbehaving drafter that errors or under-produces causes a clean break, never unbounded generation.
  • Conventions. No new unsafe, no AI attribution, no em dashes in the added lines, no secret/PII in the new log lines (only the drafter path, consistent with the existing classic-path log).

Non-blocking observations (MEDIUM / LOW, no action taken)

  • MEDIUM (pre-existing, shared code): MtpGenerator::generate asserts !prompt_tokens.is_empty(). An empty tokenized prompt under --draft-kind mtp would abort instead of returning an error. In practice Gemma 4 prepends BOS, so this is not reachable on normal input; the classic path degrades rather than asserting. Worth a guard if offline empty-prompt robustness is desired, but it is shared-generator behavior, not introduced here.
  • LOW: run_offline_mtp matches over LoadedModel twice (target resolution, then adapter selection) with an unreachable!() in the second match. The branch is provably dead given the first gate, but folding both into a single match would remove the parallel-match fragility.

Correct a wrong flag name in a generate.rs comment (line 1157 called the
resolved token-bias source `--logit-bias`; the actual flag is `--lang-bias`
resolved via `resolve_cli_token_bias`). Functional behavior was already
correct; this is comment accuracy only.

Add a paragraph to docs/benchmarks.md under the speculative decoding section
covering the offline `mlxcel generate --draft-kind mtp` path (issue #166):
which targets are supported (Gemma 4 text / VLM / Unified), the temp-0 parity
condition (within the f16 / #203 batched-kernel jitter class on some hardware),
and the greedy-parity-only limitation when any sampling penalty is active. The
same caveat is added in brief to the `--draft-kind` CLI help text in
speculative_args.rs for discoverability via --help.
@inureyes

Copy link
Copy Markdown
Member Author

PR Finalization (959afba)

Changes

Comment fix (src/commands/generate.rs line 1157): The comment named --logit-bias as the source of the resolved token bias. The actual flag is --lang-bias (resolved via resolve_cli_token_bias in src/lang_bias.rs). Corrected to --lang-bias. Functional behavior was already correct; this was comment accuracy only.

Documentation (docs/benchmarks.md): Added a paragraph in the existing "Speculative decoding (MTP)" section covering the offline mlxcel generate --draft-kind mtp path (issue #166): supported targets (Gemma 4 text / VLM / Unified), the temp-0 parity condition (within the f16 / #203 batched-kernel jitter class), and the greedy-parity-only limitation when any sampling penalty is active (first bonus token penalized, subsequent tokens in each verify window are greedy, so penalized requests are not byte-identical to the non-speculative path). This mirrors the server burst path's known limitation.

CLI help text (src/cli/speculative_args.rs): Added the same greedy-parity caveat in brief to the --draft-kind doc comment for discoverability via --help.

Checks

  • cargo fmt --check: clean
  • cargo clippy -p mlxcel --bin mlxcel --features metal,accelerate --tests -- -D warnings: clean (15s, no warnings)

Tests

Existing routing + strip_trailing_eos unit tests in src/commands/generate_tests.rs cover the PR's logic. No obvious cheap gap identified that doesn't require model loading.

@inureyes inureyes added status:done Completed and removed status:review Under review labels Jun 21, 2026
@inureyes inureyes merged commit ca5200d into main Jun 21, 2026
5 checks passed
@inureyes inureyes deleted the feature/issue-166-offline-mtp-speculative branch June 21, 2026 16:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:models Model architectures, weights, loading, metadata priority:low Low priority status:done Completed type:enhancement New features, capabilities, or significant additions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: wire MTP speculative decoding into the offline mlxcel generate path

1 participant