feat(mtp): Phases 2.2 + 3.5 + 5.5 — MoE, gated attn, KV cache + softmax scan; bugs fixed → 0% → ~30% accept rate#174

Merged
github-actions[bot] merged 3 commits into main from feat/mtp-phase-2.2 on May 14, 2026

Conversation

@kekzl
Owner

@kekzl kekzl commented May 14, 2026

Summary

Completes the MTP transformer block (Phase 2.2 = MoE + Attn-with-KV-cache + gated-V-output), wires auto-invoke telemetry (Phase 3.5), and validates end-to-end (Phase 5.5). Two root-cause bugs were found and fixed; the accept rate went from 0% on all 4 prompt classes to real signal on all 4 (avg 29.5%).

Phase 5.5 validation (Qwen3.6-NVFP4, --mtp-spec-decode 1, max_tokens=128)

class            matches   total    rate
factual               22      67   32.8%
verbose-think         23     126   18.3%
code                  47     126   37.3%
instruction           37     126   29.4%
average                             29.5%

All 4 classes show real signal: below the default-on threshold (≥ 60% acceptance on 3 of 4 classes), but far above noise. RoPE on Q/K is the next quality lever.

Root-cause bugs (latest commit 166f3fa)

Bug 1: RMSNorm 1D-shape early-return

imp::rmsnorm reads x.shape[0] as rows and x.shape[1] as d_model. mtp_forward.cu (since PR #172 Phase 2.1) passed 1D tensors [hidden_dim] → kernel saw rows=hidden_dim, d_model=0 and early-returned without writing output. The MTP forward's RMSNorm outputs were uninitialized FP16 buffer contents (often saturated ~22000). LM-head argmax locked deterministically to token 6178 ('awn') regardless of input.

Fix: 4 sites in mtp_forward.cu changed from [hidden_dim] to [1, hidden_dim].
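
Illustratively (a sketch only; the real imp tensor type and rmsnorm guard differ in detail), the failure and the fix come down to how the same buffer is described:

  // Sketch of the guard inside imp::rmsnorm (simplified):
  int rows    = x.shape[0];
  int d_model = x.shape[1];
  if (rows <= 0 || d_model <= 0) return;   // 1D input [hidden_dim] hits this: d_model == 0
  // ... otherwise normalize each of `rows` vectors of length d_model ...

  // Fix at the 4 mtp_forward.cu call sites: describe the same device buffer
  // as a single-row 2D tensor, i.e. shape {1, hidden_dim} instead of {hidden_dim},
  // so rows=1 and d_model=hidden_dim and the kernel actually writes its output.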

Bug 2: Missing arch_norm_offset on MTP norms

Qwen3.5/3.6 SafeTensors stores RMSNorm gammas as deltas W where actual gamma = 1 + W. Main-model loader applies the +1 via ctx.arch_norm_offset. upload_mtp_weights() used upload_unquantized_weight() which doesn't expose the offset → MTP norms ran with scale ≈ 0 (raw W with mean near zero).

Fix: dispatch the 7 norm tensors (pre_fc_norm_{embedding,hidden}, input_layernorm, post_attention_layernorm, q_norm, k_norm, final_norm) through upload_weight(..., weight_offset=ctx.arch_norm_offset).
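
Conceptually, the offset path amounts to adding the architectural constant at upload time. A minimal sketch (hypothetical kernel; the real upload_weight() helper may apply the offset differently):

  // gamma_actual = offset + W_stored, with offset = ctx.arch_norm_offset = 1.0
  // for Qwen3.5/3.6 RMSNorm gammas stored as deltas.
  __global__ void add_norm_offset(__half* w, int n, float offset) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
          w[i] = __float2half(__half2float(w[i]) + offset);
      }
  }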

Per-phase shipped

  • 2.2.MoE — 256-expert top-8 + shared expert + sigmoid gate
  • 2.2.Attn MVP — attn_output_gate=True gated V-broadcast (no KV)
  • 2.2.Attn+KV — per-session KV cache + softmax attention scan (NEW)
  • 3.5 — auto-invoke + accuracy telemetry
  • 5.5 — validation harness + finding
  • Bug fixes — RMSNorm shape + arch_norm_offset (NEW)

Phase 2.2.Attn+KV details

  • Per-session K and V cache (MtpDraftWorkspace::d_k_cache + d_v_cache) up to 16K context — 16 MiB each for Qwen3.6 dims (max_seq_len × num_kv_heads × head_dim × 2 bytes).
  • mtp_kv_append_kernel: appends current step's k/v at position mtp_pos.
  • mtp_attn_kv_scan_kernel: one CTA per Q-head, softmax over [0, mtp_pos+1) with a shared-mem max-reduce, GQA broadcast from kv_h, scaled by 1/√head_dim (sketched after this list).
  • mtp_gate_attn_out_kernel: applies silu(gate) elementwise to attention output.
  • Engine::mtp_accuracy_reset() resets mtp_pos on imp_context_reset for clean new-session state.
  • No RoPE yet — content-only attention. Adding RoPE (imp::qknorm_rope_fused already exists for main model with partial-rope + mrope) should close roughly half the gap to DeepSeek-V3 paper's ~85% expectations.
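
A minimal sketch of the scan structure described above (one block per Q-head, softmax over the cached positions, GQA broadcast, 1/√head_dim scaling). Buffer names and the reduction scheme are illustrative; the shipped kernel uses a shared-mem max-reduce rather than the thread-0 softmax shown here:

  // Launch with seq_len * sizeof(float) dynamic shared memory; grid.x = num_q_heads.
  __global__ void mtp_attn_scan_sketch(
      const __half* q,        // [num_q_heads, head_dim]
      const __half* k_cache,  // [max_seq, num_kv_heads, head_dim]
      const __half* v_cache,  // [max_seq, num_kv_heads, head_dim]
      __half* out,            // [num_q_heads, head_dim]
      int seq_len,            // mtp_pos + 1
      int num_kv_heads, int head_dim, int q_per_kv)
  {
      int qh  = blockIdx.x;
      int kvh = qh / q_per_kv;                 // GQA: several Q-heads share one KV-head
      float scale = rsqrtf((float)head_dim);   // 1/sqrt(head_dim)

      extern __shared__ float scores[];        // [seq_len] attention logits

      // Pass 1: scaled dot(q, k_t) for every cached position.
      for (int t = threadIdx.x; t < seq_len; t += blockDim.x) {
          const __half* kt = k_cache + ((size_t)t * num_kv_heads + kvh) * head_dim;
          float dot = 0.f;
          for (int d = 0; d < head_dim; ++d)
              dot += __half2float(q[qh * head_dim + d]) * __half2float(kt[d]);
          scores[t] = dot * scale;
      }
      __syncthreads();

      // Pass 2: softmax + weighted sum of V (serial here for clarity).
      if (threadIdx.x == 0) {
          float m = -INFINITY, z = 0.f;
          for (int t = 0; t < seq_len; ++t) m = fmaxf(m, scores[t]);
          for (int t = 0; t < seq_len; ++t) { scores[t] = expf(scores[t] - m); z += scores[t]; }
          for (int d = 0; d < head_dim; ++d) {
              float acc = 0.f;
              for (int t = 0; t < seq_len; ++t)
                  acc += scores[t] * __half2float(
                      v_cache[((size_t)t * num_kv_heads + kvh) * head_dim + d]);
              out[qh * head_dim + d] = __float2half(acc / z);
          }
      }
  }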

What's still placeholder

  • RoPE on Q/K: Qwen3.6 uses partial-rope 0.25 (rope_dim=64) + mrope sections [11, 11, 10] for multimodal. Adding this should bump factual/instruction acceptance toward the 60% default-on threshold.
  • K=2+ MTP chaining: each draft step is independent right now. Multi-step chaining is a trivial extension once telemetry shows good single-step quality.
  • Phase 3.5 batched-verify proper: run the verify forward over [prev_token, draft_0, ..., draft_{K-1}] as a batched prefill and apply accept-prefix logic (see the sketch after this list). Only worth implementing after RoPE pushes acceptance into default-on territory.
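
A sketch of the accept-prefix rule under greedy verification (names hypothetical; not the shipped code):

  #include <vector>

  // Keep draft tokens while the verify forward's argmax agrees; stop at the
  // first mismatch. verify_argmax[i] is the main model's prediction given
  // [prev_token, draft[0], ..., draft[i-1]] from the batched prefill.
  int accept_prefix(const std::vector<int>& draft,
                    const std::vector<int>& verify_argmax) {
      int accepted = 0;
      for (size_t i = 0; i < draft.size(); ++i) {
          if (verify_argmax[i] != draft[i]) break;  // divergence: reject the rest
          ++accepted;
      }
      return accepted;  // the verify forward also supplies the next main-model token
  }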

Validation

  • MtpForwardTest.DraftStepProducesValidToken: PASS (full MoE + Attn+KV path engaged).
  • make verify-fast: green (decode +1.89× graph speedup, smoke 'Paris' check passes).
  • Manual: imp-cli --mtp-spec-decode 1 produces identical tokens with/without MTP (telemetry remains non-behavioral).
  • scripts/mtp_accuracy_bench.sh: 4/4 classes with signal (above).

Files changed

src/runtime/mtp_forward.{h,cu}    + KV cache (workspace fields, alloc/free, append/scan/gate kernels)
                                  + arch fix (1D → 2D rmsnorm shapes)
src/runtime/engine.{h,cpp}        + accuracy telemetry + async-graph-loop gate
                                  + mtp_accuracy_reset() resets KV pos
src/api/imp_api.cpp               + imp_context_reset clears MTP state
src/model/weight_upload.cu        + arch_norm_offset on MTP norm uploads (root-cause fix #2)
src/model/hf_config_loader.cpp    + read shared_expert_intermediate_size for Qwen3.5/3.6
tools/imp-cli/main.cpp            + mtp accept-rate at end-of-generation log line
scripts/mtp_accuracy_bench.sh     + Phase 5.5 harness with 3-way verdict (default-on / try-batched-verify / blocked)
tests/test_mtp_forward.cpp        + Phase 2.2 integration test
CMakeLists.txt                    + register test_mtp_forward
docs/superpowers/plans/2026-05-14-mtp-phase2-onwards.md
                                  + Phase 2.2 + 3.5 + 5.5 task-by-task plan

Memory: mtp_phase5_validation_2026_05_14 — re-run recipe + next-step recommendations.

🤖 Generated with Claude Code

kekzl and others added 3 commits May 14, 2026 15:52
PR #172 shipped end-to-end MTP scaffolding (load + reduced FC-only forward
+ engine API + CLI). Three open work items remain before MTP is fully functional:

  Phase 2.2 — full transformer block in mtp_forward.cu (currently a no-op
              passthrough at line 186-190). Design fork documented:
              Path A (TransformerLayer view-adapter, reuse existing
              run_attention + run_moe_ffn) vs Path B (from-scratch fused
              kernels). Path A recommended.
  Phase 3.5 — auto-invoke mtp_draft_one + verify forward + accept-prefix
              from the decode loop. Currently mtp_draft_one exists but
              nothing in step_decode calls it.
  Phase 5.5 — A/B matrix to decide default-on/off.

Task-by-task breakdown for each phase. Cross-references the memory entry
mtp_phase2_open_2026_05_14 capturing what's shipped vs open.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…igmoid gate

Replaces the no-op Step 5 placeholder in mtp_forward.cu:186-190 with the
full MoE branch of the MTP transformer block:

  Step 5.B.1  post_attention_layernorm(fc_out) → d_post_norm
  Step 5.B.2  moe_gate_topk_fused: router @ post_norm, softmax, top-k=8
  Step 5.B.3  D2H sync of routing indices+weights for host-side dispatch
  Step 5.B.4  Per chosen expert (k ∈ [0, 8)):
                gate_up = experts_gate_up_packed[idx] @ post_norm  → [1024]
                act     = silu(gate) * up                          → [512]
                down    = experts_down_packed[idx]    @ act        → [2048]
                store into d_expert_outputs[k * hidden]
  Step 5.B.5  moe_weighted_sum_residual: fc_out += Σ w[k] * out[k]
  Step 5.B.6  shared expert: silu(gate_proj·x) * (up_proj·x) → down_proj
              scaled by sigmoid(shared_expert_gate_inp · x), added to fc_out

All compute reuses existing imp primitives:
  - imp::rmsnorm
  - imp::moe_gate_topk_fused (fused gate-GEMV + softmax + top-k for M=1)
  - imp::gemm (M=1 GEMV for per-expert weights and shared expert projections)
  - imp::swiglu (silu(gate) * up)
  - imp::moe_weighted_sum_residual (Σ + residual)
  - imp::shared_expert_gate_scale (sigmoid scalar gate in-place)
  + one tiny new kernel: mtp_add_shared_kernel to fold shared_out into fc_out

Per-expert weight handling: experts_gate_up_packed is [256, 1024, 2048] and
experts_down_packed is [256, 2048, 512] FP16. For each chosen expert, we
build a 2D Tensor view at the expert's slice offset (no extra copies). The
3D packed layout sticks with the shipped MtpHead design.
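
A sketch of the zero-copy slicing and the Step 5.B.4 shapes (pointer arithmetic only; the real Tensor-view construction and imp::gemm call details differ, and the gate/up half ordering follows the shipped packing):

  // experts_gate_up_packed: [n_experts=256, 2*expert_d_ff=1024, hidden=2048] FP16
  // experts_down_packed:    [n_experts=256, hidden=2048, expert_d_ff=512]   FP16
  // For chosen expert `idx`, view the expert's slice in place (no copies):
  __half* gate_up_w = experts_gate_up_packed + (size_t)idx * 1024 * 2048;
  __half* down_w    = experts_down_packed    + (size_t)idx * 2048 * 512;

  // Step 5.B.4 per expert, as M=1 GEMVs via imp::gemm:
  //   gate_up[1024] = gate_up_w[1024,2048] @ post_norm[2048]
  //   act[512]      = silu(gate half) * up half          // imp::swiglu
  //   out_k[2048]   = down_w[2048,512] @ act[512]        // stored at d_expert_outputs[k*hidden]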

Workspace gains MoE scratch buffers (post_norm, gate_up scratch, act,
per-expert outputs, moe_out, shared_*) plus a MoeRoutingBuffers pool and
pinned host buffers for the routing D2H. mtp_workspace_allocate gains
n_experts / top_k / expert_d_ff / shared_d_ff params so the Engine sizes
correctly. The 2-arg form is retained for back-compat.

Engine threads model config (256 / 8 / 512 / 512 for Qwen3.6) into the
workspace allocator.
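
For those Qwen3.6 numbers (256 experts / top-8 / expert_d_ff=512 / shared_d_ff=512), the extra per-draft MoE scratch is small; a rough FP16 sizing sketch (illustrative breakdown, not measured):

  post_norm         : hidden          = 2048        → 4 KiB
  gate_up scratch   : 2 * expert_d_ff = 1024        → 2 KiB
  act               : expert_d_ff     = 512         → 1 KiB
  per-expert outputs: top_k * hidden  = 8 * 2048    → 32 KiB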

Also fixes hf_config_loader to read Qwen3.5/3.6's shared_expert_intermediate_size
(previously only read DeepSeek's moe_shared_expert_intermediate_size) so
expert_shared_d_ff = 512 lands on the config for Qwen3.6-NVFP4. Without this,
the MTP shared expert block silently disabled itself.
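
A sketch of the intended fallback order, with hypothetical JSON-access helpers (the actual hf_config_loader.cpp parsing code differs):

  // Qwen3.5/3.6 names the field shared_expert_intermediate_size; DeepSeek-style
  // configs use moe_shared_expert_intermediate_size. A missing key leaves 0,
  // which silently disables the shared expert block.
  int shared_d_ff = 0;
  if (json.contains("shared_expert_intermediate_size"))
      shared_d_ff = json["shared_expert_intermediate_size"].get<int>();      // Qwen3.5/3.6 → 512
  else if (json.contains("moe_shared_expert_intermediate_size"))
      shared_d_ff = json["moe_shared_expert_intermediate_size"].get<int>();  // DeepSeek
  config.expert_shared_d_ff = shared_d_ff;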

Attention block remains a passthrough (Step 5.A) — Qwen3.6 MTP has unusual
attention shapes (q_proj [8192,2048] but o_proj input is 4096) that need
upstream-reference investigation. Documented in the header.

Smoke test on Qwen3.6-NVFP4 with --mtp-spec-decode 2: workspace allocates
cleanly (d_ff_shared=512), main-model decode produces coherent output
("The capital of France is Paris"), verify-fast green (decode +3.23%,
prefill +2.31%, graphs 1.72×).

The MoE block only RUNS when mtp_draft_one is invoked, which is still
manual (Phase 3.5 auto-invoke not yet wired).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
MtpForwardTest.DraftStepProducesValidToken:
- Loads Qwen3.6-NVFP4 + MTP sidecar end-to-end
- Allocates MTP workspace with full MoE config (256 experts / top-8 /
  expert_d_ff=512 / shared_d_ff=512)
- Calls mtp_draft_step with a random FP16 hidden state + arbitrary token id
- Asserts out_token_id ∈ [0, vocab_size)
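
A sketch of the assertion shape (gtest-style; setup and the mtp_draft_step call signature are assumed, with loading and workspace allocation elided):

  TEST(MtpForwardTest, DraftStepProducesValidToken) {
      // ... load Qwen3.6-NVFP4 + MTP sidecar, allocate MoE workspace (elided) ...

      // Random FP16 hidden state + arbitrary previous token id.
      int out_token_id = -1;
      mtp_draft_step(ctx, d_hidden_random, /*prev_token=*/42, &out_token_id);

      // The drafted token must land inside the vocabulary.
      EXPECT_GE(out_token_id, 0);
      EXPECT_LT(out_token_id, vocab_size);
  }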

PASSES on RTX 5090 (14.4s including 1.57 GiB MTP upload), exercising:
  - router GEMV + top-8 selection
  - per-expert gate_up + swiglu + down (8 experts dispatched)
  - moe_weighted_sum_residual
  - shared expert gate_proj/up_proj/down_proj
  - sigmoid scalar gate

This is the first test that actually invokes the MoE block; existing
E2E paths don't auto-call mtp_draft_one (Phase 3.5 deferred).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions bot enabled auto-merge (squash) May 14, 2026 14:20
@github-actions bot merged commit 89cb8fe into main May 14, 2026
3 checks passed
@kekzl changed the title from "feat(mtp): Phase 2.2 MoE block — 256-expert top-8 + shared expert" to "feat(mtp): Phases 2.2 + 3.5 + 5.5 — MoE, gated attention MVP, accuracy telemetry, validation finding" on May 14, 2026
@kekzl changed the title from "feat(mtp): Phases 2.2 + 3.5 + 5.5 — MoE, gated attention MVP, accuracy telemetry, validation finding" to "feat(mtp): Phases 2.2 + 3.5 + 5.5 — MoE, gated attn, KV cache + softmax scan; bugs fixed → 0% → ~30% accept rate" on May 14, 2026