
Feature: Quantized KV cache support in scaled_dot_product_attention (TurboQuant) #3404

@Landon-Molt

Description

Feature Request: Quantized KV Cache Support in mx.fast.scaled_dot_product_attention

Summary

Add native support for reading quantized (TurboQuant codebook) KV cache data in mx.fast.scaled_dot_product_attention, eliminating the need for Python-level custom Metal kernels or full-cache dequantization during attention.

Motivation

TurboQuant (Google Research, ICLR 2026) compresses KV caches to 2-4 bits via rotation + codebook quantization. The MLX ecosystem has adopted it widely:

  • mlx-vlm (Blaizzy/mlx-vlm) — ships TurboQuantKVCache with Metal decode kernels
  • mlx-lm (ml-explore/mlx-lm#1060) — community discussion, TheTom's TurboQuant+ results
  • omlx (jundot/omlx) — continuous-batching server using TQ for KV compression

The problem: mx.fast.scaled_dot_product_attention only accepts fp16/bf16/fp32 tensors. When the KV cache is quantized, there are only two options today:

  1. Dequantize the entire cache to fp16 before attention — works, fast (uses native SDPA), but creates a full fp16 copy of the KV cache in memory (31 GB activation spike at 128K context on gemma-4-31B). Defeats the purpose of quantization at long context.

  2. Custom Metal kernel via mx.fast.metal_kernel — we built a working proof-of-concept that fuses score + online softmax + value accumulation in one dispatch. Correctness is perfect (cosine similarity 1.0 vs the dequantize path) and memory stays bounded, but it's 3-4x slower than native SDPA due to Python dispatch overhead and the inability to leverage MLX's internal kernel fusion and memory planning.

Neither option is satisfactory. llama.cpp and vLLM solved this by implementing quantized-cache attention in their C++/CUDA/Metal core — the quantized format is a first-class citizen in the attention kernel. MLX could do the same.
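To make option 1's memory problem concrete, here is a minimal NumPy sketch of the dequantize-first path. The flat per-element index layout, `dequantize_cache` helper, and codebook size are illustrative stand-ins, not the actual TurboQuant state format:

```python
import numpy as np

def dequantize_cache(indices, norms, codebook):
    """Decode a codebook-quantized KV cache back to dense fp16.

    indices:  (T, D) uint8 codebook indices, one per element (illustrative)
    norms:    (T,)   fp16 per-token scale factors
    codebook: (K,)   fp16 codebook entries
    """
    # Full-cache decode: materializes a dense (T, D) fp16 copy -- this is
    # exactly the activation spike described above at long context.
    return (codebook[indices] * norms[:, None]).astype(np.float16)

T, D, K = 4096, 128, 16                # KV tokens, head dim, codebook size
rng = np.random.default_rng(0)
codebook = rng.standard_normal(K).astype(np.float16)
indices = rng.integers(0, K, size=(T, D), dtype=np.uint8)
norms = np.abs(rng.standard_normal(T)).astype(np.float16)

keys_fp16 = dequantize_cache(indices, norms, codebook)
print(keys_fp16.shape, keys_fp16.dtype)   # (4096, 128) float16
```

The dense copy scales linearly with context length, so the spike grows with T regardless of how small the quantized representation is.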

Prior Art in the MLX Ecosystem

  • TheTom's MLX fork (TheTom/mlx:feature/turboquant-plus) — adds mx.fast.scaled_dot_product_attention_qv as a C++ Metal kernel reading 4-bit quantized values. Achieves near-native decode speed. Decode-only (L=1), uses mx.quantize affine format.

  • mlx-vlm _fused_integer_decode_single_tile_kernel — existing Python-generated Metal kernel for TQ decode attention. Fuses score + online softmax + value accumulation. Works well for decode but is invoked via mx.fast.metal_kernel (Python dispatch).

  • Our proof-of-concept (Landon-Molt/mlx-vlm:feat/fused-tq-prefill) — extends the decode kernel to L>1 (prefill) with grid-parallel queries. Validates the algorithm and correctness. Benchmarks below.

Benchmarks (Our Proof-of-Concept)

Single layer, B=1, H=8 KV heads, D=128, GQA 4:1, L=128 queries, MacBook Pro M5 Max 128GB:

| KV tokens | Fused TQ (Python Metal) | Dequantize + native SDPA | Ratio |
|-----------|-------------------------|--------------------------|-------|
| 2K        | 0.058s                  | 0.003s                   | 17.9x |
| 8K        | 0.025s                  | 0.009s                   | 2.9x  |
| 32K       | 0.097s                  | 0.029s                   | 3.4x  |
| 64K       | 0.191s                  | 0.056s                   | 3.4x  |

The 3-4x gap at long context comes from Python→Metal dispatch overhead versus native C++ dispatch. The Metal kernel itself is efficient: at 8K tokens, where compute starts to amortize the fixed dispatch cost, the gap narrows to 2.9x, while at 2K that fixed overhead dominates outright (17.9x).

For a 60-layer model at 32K context, this translates to ~5.8s attention-only (vs ~1.7s native) — the difference between responsive and sluggish prefill.
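The per-model estimate is just the 32K per-layer numbers from the table scaled by layer count:

```python
# Per-layer attention times at 32K KV tokens (from the benchmark table).
fused_tq = 0.097   # s/layer, Python-dispatched fused TQ kernel
native   = 0.029   # s/layer, dequantize + native SDPA
layers   = 60      # illustrative 60-layer model

print(f"fused TQ: {fused_tq * layers:.1f}s")   # ~5.8s
print(f"native:   {native * layers:.1f}s")     # ~1.7s
```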

Proposed Approach

Extend the existing SDPA Metal kernels to read TurboQuant codebook-quantized KV data:

Decode (L=1) — extend sdpa_vector.h:

  • Score: unpack key codebook indices inline, dot product with query in rotated space
  • Value: unpack value codebook indices inline during weighted accumulation
  • Same online softmax, same threading model, same simdgroup reduction

Prefill (L>1) — extend steel flash attention kernels:

  • Grid-parallel over query positions (each threadgroup handles one query)
  • Tile over KV tokens with online softmax (same as current flash attention)
  • Inline codebook dequantization per element during score and value phases
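The tile-over-KV loop with online softmax and inline dequantization can be sketched in NumPy for a single query. The flat (T, D) index layout, per-token norms, and single shared codebook are simplified stand-ins for the actual TurboQuant state:

```python
import numpy as np

def attention_tq_tiled(q, k_idx, k_norms, v_idx, v_norms, codebook,
                       tile=256, scale=None):
    """Flash-style attention for one query over a codebook-quantized KV
    cache: tile over KV tokens, dequantize each tile inline, and keep a
    running (online) softmax so no full fp copy is ever materialized."""
    T, D = k_idx.shape
    scale = scale if scale is not None else D ** -0.5
    m = -np.inf            # running max of scores
    l = 0.0                # running softmax denominator
    acc = np.zeros(D)      # running weighted value accumulator
    for t0 in range(0, T, tile):
        sl = slice(t0, min(t0 + tile, T))
        # Inline dequant: only this tile exists in floating point.
        k = codebook[k_idx[sl]] * k_norms[sl, None]
        v = codebook[v_idx[sl]] * v_norms[sl, None]
        s = (k @ q) * scale
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)         # rescale old partial sums
        p = np.exp(s - m_new)
        l = l * corr + p.sum()
        acc = acc * corr + p @ v
        m = m_new
    return acc / l

# Quick check against full-precision reference attention.
rng = np.random.default_rng(1)
T, D, K = 1000, 64, 16
codebook = rng.standard_normal(K)
k_idx = rng.integers(0, K, (T, D)); v_idx = rng.integers(0, K, (T, D))
k_norms = np.abs(rng.standard_normal(T)); v_norms = np.abs(rng.standard_normal(T))
q = rng.standard_normal(D)
out = attention_tq_tiled(q, k_idx, k_norms, v_idx, v_norms, codebook)
```

A C++/Metal version would run this loop per threadgroup with simdgroup reductions; the math is identical.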

Quantized state format (from mlx-vlm's TurboQuantKVCache):

  • Keys (TurboQuantProdState): norms (fp16, per-token), mse_indices (packed uint32), residual_norms (fp16), qjl_signs (packed uint32), plus a small constant codebook
  • Values (TurboQuantMSEState): norms (fp16, per-token), indices (packed uint32), plus a small constant codebook
  • Per-element dequant: codebook[unpack_bits(packed, d)] * norm[t]
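The per-element decode can be sketched as follows. The little-endian 4-bit packing, `unpack_bits` helper, and uniform codebook are assumptions for illustration, not the actual TurboQuant layout:

```python
import numpy as np

def unpack_bits(packed, d, bits=4):
    """Extract the d-th `bits`-wide codebook index from an array of
    uint32 words. Assumes little-endian packing with 32 // bits indices
    per word (an illustrative layout)."""
    per_word = 32 // bits
    word = packed[d // per_word]
    shift = (d % per_word) * bits
    return (word >> shift) & ((1 << bits) - 1)

codebook = np.linspace(-1.0, 1.0, 16, dtype=np.float16)   # 4-bit codebook
norms = np.array([0.5], dtype=np.float16)                  # per-token norm
packed = np.array([0x76543210], dtype=np.uint32)           # nibbles 0..7

t, d = 0, 5
x = codebook[unpack_bits(packed, d)] * norms[t]   # dequantized element (t, d)
```

In the kernel this runs per element inside the score and value loops, so the quantized words are read straight from device memory with no intermediate buffer.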

API Options

Option A: New function

mx.fast.scaled_dot_product_attention_tq(
    queries, keys_state, values_state,
    key_codebook, value_codebook,
    scale=..., mask=...
)

Option B: Extend existing SDPA to detect quantized inputs

# If keys/values are QuantizedStateProxy objects, dispatch to TQ kernel
mx.fast.scaled_dot_product_attention(queries, keys, values, scale=..., mask=...)
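Option B's dispatch could be prototyped as a thin Python shim. `QuantizedStateProxy` and the two kernel stubs below are hypothetical names for illustration, not an existing MLX API:

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class QuantizedStateProxy:
    """Hypothetical container standing in for a quantized KV tensor."""
    indices: Any    # packed uint32 codebook indices
    norms: Any      # fp16 per-token norms
    codebook: Any   # small constant codebook

def _sdpa_tq(q, k, v, *, scale, mask=None):
    return "tq-kernel"      # stand-in for a fused quantized-cache kernel

def _sdpa_native(q, k, v, *, scale, mask=None):
    return "native-sdpa"    # stand-in for the existing fp16/bf16/fp32 path

def scaled_dot_product_attention(queries, keys, values, *, scale, mask=None):
    # Dispatch on input type: quantized proxies go to the TQ kernel,
    # plain arrays go to the existing native SDPA path.
    if isinstance(keys, QuantizedStateProxy) or isinstance(values, QuantizedStateProxy):
        return _sdpa_tq(queries, keys, values, scale=scale, mask=mask)
    return _sdpa_native(queries, keys, values, scale=scale, mask=mask)
```

The upside of option B is that callers keep a single entry point; the downside is that the C++ op must accept a heterogeneous input type rather than plain arrays.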

Related Issues

  • mlx-vlm#1016 — root cause: prefill_attention() dead after ProdCodec removal
  • mlx-vlm#939 — tiled prefill (2 dispatches, 17x slower)
  • mlx-lm#1060 — TurboQuant community discussion
  • #3302 — GPU watchdog on long-context SDPA (related: long prefill)
  • #3361 — SDPA fix for >32K KV (actively maintained area)
