From e4838046f311076c2746f0e0266757f753ae1e18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 19 Nov 2025 09:44:04 +0200 Subject: [PATCH] llama : update worst-case graph for unified cache --- examples/embedding/embedding.cpp | 7 ++++--- src/llama-context.cpp | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 9e3ab5905bb37..fe91b308cdc0a 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -104,12 +104,16 @@ int main(int argc, char ** argv) { params.embedding = true; + // get max number of sequences per batch + const int n_seq_max = llama_max_parallel_sequences(); + // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache // in order to support any number of prompts if (params.n_parallel == 1) { LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__); params.kv_unified = true; + params.n_parallel = n_seq_max; } // utilize the full context @@ -123,9 +127,6 @@ int main(int argc, char ** argv) { params.n_ubatch = params.n_batch; } - // get max number of sequences per batch - const int n_seq_max = llama_max_parallel_sequences(); - llama_backend_init(); llama_numa_init(params.numa); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 70a3ec62dfc63..2f6cd7e2a0ee8 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -299,7 +299,7 @@ llama_context::llama_context( cross.v_embd.clear(); - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // avoid reserving graphs with zero outputs - assume one output per sequence @@ -542,7 +542,7 @@ bool llama_context::memory_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());