From 8d518d7c0e05e82eb9dff9abb398ed83013318cc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 2 Oct 2025 18:03:52 +0300 Subject: [PATCH 01/17] minor : code style --- common/chat.h | 6 ++--- src/llama-kv-cache.cpp | 7 ++---- tools/server/server.cpp | 12 ++++++---- tools/server/utils.hpp | 52 ++++++++++++++++++++++++----------------- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/common/chat.h b/common/chat.h index a1afe574bd0ca..f7b36ec711df4 100644 --- a/common/chat.h +++ b/common/chat.h @@ -33,8 +33,8 @@ struct common_chat_msg_content_part { struct common_chat_msg { std::string role; std::string content; - std::vector content_parts = {}; - std::vector tool_calls = {}; + std::vector content_parts; + std::vector tool_calls; std::string reasoning_content; std::string tool_name; std::string tool_call_id; @@ -44,7 +44,7 @@ struct common_chat_msg { bool empty() const { return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); } - void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { + void set_tool_call_ids(std::vector & ids_cache, const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { if (ids_cache.size() <= i) { auto id = tool_calls[i].id; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 816f2d5de592b..736693e174527 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index de6e1a322b2c2..f2d1a7971a7a4 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1619,7 +1619,7 @@ struct server_slot { /* is_partial= */ stop != STOP_TYPE_EOS, params.oaicompat_chat_syntax); if (!new_msg.empty()) { - new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); + new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); chat_msg = new_msg; diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? 
previous_msg : new_msg); } @@ -2749,7 +2749,7 @@ struct server_context { } // if multimodal is enabled, send an error and return false - bool ensure_no_mtmd(const int id_task) { + bool check_no_mtmd(const int id_task) { if (mctx) { send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); return false; @@ -3121,7 +3121,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_SAVE: { - if (!ensure_no_mtmd(task.id)) { + if (!check_no_mtmd(task.id)) { break; } @@ -3162,7 +3162,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { - if (!ensure_no_mtmd(task.id)) break; + if (!check_no_mtmd(task.id)) break; int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -3209,7 +3209,9 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { - if (!ensure_no_mtmd(task.id)) break; + if (!check_no_mtmd(task.id)) { + break; + } int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 4ca1423aaf2d4..8ee3042e2c7de 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1144,9 +1144,8 @@ struct server_tokens { auto it = map_pos_to_media.find(pos); if (it != map_pos_to_media.end()) { return it->second; - } else { - throw std::runtime_error("Chunk not found"); } + throw std::runtime_error("Chunk not found"); } void push_back(llama_token tok) { @@ -1170,7 +1169,7 @@ struct server_tokens { map_pos_to_media[start_pos] = std::move(new_chunk); } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { size_t n_tokens; - auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); for (size_t i = 0; i < n_tokens; ++i) { push_back(text_tokens[i]); } @@ -1190,7 +1189,7 @@ struct server_tokens { // We could also just check, but this will prevent silently dropping MTMD data. 
GGML_ASSERT(has_mtmd); for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) { - auto chunk = tokens.map_pos_to_media[it->first].get(); + auto * chunk = tokens.map_pos_to_media[it->first].get(); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); map_pos_to_media[start_pos+it->first] = std::move(new_chunk); } @@ -1271,33 +1270,42 @@ struct server_tokens { } size_t get_common_prefix(const server_tokens & b) const { - size_t max_idx = std::min(tokens.size(), b.tokens.size()); + const size_t max_idx = std::min(tokens.size(), b.tokens.size()); + for (size_t i = 0; i < max_idx; ++i) { - auto & ai = tokens[i]; - auto & bi = b.tokens[i]; + const llama_token ai = tokens[i]; + const llama_token bi = b.tokens[i]; if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { GGML_ASSERT(has_mtmd); + const auto & a_chunk = find_chunk(i); const auto & b_chunk = b.find_chunk(i); + GGML_ASSERT(a_chunk && b_chunk); - std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); - std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); - size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); - size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); - if (ai_id == bi_id && a_pos == b_pos) { - GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen - i += a_pos - 1; // will be +1 by the for loop + + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); + + const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get()); + const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get()); + + if (id_ai == id_bi && pos_a == pos_b) { + GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen + i += pos_a - 1; // will be +1 by the for loop continue; - } else { - return i; } - } else if (ai == bi) { - continue; - } else { + return i; } + + if (ai == bi) { + continue; + } + + return i; } + return max_idx; // all tokens are equal } @@ -1308,7 +1316,7 @@ struct server_tokens { const int32_t n_vocab = llama_vocab_n_tokens(vocab); for (size_t i = 0; i < tokens.size(); ++i) { - auto & t = tokens[i]; + const auto & t = tokens[i]; if (t == LLAMA_TOKEN_NULL) { try { const auto & chunk = find_chunk(i); @@ -1330,8 +1338,8 @@ struct server_tokens { mtmd_context * mctx, llama_pos n_past, int32_t seq_id, - llama_pos & n_pos_out) { - auto & chunk = find_chunk(n_past); + llama_pos & n_pos_out) const { + const auto & chunk = find_chunk(n_past); const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? 
"image" : "audio"; SRV_INF("processing %s...\n", name); From ca01e7f1c037769b3e40f643d6899da7a6ecd8f7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 2 Oct 2025 18:04:11 +0300 Subject: [PATCH 02/17] server : fix prompt similarity calculation --- tools/server/server.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index f2d1a7971a7a4..19d552fcc5942 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2366,8 +2366,8 @@ struct server_context { // length of the Longest Common Subsequence between the current slot's prompt and the input prompt int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); - // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + // fraction of the common subsequence length + float cur_similarity = float(cur_lcs_len) / task.prompt_tokens.size(); // select the current slot if the criteria match if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { From 668a436ee356a69e21893cd89a67b305503860c7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 2 Oct 2025 20:08:58 +0300 Subject: [PATCH 03/17] server : initial host-memory prompt caching --- tools/server/server.cpp | 314 ++++++++++++++++++++++++++++++++-------- tools/server/utils.hpp | 3 +- 2 files changed, 255 insertions(+), 62 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 19d552fcc5942..850006d2b4ab8 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -273,7 +273,8 @@ struct server_task { // used by SERVER_TASK_TYPE_INFERENCE slot_params params; server_tokens prompt_tokens; - int id_selected_slot = -1; + + int id_slot = -1; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE struct slot_action { @@ -764,13 +765,6 @@ struct completion_token_output { } }; -struct ctx_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; -}; - struct server_task_result_cmpl_final : server_task_result { int index = 0; @@ -1404,6 +1398,70 @@ struct server_task_result_apply_lora : server_task_result { } }; +struct ctx_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const { + return data.size(); + } +}; + +struct server_slot_prompt_state { + server_tokens tokens; + + std::vector data; + + std::list checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto & checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_slot; + +struct server_prompt_cache { + std::list states; + + size_t limit_size = 0; // 0 = no limit + + size_t size() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.size(); + } + + return res; + } + + int n_tokens() const { + int res = 0; + + for (const auto & state : states) { + res += state.n_tokens(); + } + + return res; + } + + void save(server_slot & slot); + void load(server_slot & slot, const server_tokens & prompt); +}; + struct server_slot { int id; int id_task = -1; @@ -1454,21 +1512,22 @@ struct server_slot { std::string generated_text; llama_tokens generated_tokens; - common_chat_msg chat_msg; - server_tokens cache_tokens; + common_chat_msg chat_msg; std::vector generated_token_probs; - std::vector ctx_checkpoints; - bool 
has_next_token = true; bool has_new_line = false; bool truncated = false; + stop_type stop; std::string stopping_word; + // state + server_slot_prompt_state prompt_state; + // sampling json json_schema; @@ -1722,6 +1781,116 @@ struct server_slot { } }; +void server_prompt_cache::save(server_slot & slot) { + auto & state = slot.prompt_state; + + assert(state.data.size() == 0); + + // first check if the current state is contained fully in the cache + for (auto it = states.begin(); it != states.end(); ++it) { + const auto & cached_prompt = it->tokens; + + const int cur_lcs_len = cached_prompt.get_common_prefix(state.tokens); + + if (cur_lcs_len == (int) state.tokens.size()) { + SRV_INF("%s", " - prompt is already cached, skipping\n"); + return; + } + } + + // next, remove any cached prompts that are fully contained in the current prompt + for (auto it = states.begin(); it != states.end();) { + const auto & cached_prompt = it->tokens; + + const int len = cached_prompt.get_common_prefix(state.tokens); + + if (len == (int) cached_prompt.size()) { + SRV_INF(" - removing cached prompt with length %d\n", len); + + it = states.erase(it); + } else { + ++it; + } + } + + const size_t cur_size = llama_state_seq_get_size_ext(slot.ctx, slot.id, 0); + + SRV_INF(" - saving prompt with length %d, total cache size = %.3f MiB\n", + (int) state.tokens.size(), cur_size / (1024.0 * 1024.0)); + + // if there is a limit, remove the oldest entries to make room + if (limit_size > 0) { + while (size() + cur_size > limit_size) { + if (states.empty()) { + break; + } + + states.pop_front(); + } + } else { + // else, make sure the number of cached tokens doesn't exceed the context size of the slot + while (n_tokens() + (int) state.tokens.size() > slot.n_ctx) { + if (states.empty()) { + break; + } + + states.pop_front(); + } + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto & cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(state.tokens.get_text_tokens(), false), + /*.data =*/ std::vector(cur_size), + /*.checkpoints =*/ state.checkpoints, + }; + + llama_state_seq_get_data_ext(slot.ctx, cur.data.data(), cur_size, slot.id, 0); + + SRV_INF(" - cache state: %zu prompts, %.3f MiB\n", states.size(), size() / (1024.0 * 1024.0)); +} + +void server_prompt_cache::load(server_slot & slot, const server_tokens & prompt) { + auto & state = slot.prompt_state; + + int lcs_len = state.tokens.get_common_prefix(prompt); + + SRV_INF(" - looking for better prompt, base lcs_len = %d\n", lcs_len); + + auto it_best = states.end(); + + // find the most similar cached prompt + for (auto it = states.begin(); it != states.end(); ++it) { + const auto & cached_prompt = it->tokens; + + const int cur_lcs_len = cached_prompt.get_common_prefix(prompt); + + if (lcs_len < cur_lcs_len) { + lcs_len = cur_lcs_len; + it_best = it; + } + } + + if (it_best != states.end()) { + SRV_INF(" - found better prompt with lcs_len = %d\n", lcs_len); + + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data_ext(slot.ctx, it_best->data.data(), size, slot.id, 0); + if (n != size) { + SLT_WRN(slot, "failed to restore slot state with size %zu\n", size); + return; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + state = std::move(*it_best); + + states.erase(it_best); + } +} + struct server_metrics { int64_t t_start = 0; @@ -2114,6 +2283,8 @@ struct server_context { server_queue queue_tasks; server_response queue_results; + server_prompt_cache prompt_cache; + 
server_metrics metrics; // Necessary similarity of prompt for slot selection @@ -2270,7 +2441,7 @@ struct server_context { slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; slot.mctx = mctx; - slot.cache_tokens.has_mtmd = mctx != nullptr; + slot.prompt_state.tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2347,6 +2518,8 @@ struct server_context { server_slot * get_available_slot(const server_task & task) { server_slot * ret = nullptr; + bool update_cache = false; + // find the slot that has at least n% prompt similarity if (ret == nullptr && slot_prompt_similarity != 0.0f) { int lcs_len = 0; @@ -2359,26 +2532,33 @@ struct server_context { } // skip the slot if it does not contains cached tokens - if (slot.cache_tokens.empty()) { + if (slot.prompt_state.tokens.empty()) { continue; } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); + const int cur_lcs_len = slot.prompt_state.tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length - float cur_similarity = float(cur_lcs_len) / task.prompt_tokens.size(); + const float cur_similarity = float(cur_lcs_len) / task.prompt_tokens.size(); // select the current slot if the criteria match if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { - lcs_len = cur_lcs_len; + lcs_len = cur_lcs_len; similarity = cur_similarity; + ret = &slot; } } if (ret != nullptr) { SLT_INF(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %.3f (> %.3f thold)\n", lcs_len, similarity, slot_prompt_similarity); + + // if we are going to lose a large portion of the existing context - save it in the prompt cache + const float f_keep = float(lcs_len) / ret->prompt_state.tokens.size(); + if (f_keep < 0.5f) { + update_cache = true; + } } } @@ -2401,9 +2581,19 @@ struct server_context { if (ret != nullptr) { SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); + + update_cache = true; } } + // TODO: mtmd does not support prompt cache + if (update_cache && ret->prompt_state.tokens.size() > 0 && !ret->prompt_state.tokens.has_mtmd) { + SRV_INF("%s", "updating prompt cache\n"); + + prompt_cache.save(*ret); + prompt_cache.load(*ret, task.prompt_tokens); + } + return ret; } @@ -2419,7 +2609,7 @@ struct server_context { // if lora has changed, check to see if the cache should be cleared if (lora_should_clear_cache(slot.lora, slot.params.lora)) { SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size()); - slot.cache_tokens.clear(); + slot.prompt_state.tokens.clear(); } else { SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size()); } @@ -2767,7 +2957,7 @@ struct server_context { res->is_progress = true; res->progress.total = slot.n_prompt_tokens; res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.cache_tokens.size(); + res->progress.processed = slot.prompt_state.tokens.size(); res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000); } else { res->content = tkn.text_to_send; @@ -3034,7 +3224,7 @@ struct server_context { case SERVER_TASK_TYPE_EMBEDDING: case SERVER_TASK_TYPE_RERANK: { - const int id_slot = task.id_selected_slot; + const int id_slot = task.id_slot; server_slot * slot = id_slot != -1 ? 
get_slot_by_id(id_slot) : get_available_slot(task); @@ -3138,13 +3328,13 @@ struct server_context { break; } - const size_t token_count = slot->cache_tokens.size(); + const size_t token_count = slot->prompt_state.tokens.size(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const llama_tokens & tokens = slot->prompt_state.tokens.get_text_tokens(); const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); @@ -3186,13 +3376,13 @@ struct server_context { size_t token_count = 0; size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.clear(); // KV may already been invalidated? + slot->prompt_state.tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } tokens.resize(token_count); - slot->cache_tokens.clear(); - slot->cache_tokens.insert(tokens); + slot->prompt_state.tokens.clear(); + slot->prompt_state.tokens.insert(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -3226,9 +3416,9 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); + const size_t n_erased = slot->prompt_state.tokens.size(); llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); - slot->cache_tokens.clear(); + slot->prompt_state.tokens.clear(); auto res = std::make_unique(); res->id = task.id; @@ -3307,14 +3497,14 @@ struct server_context { // add generated tokens to cache { - llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + llama_tokens new_tokens = slot.prompt_state.tokens.get_text_tokens(); // copy for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; } - new_tokens.resize(slot.cache_tokens.size() - n_discard); - slot.cache_tokens.clear(); - slot.cache_tokens.insert(new_tokens); + new_tokens.resize(slot.prompt_state.tokens.size() - n_discard); + slot.prompt_state.tokens.clear(); + slot.prompt_state.tokens.insert(new_tokens); } slot.n_past -= n_discard; @@ -3351,10 +3541,10 @@ struct server_context { common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; - slot.cache_tokens.push_back(slot.sampled); + slot.prompt_state.tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); + slot.n_ctx, slot.n_past, (int) slot.prompt_state.tokens.size(), slot.truncated); } // process in chunks of params.n_batch @@ -3482,7 +3672,7 @@ struct server_context { if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + slot.n_past = slot.prompt_state.tokens.get_common_prefix(prompt_tokens); // if there is an alora invoked, don't cache after the invocation start if (slot.alora_invocation_start >= 0) { @@ -3502,13 +3692,13 @@ struct server_context { SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c 
< slot.cache_tokens.size() && + while (head_c < slot.prompt_state.tokens.size() && head_p < prompt_tokens.size()) { size_t n_match = 0; - while (head_c + n_match < slot.cache_tokens.size() && + while (head_c + n_match < slot.prompt_state.tokens.size() && head_p + n_match < prompt_tokens.size() && - slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + slot.prompt_state.tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { n_match++; } @@ -3525,7 +3715,7 @@ struct server_context { llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); + slot.prompt_state.tokens.set_token(head_p + i, slot.prompt_state.tokens[head_c + i]); slot.n_past++; } @@ -3549,28 +3739,27 @@ struct server_context { // the largest pos_min required for a checkpoint to be useful const auto pos_min_thold = std::max(0, slot.n_past - n_swa); - if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { + if (slot.n_past > 0 && slot.n_past < (int) slot.prompt_state.tokens.size()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); + SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt_state.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } if (pos_min > pos_min_thold) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt_state.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint const auto it = std::find_if( - slot.ctx_checkpoints.rbegin(), - slot.ctx_checkpoints.rend(), + slot.prompt_state.checkpoints.rbegin(), + slot.prompt_state.checkpoints.rend(), [&](const auto & cur) { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] return cur.pos_min < pos_min_thold; } ); - bool do_reset = it == slot.ctx_checkpoints.rend(); - //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? 
"true" : "false"); + bool do_reset = it == slot.prompt_state.checkpoints.rend(); if (!do_reset) { // restore the context checkpoint @@ -3597,11 +3786,13 @@ struct server_context { { // erase any checkpoints with pos_min > pos_min_thold - for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) { - const auto & cur = slot.ctx_checkpoints[i]; + for (auto it = slot.prompt_state.checkpoints.begin(); it != slot.prompt_state.checkpoints.end();) { + const auto & cur = *it; if (cur.pos_min > pos_min_thold) { SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i); + it = slot.prompt_state.checkpoints.erase(it); + } else { + ++it; } } } @@ -3638,7 +3829,7 @@ struct server_context { SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past); // remove the non-common part from the cache - slot.cache_tokens.keep_first(slot.n_past); + slot.prompt_state.tokens.keep_first(slot.n_past); // check if we should process the image if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { @@ -3657,7 +3848,7 @@ struct server_context { // add the image chunk to cache { const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); - slot.cache_tokens.push_back(chunk.get()); // copy + slot.prompt_state.tokens.push_back(chunk.get()); // copy } slot.n_past += n_pos; @@ -3712,7 +3903,7 @@ struct server_context { const bool need_embd = server_task_type_need_embd(slot.task_type); common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); - slot.cache_tokens.push_back(cur_tok); + slot.prompt_state.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; slot.n_past++; @@ -3759,21 +3950,22 @@ struct server_context { do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64); + do_checkpoint = do_checkpoint && (slot.prompt_state.checkpoints.empty() || pos_max > slot.prompt_state.checkpoints.back().pos_max + 64); if (do_checkpoint) { - while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + while (slot.prompt_state.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { // make room for the new checkpoint, if needed - const auto & cur = slot.ctx_checkpoints.front(); + const auto & cur = slot.prompt_state.checkpoints.front(); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin()); + slot.prompt_state.checkpoints.erase(slot.prompt_state.checkpoints.begin()); } const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{ + auto & cur = slot.prompt_state.checkpoints.emplace_back(ctx_checkpoint{ /*.pos_min = */ pos_min, /*.pos_max = */ pos_max, /*.data = */ std::vector(checkpoint_size), @@ -3782,7 +3974,7 @@ struct server_context { llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) 
slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + (int) slot.prompt_state.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } } } @@ -3991,7 +4183,7 @@ struct server_context { params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + const llama_tokens & cached_text_tokens = slot.prompt_state.tokens.get_text_tokens(); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts @@ -4025,8 +4217,8 @@ struct server_context { // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - slot.cache_tokens.push_back(id); - slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); + slot.prompt_state.tokens.push_back(id); + slot.prompt_state.tokens.insert({ids.begin(), ids.end() - 1}); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); @@ -4657,7 +4849,7 @@ int main(int argc, char ** argv) { ctx_server.ctx, ctx_server.params_base, data); - task.id_selected_slot = json_value(data, "id_slot", -1); + task.id_slot = json_value(data, "id_slot", -1); // OAI-compat task.params.oaicompat = oaicompat; diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 8ee3042e2c7de..3e6974df8d442 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1102,6 +1102,7 @@ struct server_tokens { ~server_tokens() = default; // Prevent copying + // TODO: server_tokens should be copyable - remove this: server_tokens(const server_tokens&) = delete; server_tokens& operator=(const server_tokens&) = delete; @@ -1119,7 +1120,7 @@ struct server_tokens { } } - server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} // for debugging std::string str() const { From 3234723004f52669be2df7fe7a6fcac0a23db543 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 12:50:49 +0300 Subject: [PATCH 04/17] cont --- tools/server/server.cpp | 47 +++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 850006d2b4ab8..60677bca99e51 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2522,8 +2522,8 @@ struct server_context { // find the slot that has at least n% prompt similarity if (ret == nullptr && slot_prompt_similarity != 0.0f) { - int lcs_len = 0; - float similarity = 0; + int lcs_len_best = 0; + float sim_best = 0; for (server_slot & slot : slots) { // skip the slot if it is not available @@ -2531,31 +2531,33 @@ struct server_context { continue; } + const auto & tokens = slot.prompt_state.tokens; + // skip the slot if it does not contains cached tokens - if (slot.prompt_state.tokens.empty()) { + if (tokens.empty()) { continue; } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - const int cur_lcs_len = slot.prompt_state.tokens.get_common_prefix(task.prompt_tokens); + const int lcs_len_cur = tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length - const float cur_similarity = float(cur_lcs_len) / task.prompt_tokens.size(); + const float sim_cur = 
float(lcs_len_cur) / task.prompt_tokens.size(); // select the current slot if the criteria match - if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { - lcs_len = cur_lcs_len; - similarity = cur_similarity; + if (lcs_len_cur > lcs_len_best && sim_cur > slot_prompt_similarity) { + lcs_len_best = lcs_len_cur; + sim_best = sim_cur; ret = &slot; } } if (ret != nullptr) { - SLT_INF(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %.3f (> %.3f thold)\n", lcs_len, similarity, slot_prompt_similarity); + SLT_INF(*ret, "selected slot by lcs similarity, lcs_len_best = %d, sim_best = %.3f (> %.3f thold)\n", lcs_len_best, sim_best, slot_prompt_similarity); - // if we are going to lose a large portion of the existing context - save it in the prompt cache - const float f_keep = float(lcs_len) / ret->prompt_state.tokens.size(); + // if we are about to lose a large portion of the existing context - save it in the prompt cache + const float f_keep = float(lcs_len_best) / ret->prompt_state.tokens.size(); if (f_keep < 0.5f) { update_cache = true; } @@ -2586,12 +2588,25 @@ struct server_context { } } - // TODO: mtmd does not support prompt cache - if (update_cache && ret->prompt_state.tokens.size() > 0 && !ret->prompt_state.tokens.has_mtmd) { - SRV_INF("%s", "updating prompt cache\n"); + if (ret) { + const auto & tokens = ret->prompt_state.tokens; + + // don't update the cache if the slot's context is empty + update_cache = update_cache && tokens.size() > 0; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && !ret->prompt_state.tokens.has_mtmd; + + if (update_cache) { + SRV_INF("%s", "updating prompt cache\n"); - prompt_cache.save(*ret); - prompt_cache.load(*ret, task.prompt_tokens); + const int64_t t_start = ggml_time_us(); + + prompt_cache.save(*ret); + prompt_cache.load(*ret, task.prompt_tokens); + + SRV_INF("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } } return ret; From 967b1e45ee4565bc68be9460f709381abb88adf3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 13:43:03 +0300 Subject: [PATCH 05/17] server : refactor --- common/common.h | 2 +- tools/server/server.cpp | 385 ++++++++++------------ tools/server/tests/unit/test_ctx_shift.py | 13 +- 3 files changed, 184 insertions(+), 216 deletions(-) diff --git a/common/common.h b/common/common.h index 8a8ecd667f2cc..832c047e4bd85 100644 --- a/common/common.h +++ b/common/common.h @@ -378,7 +378,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool no_perf = false; // disable performance metrics - bool ctx_shift = false; // context shift on infinite text generation + bool ctx_shift = false; // context shift on infinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 60677bca99e51..5fe150ad6bae1 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -265,16 +265,15 @@ struct server_task { int id = -1; // to be filled by server_queue int index = -1; // used when there are multiple prompts (batch request) - server_task_type type; - // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; + int id_slot = -1; // used by SERVER_TASK_TYPE_INFERENCE slot_params params; - server_tokens 
prompt_tokens; + server_tokens tokens; - int id_slot = -1; + server_task_type type; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE struct slot_action { @@ -1367,17 +1366,17 @@ struct server_task_result_slot_save_load : server_task_result { { "save_ms", t_ms } }}, }; - } else { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; } + + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; } }; @@ -1398,7 +1397,7 @@ struct server_task_result_apply_lora : server_task_result { } }; -struct ctx_checkpoint { +struct server_slot_prompt_checkpoint { llama_pos pos_min; llama_pos pos_max; @@ -1409,12 +1408,12 @@ struct ctx_checkpoint { } }; -struct server_slot_prompt_state { +struct server_slot_prompt { server_tokens tokens; std::vector data; - std::list checkpoints; + std::list checkpoints; size_t size() const { size_t res = data.size(); @@ -1431,10 +1430,8 @@ struct server_slot_prompt_state { } }; -struct server_slot; - struct server_prompt_cache { - std::list states; + std::list states; size_t limit_size = 0; // 0 = no limit @@ -1457,9 +1454,6 @@ struct server_prompt_cache { return res; } - - void save(server_slot & slot); - void load(server_slot & slot, const server_tokens & prompt); }; struct server_slot { @@ -1500,13 +1494,15 @@ struct server_slot { int32_t i_batch = -1; int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated - int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; // input prompt tokens - server_tokens prompt_tokens; + server_tokens input_tokens; + + int32_t n_prompt_tokens() const { + return input_tokens.size(); + } size_t last_nl_pos = 0; @@ -1526,7 +1522,10 @@ struct server_slot { std::string stopping_word; // state - server_slot_prompt_state prompt_state; + server_slot_prompt prompt; + + void prompt_save(server_prompt_cache & prompt_cache); + void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens); // sampling json json_schema; @@ -1556,7 +1555,6 @@ struct server_slot { void reset() { SLT_DBG(*this, "%s", "\n"); - n_prompt_tokens = 0; n_prompt_tokens_cache = 0; last_nl_pos = 0; @@ -1767,7 +1765,7 @@ struct server_slot { {"speculative", can_speculate()}, {"is_processing", is_processing()}, {"params", params.to_json()}, - {"prompt", prompt_tokens.detokenize(ctx, true)}, + {"prompt", input_tokens.detokenize(ctx, true)}, {"next_token", { {"has_next_token", has_next_token}, @@ -1781,18 +1779,18 @@ struct server_slot { } }; -void server_prompt_cache::save(server_slot & slot) { - auto & state = slot.prompt_state; +void server_slot::prompt_save(server_prompt_cache & prompt_cache) { + auto & states = prompt_cache.states; - assert(state.data.size() == 0); + assert(prompt.data.size() == 0); // first check if the current state is contained fully in the cache for (auto it = states.begin(); it != states.end(); ++it) { const auto & cached_prompt = it->tokens; - const int cur_lcs_len = cached_prompt.get_common_prefix(state.tokens); + const int cur_lcs_len = cached_prompt.get_common_prefix(prompt.tokens); - if (cur_lcs_len == (int) state.tokens.size()) { + if (cur_lcs_len == (int) prompt.tokens.size()) { SRV_INF("%s", 
" - prompt is already cached, skipping\n"); return; } @@ -1802,7 +1800,7 @@ void server_prompt_cache::save(server_slot & slot) { for (auto it = states.begin(); it != states.end();) { const auto & cached_prompt = it->tokens; - const int len = cached_prompt.get_common_prefix(state.tokens); + const int len = cached_prompt.get_common_prefix(prompt.tokens); if (len == (int) cached_prompt.size()) { SRV_INF(" - removing cached prompt with length %d\n", len); @@ -1813,14 +1811,14 @@ void server_prompt_cache::save(server_slot & slot) { } } - const size_t cur_size = llama_state_seq_get_size_ext(slot.ctx, slot.id, 0); + const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); SRV_INF(" - saving prompt with length %d, total cache size = %.3f MiB\n", - (int) state.tokens.size(), cur_size / (1024.0 * 1024.0)); + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); // if there is a limit, remove the oldest entries to make room - if (limit_size > 0) { - while (size() + cur_size > limit_size) { + if (prompt_cache.limit_size > 0) { + while (prompt_cache.size() + cur_size > prompt_cache.limit_size) { if (states.empty()) { break; } @@ -1829,7 +1827,7 @@ void server_prompt_cache::save(server_slot & slot) { } } else { // else, make sure the number of cached tokens doesn't exceed the context size of the slot - while (n_tokens() + (int) state.tokens.size() > slot.n_ctx) { + while (prompt_cache.n_tokens() + (int) prompt.tokens.size() > n_ctx) { if (states.empty()) { break; } @@ -1841,20 +1839,20 @@ void server_prompt_cache::save(server_slot & slot) { // TODO: for some reason we can't copy server_tokens, so we have to do this workaround auto & cur = states.emplace_back(); cur = { - /*.tokens =*/ server_tokens(state.tokens.get_text_tokens(), false), + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), /*.data =*/ std::vector(cur_size), - /*.checkpoints =*/ state.checkpoints, + /*.checkpoints =*/ prompt.checkpoints, }; - llama_state_seq_get_data_ext(slot.ctx, cur.data.data(), cur_size, slot.id, 0); + llama_state_seq_get_data_ext(ctx, cur.data.data(), cur_size, id, 0); - SRV_INF(" - cache state: %zu prompts, %.3f MiB\n", states.size(), size() / (1024.0 * 1024.0)); + SRV_INF(" - cache state: %zu prompts, %.3f MiB\n", states.size(), prompt_cache.size() / (1024.0 * 1024.0)); } -void server_prompt_cache::load(server_slot & slot, const server_tokens & prompt) { - auto & state = slot.prompt_state; +void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { + auto & states = prompt_cache.states; - int lcs_len = state.tokens.get_common_prefix(prompt); + int lcs_len = prompt.tokens.get_common_prefix(tokens); SRV_INF(" - looking for better prompt, base lcs_len = %d\n", lcs_len); @@ -1864,7 +1862,7 @@ void server_prompt_cache::load(server_slot & slot, const server_tokens & prompt) for (auto it = states.begin(); it != states.end(); ++it) { const auto & cached_prompt = it->tokens; - const int cur_lcs_len = cached_prompt.get_common_prefix(prompt); + const int cur_lcs_len = cached_prompt.get_common_prefix(tokens); if (lcs_len < cur_lcs_len) { lcs_len = cur_lcs_len; @@ -1876,16 +1874,16 @@ void server_prompt_cache::load(server_slot & slot, const server_tokens & prompt) SRV_INF(" - found better prompt with lcs_len = %d\n", lcs_len); const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(slot.ctx, it_best->data.data(), size, slot.id, 0); + const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id, 0); if (n 
!= size) { - SLT_WRN(slot, "failed to restore slot state with size %zu\n", size); + SLT_WRN(*this, "failed to restore slot state with size %zu\n", size); return; } it_best->data.clear(); it_best->data.shrink_to_fit(); - state = std::move(*it_best); + prompt = std::move(*it_best); states.erase(it_best); } @@ -2441,7 +2439,7 @@ struct server_context { slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; slot.mctx = mctx; - slot.prompt_state.tokens.has_mtmd = mctx != nullptr; + slot.prompt.tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2531,7 +2529,7 @@ struct server_context { continue; } - const auto & tokens = slot.prompt_state.tokens; + const auto & tokens = slot.prompt.tokens; // skip the slot if it does not contains cached tokens if (tokens.empty()) { @@ -2539,10 +2537,10 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - const int lcs_len_cur = tokens.get_common_prefix(task.prompt_tokens); + const int lcs_len_cur = tokens.get_common_prefix(task.tokens); // fraction of the common subsequence length - const float sim_cur = float(lcs_len_cur) / task.prompt_tokens.size(); + const float sim_cur = float(lcs_len_cur) / task.tokens.size(); // select the current slot if the criteria match if (lcs_len_cur > lcs_len_best && sim_cur > slot_prompt_similarity) { @@ -2554,10 +2552,12 @@ struct server_context { } if (ret != nullptr) { - SLT_INF(*ret, "selected slot by lcs similarity, lcs_len_best = %d, sim_best = %.3f (> %.3f thold)\n", lcs_len_best, sim_best, slot_prompt_similarity); + const float f_keep = float(lcs_len_best) / ret->prompt.tokens.size(); + + SLT_INF(*ret, "selected slot by lcs similarity, lcs_len_best = %d, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + lcs_len_best, sim_best, slot_prompt_similarity, f_keep); // if we are about to lose a large portion of the existing context - save it in the prompt cache - const float f_keep = float(lcs_len_best) / ret->prompt_state.tokens.size(); if (f_keep < 0.5f) { update_cache = true; } @@ -2589,21 +2589,21 @@ struct server_context { } if (ret) { - const auto & tokens = ret->prompt_state.tokens; + const auto & tokens = ret->prompt.tokens; // don't update the cache if the slot's context is empty update_cache = update_cache && tokens.size() > 0; // TODO: mtmd does not support prompt cache - update_cache = update_cache && !ret->prompt_state.tokens.has_mtmd; + update_cache = update_cache && !ret->prompt.tokens.has_mtmd; if (update_cache) { SRV_INF("%s", "updating prompt cache\n"); const int64_t t_start = ggml_time_us(); - prompt_cache.save(*ret); - prompt_cache.load(*ret, task.prompt_tokens); + ret->prompt_save(prompt_cache); + ret->prompt_load(prompt_cache, task.tokens); SRV_INF("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); } @@ -2614,17 +2614,17 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); - slot.id_task = task.id; - slot.index = task.index; - slot.task_type = task.type; - slot.params = std::move(task.params); - slot.prompt_tokens = std::move(task.prompt_tokens); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.input_tokens = std::move(task.tokens); if (!are_lora_equal(slot.params.lora, slot.lora)) { // if lora has changed, check to see if the cache should be cleared if 
(lora_should_clear_cache(slot.lora, slot.params.lora)) { SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size()); - slot.prompt_state.tokens.clear(); + slot.prompt.tokens.clear(); } else { SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size()); } @@ -2632,9 +2632,8 @@ struct server_context { } // if using alora, make sure it's only a single one requested and active - size_t alora_invocation_start = slot.prompt_tokens.size(); + size_t alora_invocation_start = slot.input_tokens.size(); if (lora_all_alora(slot.lora)) { - const auto & enabled_ids = lora_get_enabled_ids(slot.lora); // TODO: This will error out if a user requests two aloras, but only // provides the activation string for one. We could, instead search @@ -2653,10 +2652,10 @@ struct server_context { // scan backwards through the prompt tokens to find the last // occurrence of the invocation sequence int match_idx = static_cast(n_invocation_tokens) - 1; - for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) { + for (int i = slot.input_tokens.size() - 1; i >= 0; --i) { // the token in this position matches the next token to find in // the invocation sequence - if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) { + if (slot.input_tokens[i] == invocation_tokens[match_idx]) { // if it's a full match, we've found the start if (match_idx == 0) { alora_invocation_start = i; @@ -2671,7 +2670,7 @@ struct server_context { } // if the activation string is not found, disable the alora - if (alora_invocation_start == slot.prompt_tokens.size()) { + if (alora_invocation_start == slot.input_tokens.size()) { SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]); slot.lora[enabled_ids[0]].scale = 0.0f; } else { @@ -2680,7 +2679,7 @@ struct server_context { } } - if (!slot.prompt_tokens.validate(ctx)) { + if (!slot.input_tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -2850,7 +2849,7 @@ struct server_context { slot.has_next_token = false; SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); } if (llama_vocab_is_eog(vocab, result.tok)) { @@ -2862,7 +2861,7 @@ struct server_context { const auto n_ctx_train = llama_model_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction @@ -2933,7 +2932,7 @@ struct server_context { } void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx); + send_error(slot.id_task, error, type, slot.n_prompt_tokens(), slot.n_ctx); } void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { @@ -2970,9 +2969,9 @@ struct server_context { if (is_progress) { res->is_progress = true; - res->progress.total = slot.n_prompt_tokens; + res->progress.total = slot.n_prompt_tokens(); 
res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.prompt_state.tokens.size(); + res->progress.processed = slot.prompt.tokens.size(); res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000); } else { res->content = tkn.text_to_send; @@ -2982,7 +2981,7 @@ struct server_context { } res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens = slot.n_prompt_tokens(); res->post_sampling_probs = slot.params.post_sampling_probs; res->verbose = slot.params.verbose; @@ -3012,25 +3011,25 @@ struct server_context { res->content = slot.generated_text; res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = slot.prompt_tokens.detokenize(ctx, true); + res->prompt = slot.input_tokens.detokenize(ctx, true); res->response_fields = std::move(slot.params.response_fields); res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens = slot.n_prompt_tokens(); res->n_tokens_cached = slot.n_past; res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; res->stop = slot.stop; res->post_sampling_probs = slot.params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->stream = slot.params.stream; - res->include_usage = slot.params.include_usage; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->include_usage = slot.params.include_usage; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); // populate res.probs_output if (slot.params.sampling.n_probs > 0) { @@ -3057,7 +3056,7 @@ struct server_context { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; + res->n_tokens = slot.n_prompt_tokens(); res->oaicompat = slot.params.oaicompat; const int n_embd = llama_model_n_embd(model); @@ -3088,9 +3087,9 @@ struct server_context { common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize); res->embedding.push_back(embd_res); break; - } else { - res->embedding.emplace_back(embd, embd + n_embd); } + + res->embedding.emplace_back(embd, embd + n_embd); } SLT_DBG(slot, "%s", "sending embeddings\n"); @@ -3100,9 +3099,9 @@ struct server_context { void send_rerank(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens(); for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { @@ -3343,13 +3342,13 @@ struct server_context { break; } - const size_t token_count = slot->prompt_state.tokens.size(); + const size_t token_count = slot->prompt.tokens.size(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const llama_tokens & tokens = slot->prompt_state.tokens.get_text_tokens(); + const llama_tokens & tokens = 
slot->prompt.tokens.get_text_tokens(); const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); @@ -3391,13 +3390,13 @@ struct server_context { size_t token_count = 0; size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->prompt_state.tokens.clear(); // KV may already been invalidated? + slot->prompt.tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } tokens.resize(token_count); - slot->prompt_state.tokens.clear(); - slot->prompt_state.tokens.insert(tokens); + slot->prompt.tokens.clear(); + slot->prompt.tokens.insert(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -3431,9 +3430,9 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->prompt_state.tokens.size(); + const size_t n_erased = slot->prompt.tokens.size(); llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); - slot->prompt_state.tokens.clear(); + slot->prompt.tokens.clear(); auto res = std::make_unique(); res->id = task.id; @@ -3512,14 +3511,14 @@ struct server_context { // add generated tokens to cache { - llama_tokens new_tokens = slot.prompt_state.tokens.get_text_tokens(); // copy + llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; } - new_tokens.resize(slot.prompt_state.tokens.size() - n_discard); - slot.prompt_state.tokens.clear(); - slot.prompt_state.tokens.insert(new_tokens); + new_tokens.resize(slot.prompt.tokens.size() - n_discard); + slot.prompt.tokens.clear(); + slot.prompt.tokens.insert(new_tokens); } slot.n_past -= n_discard; @@ -3556,10 +3555,10 @@ struct server_context { common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; - slot.prompt_state.tokens.push_back(slot.sampled); + slot.prompt.tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.prompt_state.tokens.size(), slot.truncated); + slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated); } // process in chunks of params.n_batch @@ -3582,7 +3581,7 @@ struct server_context { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - auto & prompt_tokens = slot.prompt_tokens; + const auto & input_tokens = slot.input_tokens; // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -3590,26 +3589,26 @@ struct server_context { slot.t_start_generation = 0; slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); slot.state = SLOT_STATE_PROCESSING_PROMPT; - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", + slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens()); // print prompt tokens (for debugging) /*if (1) { // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, 
prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + for (int i = 0; i < std::min(16, input_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); } } else { // all - for (int i = 0; i < (int) prompt_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + for (int i = 0; i < (int) input_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); } }*/ // empty prompt passed -> release the slot and send empty response - if (prompt_tokens.empty()) { + if (input_tokens.empty()) { SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); slot.release(); @@ -3626,68 +3625,33 @@ struct server_context { } if (!slot.can_split()) { - if (slot.n_prompt_tokens > n_ubatch) { + if (slot.n_prompt_tokens() > n_ubatch) { slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); continue; } - if (slot.n_prompt_tokens > slot.n_ctx) { + if (slot.n_prompt_tokens() > slot.n_ctx) { slot.release(); send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); continue; } } else { - if (!params_base.ctx_shift) { - // if context shift is disabled, we make sure prompt size is smaller than KV size - // TODO: there should be a separate parameter that control prompt truncation - // context shift should be applied only during the generation phase - if (slot.n_prompt_tokens >= slot.n_ctx) { - slot.release(); - send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - continue; - } + if (slot.n_prompt_tokens() >= slot.n_ctx) { + slot.release(); + send_error(slot, "the request exceeds the available context size. 
try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + continue; } + if (slot.params.n_keep < 0) { - slot.params.n_keep = slot.n_prompt_tokens; + slot.params.n_keep = slot.n_prompt_tokens(); } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - // if input prompt is too big, truncate it - if (slot.n_prompt_tokens >= slot.n_ctx) { - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - const int n_left = slot.n_ctx - slot.params.n_keep; - - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - - const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); - llama_tokens new_tokens( - curr_tokens.begin(), - curr_tokens.begin() + slot.params.n_keep); - - new_tokens.insert( - new_tokens.end(), - curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - curr_tokens.end()); - - prompt_tokens.clear(); - prompt_tokens.insert(new_tokens); - - slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); - SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); - } + slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.prompt_state.tokens.get_common_prefix(prompt_tokens); + slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens); // if there is an alora invoked, don't cache after the invocation start if (slot.alora_invocation_start >= 0) { @@ -3707,13 +3671,13 @@ struct server_context { SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c < slot.prompt_state.tokens.size() && - head_p < prompt_tokens.size()) { + while (head_c < slot.prompt.tokens.size() && + head_p < input_tokens.size()) { size_t n_match = 0; - while (head_c + n_match < slot.prompt_state.tokens.size() && - head_p + n_match < prompt_tokens.size() && - slot.prompt_state.tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + while (head_c + n_match < slot.prompt.tokens.size() && + head_p + n_match < input_tokens.size() && + slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { n_match++; } @@ -3730,7 +3694,7 @@ struct server_context { llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.prompt_state.tokens.set_token(head_p + i, slot.prompt_state.tokens[head_c + i]); + slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); slot.n_past++; } @@ -3754,27 +3718,27 @@ struct server_context { // the largest pos_min required for a checkpoint to be useful const auto pos_min_thold = std::max(0, slot.n_past - n_swa); - if (slot.n_past > 0 && slot.n_past < (int) slot.prompt_state.tokens.size()) { + if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt_state.tokens.size(), slot.id, pos_min); + SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, 
(int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } if (pos_min > pos_min_thold) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt_state.tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint const auto it = std::find_if( - slot.prompt_state.checkpoints.rbegin(), - slot.prompt_state.checkpoints.rend(), + slot.prompt.checkpoints.rbegin(), + slot.prompt.checkpoints.rend(), [&](const auto & cur) { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] return cur.pos_min < pos_min_thold; } ); - bool do_reset = it == slot.prompt_state.checkpoints.rend(); + bool do_reset = it == slot.prompt.checkpoints.rend(); if (!do_reset) { // restore the context checkpoint @@ -3801,11 +3765,11 @@ struct server_context { { // erase any checkpoints with pos_min > pos_min_thold - for (auto it = slot.prompt_state.checkpoints.begin(); it != slot.prompt_state.checkpoints.end();) { + for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; if (cur.pos_min > pos_min_thold) { SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - it = slot.prompt_state.checkpoints.erase(it); + it = slot.prompt.checkpoints.erase(it); } else { ++it; } @@ -3814,8 +3778,8 @@ struct server_context { } // [TAG_PROMPT_LOGITS] - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens); + if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens()); slot.n_past--; SLT_WRN(slot, "n_past was set to %d\n", slot.n_past); } @@ -3826,7 +3790,7 @@ struct server_context { if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { continue; } } @@ -3844,15 +3808,13 @@ struct server_context { SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past); // remove the non-common part from the cache - slot.prompt_state.tokens.keep_first(slot.n_past); + slot.prompt.tokens.keep_first(slot.n_past); // check if we should process the image - if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { // process the image int32_t new_n_past; - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); - int32_t n_pos = new_n_past - slot.n_past; - + int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); slot.release(); @@ -3862,10 +3824,12 @@ struct server_context { // add the image chunk to 
cache { - const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); - slot.prompt_state.tokens.push_back(chunk.get()); // copy + const auto & chunk = input_tokens.find_chunk(slot.n_past); + slot.prompt.tokens.push_back(chunk.get()); // copy } + const int32_t n_pos = new_n_past - slot.n_past; + slot.n_past += n_pos; slot.n_prompt_tokens_processed += n_pos; } @@ -3899,9 +3863,9 @@ struct server_context { ); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) { // get next token to process - llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + llama_token cur_tok = input_tokens[slot.n_past]; if (cur_tok == LLAMA_TOKEN_NULL) { break; // end of text chunk } @@ -3918,33 +3882,32 @@ struct server_context { const bool need_embd = server_task_type_need_embd(slot.task_type); common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); - slot.prompt_state.tokens.push_back(cur_tok); + slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; slot.n_past++; // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. - if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) { + if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) { break; } } // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens()); // entire prompt has been processed - if (slot.n_past == slot.n_prompt_tokens) { + if (slot.n_past == slot.n_prompt_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); - GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system - for (int i = 0; i < slot.n_prompt_tokens; ++i) { - llama_token id = slot.prompt_tokens[i]; + for (int i = 0; i < slot.n_prompt_tokens(); ++i) { + llama_token id = input_tokens[i]; if (id != LLAMA_TOKEN_NULL) { common_sampler_accept(slot.smpl, id, false); } @@ -3965,22 +3928,22 @@ struct server_context { do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt_state.checkpoints.empty() || pos_max > slot.prompt_state.checkpoints.back().pos_max + 64); + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); if (do_checkpoint) { - while (slot.prompt_state.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { // make room for the new checkpoint, if needed - const auto & cur = slot.prompt_state.checkpoints.front(); + const auto & cur = slot.prompt.checkpoints.front(); SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - slot.prompt_state.checkpoints.erase(slot.prompt_state.checkpoints.begin()); + 
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - auto & cur = slot.prompt_state.checkpoints.emplace_back(ctx_checkpoint{ + auto & cur = slot.prompt.checkpoints.emplace_back(server_slot_prompt_checkpoint{ /*.pos_min = */ pos_min, /*.pos_max = */ pos_max, /*.data = */ std::vector(checkpoint_size), @@ -3989,7 +3952,7 @@ struct server_context { llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.prompt_state.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } } } @@ -4198,7 +4161,7 @@ struct server_context { params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; - const llama_tokens & cached_text_tokens = slot.prompt_state.tokens.get_text_tokens(); + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts @@ -4232,8 +4195,8 @@ struct server_context { // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - slot.prompt_state.tokens.push_back(id); - slot.prompt_state.tokens.insert({ids.begin(), ids.end() - 1}); + slot.prompt.tokens.push_back(id); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); @@ -4859,8 +4822,8 @@ int main(int argc, char ** argv) { task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, data); @@ -5233,9 +5196,9 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); // OAI-compat task.params.oaicompat = oaicompat; @@ -5331,10 +5294,10 @@ int main(int argc, char ** argv) { tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = std::move(tmp); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); tasks.push_back(std::move(task)); } diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index 92e49f2bb05a4..4adbbde64f594 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -4,6 +4,12 @@ server = 
ServerPreset.tinyllama2() +SHORT_TEXT = """ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +""".strip() + LONG_TEXT = """ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. @@ -21,19 +27,18 @@ def create_server(): def test_ctx_shift_enabled(): - # the prompt is 301 tokens + # the prompt is 226 tokens # the slot context is 512/2 = 256 tokens - # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens # 96 tokens are generated thanks to shifting the context when it gets full global server server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 96, - "prompt": LONG_TEXT, + "prompt": SHORT_TEXT, }) assert res.status_code == 200 - assert res.body["timings"]["prompt_n"] == 173 + assert res.body["timings"]["prompt_n"] == 226 assert res.body["timings"]["predicted_n"] == 96 assert res.body["truncated"] is True From 83ce8cbc6ece05362374d4dd336980155f26d012 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 14:32:26 +0300 Subject: [PATCH 06/17] cont --- tools/server/server.cpp | 271 +++++++++--------- .../server/tests/unit/test_chat_completion.py | 6 +- tools/server/tests/unit/test_completion.py | 4 +- tools/server/utils.hpp | 8 +- 4 files changed, 138 insertions(+), 151 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 5fe150ad6bae1..41fee8f9e5e33 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -158,7 +158,6 @@ struct slot_params { if (only_metrics) { return json { - {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, {"temperature", sampling.temp}, {"dynatemp_range", sampling.dynatemp_range}, @@ -181,7 +180,8 @@ struct slot_params { {"mirostat", sampling.mirostat}, {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, - {"max_tokens", n_predict}, // User configured n_predict + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? {"n_keep", n_keep}, {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, @@ -209,7 +209,6 @@ struct slot_params { } return json { - {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, {"temperature", sampling.temp}, {"dynatemp_range", sampling.dynatemp_range}, @@ -234,7 +233,8 @@ struct slot_params { {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, {"stop", antiprompt}, - {"max_tokens", n_predict}, // User configured n_predict + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? 
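            // (note: "max_tokens" and "n_predict" intentionally serialize the same
            //  slot_params::n_predict value here; "max_tokens" is the OAI-compatible field
            //  name while "n_predict" is the native llama.cpp name, which is presumably why
            //  both are kept until the TODO above is resolved)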
{"n_keep", n_keep}, {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, @@ -289,6 +289,8 @@ struct server_task { // used by SERVER_TASK_TYPE_SET_LORA std::vector set_lora; + server_task() = default; + server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( @@ -305,6 +307,7 @@ struct server_task { defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; defaults.antiprompt = params_base.antiprompt; // enabling this will output extra debug information in the HTTP responses from the server @@ -1458,10 +1461,6 @@ struct server_prompt_cache { struct server_slot { int id; - int id_task = -1; - - // only used for completion/embedding/infill/rerank - server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; llama_batch batch_spec = {}; @@ -1473,16 +1472,6 @@ struct server_slot { common_speculative * spec = nullptr; - std::vector lora; - int32_t alora_invocation_start = -1; - - // the index relative to completion multi-task request - size_t index = 0; - - struct slot_params params; - - slot_state state = SLOT_STATE_IDLE; - // used to determine the slot that has been used the longest int64_t t_last_used = -1; @@ -1492,16 +1481,12 @@ struct server_slot { int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; - int32_t n_predict = -1; // TODO: disambiguate from params.n_predict int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; - // input prompt tokens - server_tokens input_tokens; - int32_t n_prompt_tokens() const { - return input_tokens.size(); + return task.tokens.size(); } size_t last_nl_pos = 0; @@ -1522,11 +1507,18 @@ struct server_slot { std::string stopping_word; // state + slot_state state = SLOT_STATE_IDLE; + + server_task task; + server_slot_prompt prompt; void prompt_save(server_prompt_cache & prompt_cache); void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens); + std::vector lora; + int32_t alora_invocation_start = -1; + // sampling json json_schema; @@ -1557,16 +1549,15 @@ struct server_slot { n_prompt_tokens_cache = 0; - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - task_type = SERVER_TASK_TYPE_COMPLETION; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; generated_tokens.clear(); generated_token_probs.clear(); @@ -1583,11 +1574,11 @@ struct server_slot { } bool need_embd() const { - return server_task_type_need_embd(task_type); + return server_task_type_need_embd(task.type); } bool need_logits() const { - return server_task_type_need_logits(task_type); + return server_task_type_need_logits(task.type); } // if the context does not have a memory module then all embeddings have to be computed within a single ubatch @@ -1599,18 +1590,18 @@ struct server_slot { } bool can_batch_with(server_slot & other_slot) const { - return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora); + return task.type == other_slot.task.type && are_lora_equal(lora, other_slot.lora); } bool has_budget(const common_params & global_params) { - if (params.n_predict == -1 && global_params.n_predict == -1) { + if (task.params.n_predict == 
-1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - if (params.n_predict != -1) { - n_remaining = params.n_predict - n_decoded; + if (task.params.n_predict != -1) { + n_remaining = task.params.n_predict - n_decoded; } else if (global_params.n_predict != -1) { n_remaining = global_params.n_predict - n_decoded; } @@ -1623,7 +1614,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + return ctx_dft && task.params.speculative.n_max > 0 && task.params.cache_prompt; } void add_token(const completion_token_output & token) { @@ -1674,7 +1665,7 @@ struct server_slot { auto new_msg = common_chat_parse( generated_text, /* is_partial= */ stop != STOP_TYPE_EOS, - params.oaicompat_chat_syntax); + task.params.oaicompat_chat_syntax); if (!new_msg.empty()) { new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); chat_msg = new_msg; @@ -1686,7 +1677,7 @@ struct server_slot { size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { size_t stop_pos = std::string::npos; - for (const std::string & word : params.antiprompt) { + for (const std::string & word : task.params.antiprompt) { size_t pos; if (is_full_stop) { @@ -1742,11 +1733,11 @@ struct server_slot { if (only_metrics) { return json { {"id", id}, - {"id_task", id_task}, + {"id_task", task.id}, {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, - {"params", params.to_json(true)}, + {"params", task.params.to_json(true)}, {"next_token", { {"has_next_token", has_next_token}, @@ -1760,12 +1751,12 @@ struct server_slot { return json { {"id", id}, - {"id_task", id_task}, + {"id_task", task.id}, {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, - {"params", params.to_json()}, - {"prompt", input_tokens.detokenize(ctx, true)}, + {"params", task.params.to_json()}, + {"prompt", task.tokens.detokenize(ctx, true)}, {"next_token", { {"has_next_token", has_next_token}, @@ -2437,7 +2428,6 @@ struct server_context { slot.id = i; slot.ctx = ctx; slot.n_ctx = n_ctx_slot; - slot.n_predict = params_base.n_predict; slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; @@ -2462,9 +2452,6 @@ struct server_context { SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); - slot.params.sampling = params_base.sampling; - slot.params.n_keep = params_base.n_keep; - slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); }; @@ -2474,7 +2461,12 @@ struct server_context { slots.push_back(std::move(slot)); } - default_generation_settings_for_props = slots[0].to_json(); + { + slots[0].task.params.sampling = params_base.sampling; + slots[0].task.params.n_keep = params_base.n_keep; + + default_generation_settings_for_props = slots[0].to_json(); + } // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. 
for non-causal attention models such as BERT where the KV cache is not used) @@ -2595,7 +2587,7 @@ struct server_context { update_cache = update_cache && tokens.size() > 0; // TODO: mtmd does not support prompt cache - update_cache = update_cache && !ret->prompt.tokens.has_mtmd; + update_cache = update_cache && (ret->mctx == nullptr); if (update_cache) { SRV_INF("%s", "updating prompt cache\n"); @@ -2614,25 +2606,20 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); - slot.id_task = task.id; - slot.index = task.index; - slot.task_type = task.type; - slot.params = std::move(task.params); - slot.input_tokens = std::move(task.tokens); - if (!are_lora_equal(slot.params.lora, slot.lora)) { + if (!are_lora_equal(task.params.lora, slot.lora)) { // if lora has changed, check to see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, slot.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size()); + if (lora_should_clear_cache(slot.lora, task.params.lora)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); slot.prompt.tokens.clear(); } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size()); + SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); } - slot.lora = slot.params.lora; + slot.lora = task.params.lora; } // if using alora, make sure it's only a single one requested and active - size_t alora_invocation_start = slot.input_tokens.size(); + size_t alora_invocation_start = task.tokens.size(); if (lora_all_alora(slot.lora)) { const auto & enabled_ids = lora_get_enabled_ids(slot.lora); // TODO: This will error out if a user requests two aloras, but only @@ -2652,10 +2639,10 @@ struct server_context { // scan backwards through the prompt tokens to find the last // occurrence of the invocation sequence int match_idx = static_cast(n_invocation_tokens) - 1; - for (int i = slot.input_tokens.size() - 1; i >= 0; --i) { + for (int i = task.tokens.size() - 1; i >= 0; --i) { // the token in this position matches the next token to find in // the invocation sequence - if (slot.input_tokens[i] == invocation_tokens[match_idx]) { + if (task.tokens[i] == invocation_tokens[match_idx]) { // if it's a full match, we've found the start if (match_idx == 0) { alora_invocation_start = i; @@ -2670,7 +2657,7 @@ struct server_context { } // if the activation string is not found, disable the alora - if (alora_invocation_start == slot.input_tokens.size()) { + if (alora_invocation_start == task.tokens.size()) { SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]); slot.lora[enabled_ids[0]].scale = 0.0f; } else { @@ -2679,24 +2666,20 @@ struct server_context { } } - if (!slot.input_tokens.validate(ctx)) { + if (!task.tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } - SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { - // Might be better to reject the request with a 400 ? 
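             // (context: this removed block clamped a user-requested n_predict to the
             //  server-configured slot.n_predict; slot.n_predict is dropped in this patch and
             //  defaults.n_predict is now seeded from params_base.n_predict in
             //  params_from_json_cmpl, so the server-wide default flows in there instead)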
- SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); - slot.params.n_predict = slot.n_predict; - } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + // initialize samplers { if (slot.smpl != nullptr) { common_sampler_free(slot.smpl); } - slot.smpl = common_sampler_init(model, slot.params.sampling); + slot.smpl = common_sampler_init(model, task.params.sampling); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -2704,12 +2687,15 @@ struct server_context { } } + // initialize draft batch if (slot.ctx_dft) { llama_batch_free(slot.batch_spec); - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); } + slot.task = std::move(task); + slot.state = SLOT_STATE_STARTED; SLT_INF(slot, "%s", "processing task\n"); @@ -2731,7 +2717,7 @@ struct server_context { slot.sampled = result.tok; slot.generated_text += token_str; - if (slot.params.return_tokens) { + if (slot.task.params.return_tokens) { slot.generated_tokens.push_back(result.tok); } slot.has_next_token = true; @@ -2768,7 +2754,7 @@ struct server_context { } slot.add_token(result); - if (slot.params.stream) { + if (slot.task.params.stream) { send_partial_response(slot, result, false); } } @@ -2790,12 +2776,12 @@ struct server_context { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task.params.n_predict); } if (slot.has_new_line) { // require that each new line has a whitespace prefix (i.e. 
indentation) of at least slot.params.n_indent - if (slot.params.n_indent > 0) { + if (slot.task.params.n_indent > 0) { // check the current indentation // TODO: improve by not doing it more than once for each new line if (slot.last_nl_pos > 0) { @@ -2807,7 +2793,7 @@ struct server_context { pos++; } - if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + if (pos < slot.generated_text.size() && n_indent < slot.task.params.n_indent) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; @@ -2834,11 +2820,11 @@ struct server_context { slot.has_new_line = true; // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + if (slot.task.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task.params.t_max_predict_ms)) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task.params.t_max_predict_ms); } } @@ -2861,7 +2847,7 @@ struct server_context { const auto n_ctx_train = llama_model_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { + if (slot.task.params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction @@ -2869,7 +2855,7 @@ struct server_context { SLT_WRN(slot, "n_predict (%d) is set for infinite generation. 
" "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", - slot.params.n_predict, n_ctx_train); + slot.task.params.n_predict, n_ctx_train); } SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); @@ -2878,7 +2864,7 @@ struct server_context { } void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.params.sampling.n_probs; + size_t n_probs = slot.task.params.sampling.n_probs; size_t n_vocab = llama_vocab_n_tokens(vocab); if (post_sampling) { @@ -2932,7 +2918,7 @@ struct server_context { } void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.id_task, error, type, slot.n_prompt_tokens(), slot.n_ctx); + send_error(slot.task.id, error, type, slot.n_prompt_tokens(), slot.n_ctx); } void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { @@ -2964,8 +2950,8 @@ struct server_context { void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; + res->id = slot.task.id; + res->index = slot.task.index; if (is_progress) { res->is_progress = true; @@ -2982,20 +2968,20 @@ struct server_context { res->n_decoded = slot.n_decoded; res->n_prompt_tokens = slot.n_prompt_tokens(); - res->post_sampling_probs = slot.params.post_sampling_probs; + res->post_sampling_probs = slot.task.params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->verbose = slot.task.params.verbose; + res->oaicompat = slot.task.params.oaicompat; + res->oaicompat_model = slot.task.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task.params.oaicompat_cmpl_id; // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { + if (slot.task.params.sampling.n_probs > 0) { res->prob_output = tkn; // copy the token probs } // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + if (slot.stop != STOP_TYPE_NONE || slot.task.params.timings_per_token) { res->timings = slot.get_timings(); } @@ -3004,15 +2990,15 @@ struct server_context { void send_final_response(server_slot & slot) { auto res = std::make_unique(); - res->id = slot.id_task; - res->id_slot = slot.id; + res->id = slot.task.id; + res->id_slot = slot.id; - res->index = slot.index; + res->index = slot.task.index; res->content = slot.generated_text; res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = slot.input_tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.params.response_fields); + res->prompt = slot.task.tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.task.params.response_fields); res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; @@ -3021,19 +3007,19 @@ struct server_context { res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; res->stop = slot.stop; - res->post_sampling_probs = slot.params.post_sampling_probs; - - res->verbose = 
slot.params.verbose; - res->stream = slot.params.stream; - res->include_usage = slot.params.include_usage; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->post_sampling_probs = slot.task.params.post_sampling_probs; + + res->verbose = slot.task.params.verbose; + res->stream = slot.task.params.stream; + res->include_usage = slot.task.params.include_usage; + res->oaicompat = slot.task.params.oaicompat; + res->oaicompat_model = slot.task.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task.params.oaicompat_cmpl_id; res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { - if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + if (slot.task.params.sampling.n_probs > 0) { + if (!slot.task.params.stream && slot.stop == STOP_TYPE_WORD) { const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); @@ -3047,17 +3033,17 @@ struct server_context { } } - res->generation_params = slot.params; // copy the parameters + res->generation_params = slot.task.params; // copy the parameters queue_results.send(std::move(res)); } void send_embedding(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; + res->id = slot.task.id; + res->index = slot.task.index; res->n_tokens = slot.n_prompt_tokens(); - res->oaicompat = slot.params.oaicompat; + res->oaicompat = slot.task.params.oaicompat; const int n_embd = llama_model_n_embd(model); @@ -3084,7 +3070,7 @@ struct server_context { // normalize only when there is pooling if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize); + common_embd_normalize(embd, embd_res.data(), n_embd, slot.task.params.embd_normalize); res->embedding.push_back(embd_res); break; } @@ -3099,8 +3085,8 @@ struct server_context { void send_rerank(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; + res->id = slot.task.id; + res->index = slot.task.index; res->n_tokens = slot.n_prompt_tokens(); for (int i = 0; i < batch.n_tokens; ++i) { @@ -3265,7 +3251,7 @@ struct server_context { { // release slot linked with the task id for (auto & slot : slots) { - if (slot.id_task == task.id_target) { + if (slot.task.id == task.id_target) { slot.release(); break; } @@ -3500,9 +3486,9 @@ struct server_context { } // Shift context - const int n_keep = slot.params.n_keep + add_bos_token; + const int n_keep = slot.task.params.n_keep + add_bos_token; const int n_left = slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + const int n_discard = slot.task.params.n_discard ? 
slot.task.params.n_discard : (n_left / 2); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); @@ -3534,7 +3520,8 @@ struct server_context { server_slot * slot_batched = nullptr; auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + return params_base.special || + slot.task.params.sampling.preserved_tokens.find(token) != slot.task.params.sampling.preserved_tokens.end(); }; // frist, add sampled tokens from any ongoing sequences @@ -3581,7 +3568,7 @@ struct server_context { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - const auto & input_tokens = slot.input_tokens; + const auto & input_tokens = slot.task.tokens; // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -3592,7 +3579,7 @@ struct server_context { slot.state = SLOT_STATE_PROCESSING_PROMPT; SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", - slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens()); + slot.n_ctx, slot.task.params.n_keep, slot.n_prompt_tokens()); // print prompt tokens (for debugging) /*if (1) { @@ -3643,13 +3630,13 @@ struct server_context { continue; } - if (slot.params.n_keep < 0) { - slot.params.n_keep = slot.n_prompt_tokens(); + if (slot.task.params.n_keep < 0) { + slot.task.params.n_keep = slot.n_prompt_tokens(); } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); + slot.task.params.n_keep = std::min(slot.n_ctx - 4, slot.task.params.n_keep); - if (slot.params.cache_prompt) { + if (slot.task.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens); @@ -3879,7 +3866,7 @@ struct server_context { } // embedding requires all tokens in the batch to be output - const bool need_embd = server_task_type_need_embd(slot.task_type); + const bool need_embd = server_task_type_need_embd(slot.task.type); common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); slot.prompt.tokens.push_back(cur_tok); @@ -4050,7 +4037,7 @@ struct server_context { for (auto & slot : slots) { // optionally send prompt processing progress if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.params.stream && slot.params.return_progress) { + if (slot.task.params.stream && slot.task.params.return_progress) { send_partial_response(slot, {}, true); } } @@ -4060,7 +4047,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + if (slot.task.type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -4068,7 +4055,7 @@ struct server_context { continue; // continue loop of slots } - if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + if (slot.task.type == SERVER_TASK_TYPE_RERANK) { send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -4106,8 +4093,8 @@ struct server_context { result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - if (slot.params.sampling.n_probs > 0) { - populate_token_probs(slot, result, 
slot.params.post_sampling_probs, params_base.special, tok_idx); + if (slot.task.params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task.params.post_sampling_probs, params_base.special, tok_idx); } if (!process_token(result, slot)) { @@ -4136,7 +4123,7 @@ struct server_context { } // determine the max draft that fits the current slot state - int n_draft_max = slot.params.speculative.n_max; + int n_draft_max = slot.task.params.speculative.n_max; // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -4148,8 +4135,8 @@ struct server_context { SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < slot.params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + if (n_draft_max < slot.task.params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task.params.speculative.n_min); continue; } @@ -4157,16 +4144,16 @@ struct server_context { llama_token id = slot.sampled; struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task.params.speculative.n_max; + params_spec.p_min = slot.task.params.speculative.p_min; const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + if (slot.task.params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task.params.speculative.n_min); continue; } diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 2979ed4bb7b12..6e5a3488e789b 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -19,8 +19,8 @@ def create_server(): (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'), (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None), (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, 
None), (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), ] @@ -54,7 +54,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), - ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), + ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"), ] ) def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index 11483e679a505..00ba78cf67c09 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -16,7 +16,7 @@ def create_server(): @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [ ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), + ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), ]) def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool): global server @@ -41,7 +41,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), ]) def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): global server diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 3e6974df8d442..1fcfbfda5e508 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -31,10 +31,10 @@ using json = nlohmann::ordered_json; -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).task.id, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).task.id, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) 
LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).task.id, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).task.id, __VA_ARGS__) #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) From ba8ffa782d9a4b19c969b4e00e913467fd65fce4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 16:50:14 +0300 Subject: [PATCH 07/17] cont : make the server task of the slot const --- tools/server/server.cpp | 272 ++++++++++++++------------ tools/server/tests/unit/test_basic.py | 3 +- tools/server/utils.hpp | 8 +- 3 files changed, 149 insertions(+), 134 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 41fee8f9e5e33..735644aaa69e5 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -693,7 +693,7 @@ struct server_task_result { // using shared_ptr for polymorphism of server_task_result using server_task_result_ptr = std::unique_ptr; -inline std::string stop_type_to_str(stop_type type) { +static inline std::string stop_type_to_str(stop_type type) { switch (type) { case STOP_TYPE_EOS: return "eos"; case STOP_TYPE_WORD: return "word"; @@ -1478,6 +1478,7 @@ struct server_slot { // generation props int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; + int32_t n_keep = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; @@ -1486,7 +1487,7 @@ struct server_slot { int32_t n_prompt_tokens_processed = 0; int32_t n_prompt_tokens() const { - return task.tokens.size(); + return task->tokens.size(); } size_t last_nl_pos = 0; @@ -1509,7 +1510,7 @@ struct server_slot { // state slot_state state = SLOT_STATE_IDLE; - server_task task; + std::unique_ptr task; server_slot_prompt prompt; @@ -1569,16 +1570,22 @@ struct server_slot { n_draft_total = 0; n_draft_accepted = 0; + task.reset(); + // clear alora start alora_invocation_start = -1; } bool need_embd() const { - return server_task_type_need_embd(task.type); + GGML_ASSERT(task); + + return server_task_type_need_embd(task->type); } bool need_logits() const { - return server_task_type_need_logits(task.type); + GGML_ASSERT(task); + + return server_task_type_need_logits(task->type); } // if the context does not have a memory module then all embeddings have to be computed within a single ubatch @@ -1590,18 +1597,22 @@ struct server_slot { } bool can_batch_with(server_slot & other_slot) const { - return task.type == other_slot.task.type && are_lora_equal(lora, other_slot.lora); + GGML_ASSERT(task); + + return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); } bool has_budget(const common_params & global_params) { - if (task.params.n_predict == -1 && global_params.n_predict == -1) { + GGML_ASSERT(task); + + if (task->params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - if (task.params.n_predict != -1) { - n_remaining = task.params.n_predict - n_decoded; + if (task->params.n_predict != -1) { + n_remaining = task->params.n_predict - n_decoded; } else if (global_params.n_predict != -1) { n_remaining = global_params.n_predict - n_decoded; } @@ -1614,7 +1625,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft && task.params.speculative.n_max > 0 && task.params.cache_prompt; + return ctx_dft; } void add_token(const completion_token_output & token) { @@ -1627,11 +1638,15 @@ 
struct server_slot { void release() { if (is_processing()) { + GGML_ASSERT(task); + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); t_last_used = ggml_time_us(); t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; state = SLOT_STATE_IDLE; + task.reset(); + callback_on_release(id); } } @@ -1660,12 +1675,14 @@ struct server_slot { } const common_chat_msg & update_chat_msg(std::vector & diffs) { + GGML_ASSERT(task); + auto previous_msg = chat_msg; SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); auto new_msg = common_chat_parse( generated_text, /* is_partial= */ stop != STOP_TYPE_EOS, - task.params.oaicompat_chat_syntax); + task->params.oaicompat_chat_syntax); if (!new_msg.empty()) { new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); chat_msg = new_msg; @@ -1675,9 +1692,11 @@ struct server_slot { } size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + GGML_ASSERT(task); + size_t stop_pos = std::string::npos; - for (const std::string & word : task.params.antiprompt) { + for (const std::string & word : task->params.antiprompt) { size_t pos; if (is_full_stop) { @@ -1730,43 +1749,33 @@ struct server_slot { } json to_json(bool only_metrics = false) const { - if (only_metrics) { - return json { - {"id", id}, - {"id_task", task.id}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - {"params", task.params.to_json(true)}, - {"next_token", - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - } - }, - }; - } + json res; - return json { + res = { {"id", id}, - {"id_task", task.id}, {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, - {"params", task.params.to_json()}, - {"prompt", task.tokens.detokenize(ctx, true)}, - {"next_token", + {"id_task", task ? task->id : -1}, + }; + + if (task) { + res["params"] = task->params.to_json(only_metrics); + res["next_token"] = { { {"has_next_token", has_next_token}, {"has_new_line", has_new_line}, {"n_remain", n_remaining}, {"n_decoded", n_decoded}, - {"stopping_word", stopping_word}, } - }, - }; + }; + + if (!only_metrics) { + res["prompt"] = task->tokens.detokenize(ctx, true); + } + } + + return res; } }; @@ -2267,7 +2276,6 @@ struct server_context { // slots / clients std::vector slots; - json default_generation_settings_for_props; server_queue queue_tasks; server_response queue_results; @@ -2461,13 +2469,6 @@ struct server_context { slots.push_back(std::move(slot)); } - { - slots[0].task.params.sampling = params_base.sampling; - slots[0].task.params.n_keep = params_base.n_keep; - - default_generation_settings_for_props = slots[0].to_json(); - } - // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. 
for non-causal attention models such as BERT where the KV cache is not used) { @@ -2694,7 +2695,7 @@ struct server_context { slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); } - slot.task = std::move(task); + slot.task = std::make_unique(std::move(task)); slot.state = SLOT_STATE_STARTED; @@ -2717,7 +2718,7 @@ struct server_context { slot.sampled = result.tok; slot.generated_text += token_str; - if (slot.task.params.return_tokens) { + if (slot.task->params.return_tokens) { slot.generated_tokens.push_back(result.tok); } slot.has_next_token = true; @@ -2754,7 +2755,7 @@ struct server_context { } slot.add_token(result); - if (slot.task.params.stream) { + if (slot.task->params.stream) { send_partial_response(slot, result, false); } } @@ -2776,12 +2777,12 @@ struct server_context { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task.params.n_predict); + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); } if (slot.has_new_line) { // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent - if (slot.task.params.n_indent > 0) { + if (slot.task->params.n_indent > 0) { // check the current indentation // TODO: improve by not doing it more than once for each new line if (slot.last_nl_pos > 0) { @@ -2793,7 +2794,7 @@ struct server_context { pos++; } - if (pos < slot.generated_text.size() && n_indent < slot.task.params.n_indent) { + if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; @@ -2820,11 +2821,11 @@ struct server_context { slot.has_new_line = true; // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.task.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task.params.t_max_predict_ms)) { + if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task.params.t_max_predict_ms); + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); } } @@ -2847,7 +2848,7 @@ struct server_context { const auto n_ctx_train = llama_model_n_ctx_train(model); - if (slot.task.params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { + if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction @@ -2855,7 +2856,7 @@ struct server_context { SLT_WRN(slot, "n_predict (%d) is set for infinite generation. 
" "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", - slot.task.params.n_predict, n_ctx_train); + slot.task->params.n_predict, n_ctx_train); } SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); @@ -2864,7 +2865,7 @@ struct server_context { } void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.task.params.sampling.n_probs; + size_t n_probs = slot.task->params.sampling.n_probs; size_t n_vocab = llama_vocab_n_tokens(vocab); if (post_sampling) { @@ -2918,7 +2919,7 @@ struct server_context { } void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.task.id, error, type, slot.n_prompt_tokens(), slot.n_ctx); + send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx); } void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { @@ -2950,8 +2951,8 @@ struct server_context { void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { auto res = std::make_unique(); - res->id = slot.task.id; - res->index = slot.task.index; + res->id = slot.task->id; + res->index = slot.task->index; if (is_progress) { res->is_progress = true; @@ -2968,20 +2969,20 @@ struct server_context { res->n_decoded = slot.n_decoded; res->n_prompt_tokens = slot.n_prompt_tokens(); - res->post_sampling_probs = slot.task.params.post_sampling_probs; + res->post_sampling_probs = slot.task->params.post_sampling_probs; - res->verbose = slot.task.params.verbose; - res->oaicompat = slot.task.params.oaicompat; - res->oaicompat_model = slot.task.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task.params.oaicompat_cmpl_id; + res->verbose = slot.task->params.verbose; + res->oaicompat = slot.task->params.oaicompat; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; // populate res.probs_output - if (slot.task.params.sampling.n_probs > 0) { + if (slot.task->params.sampling.n_probs > 0) { res->prob_output = tkn; // copy the token probs } // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.task.params.timings_per_token) { + if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { res->timings = slot.get_timings(); } @@ -2990,15 +2991,16 @@ struct server_context { void send_final_response(server_slot & slot) { auto res = std::make_unique(); - res->id = slot.task.id; + + res->id = slot.task->id; res->id_slot = slot.id; - res->index = slot.task.index; + res->index = slot.task->index; res->content = slot.generated_text; res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = slot.task.tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.task.params.response_fields); + res->prompt = slot.task->tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; @@ -3007,19 +3009,19 @@ struct server_context { res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; res->stop = slot.stop; - res->post_sampling_probs = 
slot.task.params.post_sampling_probs; - - res->verbose = slot.task.params.verbose; - res->stream = slot.task.params.stream; - res->include_usage = slot.task.params.include_usage; - res->oaicompat = slot.task.params.oaicompat; - res->oaicompat_model = slot.task.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task.params.oaicompat_cmpl_id; + res->post_sampling_probs = slot.task->params.post_sampling_probs; + + res->verbose = slot.task->params.verbose; + res->stream = slot.task->params.stream; + res->include_usage = slot.task->params.include_usage; + res->oaicompat = slot.task->params.oaicompat; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); // populate res.probs_output - if (slot.task.params.sampling.n_probs > 0) { - if (!slot.task.params.stream && slot.stop == STOP_TYPE_WORD) { + if (slot.task->params.sampling.n_probs > 0) { + if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); @@ -3033,17 +3035,17 @@ struct server_context { } } - res->generation_params = slot.task.params; // copy the parameters + res->generation_params = slot.task->params; // copy the parameters queue_results.send(std::move(res)); } void send_embedding(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.task.id; - res->index = slot.task.index; + res->id = slot.task->id; + res->index = slot.task->index; res->n_tokens = slot.n_prompt_tokens(); - res->oaicompat = slot.task.params.oaicompat; + res->oaicompat = slot.task->params.oaicompat; const int n_embd = llama_model_n_embd(model); @@ -3070,7 +3072,7 @@ struct server_context { // normalize only when there is pooling if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.task.params.embd_normalize); + common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; } @@ -3085,8 +3087,8 @@ struct server_context { void send_rerank(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.task.id; - res->index = slot.task.index; + res->id = slot.task->id; + res->index = slot.task->index; res->n_tokens = slot.n_prompt_tokens(); for (int i = 0; i < batch.n_tokens; ++i) { @@ -3251,7 +3253,7 @@ struct server_context { { // release slot linked with the task id for (auto & slot : slots) { - if (slot.task.id == task.id_target) { + if (slot.task && slot.task->id == task.id_target) { slot.release(); break; } @@ -3474,8 +3476,8 @@ struct server_context { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() - slot.release(); send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + slot.release(); continue; } @@ -3486,9 +3488,16 @@ struct server_context { } // Shift context - const int n_keep = slot.task.params.n_keep + add_bos_token; + int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep; + + if (add_bos_token) { + n_keep += 1; + } + + n_keep = std::min(slot.n_ctx - 4, n_keep); + const int n_left = slot.n_past - n_keep; - const int n_discard = slot.task.params.n_discard ? 
slot.task.params.n_discard : (n_left / 2); + const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); @@ -3521,7 +3530,7 @@ struct server_context { auto accept_special_token = [&](server_slot & slot, llama_token token) { return params_base.special || - slot.task.params.sampling.preserved_tokens.find(token) != slot.task.params.sampling.preserved_tokens.end(); + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); }; // first, add sampled tokens from any ongoing sequences @@ -3568,7 +3577,7 @@ struct server_context { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - const auto & input_tokens = slot.task.tokens; + const auto & input_tokens = slot.task->tokens; // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -3579,7 +3588,7 @@ struct server_context { slot.state = SLOT_STATE_PROCESSING_PROMPT; SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", - slot.n_ctx, slot.task.params.n_keep, slot.n_prompt_tokens()); + slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens()); // print prompt tokens (for debugging) /*if (1) { @@ -3598,45 +3607,40 @@ struct server_context { if (input_tokens.empty()) { SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - slot.release(); slot.print_timings(); send_final_response(slot); + slot.release(); + continue; } // TODO: support memory-less logits computation if (slot.need_logits() && !llama_get_memory(ctx)) { - slot.release(); send_error(slot, "the current context does not support logits computation. skipping", ERROR_TYPE_SERVER); + slot.release(); continue; } if (!slot.can_split()) { if (slot.n_prompt_tokens() > n_ubatch) { - slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + slot.release(); continue; } if (slot.n_prompt_tokens() > slot.n_ctx) { - slot.release(); send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); continue; } } else { if (slot.n_prompt_tokens() >= slot.n_ctx) { - slot.release(); send_error(slot, "the request exceeds the available context size.
try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); continue; } - if (slot.task.params.n_keep < 0) { - slot.task.params.n_keep = slot.n_prompt_tokens(); - } - - slot.task.params.n_keep = std::min(slot.n_ctx - 4, slot.task.params.n_keep); - - if (slot.task.params.cache_prompt) { + if (slot.task->params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens); @@ -3804,8 +3808,8 @@ struct server_context { int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); - slot.release(); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + slot.release(); continue; } @@ -3866,9 +3870,7 @@ struct server_context { } // embedding requires all tokens in the batch to be output - const bool need_embd = server_task_type_need_embd(slot.task.type); - - common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -4013,8 +4015,8 @@ struct server_context { if (!err.empty()) { SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); for (auto & slot : slots) { - slot.release(); send_error(slot, err); + slot.release(); } break; } @@ -4037,7 +4039,7 @@ struct server_context { for (auto & slot : slots) { // optionally send prompt processing progress if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task.params.stream && slot.task.params.return_progress) { + if (slot.task->params.stream && slot.task->params.return_progress) { send_partial_response(slot, {}, true); } } @@ -4047,7 +4049,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task.type == SERVER_TASK_TYPE_EMBEDDING) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -4055,7 +4057,7 @@ struct server_context { continue; // continue loop of slots } - if (slot.task.type == SERVER_TASK_TYPE_RERANK) { + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -4093,16 +4095,17 @@ struct server_context { result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - if (slot.task.params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task.params.post_sampling_probs, params_base.special, tok_idx); + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); } if (!process_token(result, slot)) { // release slot because of stop condition - slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + slot.release(); + continue; } } @@ -4123,7 +4126,7 @@ struct server_context { } // determine the max draft that fits the current slot state - int n_draft_max = slot.task.params.speculative.n_max; + int n_draft_max = slot.task->params.speculative.n_max; // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -4135,8 +4138,8 @@ 
struct server_context { SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < slot.task.params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task.params.speculative.n_min); + if (n_draft_max < slot.task->params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); continue; } @@ -4145,15 +4148,15 @@ struct server_context { struct common_speculative_params params_spec; params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task.params.speculative.n_max; - params_spec.p_min = slot.task.params.speculative.p_min; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts - if (slot.task.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task.params.speculative.n_min); + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); continue; } @@ -4197,11 +4200,11 @@ struct server_context { // TODO: set result.probs if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + slot.release(); + break; } } @@ -4702,9 +4705,22 @@ int main(int argc, char ** argv) { }; const auto handle_props = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + json default_generation_settings_for_props; + + { + slot_params params; + + params.sampling = ctx_server.params_base.sampling; + + default_generation_settings_for_props = json { + {"params", params.to_json(true)}, + {"n_ctx", ctx_server.slots[0].n_ctx}, + }; + } + // this endpoint is publicly available, please only return what is safe to be exposed json data = { - { "default_generation_settings", ctx_server.default_generation_settings_for_props }, + { "default_generation_settings", default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model.path }, { "modalities", json { diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 829af2ebe7bfb..720b136b05175 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -66,8 +66,7 @@ def test_server_slots(): assert len(res.body) == server.n_slots assert server.n_ctx is not None and server.n_slots is not None assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots - assert "params" in res.body[0] - assert res.body[0]["params"]["seed"] == server.seed + assert "params" not in res.body[0] def test_load_split_model(): diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 1fcfbfda5e508..884f4108e6b64 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -31,10 +31,10 @@ using json = nlohmann::ordered_json; -#define SLT_INF(slot, fmt, ...) 
LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).task.id, __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).task.id, __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).task.id, __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).task.id, __VA_ARGS__) +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) From 23b7f7653c42df1b626e257a3b0afad159be21d3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 17:03:34 +0300 Subject: [PATCH 08/17] cont : minor [no ci] --- tools/server/server.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 735644aaa69e5..c17fb7705eaa5 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1464,6 +1464,7 @@ struct server_slot { llama_batch batch_spec = {}; + // TODO: change to unique_ptrs for consistency: llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1472,6 +1473,8 @@ struct server_slot { common_speculative * spec = nullptr; + std::unique_ptr task; + // used to determine the slot that has been used the longest int64_t t_last_used = -1; @@ -1510,8 +1513,6 @@ struct server_slot { // state slot_state state = SLOT_STATE_IDLE; - std::unique_ptr task; - server_slot_prompt prompt; void prompt_save(server_prompt_cache & prompt_cache); From c32d8b40be2999b574d75825d121cf61bf0a9c0e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 18:02:13 +0300 Subject: [PATCH 09/17] server : cache prompts and checkpoints only for completion tasks --- tools/server/server.cpp | 51 ++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index c17fb7705eaa5..88ada41cd8d24 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1436,7 +1436,8 @@ struct server_slot_prompt { struct server_prompt_cache { std::list states; - size_t limit_size = 0; // 0 = no limit + // in bytes, 0 = no limit + size_t limit_size = 2ull*1024*1024*1024; size_t size() const { size_t res = 0; @@ -1532,7 +1533,7 @@ struct server_slot { std::vector generated_tool_call_ids; // stats - size_t n_sent_text = 0; // number of sent text character + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; int64_t t_start_generation; @@ -1792,7 +1793,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { const int cur_lcs_len = cached_prompt.get_common_prefix(prompt.tokens); if (cur_lcs_len == (int) prompt.tokens.size()) { - SRV_INF("%s", " - 
prompt is already cached, skipping\n"); + SRV_WRN("%s", " - prompt is already cached, skipping\n"); return; } } @@ -1804,7 +1805,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { const int len = cached_prompt.get_common_prefix(prompt.tokens); if (len == (int) cached_prompt.size()) { - SRV_INF(" - removing cached prompt with length %d\n", len); + SRV_WRN(" - removing cached prompt with length %d\n", len); it = states.erase(it); } else { @@ -1814,7 +1815,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); - SRV_INF(" - saving prompt with length %d, total cache size = %.3f MiB\n", + SRV_WRN(" - saving prompt with length %d, total cache size = %.3f MiB\n", (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); // if there is a limit, remove the oldest entries to make room @@ -1824,6 +1825,8 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { break; } + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + states.pop_front(); } } else { @@ -1833,6 +1836,8 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { break; } + SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + states.pop_front(); } } @@ -1847,7 +1852,11 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { llama_state_seq_get_data_ext(ctx, cur.data.data(), cur_size, id, 0); - SRV_INF(" - cache state: %zu prompts, %.3f MiB\n", states.size(), prompt_cache.size() / (1024.0 * 1024.0)); + SRV_WRN(" - cache state: %zu prompts, %.3f MiB\n", states.size(), prompt_cache.size() / (1024.0 * 1024.0)); + + for (const auto & state : states) { + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } } void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { @@ -1855,7 +1864,7 @@ void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_t int lcs_len = prompt.tokens.get_common_prefix(tokens); - SRV_INF(" - looking for better prompt, base lcs_len = %d\n", lcs_len); + SRV_WRN(" - looking for better prompt, base lcs_len = %d\n", lcs_len); auto it_best = states.end(); @@ -1872,7 +1881,7 @@ void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_t } if (it_best != states.end()) { - SRV_INF(" - found better prompt with lcs_len = %d\n", lcs_len); + SRV_WRN(" - found better prompt with lcs_len = %d\n", lcs_len); const size_t size = it_best->data.size(); const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id, 0); @@ -2454,7 +2463,7 @@ struct server_context { SRV_ERR("%s", "failed to create speculator\n"); return; } - for (auto &pair : params_base.speculative.replacements) { + for (auto & pair : params_base.speculative.replacements) { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } @@ -2483,7 +2492,7 @@ struct server_context { // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("Enable thinking? 
%d\n", enable_thinking); + SRV_INF("thinking = %d\n", enable_thinking); oai_parser_opt = { /* use_jinja */ params_base.use_jinja, @@ -2585,6 +2594,9 @@ struct server_context { if (ret) { const auto & tokens = ret->prompt.tokens; + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + // don't update the cache if the slot's context is empty update_cache = update_cache && tokens.size() > 0; @@ -2592,14 +2604,14 @@ struct server_context { update_cache = update_cache && (ret->mctx == nullptr); if (update_cache) { - SRV_INF("%s", "updating prompt cache\n"); + SRV_WRN("%s", "updating prompt cache\n"); const int64_t t_start = ggml_time_us(); ret->prompt_save(prompt_cache); ret->prompt_load(prompt_cache, task.tokens); - SRV_INF("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); } } @@ -3734,16 +3746,16 @@ struct server_context { if (!do_reset) { // restore the context checkpoint - const size_t ctx_checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + const size_t checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != ctx_checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); + if (n != checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); } } @@ -3842,6 +3854,9 @@ struct server_context { bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + // make a checkpoint of the parts of the memory that cannot be rolled back. 
// checkpoints are created only if: // - the model uses SWA and we are not using `swa_full` @@ -3941,7 +3956,7 @@ struct server_context { llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } } From 677b10dda1523b7b159ad8df271b6a9c713d34bf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 19:15:00 +0300 Subject: [PATCH 10/17] server : improve prompt caching logic --- tools/server/server.cpp | 80 +++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 88ada41cd8d24..dc8499978899b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -9,7 +9,6 @@ #include "sampling.h" #include "speculative.h" #include "mtmd.h" -#include "mtmd-helper.h" // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" @@ -1439,6 +1438,9 @@ struct server_prompt_cache { // in bytes, 0 = no limit size_t limit_size = 2ull*1024*1024*1024; + // in tokens, 0 = no limit + size_t limit_tokens = 0; + size_t size() const { size_t res = 0; @@ -1449,8 +1451,8 @@ struct server_prompt_cache { return res; } - int n_tokens() const { - int res = 0; + size_t n_tokens() const { + size_t res = 0; for (const auto & state : states) { res += state.n_tokens(); @@ -1458,6 +1460,42 @@ struct server_prompt_cache { return res; } + + void update() { + // always keep at least one state, regardless of the limits + if (states.size() > 1) { + if (limit_size > 0) { + while (size() > limit_size) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + if (limit_tokens > 0) { + while (n_tokens() > limit_tokens) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + } + + SRV_WRN(" - cache state: %zu prompts, %.3f MiB, limits: %.3f MiB, %zu tokens\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens); + + for (const auto & state : states) { + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } + } }; struct server_slot { @@ -1805,7 +1843,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { const int len = cached_prompt.get_common_prefix(prompt.tokens); if (len == (int) cached_prompt.size()) { - SRV_WRN(" - removing cached prompt with length %d\n", len); + SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); it = states.erase(it); } else { @@ -1815,33 +1853,9 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); - SRV_WRN(" - saving prompt with length %d, total cache size = %.3f MiB\n", + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); - // if 
there is a limit, remove the oldest entries to make room - if (prompt_cache.limit_size > 0) { - while (prompt_cache.size() + cur_size > prompt_cache.limit_size) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } else { - // else, make sure the number of cached tokens doesn't exceed the context size of the slot - while (prompt_cache.n_tokens() + (int) prompt.tokens.size() > n_ctx) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - // TODO: for some reason we can't copy server_tokens, so we have to do this workaround auto & cur = states.emplace_back(); cur = { @@ -1851,12 +1865,6 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) { }; llama_state_seq_get_data_ext(ctx, cur.data.data(), cur_size, id, 0); - - SRV_WRN(" - cache state: %zu prompts, %.3f MiB\n", states.size(), prompt_cache.size() / (1024.0 * 1024.0)); - - for (const auto & state : states) { - SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); - } } void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { @@ -2611,6 +2619,8 @@ struct server_context { ret->prompt_save(prompt_cache); ret->prompt_load(prompt_cache, task.tokens); + prompt_cache.update(); + SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); } } From 264d2c37e4cb9172f06136b3f6625dd1baa1b414 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Oct 2025 21:43:47 +0300 Subject: [PATCH 11/17] cont : fix check for number of cached prompts [no ci] --- tools/server/server.cpp | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index dc8499978899b..9f95e55cb03d7 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1462,30 +1462,28 @@ struct server_prompt_cache { } void update() { - // always keep at least one state, regardless of the limits - if (states.size() > 1) { - if (limit_size > 0) { - while (size() > limit_size) { - if (states.empty()) { - break; - } + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } - SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - states.pop_front(); - } + states.pop_front(); } + } - if (limit_tokens > 0) { - while (n_tokens() > limit_tokens) { - if (states.empty()) { - break; - } + if (limit_tokens > 0) { + while (states.size() > 1 && n_tokens() > limit_tokens) { + if (states.empty()) { + break; + } - SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - states.pop_front(); - } + states.pop_front(); } } From f42dfa451ee03bd2a0afd1e3a99446516a7f3bc0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov 
Date: Wed, 8 Oct 2025 15:33:37 +0300 Subject: [PATCH 12/17] server : improve caching logic, add -cram CLI arg --- common/arg.cpp | 8 ++ common/common.h | 3 +- tools/server/server.cpp | 193 ++++++++++++++++++++++++---------------- 3 files changed, 124 insertions(+), 80 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index ecc296485cb47..a2be2314146f7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1935,6 +1935,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.n_ctx_checkpoints = value; } ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--cache-ram", "-cram"}, "N", + string_format("set the maximum cache size in MiB (default: %d, 0 - no limit)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mb), + [](common_params & params, int value) { + params.cache_ram_mb = value; + } + ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--kv-unified", "-kvu"}, string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n" diff --git a/common/common.h b/common/common.h index 832c047e4bd85..627a1a34f313a 100644 --- a/common/common.h +++ b/common/common.h @@ -425,7 +425,8 @@ struct common_params { int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting - int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot + int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot + int32_t cache_ram_mb = 8192; // 0 = no limit, 1 = 1 MiB, etc. 
std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 9f95e55cb03d7..93c7d7a342cfd 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1436,11 +1436,16 @@ struct server_prompt_cache { std::list states; // in bytes, 0 = no limit - size_t limit_size = 2ull*1024*1024*1024; + size_t limit_size = 0; // in tokens, 0 = no limit size_t limit_tokens = 0; + void init(size_t limit_size, size_t limit_tokens) { + this->limit_size = limit_size; + this->limit_tokens = limit_tokens; + } + size_t size() const { size_t res = 0; @@ -1461,6 +1466,97 @@ struct server_prompt_cache { return res; } + server_slot_prompt * alloc(const server_tokens & tokens, size_t state_size) { + // first check if the current state is contained fully in the cache + for (auto it = states.begin(); it != states.end(); ++it) { + const int cur_lcs_len = it->tokens.get_common_prefix(tokens); + + if (cur_lcs_len == (int) tokens.size()) { + SRV_WRN("%s", " - prompt is already cached, skipping\n"); + return nullptr; + } + } + + // next, remove any cached prompts that are fully contained in the current prompt + for (auto it = states.begin(); it != states.end();) { + const int len = it->tokens.get_common_prefix(tokens); + + if (len == (int) it->tokens.size()) { + SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); + + it = states.erase(it); + } else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } catch (const std::bad_alloc & e) { + SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4*size()); + + SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto & cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ {}, + }; + + return &cur; + } + + bool load(server_slot_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { + int lcs_len = prompt.tokens.get_common_prefix(tokens_new); + + SRV_WRN(" - looking for better prompt, base lcs_len = %d\n", lcs_len); + + auto it_best = states.end(); + + // find the most similar cached prompt + for (auto it = states.begin(); it != states.end(); ++it) { + const int cur_lcs_len = it->tokens.get_common_prefix(tokens_new); + + if (lcs_len < cur_lcs_len) { + lcs_len = cur_lcs_len; + it_best = it; + } + } + + if (it_best != states.end()) { + SRV_WRN(" - found better prompt with lcs_len = %d\n", lcs_len); + + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; + } + void update() { if (limit_size > 0) { // always keep at least one state, regardless of the limits @@ -1487,11 +1583,11 @@ struct server_prompt_cache { } } - SRV_WRN(" - cache state: %zu prompts, %.3f MiB, limits: %.3f MiB, %zu tokens\n", + SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n", states.size(), size() / 
(1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens); for (const auto & state : states) { - SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); } } }; @@ -1552,7 +1648,7 @@ struct server_slot { server_slot_prompt prompt; - void prompt_save(server_prompt_cache & prompt_cache); + void prompt_save(server_prompt_cache & prompt_cache) const; void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens); std::vector lora; @@ -1817,91 +1913,28 @@ struct server_slot { } }; -void server_slot::prompt_save(server_prompt_cache & prompt_cache) { - auto & states = prompt_cache.states; - +void server_slot::prompt_save(server_prompt_cache & prompt_cache) const { assert(prompt.data.size() == 0); - // first check if the current state is contained fully in the cache - for (auto it = states.begin(); it != states.end(); ++it) { - const auto & cached_prompt = it->tokens; - - const int cur_lcs_len = cached_prompt.get_common_prefix(prompt.tokens); - - if (cur_lcs_len == (int) prompt.tokens.size()) { - SRV_WRN("%s", " - prompt is already cached, skipping\n"); - return; - } - } - - // next, remove any cached prompts that are fully contained in the current prompt - for (auto it = states.begin(); it != states.end();) { - const auto & cached_prompt = it->tokens; - - const int len = cached_prompt.get_common_prefix(prompt.tokens); - - if (len == (int) cached_prompt.size()) { - SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); - - it = states.erase(it); - } else { - ++it; - } - } - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); - // TODO: for some reason we can't copy server_tokens, so we have to do this workaround - auto & cur = states.emplace_back(); - cur = { - /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), - /*.data =*/ std::vector(cur_size), - /*.checkpoints =*/ prompt.checkpoints, - }; - - llama_state_seq_get_data_ext(ctx, cur.data.data(), cur_size, id, 0); -} - -void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - auto & states = prompt_cache.states; - - int lcs_len = prompt.tokens.get_common_prefix(tokens); - - SRV_WRN(" - looking for better prompt, base lcs_len = %d\n", lcs_len); - - auto it_best = states.end(); - - // find the most similar cached prompt - for (auto it = states.begin(); it != states.end(); ++it) { - const auto & cached_prompt = it->tokens; - - const int cur_lcs_len = cached_prompt.get_common_prefix(tokens); - - if (lcs_len < cur_lcs_len) { - lcs_len = cur_lcs_len; - it_best = it; - } + auto * cur = prompt_cache.alloc(prompt.tokens, cur_size); + if (cur == nullptr) { + return; } - if (it_best != states.end()) { - SRV_WRN(" - found better prompt with lcs_len = %d\n", lcs_len); - - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id, 0); - if (n != size) { - SLT_WRN(*this, "failed to restore slot state with size %zu\n", size); - return; - } + cur->checkpoints = prompt.checkpoints; - it_best->data.clear(); - it_best->data.shrink_to_fit(); - - prompt = std::move(*it_best); + 
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); +} - states.erase(it_best); +void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { + bool res = prompt_cache.load(prompt, tokens, ctx, id); + if (!res) { + SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); } } @@ -2494,6 +2527,8 @@ struct server_context { metrics.init(); + prompt_cache.init(1024*1024ull*params_base.cache_ram_mb, n_ctx); + // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it @@ -3908,7 +3943,7 @@ struct server_context { // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens()); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens()); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens()) { From bf10940e90537b196927a3553e2869d8d030db74 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 8 Oct 2025 15:52:17 +0300 Subject: [PATCH 13/17] server : print prompt mismatch info --- tools/server/server.cpp | 43 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 93c7d7a342cfd..d39a7a24dc019 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3772,6 +3772,49 @@ struct server_context { GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } + // when the prompt prefix does not match, print the tokens around the mismatch + // this is useful for debugging prompt caching + { + const int np0 = std::max(slot.n_past - 4, 0); + const int np1 = std::min(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); + + std::stringstream ss0; + std::stringstream ss1; + + std::stringstream st0; + std::stringstream st1; + + ss0 << "old: ... "; + ss1 << "new: ... 
"; + + for (int i = np0; i < np1; i++) { + if (i == slot.n_past) { + ss0 << " | "; + ss1 << " | "; + } + + { + const auto token = slot.prompt.tokens[i]; + const auto piece = common_token_to_piece(ctx, token); + ss0 << piece; + st0 << std::setw(8) << token; + } + + { + const auto token = slot.task->tokens[i]; + const auto piece = common_token_to_piece(ctx, token); + ss1 << piece; + st1 << std::setw(8) << token; + } + } + + SLT_WRN(slot, "%s\n", ss0.str().c_str()); + SLT_WRN(slot, "%s\n", ss1.str().c_str()); + + SLT_WRN(slot, "%s\n", st0.str().c_str()); + SLT_WRN(slot, "%s\n", st1.str().c_str()); + } + if (pos_min > pos_min_thold) { SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); From bc6e238ed86133de62d7715c3ec228208167c75c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 8 Oct 2025 16:05:52 +0300 Subject: [PATCH 14/17] cont : better naming [no ci] --- tools/server/server.cpp | 152 +++++++++++++++++++--------------------- 1 file changed, 74 insertions(+), 78 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index d39a7a24dc019..4c376f4ae0650 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -325,32 +325,32 @@ struct server_task { params.n_discard = json_value(data, "n_discard", defaults.n_discard); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - 
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); @@ -792,11 +792,12 @@ struct server_task_result_cmpl_final : server_task_result { slot_params generation_params; // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + std::vector oaicompat_msg_diffs; virtual int get_index() 
override {
@@ -1399,7 +1400,7 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
 
-struct server_slot_prompt_checkpoint {
+struct server_prompt_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
@@ -1410,12 +1411,12 @@ struct server_slot_prompt_checkpoint {
     }
 };
 
-struct server_slot_prompt {
+struct server_prompt {
     server_tokens tokens;
 
     std::vector<uint8_t> data;
 
-    std::list<server_slot_prompt_checkpoint> checkpoints;
+    std::list<server_prompt_checkpoint> checkpoints;
 
     size_t size() const {
         size_t res = data.size();
@@ -1433,7 +1434,7 @@ struct server_slot_prompt {
 };
 
 struct server_prompt_cache {
-    std::list<server_slot_prompt> states;
+    std::list<server_prompt> states;
 
     // in bytes, 0 = no limit
     size_t limit_size = 0;
@@ -1466,12 +1467,12 @@ struct server_prompt_cache {
         return res;
     }
 
-    server_slot_prompt * alloc(const server_tokens & tokens, size_t state_size) {
+    server_prompt * alloc(const server_prompt & prompt, size_t state_size) {
         // first check if the current state is contained fully in the cache
         for (auto it = states.begin(); it != states.end(); ++it) {
-            const int cur_lcs_len = it->tokens.get_common_prefix(tokens);
+            const int cur_lcs_len = it->tokens.get_common_prefix(prompt.tokens);
 
-            if (cur_lcs_len == (int) tokens.size()) {
+            if (cur_lcs_len == (int) prompt.tokens.size()) {
                 SRV_WRN("%s", " - prompt is already cached, skipping\n");
                 return nullptr;
             }
@@ -1479,7 +1480,7 @@ struct server_prompt_cache {
 
         // next, remove any cached prompts that are fully contained in the current prompt
         for (auto it = states.begin(); it != states.end();) {
-            const int len = it->tokens.get_common_prefix(tokens);
+            const int len = it->tokens.get_common_prefix(prompt.tokens);
 
             if (len == (int) it->tokens.size()) {
                 SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);
@@ -1510,15 +1511,15 @@ struct server_prompt_cache {
         // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
         auto & cur = states.emplace_back();
         cur = {
-            /*.tokens =*/ server_tokens(tokens.get_text_tokens(), false),
+            /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
             /*.data =*/ std::move(state_data),
-            /*.checkpoints =*/ {},
+            /*.checkpoints =*/ prompt.checkpoints,
         };
 
         return &cur;
     }
 
-    bool load(server_slot_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
         int lcs_len = prompt.tokens.get_common_prefix(tokens_new);
 
         SRV_WRN(" - looking for better prompt, base lcs_len = %d\n", lcs_len);
@@ -1646,10 +1647,30 @@ struct server_slot {
     // state
     slot_state state = SLOT_STATE_IDLE;
 
-    server_slot_prompt prompt;
+    server_prompt prompt;
+
+    void prompt_save(server_prompt_cache & prompt_cache) const {
+        assert(prompt.data.size() == 0);
+
+        const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
 
-    void prompt_save(server_prompt_cache & prompt_cache) const;
-    void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens);
+        SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
+                (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
+
+        auto * cur = prompt_cache.alloc(prompt, cur_size);
+        if (cur == nullptr) {
+            return;
+        }
+
+        llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
+    }
+
+    void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
+        bool res = prompt_cache.load(prompt, tokens, ctx, id);
+        if (!res) {
+            SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
+        }
+    }
 
     std::vector<common_adapter_lora_info> lora;
 
     int32_t alora_invocation_start = -1;
@@ -1789,19 +1810,19 @@ struct server_slot {
         result_timings timings;
         timings.cache_n = n_prompt_tokens_cache;
 
-        timings.prompt_n = n_prompt_tokens_processed;
-        timings.prompt_ms = t_prompt_processing;
+        timings.prompt_n = n_prompt_tokens_processed;
+        timings.prompt_ms = t_prompt_processing;
         timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
-        timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+        timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
 
-        timings.predicted_n = n_decoded;
-        timings.predicted_ms = t_token_generation;
+        timings.predicted_n = n_decoded;
+        timings.predicted_ms = t_token_generation;
         timings.predicted_per_token_ms = t_token_generation / n_decoded;
-        timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
+        timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
 
         // Add speculative metrics
         if (n_draft_total > 0) {
-            timings.draft_n = n_draft_total;
+            timings.draft_n = n_draft_total;
             timings.draft_n_accepted = n_draft_accepted;
         }
 
@@ -1913,31 +1934,6 @@ struct server_slot {
     }
 };
 
-void server_slot::prompt_save(server_prompt_cache & prompt_cache) const {
-    assert(prompt.data.size() == 0);
-
-    const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
-
-    SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
-            (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
-
-    auto * cur = prompt_cache.alloc(prompt.tokens, cur_size);
-    if (cur == nullptr) {
-        return;
-    }
-
-    cur->checkpoints = prompt.checkpoints;
-
-    llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
-}
-
-void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
-    bool res = prompt_cache.load(prompt, tokens, ctx, id);
-    if (!res) {
-        SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
-    }
-}
-
 struct server_metrics {
     int64_t t_start = 0;
 
@@ -4034,7 +4030,7 @@ struct server_context {
 
                     const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                    auto & cur = slot.prompt.checkpoints.emplace_back(server_slot_prompt_checkpoint{
+                    auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
                         /*.pos_min = */ pos_min,
                        /*.pos_max = */ pos_max,
                        /*.data = */ std::vector<uint8_t>(checkpoint_size),

From b612f7fd65efb68978124c7b98af55c78f12458a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 8 Oct 2025 16:57:03 +0300
Subject: [PATCH 15/17] server : improve prompt cache loading logic

---
 tools/server/server.cpp | 50 ++++++++++++++++++++++++-----------------
 tools/server/utils.hpp  | 14 ++++++++++--
 2 files changed, 41 insertions(+), 23 deletions(-)

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 4c376f4ae0650..708f5a580f592 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1470,9 +1470,9 @@ struct server_prompt_cache {
     server_prompt * alloc(const server_prompt & prompt, size_t state_size) {
         // first check if the current state is contained fully in the cache
         for (auto it = states.begin(); it != states.end(); ++it) {
-            const int cur_lcs_len = it->tokens.get_common_prefix(prompt.tokens);
+            const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
 
-            if (cur_lcs_len == (int) prompt.tokens.size()) {
+            if (cur_lcp_len == (int) prompt.tokens.size()) {
                 SRV_WRN("%s", " - prompt is already cached, skipping\n");
                 return nullptr;
             }
@@ -1520,24 +1520,37 @@
     }
 
     bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
-        int lcs_len = prompt.tokens.get_common_prefix(tokens_new);
+        const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
 
-        SRV_WRN(" - looking for better prompt, base lcs_len = %d\n", lcs_len);
+        float f_keep_best = float(lcp_best) / prompt.tokens.size();
+        float sim_best    = float(lcp_best) / tokens_new.size();
+
+        SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
 
         auto it_best = states.end();
 
-        // find the most similar cached prompt
+        // find the most similar cached prompt, that would also preserve the most context
         for (auto it = states.begin(); it != states.end(); ++it) {
-            const int cur_lcs_len = it->tokens.get_common_prefix(tokens_new);
+            const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
+
+            const float f_keep_cur = float(lcp_cur) / it->tokens.size();
+            const float sim_cur    = float(lcp_cur) / tokens_new.size();
+
+            // don't trash large prompts
+            if (f_keep_cur < 0.25f) {
+                continue;
+            }
+
+            if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
+                f_keep_best = f_keep_cur;
+                sim_best    = sim_cur;
 
-            if (lcs_len < cur_lcs_len) {
-                lcs_len = cur_lcs_len;
                 it_best = it;
             }
         }
 
         if (it_best != states.end()) {
-            SRV_WRN(" - found better prompt with lcs_len = %d\n", lcs_len);
+            SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
 
             const size_t size = it_best->data.size();
             const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
@@ -2560,7 +2573,6 @@ struct server_context {
 
         // find the slot that has at least n% prompt similarity
        if (ret == nullptr && slot_prompt_similarity != 0.0f) {
-            int lcs_len_best = 0;
             float sim_best = 0;
 
             for (server_slot & slot : slots) {
@@ -2576,26 +2588,22 @@ struct server_context {
                     continue;
                 }
 
-                // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
-                const int lcs_len_cur = tokens.get_common_prefix(task.tokens);
-
-                // fraction of the common subsequence length
-                const float sim_cur = float(lcs_len_cur) / task.tokens.size();
+                // fraction of the Longest Common Prefix length with respect to the input prompt length
+                const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size();
 
                 // select the current slot if the criteria match
-                if (lcs_len_cur > lcs_len_best && sim_cur > slot_prompt_similarity) {
-                    lcs_len_best = lcs_len_cur;
-                    sim_best = sim_cur;
+                if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
+                    sim_best = sim_cur;
 
                     ret = &slot;
                 }
             }
 
             if (ret != nullptr) {
-                const float f_keep = float(lcs_len_best) / ret->prompt.tokens.size();
+                const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size();
 
-                SLT_INF(*ret, "selected slot by lcs similarity, lcs_len_best = %d, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
-                        lcs_len_best, sim_best, slot_prompt_similarity, f_keep);
+                SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
+                        sim_best, slot_prompt_similarity, f_keep);
 
                 // if we are about to lose a large portion of the existing context - save it in the prompt cache
                 if (f_keep < 0.5f) {
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index 884f4108e6b64..f175115f4fd6a 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1273,13 +1273,23 @@ struct server_tokens {
     size_t get_common_prefix(const server_tokens & b) const {
         const size_t max_idx = std::min(tokens.size(), b.tokens.size());
 
+        if (!has_mtmd) {
+            for (size_t i = 0; i < max_idx; ++i) {
+                if (tokens[i] == b.tokens[i]) {
+                    continue;
+                }
+
+                return i;
+            }
+
+            return max_idx;
+        }
+
         for (size_t i = 0; i < max_idx; ++i) {
             const llama_token ai = tokens[i];
             const llama_token bi = b.tokens[i];
 
             if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
-                GGML_ASSERT(has_mtmd);
-
                 const auto & a_chunk = find_chunk(i);
                 const auto & b_chunk = b.find_chunk(i);
 

From c5e5167d53ddd1926b032d500a901fd6e09ed58e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 9 Oct 2025 17:18:01 +0300
Subject: [PATCH 16/17] server : add option to debug the slot contents (#16482)

* server : add option to debug the slot contents

* Update tools/server/server.cpp

---------

Co-authored-by: Xuan-Son Nguyen
---
 tools/server/server.cpp | 64 ++++++++++++++++++++++++---------------
 1 file changed, 41 insertions(+), 23 deletions(-)

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 708f5a580f592..38967d8d2f5db 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1621,6 +1621,7 @@ struct server_slot {
     common_speculative * spec = nullptr;
 
     std::unique_ptr<server_task> task;
+    std::unique_ptr<server_task> task_prev; // used for debugging
 
     // used to determine the slot that has been used the longest
     int64_t t_last_used = -1;
@@ -1739,6 +1740,7 @@ struct server_slot {
         n_draft_accepted = 0;
 
         task.reset();
+        task_prev.reset();
 
         // clear alora start
         alora_invocation_start = -1;
@@ -1813,6 +1815,8 @@ struct server_slot {
         t_last_used = ggml_time_us();
         t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
         state = SLOT_STATE_IDLE;
+
+        task_prev = std::move(task);
         task.reset();
 
         callback_on_release(id);
@@ -1924,11 +1928,13 @@ struct server_slot {
             {"n_ctx", n_ctx},
             {"speculative", can_speculate()},
             {"is_processing", is_processing()},
-            {"id_task", task ? task->id : -1},
         };
 
-        if (task) {
-            res["params"] = task->params.to_json(only_metrics);
+        const auto & ptask = task ? task : task_prev;
+
+        if (ptask) {
+            res["id_task"] = ptask->id;
+            res["params"] = ptask->params.to_json(only_metrics);
             res["next_token"] = {
                 {
                     {"has_next_token", has_next_token},
@@ -1939,7 +1945,8 @@ struct server_slot {
             };
 
             if (!only_metrics) {
-                res["prompt"] = task->tokens.detokenize(ctx, true);
+                res["prompt"] = ptask->tokens.detokenize(ctx, true);
+                res["generated"] = generated_text;
             }
         }
 
@@ -2335,6 +2342,8 @@ struct server_context {
     // slots / clients
     std::vector<server_slot> slots;
 
+    int slots_debug = 0;
+
     server_queue queue_tasks;
     server_response queue_results;
 
@@ -2527,6 +2536,15 @@ struct server_context {
             slots.push_back(std::move(slot));
         }
 
+        {
+            const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
+            slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
+
+            if (slots_debug) {
+                SRV_WRN("slots debug = %d\n", slots_debug);
+            }
+        }
+
         // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
@@ -3331,7 +3349,7 @@ struct server_context {
         int n_processing_slots = 0;
 
         for (server_slot & slot : slots) {
-            json slot_data = slot.to_json(true);
+            json slot_data = slot.to_json(slots_debug == 0);
 
             if (slot.is_processing()) {
                 n_processing_slots++;
@@ -4578,18 +4596,18 @@ int main(int argc, char ** argv) {
        }
 
         // TODO: get rid of this dynamic_cast
-        auto res_metrics = dynamic_cast<server_task_result_metrics *>(result.get());
-        GGML_ASSERT(res_metrics != nullptr);
+        auto res_task = dynamic_cast<server_task_result_metrics *>(result.get());
+        GGML_ASSERT(res_task != nullptr);
 
         // optionally return "fail_on_no_slot" error
         if (req.has_param("fail_on_no_slot")) {
-            if (res_metrics->n_idle_slots == 0) {
+            if (res_task->n_idle_slots == 0) {
                 res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
                 return;
             }
         }
 
-        res_ok(res, res_metrics->slots_data);
+        res_ok(res, res_task->slots_data);
     };
 
     const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -4617,56 +4635,56 @@ int main(int argc, char ** argv) {
        }
 
         // TODO: get rid of this dynamic_cast
-        auto res_metrics = dynamic_cast<server_task_result_metrics *>(result.get());
-        GGML_ASSERT(res_metrics != nullptr);
+        auto res_task = dynamic_cast<server_task_result_metrics *>(result.get());
+        GGML_ASSERT(res_task != nullptr);
 
         // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
         json all_metrics_def = json {
            {"counter", {{
                    {"name", "prompt_tokens_total"},
                    {"help", "Number of prompt tokens processed."},
-                   {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total}
+                   {"value", (uint64_t) res_task->n_prompt_tokens_processed_total}
            }, {
                    {"name", "prompt_seconds_total"},
                    {"help", "Prompt process time"},
-                   {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
+                   {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3}
            }, {
                    {"name", "tokens_predicted_total"},
                    {"help", "Number of generation tokens processed."},
-                   {"value", (uint64_t) res_metrics->n_tokens_predicted_total}
+                   {"value", (uint64_t) res_task->n_tokens_predicted_total}
            }, {
                    {"name", "tokens_predicted_seconds_total"},
                    {"help", "Predict process time"},
-                   {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
+                   {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3}
            }, {
                    {"name", "n_decode_total"},
                    {"help", "Total number of llama_decode() calls"},
-                   {"value", res_metrics->n_decode_total}
+                   {"value", res_task->n_decode_total}
            }, {
                    {"name", "n_past_max"},
                    {"help", "Largest observed n_past."},
-                   {"value", res_metrics->n_past_max}
+                   {"value", res_task->n_past_max}
            }, {
                    {"name", "n_busy_slots_per_decode"},
                    {"help", "Average number of busy slots per llama_decode() call"},
-                   {"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)}
+                   {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)}
            }}},
            {"gauge", {{
                    {"name", "prompt_tokens_seconds"},
                    {"help", "Average prompt throughput in tokens/s."},
-                   {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
+                   {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.}
            },{
                    {"name", "predicted_tokens_seconds"},
                    {"help", "Average generation throughput in tokens/s."},
-                   {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
+                   {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.}
            },{
                    {"name", "requests_processing"},
                    {"help", "Number of requests processing."},
-                   {"value", (uint64_t) res_metrics->n_processing_slots}
+                   {"value", (uint64_t) res_task->n_processing_slots}
            },{
                    {"name", "requests_deferred"},
                    {"help", "Number of requests deferred."},
-                   {"value", (uint64_t) res_metrics->n_tasks_deferred}
+                   {"value", (uint64_t) res_task->n_tasks_deferred}
            }}}
        };
 
@@ -4687,7 +4705,7 @@ int main(int argc, char ** argv) {
            }
        }
 
-        res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));
+        res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start));
 
        res.set_content(prometheus.str(), "text/plain; version=0.0.4");
        res.status = 200; // HTTP OK

From ff3340626686c90b025cac3a2966b152fd4cc40a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 9 Oct 2025 17:20:26 +0300
Subject: [PATCH 17/17] server : add option to disable prompt cache

---
 common/arg.cpp          |  6 +++---
 common/common.h         |  2 +-
 tools/server/server.cpp | 38 ++++++++++++++++++++++++++------------
 3 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index a2be2314146f7..98c452e78e905 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1937,10 +1937,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--cache-ram", "-cram"}, "N",
-        string_format("set the maximum cache size in MiB (default: %d, 0 - no limit)\n"
-                      "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mb),
+        string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n"
+                      "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
         [](common_params & params, int value) {
-            params.cache_ram_mb = value;
+            params.cache_ram_mib = value;
         }
     ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
diff --git a/common/common.h b/common/common.h
index 627a1a34f313a..03d095198b516 100644
--- a/common/common.h
+++ b/common/common.h
@@ -426,7 +426,7 @@ struct common_params {
     int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
     int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
-    int32_t cache_ram_mb = 8192; // 0 = no limit, 1 = 1 MiB, etc.
+    int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc.
 
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 38967d8d2f5db..41ecb279feb89 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1434,6 +1434,11 @@ struct server_prompt {
 };
 
 struct server_prompt_cache {
+    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
+        this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
+        this->limit_tokens = limit_tokens;
+    }
+
     std::list<server_prompt> states;
 
     // in bytes, 0 = no limit
@@ -1442,11 +1447,6 @@ struct server_prompt_cache {
     // in tokens, 0 = no limit
     size_t limit_tokens = 0;
 
-    void init(size_t limit_size, size_t limit_tokens) {
-        this->limit_size = limit_size;
-        this->limit_tokens = limit_tokens;
-    }
-
     size_t size() const {
         size_t res = 0;
 
@@ -1473,7 +1473,7 @@ struct server_prompt_cache {
             const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
 
             if (cur_lcp_len == (int) prompt.tokens.size()) {
-                SRV_WRN("%s", " - prompt is already cached, skipping\n");
+                SRV_WRN("%s", " - prompt is already in the cache, skipping\n");
                 return nullptr;
             }
         }
@@ -2347,7 +2347,7 @@ struct server_context {
     server_queue queue_tasks;
     server_response queue_results;
 
-    server_prompt_cache prompt_cache;
+    std::unique_ptr<server_prompt_cache> prompt_cache;
 
     server_metrics metrics;
 
@@ -2554,7 +2554,19 @@ struct server_context {
 
         metrics.init();
 
-        prompt_cache.init(1024*1024ull*params_base.cache_ram_mb, n_ctx);
+        if (params_base.cache_ram_mib != 0) {
+            if (params_base.cache_ram_mib < 0) {
+                SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
+            } else {
+                SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
+            }
+            SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n");
+
+            prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
+        } else {
+            SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
+        }
+        SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
 
         // thinking is enabled if:
         // 1. It's not explicitly disabled (reasoning_budget == 0)
@@ -2657,6 +2669,8 @@ struct server_context {
         if (ret) {
             const auto & tokens = ret->prompt.tokens;
 
+            update_cache = update_cache && prompt_cache;
+
             // cache prompts only for completion tasks
             update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
 
@@ -2671,10 +2685,10 @@ struct server_context {
 
                 const int64_t t_start = ggml_time_us();
 
-                ret->prompt_save(prompt_cache);
-                ret->prompt_load(prompt_cache, task.tokens);
+                ret->prompt_save(*prompt_cache);
+                ret->prompt_load(*prompt_cache, task.tokens);
 
-                prompt_cache.update();
+                prompt_cache->update();
 
                 SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
             }
@@ -5682,7 +5696,7 @@ int main(int argc, char ** argv) {
 #endif
 
     LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__,
-            is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
+            is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
             string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str());
 
     // this call blocks the main thread until queue_tasks.terminate() is called