diff --git a/include/llama.h b/include/llama.h
index 98bed9d6150a0..aa9932afb844b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -463,6 +463,7 @@ extern "C" {
 
     // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
     // In some cases the requested values via llama_context_params may differ from the actual values used by the context
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
     LLAMA_API uint32_t llama_n_ctx        (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ctx_seq    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch      (const struct llama_context * ctx);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 866514038e493..e115fcd933f53 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -114,10 +114,14 @@ llama_context::llama_context(
         }
     }
 
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+
     if (cparams.kv_unified) {
         cparams.n_ctx_seq = cparams.n_ctx;
     } else {
         cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+        cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
 
         if (cparams.n_ctx_seq == 0) {
             throw std::runtime_error("n_ctx_seq == 0");
diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp
index facba1d004012..3a34102a23d08 100644
--- a/src/llama-kv-cache-iswa.cpp
+++ b/src/llama-kv-cache-iswa.cpp
@@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+    // note: the SWA cache is always padded to 256 for performance
+    // https://github.com/ggml-org/llama.cpp/issues/17037
+    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py
index 65952de8b8d4c..d2f3fba5fe7a9 100644
--- a/tools/server/tests/unit/test_speculative.py
+++ b/tools/server/tests/unit/test_speculative.py
@@ -77,10 +77,10 @@ def test_different_draft_min_draft_max():
 
 def test_slot_ctx_not_exceeded():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
         "speculative.p_min": 0.0,
@@ -91,19 +91,19 @@ def test_with_ctx_shift():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
-        "n_predict": 64,
+        "n_predict": 256,
         "speculative.p_min": 0.0,
     })
     assert res.status_code == 200
     assert len(res.body["content"]) > 0
-    assert res.body["tokens_predicted"] == 64
+    assert res.body["tokens_predicted"] == 256
     assert res.body["truncated"] == True
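
For reference: GGML_PAD(x, n) in ggml.h rounds x up to the nearest multiple of n, which is why a requested n_ctx of 64 now yields an actual context of 256 and the server tests above were updated accordingly. Below is a minimal standalone sketch of the rounding behavior; the main() driver is illustrative only and not part of the patch.

    #include <cstdint>
    #include <cstdio>

    // same definition as GGML_PAD in ggml.h: round x up to a multiple of n
    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    int main() {
        // after this patch the requested n_ctx is rounded up to a multiple
        // of 256, so callers should query llama_n_ctx() for the actual value
        const uint32_t requested[] = { 64, 256, 300, 4096 };
        for (uint32_t n_ctx : requested) {
            std::printf("requested n_ctx = %4u -> actual n_ctx = %4u\n",
                        n_ctx, GGML_PAD(n_ctx, 256));
        }
        return 0;
    }

    // output:
    //   requested n_ctx =   64 -> actual n_ctx =  256
    //   requested n_ctx =  256 -> actual n_ctx =  256
    //   requested n_ctx =  300 -> actual n_ctx =  512
    //   requested n_ctx = 4096 -> actual n_ctx = 4096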