feat(hardware): add KV cache memory estimator with 256-token rounding#65
Merged
Conversation
Add kv_cache_bytes() pure function computing per-layer KV reservation using the formula: num_layers × 2 (K+V) × num_kv_heads × head_dim × elem_bytes × round_up(ctx_len, 256) × batch. Context length is rounded up to the next multiple of 256 to match KVCache's step-aligned pre-allocation. Add KvCacheParams struct and kv_cache_bytes_from_params() config-driven wrapper. KvCacheParams.int8_kv controls elem_bytes (1 for INT8, 2 for FP16/BF16), honoring --cache-type-k / --cache-type-v / --kv-cache-mode flags. Replace the flat KV_CACHE_HEADROOM_GB = 2 constant in recommend_quantization() with an optional kv_cache_headroom_bytes parameter. When None, the 2 GiB fallback preserves backward compatibility; when Some, the computed bytes are converted to GiB (ceiling) and used for accurate fit decisions. Existing callers pass None. Wire kv_cache_bytes_from_params into quant_advisor.advise_quantization(): reads num_layers / num_kv_heads / head_dim from config.json (8192-token default context) and passes the result to recommend_quantization, replacing the flat constant path. Also expose kv_cache_bytes as QuantAdvice.kv_cache_bytes for future unified estimator. Acceptance criteria satisfied: - kv_cache_dense_mha: 32L/32H/D128/FP16/8K = 4 GiB - kv_cache_gqa_fewer_kv_heads: 32L/8H/D128/FP16/8K = 1 GiB (GQA) - kv_cache_long_context_128k: 128K ctx = 16 GiB - kv_cache_int8_half_memory: INT8 (elem=1) is exactly half of FP16 (elem=2) - kv_cache_256_token_rounding: ctx=255→256, ctx=257→512, ctx=256→256 - recommend_quant_long_context_tightens_headroom: 8B/24GB: 8K→FP16, 128K→INT4
6501224 to
5ab0fef
Compare
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Replaces the flat
KV_CACHE_HEADROOM_GB = 2constant inrecommend_quantization()with a proper KV-cache memory estimator. The new estimator computes exact bytes from model architecture parameters using the formulanum_layers × 2 × num_kv_heads × head_dim × elem_bytes × round_up(ctx_len, 256) × batch, matching the actual buffer reservation performed byKVCache.What changed
src/lib/mlxcel-core/src/hardware.rs: AddedKV_CACHE_ALLOC_STEP = 256constant,kv_cache_bytes()pure function with 256-token rounding,KvCacheParamsstruct,kv_cache_bytes_from_params()config-driven wrapper, and updatedrecommend_quantization()to acceptOption<u64>KV headroom instead of the hardcoded constant. All existing callers passNonefor backward compatibility.src/execution/quant_advisor.rs: Updated import to include new types, addedestimate_kv_cache_bytes_from_path()/estimate_kv_cache_bytes_from_config()helpers that extractnum_hidden_layers,num_key_value_heads,hidden_size,num_attention_headsfromconfig.json, and wired them intoadvise_quantization()to supply computed KV headroom. Addedkv_cache_bytes: Option<u64>toQuantAdvicefor the future unified estimator.Test plan
cargo test --lib -p mlxcel-core hardware::tests— all 24 tests pass, covering: dense MHA, GQA (kv_heads < heads), long context (128K), INT8 KV (half bytes), 256-token rounding edge cases (ctx=1, 255, 256, 257), andrecommend_quantizationcorrectness with both flat and computed headroomcargo test --lib -p mlxcel quant_advisor::tests— all 6 tests passcargo clippy --lib --tests -p mlxcel-core -- -D warnings— cleancargo check --lib --tests -p mlxcel— cleanCloses #54