feat: MLA absorption for DeepSeek V2/V3 — fuse low-rank Q/K/V into standard dense tensors #96
Hey @mvkorobkov, same situation as #103: the branch conflicts with current main. Could you rebase against current main? If you'd rather, I can also cherry-pick onto a fresh branch with your attribution preserved on the commits. Let me know which you prefer. Once rebased, I'll do a proper review of the gqa_attention_asym kernel + the DeepSeek geometry plumbing.
DS-V3 absorbed attention has qk_head_dim=192 (nope=128 + rope=64) but v_head_dim=128. The existing gqa_attention uses a single head_dim for all projections, which would corrupt V slicing and output shape.

gqa_attention_asym accepts separate qk_head_dim and v_head_dim:
- Q/K sliced with qk_head_dim (dot-product stays in the larger space)
- V sliced and output written with v_head_dim
- Returns (seq, num_q * v_head_dim)

When qk_head_dim == v_head_dim the function is numerically identical to gqa_attention (verified by the asym_sym_equivalence_when_dims_equal test).

4 tests added: shape, finiteness, sym-equivalence, seq=1 causal.

Note: gqa kernels live in larql-compute (post-ADR-0022 Step 2d); this commit places the asym variant alongside the existing gqa_attention there.
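To make the asymmetric slicing concrete, here is a minimal sketch of a causal GQA kernel with separate Q/K and V head dims. It is not the kernel from this PR: the function name `gqa_attention_asym_sketch`, the flat row-major layouts, and the `1/sqrt(qk_hd)` softmax scale are all assumptions made for illustration.

```rust
/// Illustrative sketch only (hypothetical name and signature).
/// Assumed row-major layouts:
///   q: (seq, num_q  * qk_hd)
///   k: (seq, num_kv * qk_hd)
///   v: (seq, num_kv * v_hd)
/// Returns (seq, num_q * v_hd).
fn gqa_attention_asym_sketch(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    seq: usize,
    num_q: usize,
    num_kv: usize,
    qk_hd: usize,
    v_hd: usize,
) -> Vec<f32> {
    assert_eq!(q.len(), seq * num_q * qk_hd);
    assert_eq!(k.len(), seq * num_kv * qk_hd);
    assert_eq!(v.len(), seq * num_kv * v_hd);
    assert!(num_kv > 0 && num_q % num_kv == 0);

    let group = num_q / num_kv; // query heads sharing one KV head
    let scale = 1.0 / (qk_hd as f32).sqrt(); // assumed scaling
    let mut out = vec![0.0f32; seq * num_q * v_hd];

    for h in 0..num_q {
        let kvh = h / group;
        for t in 0..seq {
            // Q and K rows are sliced with qk_hd.
            let q_row = &q[(t * num_q + h) * qk_hd..][..qk_hd];
            // Causal mask: position t attends to positions 0..=t only.
            let mut scores: Vec<f32> = (0..=t)
                .map(|s| {
                    let k_row = &k[(s * num_kv + kvh) * qk_hd..][..qk_hd];
                    scale * q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>()
                })
                .collect();

            // Numerically stable softmax over the visible positions.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut denom = 0.0f32;
            for s in scores.iter_mut() {
                *s = (*s - max).exp();
                denom += *s;
            }

            // V rows and the output row are sliced with v_hd.
            let o = &mut out[(t * num_q + h) * v_hd..][..v_hd];
            for (s, w) in scores.iter().enumerate() {
                let v_row = &v[(s * num_kv + kvh) * v_hd..][..v_hd];
                for (oi, vi) in o.iter_mut().zip(v_row) {
                    *oi += (w / denom) * vi;
                }
            }
        }
    }
    out
}
```

The only difference from a symmetric kernel is which stride indexes which buffer: `qk_hd` for Q/K, `v_hd` for V and the output. Setting `qk_hd == v_hd` collapses it back to the single-`head_dim` case.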
Three new optional fields on ModelConfig:
- qk_nope_head_dim: non-RoPE part of Q/K head dim (DS-V3: 128)
- qk_rope_head_dim: RoPE-rotated part of Q/K head dim (DS-V3: 64)
- v_head_dim: V projection head dim (DS-V3: 128)

Parsed from config.json (qk_nope_head_dim / qk_rope_head_dim / v_head_dim). Trait accessors added to ModelArchitecture with None defaults. DeepSeekArch overrides to read from config.

DS-V3 detection test extended to verify all three fields round-trip. Two GGUF test-only ModelConfig literals updated to include None stubs.
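The "None defaults plus override" pattern described above can be sketched as follows. The type and trait names carry a `Sketch` suffix because they are hypothetical stand-ins, not the real `ModelConfig` / `ModelArchitecture` definitions, and the `qk_head_dim` helper is an addition for illustration only.

```rust
/// Hypothetical, trimmed-down stand-in for the real config struct.
#[derive(Debug, Clone, Default)]
struct ModelConfigSketch {
    qk_nope_head_dim: Option<usize>,
    qk_rope_head_dim: Option<usize>,
    v_head_dim: Option<usize>,
}

/// Hypothetical stand-in for the architecture trait.
/// Every accessor defaults to None, so non-MLA architectures need no changes.
trait ModelArchitectureSketch {
    fn qk_nope_head_dim(&self) -> Option<usize> {
        None
    }
    fn qk_rope_head_dim(&self) -> Option<usize> {
        None
    }
    fn v_head_dim(&self) -> Option<usize> {
        None
    }

    /// Illustrative helper (not from the PR): full Q/K head dim,
    /// available only when both parts are present.
    fn qk_head_dim(&self) -> Option<usize> {
        Some(self.qk_nope_head_dim()? + self.qk_rope_head_dim()?)
    }
}

struct DeepSeekArchSketch {
    config: ModelConfigSketch,
}

/// The DeepSeek-style architecture overrides the defaults to read from config.
impl ModelArchitectureSketch for DeepSeekArchSketch {
    fn qk_nope_head_dim(&self) -> Option<usize> {
        self.config.qk_nope_head_dim
    }
    fn qk_rope_head_dim(&self) -> Option<usize> {
        self.config.qk_rope_head_dim
    }
    fn v_head_dim(&self) -> Option<usize> {
        self.config.v_head_dim
    }
}

/// A standard-attention architecture keeps the None defaults.
struct LlamaArchSketch;
impl ModelArchitectureSketch for LlamaArchSketch {}
```

With the DS-V3 values quoted above (128, 64, 128), `qk_head_dim()` yields `Some(192)`, while `LlamaArchSketch` returns `None` for every accessor.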
…eight matrices

Implements `mla_absorb::absorb()`, which converts the four MLA weight matrices (kv_a, kv_b, q_a, q_b) into standard dense Q/K/V tensors compatible with `gqa_attention_asym`.

Key correctness points:
- rope-K is MQA: single row in kv_a[kv_lora..] replicated num_kv times in absorbed K (not per-head in the input tensor)
- DS-V3 native per-head layout [nope|rope] → LARQL convention [rope|nope] applied symmetrically to Q and K during absorption
- V: straightforward kv_b[nope+v_hd slice] @ kv_compress

Three tests (3 passed):
- absorbed_forward_matches_reference: reference MLA forward vs absorbed path through gqa_attention_asym must match within 1e-4
- absorbed_shapes: output tensor dimensions
- rope_k_is_broadcast_not_zero: single rope-K correctly replicated across heads
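The K/V half of the points above can be sketched as below. This is not `mla_absorb::absorb()` itself: the function names, the flat row-major layouts, the per-head `[nope | v]` ordering inside `kv_b`, and the omission of any normalization between the two projections are all assumptions. The Q side would follow the same pattern (fuse `q_b @ q_a`, then reorder each head to `[rope | nope]`).

```rust
/// Row-major matmul: (m x k) @ (k x n) -> (m x n).
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let aip = a[i * k + p];
            for j in 0..n {
                c[i * n + j] += aip * b[p * n + j];
            }
        }
    }
    c
}

/// Hypothetical sketch of K/V absorption.
/// Assumed layouts (row-major):
///   kv_a: ((kv_lora + rope), hidden)  rows [compress | shared rope-K]
///   kv_b: (num_kv * (nope + v_hd), kv_lora), per head [nope | v]
/// Returns:
///   K: (num_kv * (rope + nope), hidden), per head [rope | nope]
///   V: (num_kv * v_hd, hidden)
fn absorb_kv_sketch(
    kv_a: &[f32],
    kv_b: &[f32],
    hidden: usize,
    kv_lora: usize,
    num_kv: usize,
    nope: usize,
    rope: usize,
    v_hd: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(kv_a.len(), (kv_lora + rope) * hidden);
    assert_eq!(kv_b.len(), num_kv * (nope + v_hd) * kv_lora);

    // The rope-K block is shared by all KV heads (MQA).
    let (kv_compress, k_rope_shared) = kv_a.split_at(kv_lora * hidden);

    let mut k = Vec::with_capacity(num_kv * (rope + nope) * hidden);
    let mut v = Vec::with_capacity(num_kv * v_hd * hidden);
    for h in 0..num_kv {
        let head = &kv_b[h * (nope + v_hd) * kv_lora..][..(nope + v_hd) * kv_lora];
        let (b_nope, b_v) = head.split_at(nope * kv_lora);

        // [rope | nope] ordering: the shared rope block is replicated first,
        // then this head's fused nope rows follow.
        k.extend_from_slice(k_rope_shared);
        k.extend(matmul(b_nope, kv_compress, nope, kv_lora, hidden));

        // V is the fused product of this head's V rows with the compression.
        v.extend(matmul(b_v, kv_compress, v_hd, kv_lora, hidden));
    }
    (k, v)
}
```

Replicating the shared rope block per head is what makes the result readable as ordinary per-head K, at the cost of the extra storage for the duplicated rows.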
Rebased onto current main (810f163). Branch now contains 5 focused MLA commits (656 insertions, 10 deletions):
Dropped two commits that didn't belong here:
597d2ca to 2d10daa
write_model_weights_with_opts now accepts DS-V3 / MLA architectures when all three geometry fields (qk_nope_head_dim, qk_rope_head_dim, v_head_dim) are present in config.json.

When detected:
- skips the standard-attention guard
- per layer: fetches kv_a/kv_b/q_a/q_b projections, calls mla_absorb::absorb, writes the resulting dense Q/K/V under the standard attn_q/k/v key names
- O projection is passed through unchanged (no absorption needed)

The loader remains MLA-unaware: it reads standard Q/K/V tensors just as for any Llama/Mistral model. The extra storage cost (absorbed K replicates the MQA rope-K row num_kv times) is acceptable for DS-V3 full scale (~3.5 GB extra per 61 layers on num_kv=128).

All 971 larql-vindex unit + integration tests pass.
Summary
- `gqa_attention_asym`: new attention kernel in `larql-inference` that handles asymmetric `qk_head_dim` / `v_head_dim` (required for absorbed MLA tensors where Q/K use 192-dim heads but V uses 128-dim heads in DS-V3)
- `ModelConfig`: `qk_nope_head_dim`, `qk_rope_head_dim`, `v_head_dim` parsed from `config.json`; `DeepSeekArch` exposes them via trait methods
- `mla_absorb`: new module in `larql-vindex` that fuses the four DS-V2/V3 low-rank attention projections (`kv_a`, `kv_b`, `q_a`, `q_b`) into standard dense Q/K/V weight matrices
- `write_model_weights`: F32 weight writer now accepts MLA architectures. It detects full geometry, runs absorption per layer, and writes absorbed Q/K/V under standard key names so the loader needs no MLA awareness

Why absorption
DS-V2/V3 stores attention as four low-rank matrices. Absorbing them into standard Q/K/V at extraction time means:
Correctness
Key details:
- `kv_a` rope-K is MQA (one shared row for all KV heads, not per-head), replicated `num_kv` times when building absorbed K
- DS-V3 native per-head layout is `[nope | rope]`; LARQL convention is `[rope | nope]`; absorption reorders symmetrically for both Q and K
- `absorbed_forward_matches_reference` test: reference MLA forward pass vs absorbed path through `gqa_attention_asym` must agree within 1e-4 (f32 precision)

Test plan
- `cargo test -p larql-inference -- gqa_attention_asym`: 4 tests (shape, finite, sym-equivalence, causal)
- `cargo test -p larql-vindex -- mla_absorb`: 3 tests (forward equivalence, shapes, rope broadcast)
- `cargo test -p larql-models`: existing DS-V3 detection tests extended with new geometry accessors
- `cargo test -p larql-vindex`: 971 tests, 0 failures