feat: Qwen3-Coder-Next support — 8-bit gate dequant + fused projection deinterleave#13

Closed
userFRM wants to merge 4 commits into danveloper:main from userFRM:feat/coder-next-port

Conversation


@userFRM userFRM commented Mar 23, 2026

Summary

Adds support for running Qwen3-Coder-Next (78B MoE, 48 layers, 512 experts) on flash-moe. Also fixes a class of bugs affecting any MLX model with mixed-precision quantization.

Key Findings

1. Mixed-Precision Gate Quantization (affects all MLX models)

MLX 4-bit quantized models use 8-bit quantization for routing gates (mlp.gate, mlp.shared_expert_gate), specified per-tensor in config.json:

"quantization": {
    "bits": 4, "group_size": 64,
    "model.layers.0.mlp.gate": {"bits": 8},
    "model.layers.0.mlp.shared_expert_gate": {"bits": 8}
}

The 4-bit dequant kernel reads 8 nibbles per uint32, but these tensors pack 4 bytes per uint32. This corrupts the routing scores, so the wrong experts are selected. Fix: a new dequant_matvec_8bit Metal kernel plus a CPU fallback, with a bits field in BatchMatvecSpec for per-tensor dispatch.
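For reference, a minimal CPU-side sketch of the 8-bit unpacking (names are illustrative, not the actual infer.m symbols; assumes MLX-style affine group quantization, w = scale * q + bias):

```c
#include <stdint.h>
#include <stddef.h>

// Hypothetical CPU fallback for the 8-bit path: 4 bytes per uint32,
// not the 8 nibbles the 4-bit kernel assumes.
void dequant_8bit(const uint32_t *packed, const float *scales,
                  const float *biases, float *out,
                  size_t n, size_t group_size) {
    for (size_t i = 0; i < n; i++) {
        uint32_t word = packed[i / 4];                       // 4 values per word
        uint8_t q = (uint8_t)((word >> ((i % 4) * 8)) & 0xFFu);
        size_t g = i / group_size;                           // one (scale, bias) per group
        out[i] = scales[g] * (float)q + biases[g];
    }
}
```

Reading this with the 4-bit layout (8 values per word, 4-bit shifts) is exactly the corruption described above.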

2. Fused Linear Attention Projections (Coder-Next specific)

Coder-Next uses fused projections (in_proj_qkvz, in_proj_ba) instead of Qwen3.5-397B's separate projections (in_proj_qkv, in_proj_z, in_proj_b, in_proj_a). The fused output is interleaved per k-head:

[h0_q(128), h0_k(128), h0_v(256), h0_z(256), h1_q(128), ...]

rather than flat [all_Q, all_K, all_V, all_Z]. Fix: a per-k-head deinterleave after the projection matvec, with automatic detection of fused vs. separate naming.
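The deinterleave step can be sketched as follows (hypothetical names; dims are parameters so the same routine covers q/k at 128 and v/z at 256 per k-head):

```c
#include <string.h>

// Split the per-k-head interleaved fused output [q | k | v | z] x heads
// into flat Q, K, V, Z buffers.
void deinterleave_qkvz(const float *fused, float *q, float *k,
                       float *v, float *z, int heads,
                       int qd, int kd, int vd, int zd) {
    int stride = qd + kd + vd + zd;                 // one k-head's slice
    for (int h = 0; h < heads; h++) {
        const float *src = fused + h * stride;
        memcpy(q + h * qd, src,                qd * sizeof(float));
        memcpy(k + h * kd, src + qd,           kd * sizeof(float));
        memcpy(v + h * vd, src + qd + kd,      vd * sizeof(float));
        memcpy(z + h * zd, src + qd + kd + vd, zd * sizeof(float));
    }
}
```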

3. Verified Against MLX Reference

  • CPU-only path produces correct output verified against mlx-lm (cosine similarity 0.999999 on embedding, matching gate scores)
  • GPU pipeline has a remaining buffer dimension issue in CMD2 (tracked separately)

Changes

  • shaders.metal: Added dequant_matvec_8bit kernel
  • infer.m: 8-bit pipeline, fused projection detection + deinterleave, chat template wrapping, Coder-Next model constants
  • repack_coder_next.py: Expert repacker for switch_mlp stacked tensor format

Model Constants

| Parameter | Qwen3.5-397B | Qwen3-Coder-Next |
| --- | --- | --- |
| hidden_size | 4096 | 2048 |
| num_layers | 60 | 48 |
| num_experts | 512 | 512 |
| expert_intermediate | 1024 | 512 |
| expert_size (4-bit) | 7.08 MB | 1.77 MB |
| total on disk | 209 GB | 44 GB |

Status

- [x] 8-bit gate dequant (Metal + CPU)
- [x] Fused projection detection and deinterleave
- [x] Chat template wrapping
- [x] Expert repacking from switch_mlp format
- [x] CPU-only path verified correct
- [ ] GPU pipeline CMD2 buffer dimension fix (WIP)

Test

# CPU-only path (correct output, ~0.3 tok/s)
./infer --model <path> --prompt "What is 2+2?" --tokens 20 --k 4
# With g_metal=NULL forced in code

userFRM and others added 4 commits March 22, 2026 23:04
Delta-net: register-resident state + loop fusion (3→2 loops, 60% fewer device mem ops)
Norm kernels: SIMD parallel reduction replacing serial thread-0 loop (18x faster)
down_proj: v3_small kernel with 4KB threadgroup memory (4x GPU occupancy, -6.5% latency)
Routing: partial softmax — 4 exp() instead of 512 (mathematically identical)
IO pool: atomic counter + AArch64 WFE replacing pthread_cond_wait (~300µs/token)

All changes produce bit-identical output. No precision or quality tradeoffs.

Benchmarked on M1 Pro 8-core GPU:
- v3_small down_proj: 159.5µs vs 170.5µs baseline (-6.5%)
- Delta-net kernel: register-resident with fused decay+dot+update+output
- Partial softmax: 128x fewer exp() calls per layer

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ns for kernel safety

- Add SEV instruction after atomic_fetch_add in IO worker to reliably wake WFE spinner
- Add _Static_assert for MOE_INTERMEDIATE <= 1024 (v3_small kernel guard)
- Add _Static_assert for LINEAR_KEY_DIM == 128 (SIMD reduction assumption)
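The atomic-counter/WFE pairing can be sketched portably (the WFE/SEV instructions are AArch64-only; elsewhere this degrades to a plain spin — names are illustrative):

```c
#include <stdatomic.h>

static atomic_int pending;          // outstanding IO tasks

static inline void cpu_wait(void) {
#if defined(__aarch64__)
    __asm__ volatile("wfe");        // park until an event or interrupt
#endif
}

// Worker side: decrement, then SEV so a core parked in WFE wakes reliably.
void worker_done(void) {
    atomic_fetch_add(&pending, -1);
#if defined(__aarch64__)
    __asm__ volatile("sev");
#endif
}

// Consumer side: spin/WFE until all tasks have completed.
void wait_all(void) {
    while (atomic_load(&pending) > 0)
        cpu_wait();
}
```

Without the SEV, a waiter that enters WFE just after the final decrement can sleep until the next unrelated event — the race the commit above closes.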

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Modifies expert selection to prefer experts likely in OS page cache,
with bounded quality degradation controlled by --cache-tolerance.

Algorithm:
1. Standard topK on raw gate logits (unchanged)
2. Partition top-K into cached/uncached based on LRU access tracking
3. For uncached slots (weakest first), substitute with best cached
   expert whose score is within tolerance of the evicted expert

Safety:
- K clamped to MAX_K to prevent stack overflow
- K=1 uses absolute tolerance fallback (not relative to zero range)
- Server mode preserves cache state across requests (OS page cache persists)
- Zero substitutions when tolerance=0 or all top-K already cached

Estimated impact: +20-30% tok/s with real expert data (cache misses
currently dominate at 56% of per-layer time).
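The substitution step above can be sketched as follows (hypothetical names; `cached[]` stands in for the LRU tracker and `score[]` for the raw gate logits):

```c
#include <stdbool.h>

// For each uncached selected expert (weakest first), swap in the best
// cached, not-yet-selected expert whose score is within `tol` of the
// expert being evicted. topk[] is assumed sorted by descending score.
void substitute_cached(int *topk, int k, const float *score,
                       const bool *cached, int n_experts, float tol) {
    for (int i = k - 1; i >= 0; i--) {
        if (cached[topk[i]]) continue;
        int best = -1;
        for (int e = 0; e < n_experts; e++) {
            if (!cached[e]) continue;
            bool selected = false;
            for (int j = 0; j < k; j++)
                if (topk[j] == e) selected = true;
            if (selected) continue;
            if (score[e] >= score[topk[i]] - tol &&
                (best < 0 || score[e] > score[best]))
                best = e;
        }
        if (best >= 0) topk[i] = best;   // bounded-quality substitution
    }
}
```

With tol = 0 no candidate can outscore-within-tolerance an uncached pick, so the selection reduces to standard top-K, matching the zero-substitutions guarantee above.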

Addresses audit findings from Codex/Gemini/Kimi multi-model review:
- Fixed stack overflow when K > MAX_K
- Removed unused --cache-bonus dead code
- Fixed K=1 tolerance becoming zero
- Removed per-request cache reset in serve mode

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds runtime expert remap table for co-activation disk clustering.
When .map files exist (generated by cluster_experts.py), expert file
offsets are translated through a per-layer uint16[512] lookup table.
Zero overhead when no .map files present (identity mapping).

Also sorts pread tasks by file offset before dispatch, ensuring
sequential I/O order when experts are physically adjacent.

Measured on M1 Pro (cold SSD, F_NOCACHE):
- 4 scattered preads: 8.76ms (3.2 GB/s)
- 4 adjacent preads:  6.33ms (4.4 GB/s)  <- 38% faster
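A sketch of the remap-then-sort planning step (hypothetical names; `remap` is the per-layer uint16 table, NULL meaning identity when no .map file exists):

```c
#include <stdint.h>
#include <stdlib.h>

typedef struct {
    int64_t offset;   // byte offset of the expert in the packed file
    int     expert;   // logical expert index
} PreadTask;

static int by_offset(const void *a, const void *b) {
    int64_t x = ((const PreadTask *)a)->offset;
    int64_t y = ((const PreadTask *)b)->offset;
    return (x > y) - (x < y);
}

// Translate logical expert ids through the remap table, then sort the
// pread tasks by file offset so clustered experts are read sequentially.
void plan_preads(const int *experts, int k, const uint16_t *remap,
                 int64_t expert_bytes, PreadTask *tasks) {
    for (int i = 0; i < k; i++) {
        int slot = remap ? remap[experts[i]] : experts[i];  // identity w/o .map
        tasks[i].expert = experts[i];
        tasks[i].offset = (int64_t)slot * expert_bytes;
    }
    qsort(tasks, k, sizeof(PreadTask), by_offset);
}
```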

The clustering tool (cluster_experts.py) requires a routing log:
  ./infer --collect-routing routing.bin --tokens 200
  python3 cluster_experts.py --routing routing.bin --packed-dir packed_experts

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

userFRM commented Mar 23, 2026

Closing — splitting into separate PRs. The 8-bit gate dequant fix will be submitted standalone.

@userFRM userFRM closed this Mar 23, 2026
rrr3try pushed a commit to Graf-RAGov/flash-moe-mlx that referenced this pull request Apr 17, 2026
Upstream + fork + issue context compiled for the port effort: PR diffs
(danveloper#3 runtime config, danveloper#11 perf wins, danveloper#13 Qwen3-Coder-Next, danveloper#14 8-bit dequant),
fork summaries (nerds-odd-e, gorroai), issue captures (danveloper#15 setup gotchas,
danveloper#17 expert_index scope bug, danveloper#20 other Qwen models), target architecture
spec (qwen3.6-35b-a3b-arch.md), hardcoded-constants map of upstream
flash-moe, condensed port plan. Plus benchmark results, parallelism
exploration, 10x optimization ideas.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
