
fix: per-row position holes break B>1 batched MTP greedy parity after divergent accepts #203

@inureyes

Problem / Background

Found while verifying issue #163 (PR #202) on an M1 Ultra with the Gemma 4 31B + MTP assistant pairing (REACHABLE_PAIRINGS[1], block_size 4, tests/speculative_parity.rs). Issue #163 hardened the ragged/batched MTP verify masks (NaN-safe padding rows, per-row stale-tail exclusion). Those fixes are correct and necessary, but real-model gates revealed a deeper, pre-existing structural gap that masking cannot fix.

In a B>1 batched MTP burst, every verify window is written at the shared physical cache offset, and all rows are rotated (RoPE) at those physical positions. After a divergent (mixed-accept) round, rollback_speculative_cache_explicit trims all rows to the global max accept, so a row whose accept count was below the round max keeps a hole. Its logical token sequence occupies positions [0, valid_end_r) plus the new window at [physical_offset, ...), with a gap in between. B=1 never has this gap: it trims its cache exactly, so its positions stay contiguous.
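The bookkeeping above can be sketched as a small simulation. This is an illustrative model only, not the project's cache code: `RowState`, `simulate_rounds`, and `lockstep_rounds` are hypothetical names, prompts are assumed equal-length with no left padding, and each accept count is assumed to be the number of slots a row keeps in that round.

```rust
/// Illustrative per-row state (hypothetical; not the project's cache type).
#[derive(Debug, Clone, PartialEq)]
struct RowState {
    /// Tokens the row has actually kept, i.e. its standalone B=1 length.
    logical_len: usize,
    /// Physical slots skipped so far because other rows accepted more.
    cumulative_hole: usize,
}

/// Replay mixed-accept rounds under "trim every row to the global max accept".
/// `rounds[i][r]` is the (assumed) number of slots row `r` keeps in round `i`.
/// Returns the shared physical offset after the last round plus per-row state.
fn simulate_rounds(prompt_len: usize, rounds: &[Vec<usize>]) -> (usize, Vec<RowState>) {
    let batch = rounds.first().map_or(0, |r| r.len());
    let mut rows = vec![RowState { logical_len: prompt_len, cumulative_hole: 0 }; batch];
    let mut physical_offset = prompt_len;
    for accepts in rounds {
        let max_accept = accepts.iter().copied().max().unwrap_or(0);
        for (row, &a) in rows.iter_mut().zip(accepts) {
            row.logical_len += a;
            // A sub-max row leaves (max_accept - a) unused slots behind.
            row.cumulative_hole += max_accept - a;
        }
        // The next verify window starts at the same physical slot for all rows.
        physical_offset += max_accept;
    }
    (physical_offset, rows)
}

/// Number of leading rounds in which `row` accepted the round max
/// (the "lockstep prefix", measured in rounds).
fn lockstep_rounds(rounds: &[Vec<usize>], row: usize) -> usize {
    rounds
        .iter()
        .take_while(|accepts| {
            let max_accept = accepts.iter().copied().max().unwrap_or(0);
            accepts[row] == max_accept
        })
        .count()
}
```

In this model, `physical_offset == logical_len + cumulative_hole` holds for every row after every round, so a row's hole is exactly the amount by which its physical positions run ahead of its B=1 positions.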

Because both keys and queries rotate at physical positions, every subsequent attention score between a new token and the pre-hole history has its relative RoPE distance inflated by the cumulative hole size relative to the row's standalone B=1 run. The #163 stale-tail mask correctly removes the stale keys inside the hole from the key set, but no mask can repair the inflated relative distances. Greedy parity for a row is therefore only guaranteed while the row has accepted the round max in every round so far (the lockstep prefix); after the first sub-max round the row's logits diverge from B=1 structurally and near-tie argmaxes flip.
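A toy sketch of why no mask can repair this: for a single RoPE frequency, the attention score depends only on the difference between the query and key positions. The code below uses one 2-D pair and an arbitrary frequency `theta`; `rope_rotate` and `score` are hypothetical helpers for illustration, not the project's kernels.

```rust
/// Rotate a 2-D pair by `pos * theta`, the per-frequency building block of RoPE.
fn rope_rotate(v: [f64; 2], pos: usize, theta: f64) -> [f64; 2] {
    let (s, c) = (pos as f64 * theta).sin_cos();
    [v[0] * c - v[1] * s, v[0] * s + v[1] * c]
}

/// Dot product of a query rotated at `q_pos` with a key rotated at `k_pos`.
/// Mathematically this depends only on `q_pos - k_pos`.
fn score(q: [f64; 2], q_pos: usize, k: [f64; 2], k_pos: usize, theta: f64) -> f64 {
    let qr = rope_rotate(q, q_pos, theta);
    let kr = rope_rotate(k, k_pos, theta);
    qr[0] * kr[0] + qr[1] * kr[1]
}
```

With `q = k = [1.0, 0.0]` and `theta = 0.1`, a query at position 5 against a key at position 2 scores cos(0.3), about 0.955, and shifting both positions by the same amount leaves the score unchanged. If a hole of 2 slots pushes only the query to position 7, the score becomes cos(0.5), about 0.878, even though the key set is the same.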

Evidence

M1 Ultra, gemma-4-31b-it-4bit + gemma-4-31B-it-assistant-bf16, greedy, max_tokens 24, the four equal-length prompts from greedy_parity_mtp_gemma4_batched_b4_matches_b1 and the four variable-length prompts from the new ragged gate:

  1. The strongest single confirmation: in the ragged run, row 3 was the round-max row for many consecutive rounds (accept_lens [3, 3, 1, 1, 3, ...]) and stayed byte-identical to its B=1 reference for 17 tokens, then flipped within a round or two of its first sub-max round. Rows whose accepts dropped below the max early (accept_lens all zeros for row 0) flipped within one or two rounds of the first divergence.
  2. The equal-length gate greedy_parity_mtp_gemma4_batched_b4_matches_b1 is red on current main on M1 Ultra (row 0 token 4: batched 1852 vs B=1 3161), before any #163 change. It was validated byte-identical on an M5 Max when it landed, so the strict gate is hardware-sensitive: which near-tie flips first differs by GPU class, but the hole mechanism is hardware-independent.
  3. With PR #202's stale-gap mask temporarily disabled via an env switch, the branch reproduced main's failure signature bit-exactly on the equal-length case (same stream prefix [100, 236749, 715, 496], same first flip to 1852 at token 4), proving that PR #202's only behavioral delta on the equal-length path is the mask itself and that the parity breakage pre-dates it.
  4. Separate but related M1 finding, fixed by #163: on main, the ragged path emitted token id 0 garbage from the very first token on M1 Ultra (B=1 expected 100), because the fully-masked left-padding query rows produced NaN softmax rows and the M1 SDPA kernel variant did not confine them the way the M5 kernel did. PR #202's diagonal rescue fixes this class entirely (with PR #202, the ragged prefill bonus and the lockstep prefix match B=1 on M1).

Proposed Solution

Decouple RoPE position from physical cache slot for the batched MTP verify forward, assigning each row its logical positions (valid_end_r..) for new window tokens, in the spirit of the per-row batched RoPE machinery that already exists for dense batched decode (fast_rope_batched per-row offsets). Alternatively (or additionally) compact each row's hole away after rollback so positions stay contiguous. Either approach must keep the drafter-side per-row metadata (kv_offset = left_padding + kv_valid_len anchors) consistent, and the #163 stale-gap mask then becomes redundant-but-harmless for the compacted variant or stays load-bearing for the position-assignment variant.
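The two variants can be contrasted with a minimal sketch. Everything here is hypothetical illustration: `physical_positions`, `logical_positions`, and `compact_row` are made-up names, `valid_ends` stands in for each row's valid_end_r, and a row's cache is modeled as a slice of `Option<u32>` where `None` marks a hole slot. Nothing below reflects the actual signature of fast_rope_batched.

```rust
/// Positions all rows share today: the verify window is rotated at the
/// shared physical cache offset.
fn physical_positions(physical_offset: usize, window: usize) -> Vec<usize> {
    (physical_offset..physical_offset + window).collect()
}

/// Position-assignment variant: each row's window is rotated starting at
/// that row's own logical end, regardless of where it is physically stored.
fn logical_positions(valid_ends: &[usize], window: usize) -> Vec<Vec<usize>> {
    valid_ends
        .iter()
        .map(|&end| (end..end + window).collect())
        .collect()
}

/// Compaction variant: drop hole slots so the surviving entries are
/// contiguous and physical index equals logical position again.
fn compact_row(slots: &[Option<u32>]) -> Vec<u32> {
    slots.iter().flatten().copied().collect()
}
```

For two rows with valid ends 16 and 12 and a window of 4, the position-assignment variant rotates the second row's window at 12..16 instead of the shared 16..20. In that variant the hole slots still sit in the cache, which is consistent with the stale-gap mask staying load-bearing, whereas compaction removes them outright.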

Acceptance Criteria

  • After divergent accepts, each batched row's tokens remain byte-identical to its standalone B=1 run on the same hardware.
  • The strict greedy_parity_mtp_gemma4_batched_b4_matches_b1 gate passes on M1 Ultra.
  • The PR #202 ragged gate's full-stream form passes on M1 Ultra, with the lockstep-prefix-only relaxation removed from the ragged gate.
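The difference between the strict gate and the lockstep-prefix-only relaxation can be sketched as two comparisons over token streams. The helper names are hypothetical and are not taken from tests/speculative_parity.rs; `lockstep_tokens` is assumed to be the number of tokens a row emitted while it was still in its lockstep prefix.

```rust
/// Index of the first token where the batched row and its B=1 reference
/// differ, treating a length mismatch as a difference at the shorter length.
fn first_mismatch(batched: &[u32], reference: &[u32]) -> Option<usize> {
    batched
        .iter()
        .zip(reference)
        .position(|(a, b)| a != b)
        .or_else(|| {
            (batched.len() != reference.len())
                .then(|| batched.len().min(reference.len()))
        })
}

/// Strict (full-stream) form: every token must match.
fn strict_parity(batched: &[u32], reference: &[u32]) -> bool {
    first_mismatch(batched, reference).is_none()
}

/// Relaxed form: only the first `lockstep_tokens` tokens must match.
fn lockstep_prefix_parity(batched: &[u32], reference: &[u32], lockstep_tokens: usize) -> bool {
    let n = lockstep_tokens.min(batched.len()).min(reference.len());
    batched[..n] == reference[..n]
}
```

Applied to the row 0 signature from the evidence (prefix [100, 236749, 715, 496], then batched 1852 vs B=1 3161), `first_mismatch` reports index 4: the strict form fails, while the relaxed form passes for a lockstep prefix of 4 tokens. The acceptance criteria ask for the strict form to hold.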

Technical Considerations

Both batched MTP paths are opt-in (MLXCEL_ENABLE_MTP_BATCH / MLXCEL_ENABLE_MTP_BATCH_RAGGED default off), so production defaults are unaffected. Priority reflects that this invalidates the lossless-speculation claim of the opt-in feature under mixed accepts.

Metadata

Assignees: No one assigned

Labels: area:models, priority:medium, type:bug
