Problem / Background
Found while verifying issue #163 (PR #202) on an M1 Ultra with the Gemma 4 31B + MTP assistant pairing (REACHABLE_PAIRINGS[1], block_size 4, tests/speculative_parity.rs). Issue #163 hardened the ragged/batched MTP verify masks (NaN-safe padding rows, per-row stale-tail exclusion). Those fixes are correct and necessary, but real-model gates revealed a deeper, pre-existing structural gap that masking cannot fix.
In a B>1 batched MTP burst, every verify window is written at the shared physical cache offset and rotated (RoPE) at those physical positions for all rows. After a divergent (mixed-accept) round, rollback_speculative_cache_explicit trims all rows to the global max accept, so a row whose accept count was below the round max keeps a hole: its logical token sequence occupies positions [0, valid_end_r) plus the new window at [physical_offset, ...), with a gap in between that B=1 never has (B=1 trims its cache exactly, so its positions are contiguous).
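The bookkeeping behind the hole can be sketched with a toy model. This is a minimal illustration, not the project's code: `RowState`, `simulate_shared_offset`, and the `accepts[round][row]` layout are hypothetical names, all rows are assumed to share one prompt length (no left padding), and each round is reduced to "the number of tokens a row keeps".

```rust
/// Per-row bookkeeping in a simplified model of a batched MTP burst.
/// All names here are hypothetical; this is not the actual implementation.
#[derive(Debug, Clone, PartialEq)]
pub struct RowState {
    /// Number of tokens the row has logically kept (what a B=1 run would hold).
    pub logical_len: usize,
    /// Cumulative number of dead slots (the "hole") left behind by rollbacks.
    pub hole: usize,
}

/// Simulates rounds of batched verify under the shared-physical-offset scheme.
/// `accepts[round][row]` is the number of tokens the row keeps in that round.
/// Returns the final shared physical offset and the per-row states.
pub fn simulate_shared_offset(
    prompt_len: usize,
    accepts: &[Vec<usize>],
) -> (usize, Vec<RowState>) {
    let rows = accepts.first().map_or(0, |round| round.len());
    let mut states = vec![RowState { logical_len: prompt_len, hole: 0 }; rows];
    let mut physical_offset = prompt_len;

    for round in accepts {
        assert_eq!(round.len(), rows, "every round must cover every row");
        // Rollback trims all rows to the global max accept of the round.
        let round_max = round.iter().copied().max().unwrap_or(0);
        for (state, &kept) in states.iter_mut().zip(round) {
            state.logical_len += kept;
            // Slots between the row's own accept and the round max stay dead.
            state.hole += round_max - kept;
        }
        // The next verify window is written at the shared offset for all rows.
        physical_offset += round_max;
    }
    (physical_offset, states)
}
```

In this model `physical_offset == logical_len + hole` holds for every row after every round, so a row that always accepts the round max keeps `hole == 0` and matches the B=1 layout, while any sub-max round grows its hole permanently.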
Because both keys and queries rotate at physical positions, every subsequent attention score between a new token and the pre-hole history has its relative RoPE distance inflated by the cumulative hole size relative to the row's standalone B=1 run. The #163 stale-tail mask correctly removes the stale keys inside the hole from the key set, but no mask can repair the inflated relative distances. Greedy parity for a row is therefore only guaranteed while the row has accepted the round max in every round so far (the lockstep prefix); after the first sub-max round the row's logits diverge from B=1 structurally and near-tie argmaxes flip.
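The inflation follows from the standard RoPE identity that a rotated query/key dot product depends only on the position difference. The symbols below are assumptions introduced for illustration: $p$ is a physical position, $\ell$ the logical position the same token has in the B=1 run, $R_n$ the RoPE rotation for position $n$, and $H_r$ the cumulative hole of row $r$ at the time the query is issued.

```latex
\begin{aligned}
s(q, k) &= (R_{p_q} q)^{\top} (R_{p_k} k) = q^{\top} R_{p_k - p_q}\, k, \\
\text{B=1:}\quad p_q &= \ell_q,\; p_k = \ell_k
  \;\Rightarrow\; s = q^{\top} R_{\ell_k - \ell_q}\, k, \\
\text{batched, pre-hole key:}\quad p_q &= \ell_q + H_r,\; p_k = \ell_k
  \;\Rightarrow\; s = q^{\top} R_{(\ell_k - \ell_q) - H_r}\, k.
\end{aligned}
```

A mask only selects which keys take part in the softmax; it leaves the rotation angle in each surviving score untouched, which is why the extra $H_r$ term survives masking.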
Evidence
M1 Ultra, gemma-4-31b-it-4bit + gemma-4-31B-it-assistant-bf16, greedy, max_tokens 24, the four equal-length prompts from greedy_parity_mtp_gemma4_batched_b4_matches_b1 and the four variable-length prompts from the new ragged gate:
The strongest single confirmation: in the ragged run, row 3 was the round-max row for many consecutive rounds (accept_lens [3, 3, 1, 1, 3, ...]) and stayed byte-identical to its B=1 reference for 17 tokens, then flipped within a round or two of its first sub-max round. Rows whose accepts dropped below the max early (accept_lens all zeros for row 0) flipped within one or two rounds of the first divergence.
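When reading accept traces like the ones above, it can help to locate the end of a row's lockstep prefix mechanically. The helper below is a hypothetical sketch (the name `first_sub_max_round` and the `accepts[round][row]` layout are assumptions, not part of the test suite):

```rust
/// Returns the index of the first round in which `row` accepted fewer tokens
/// than the round maximum, or `None` if the row stayed in lockstep throughout.
/// `accepts[round][row]` is the accept count of `row` in that round.
pub fn first_sub_max_round(accepts: &[Vec<usize>], row: usize) -> Option<usize> {
    accepts.iter().position(|round| {
        let round_max = round.iter().copied().max().unwrap_or(0);
        round[row] < round_max
    })
}
```

Under the hole mechanism, parity with B=1 is only guaranteed for tokens produced before the round this function returns.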
The equal-length gate greedy_parity_mtp_gemma4_batched_b4_matches_b1 is red on current main on M1 Ultra (row 0 token 4: batched 1852 vs B=1 3161), before any #163 change. It was validated byte-identical on an M5 Max when it landed, so the strict gate is hardware-sensitive: which near-tie flips first differs by GPU class, but the hole mechanism is hardware-independent.
Proposed Solution
Decouple RoPE position from physical cache slot for the batched MTP verify forward: assign each row its logical positions (valid_end_r..) for new window tokens, in the spirit of the per-row batched RoPE machinery that already exists for dense batched decode (fast_rope_batched per-row offsets). Alternatively (or additionally), compact each row's hole away after rollback so positions stay contiguous. Either approach must keep the drafter-side per-row metadata (kv_offset = left_padding + kv_valid_len anchors) consistent. The #163 stale-gap mask then becomes redundant but harmless for the compacted variant, and stays load-bearing for the position-assignment variant.
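The two variants can be contrasted on a toy slot model. This is only a sketch under stated assumptions: a row is represented as a slice of `Option<u32>` where `None` marks a dead slot, the function names are hypothetical, and the real code would operate on K/V tensors and the existing per-row offset machinery rather than on vectors of token ids.

```rust
/// Position-assignment variant: every kept entry stays in its physical slot
/// but is given its logical position, i.e. its rank among the kept entries.
/// Dead slots (`None`) get no position and still need to be masked out.
pub fn logical_positions(slots: &[Option<u32>]) -> Vec<Option<usize>> {
    let mut next = 0usize;
    let mut out = Vec::with_capacity(slots.len());
    for slot in slots {
        if slot.is_some() {
            out.push(Some(next));
            next += 1;
        } else {
            out.push(None);
        }
    }
    out
}

/// Compaction variant: drop the dead slots so the kept entries are contiguous
/// and the physical index equals the logical position again.
pub fn compact(slots: &[Option<u32>]) -> Vec<u32> {
    slots.iter().flatten().copied().collect()
}
```

For a row laid out as `[Some(10), Some(11), None, None, Some(12), Some(13)]`, the first function yields positions 0, 1, 2, 3 for the kept entries while the gap stays in place (so a stale-gap mask is still needed), and the second yields `[10, 11, 12, 13]` with no gap left to mask. The toy tracks slot indices only. If keys are cached after rotation (an assumption here), compaction has to happen before the next window is written and rotated, since moving an already-rotated key does not change its angle.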
Acceptance Criteria
After divergent accepts, each batched row's tokens remain byte-identical to its standalone B=1 run on the same hardware.
The strict greedy_parity_mtp_gemma4_batched_b4_matches_b1 gate passes on M1 Ultra.
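The byte-identity criterion amounts to a per-row comparison of token ids against the B=1 reference. A minimal helper for reporting where a row first departs might look like the following; `first_divergence` is a hypothetical name, not a function from tests/speculative_parity.rs.

```rust
/// Index of the first token at which a batched row departs from its B=1
/// reference, or `None` if the two sequences are identical. A length mismatch
/// with an otherwise equal prefix counts as a divergence at the shorter length.
pub fn first_divergence(batched: &[u32], reference: &[u32]) -> Option<usize> {
    let common = batched.len().min(reference.len());
    (0..common)
        .find(|&i| batched[i] != reference[i])
        .or_else(|| (batched.len() != reference.len()).then_some(common))
}
```

A gate built on this would require `None` for every row, and the returned index gives the token position to quote when a row fails.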
Technical Considerations
Both batched MTP paths are opt-in (MLXCEL_ENABLE_MTP_BATCH / MLXCEL_ENABLE_MTP_BATCH_RAGGED default off), so production defaults are unaffected. Priority reflects that this invalidates the lossless-speculation claim of the opt-in feature under mixed accepts.