diff --git a/CLAUDE.md b/CLAUDE.md index 4298e12c..912e3a57 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -6,7 +6,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co Java bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp) via JNI, providing a high-level API for LLM inference in Java. The Java layer communicates with a native C++ library through JNI. -Current llama.cpp pinned version: **b8887** +Current llama.cpp pinned version: **b8913** ## Upgrading CUDA Version @@ -137,7 +137,7 @@ Also review the project `CMakeLists.txt` for build-system-level breaks (e.g. ren `ggml/include/ggml.h`, `ggml/include/ggml-backend.h`, `ggml/include/ggml-opt.h`, `ggml-alloc.h`, `ggml-cpu.h`, `peg-parser.h`, `base64.hpp` -**Known breaking changes by version range** (b5022 → b8887): +**Known breaking changes by version range** (b5022 → b8913): | Version | File | Change | |---------|------|--------| @@ -162,6 +162,12 @@ Also review the project `CMakeLists.txt` for build-system-level breaks (e.g. ren | ~b8854–b8887 | `common/chat.h` | `common_chat_msg_diff_to_json_oaicompat` removed; moved to `tools/server/server-chat.cpp`; project defines it locally in `server.hpp` — importing server-chat.cpp is impractical because it pulls in `convert_transcriptions_to_chatcmpl` → `get_media_marker` → `server-common.cpp` | | ~b8854–b8887 | `common/common.h` | `common_params::reasoning_budget` and `reasoning_budget_message` moved into `common_params::sampling` sub-struct as `reasoning_budget_tokens`; update: `params_base.reasoning_budget` → `params_base.sampling.reasoning_budget_tokens` | | ~b8854–b8887 | `common/fit.h` (new) | `llama_params_fit` and `llama_memory_breakdown_print` removed from `include/llama.h`; now `common_fit_params` / `common_memory_breakdown_print` in new `common/fit.h`; not used directly by project | +| ~b8887–b8913 | `tools/server/server-chat.h` | `convert_transcriptions_to_chatcmpl` gained a new `const common_chat_templates * tmpls` second parameter; not called by project's `server.hpp` — handled automatically by upstream `server-chat.cpp` | +| ~b8887–b8913 | `tools/server/server-task.cpp` | `n_discard` clamped to non-negative: `params.n_discard = std::max(0, params.n_discard)`; applied in project's `server.hpp` after the `json_value` parse | +| ~b8887–b8913 | `tools/server/server-common.cpp` | `parallel_tool_calls` now defaults to `caps["supports_parallel_tool_calls"]` instead of hardcoded `false`; handled automatically by upstream file | +| ~b8887–b8913 | `common/chat.h` | New additive `common_chat_prompt_preset` struct and `common_chat_get_asr_prompt()` function; no project changes required | +| ~b8887–b8913 | `common/common.h` | New `string_starts_with(std::string_view, char)` overload added; no project changes required | +| ~b8887–b8913 | `tools/mtmd/mtmd.cpp` | Added `LLAMA_ROPE_TYPE_NONE` case to rope-type switch; internal fix, no project changes required | ## Build Commands diff --git a/CMakeLists.txt b/CMakeLists.txt index 86c365a3..a959183c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -97,7 +97,7 @@ set(GGML_AVX512 OFF CACHE BOOL "" FORCE) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b8887 + GIT_TAG b8913 ) FetchContent_MakeAvailable(llama.cpp) diff --git a/README.md b/README.md index a264f7df..e2632d4a 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ ![Java 8+](https://img.shields.io/badge/Java-8%2B-informational) -[![llama.cpp b8887](https://img.shields.io/badge/llama.cpp-%23b8887-informational)](https://github.com/ggml-org/llama.cpp/releases/tag/b8887) +[![llama.cpp b8913](https://img.shields.io/badge/llama.cpp-%23b8913-informational)](https://github.com/ggml-org/llama.cpp/releases/tag/b8913) # Java Bindings for [llama.cpp](https://github.com/ggerganov/llama.cpp) diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index 61fa5826..fd606d8b 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -94,17 +94,20 @@ static bool server_task_type_need_logits(server_task_type task_type) { } struct slot_params { - bool stream = true; - bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt - bool return_tokens = false; - - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = - 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict - int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters - - int64_t t_max_prompt_ms = -1; // TODO: implement + bool stream = true; + bool include_usage = false; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + bool return_progress = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters + int32_t n_cmpl = 1; // number of completions to generate from this prompt + int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) + + int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit std::vector lora; @@ -139,7 +142,7 @@ struct slot_params { auto grammar_triggers = json::array(); for (const auto &trigger : sampling.grammar_triggers) { - server_grammar_trigger ct(std::move(trigger)); + server_grammar_trigger ct(trigger); grammar_triggers.push_back(ct.to_json()); } @@ -186,11 +189,16 @@ struct slot_params { {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, {"generation_prompt", oaicompat_chat_syntax.generation_prompt}, {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.ngram_size_n", speculative.ngram_size_n}, + {"speculative.ngram_size_m", speculative.ngram_size_m}, + {"speculative.ngram_m_hits", speculative.ngram_min_hits}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"backend_sampling", sampling.backend_sampling}, {"lora", lora}, }; } @@ -238,23 +246,33 @@ struct server_task { slot_params defaults; defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; + defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; + defaults.cache_prompt = params_base.cache_prompt; + defaults.antiprompt = params_base.antiprompt; + defaults.n_cache_reuse = params_base.n_cache_reuse; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; params.timings_per_token = json_value(data, "timings_per_token", false); - params.stream = json_value(data, "stream", false); - params.cache_prompt = json_value(data, "cache_prompt", true); - params.return_tokens = json_value(data, "return_tokens", false); - params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); - params.n_indent = json_value(data, "n_indent", defaults.n_indent); - params.n_keep = json_value(data, "n_keep", defaults.n_keep); - params.n_discard = json_value(data, "n_discard", defaults.n_discard); - // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: - // implement - params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); + params.stream = json_value(data, "stream", false); + auto stream_opt = json_value(data, "stream_options", json::object()); + params.include_usage = json_value(stream_opt, "include_usage", false); + params.cache_prompt = json_value(data, "cache_prompt", defaults.cache_prompt); + params.return_tokens = json_value(data, "return_tokens", false); + params.return_progress = json_value(data, "return_progress", false); + auto max_tokens = json_value(data, "max_tokens", defaults.n_predict); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_completion_tokens", max_tokens)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + params.n_discard = std::max(0, params.n_discard); + params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1)); + params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); @@ -281,8 +299,13 @@ struct server_task { params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.sampling.adaptive_target = json_value(data, "adaptive_target", defaults.sampling.adaptive_target); + params.sampling.adaptive_decay = json_value(data, "adaptive_decay", defaults.sampling.adaptive_decay); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); + + params.speculative = defaults.speculative; params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); @@ -292,6 +315,16 @@ struct server_task { params.speculative.n_min = std::max(params.speculative.n_min, 0); params.speculative.n_max = std::max(params.speculative.n_max, 0); + params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type))); + + params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n); + params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m); + params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits); + + params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024); + params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024); + params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024); + // Use OpenAI API logprobs only if n_probs wasn't provided if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs) { params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); @@ -356,16 +389,18 @@ struct server_task { throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); } } else { - { - const std::string grammar_str = - json_value(data, "grammar", common_grammar_value(defaults.sampling.grammar)); - if (!grammar_str.empty()) { - params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, grammar_str}; + params.sampling.grammar = defaults.sampling.grammar; + + std::string grammar_str = json_value(data, "grammar", std::string()); + if (!grammar_str.empty()) { + std::string grammar_type = json_value(data, "grammar_type", std::string()); + if (grammar_type == "tool_calls") { + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, grammar_str}; } else { - params.sampling.grammar = {}; + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, grammar_str}; } + SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str()); } - SRV_DBG("Grammar: %s\n", common_grammar_value(params.sampling.grammar).c_str()); params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); } @@ -378,11 +413,20 @@ struct server_task { } else { params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; } - params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format; + common_reasoning_format reasoning_format = params_base.reasoning_format; + if (data.contains("reasoning_format")) { + reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); + } + params.oaicompat_chat_syntax.reasoning_format = reasoning_format; params.oaicompat_chat_syntax.reasoning_in_content = - params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); + params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); params.oaicompat_chat_syntax.generation_prompt = json_value(data, "generation_prompt", std::string()); + params.sampling.generation_prompt = params.oaicompat_chat_syntax.generation_prompt; + SRV_DBG("Generation prompt: '%s'\n", params.oaicompat_chat_syntax.generation_prompt.c_str()); params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); + if (data.contains("chat_parser")) { + params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get()); + } } { @@ -443,9 +487,31 @@ struct server_task { } } + // Parse reasoning budget sampler parameters + { + const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1); + const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string()); + const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string()); + const auto message = json_value(data, "reasoning_budget_message", std::string()); + params.sampling.reasoning_budget_tokens = budget; + + if (!start_tag.empty()) { + params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true); + } + if (!end_tag.empty()) { + params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true); + params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true); + + SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n", + budget, params.sampling.generation_prompt.c_str(), + params.sampling.reasoning_budget_start.size(), + params.sampling.reasoning_budget_end.size(), + params.sampling.reasoning_budget_forced.size()); + } + } + { params.sampling.logit_bias.clear(); - params.ignore_eos = json_value(data, "ignore_eos", false); const auto &logit_bias = data.find("logit_bias"); if (logit_bias != data.end() && logit_bias->is_array()) { @@ -475,6 +541,43 @@ struct server_task { } } } + } else if (logit_bias != data.end() && logit_bias->is_object()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto &el : logit_bias->items()) { + float bias; + const auto &key = el.key(); + const auto &value = el.value(); + if (value.is_number()) { + bias = value.get(); + } else if (value.is_boolean() && !value.get()) { + bias = -INFINITY; + } else { + continue; + } + + char *end; + llama_token tok = strtol(key.c_str(), &end, 10); + if (*end == 0) { + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else { + auto toks = common_tokenize(vocab, key, false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + + params.ignore_eos = json_value(data, "ignore_eos", false); + if (params.ignore_eos) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (llama_token tok = 0; tok < n_vocab; ++tok) { + if (llama_vocab_is_eog(vocab, tok)) { + params.sampling.logit_bias.push_back({tok, -INFINITY}); + } + } } } @@ -489,6 +592,9 @@ struct server_task { } } } + if (params.antiprompt.empty()) { + params.antiprompt = defaults.antiprompt; + } } { @@ -508,6 +614,10 @@ struct server_task { params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : *params_base.model_alias.begin(); params.oaicompat_model = json_value(data, "model", model_name); + if (params.n_cmpl > params_base.n_parallel) { + throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np"); + } + return params; } @@ -522,6 +632,8 @@ struct server_task { }; struct result_timings { + int32_t cache_n = -1; + int32_t prompt_n = -1; double prompt_ms; double prompt_per_token_ms; @@ -538,8 +650,9 @@ struct result_timings { json to_json() const { json base = { - {"prompt_n", prompt_n}, - {"prompt_ms", prompt_ms}, + {"cache_n", cache_n}, + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, {"prompt_per_token_ms", prompt_per_token_ms}, {"prompt_per_second", prompt_per_second}, @@ -558,6 +671,22 @@ struct result_timings { } }; +struct result_prompt_progress { + int32_t total = 0; + int32_t cache = 0; + int32_t processed = 0; + int64_t time_ms = 0; + + json to_json() const { + return json{ + {"total", total}, + {"cache", cache}, + {"processed", processed}, + {"time_ms", time_ms}, + }; + } +}; + struct server_task_result { int id = -1; int id_slot = -1; @@ -652,12 +781,14 @@ struct server_task_result_cmpl_final : server_task_result { llama_tokens tokens; bool stream; + bool include_usage; result_timings timings; std::string prompt; bool truncated; int32_t n_decoded; int32_t n_prompt_tokens; + int32_t n_prompt_tokens_cache; int32_t n_tokens_cached; bool has_new_line; std::string stopping_word; @@ -722,6 +853,15 @@ struct server_task_result_cmpl_final : server_task_result { return response_fields.empty() ? res : json_get_nested_values(response_fields, res); } + json usage_json_oaicompat() { + return json{ + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + {"prompt_tokens_details", json{{"cached_tokens", n_prompt_tokens_cache}}}, + }; + } + json to_json_oaicompat() { std::time_t t = std::time(0); json logprobs = json(nullptr); // OAI default to null @@ -741,9 +881,7 @@ struct server_task_result_cmpl_final : server_task_result { {"model", oaicompat_model}, {"system_fingerprint", std::string(llama_build_info())}, {"object", "text_completion"}, - {"usage", json{{"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}}}, + {"usage", usage_json_oaicompat()}, {"id", oaicompat_cmpl_id}}; // extra fields for debugging purposes @@ -785,9 +923,7 @@ struct server_task_result_cmpl_final : server_task_result { {"model", oaicompat_model}, {"system_fingerprint", std::string(llama_build_info())}, {"object", "chat.completion"}, - {"usage", json{{"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}}}, + {"usage", usage_json_oaicompat()}, {"id", oaicompat_cmpl_id}}; // extra fields for debugging purposes @@ -836,14 +972,21 @@ struct server_task_result_cmpl_final : server_task_result { {"model", oaicompat_model}, {"system_fingerprint", std::string(llama_build_info())}, {"object", "chat.completion.chunk"}, - {"usage", - json{ - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, }); + if (include_usage) { + // OpenAI spec: separate final chunk with empty choices and usage + deltas.push_back({ + {"choices", json::array()}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", std::string(llama_build_info())}, + {"object", "chat.completion.chunk"}, + {"usage", usage_json_oaicompat()}, + }); + } + if (timings.prompt_n >= 0) { deltas.back().push_back({"timings", timings.to_json()}); } @@ -865,10 +1008,13 @@ struct server_task_result_cmpl_partial : server_task_result { int32_t n_decoded; int32_t n_prompt_tokens; + int32_t n_prompt_tokens_cache; bool post_sampling_probs; + bool is_progress = false; completion_token_output prob_output; result_timings timings; + result_prompt_progress progress; // OAI-compat fields bool verbose = false; @@ -911,6 +1057,9 @@ struct server_task_result_cmpl_partial : server_task_result { if (timings.prompt_n > 0) { res.push_back({"timings", timings.to_json()}); } + if (is_progress) { + res.push_back({"prompt_progress", progress.to_json()}); + } if (!prob_output.probs.empty()) { res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); @@ -945,6 +1094,9 @@ struct server_task_result_cmpl_partial : server_task_result { if (timings.prompt_n >= 0) { res.push_back({"timings", timings.to_json()}); } + if (is_progress) { + res.push_back({"prompt_progress", progress.to_json()}); + } return res; } @@ -972,7 +1124,7 @@ struct server_task_result_cmpl_partial : server_task_result { }); }; // We have to send an initial update to conform to openai behavior - if (first) { + if (first || is_progress) { add_delta({ {"role", "assistant"}, {"content", nullptr}, @@ -995,6 +1147,9 @@ struct server_task_result_cmpl_partial : server_task_result { if (timings.prompt_n >= 0) { deltas[deltas.size() - 1].push_back({"timings", timings.to_json()}); } + if (is_progress) { + deltas[deltas.size() - 1].push_back({"prompt_progress", progress.to_json()}); + } } return deltas; @@ -1189,7 +1344,8 @@ struct server_slot { int32_t n_predict = -1; // TODO: disambiguate from params.n_predict // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated - int32_t n_prompt_tokens = 0; + int32_t n_prompt_tokens = 0; + int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; // input prompt tokens @@ -1240,7 +1396,8 @@ struct server_slot { void reset() { SLT_DBG(*this, "%s", "\n"); - n_prompt_tokens = 0; + n_prompt_tokens = 0; + n_prompt_tokens_cache = 0; last_nl_pos = 0; generated_text = ""; has_new_line = false; @@ -1318,6 +1475,7 @@ struct server_slot { result_timings get_timings() const { result_timings timings; + timings.cache_n = n_prompt_tokens_cache; timings.prompt_n = n_prompt_tokens_processed; timings.prompt_ms = t_prompt_processing; timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; @@ -2387,21 +2545,31 @@ struct server_context { return true; } - void send_partial_response(server_slot &slot, const completion_token_output &tkn) { + void send_partial_response(server_slot &slot, const completion_token_output &tkn, bool is_progress = false) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; - res->content = tkn.text_to_send; - res->tokens = {tkn.tok}; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; - res->post_sampling_probs = slot.params.post_sampling_probs; + if (is_progress) { + res->is_progress = true; + res->progress.total = slot.n_prompt_tokens; + res->progress.cache = slot.n_prompt_tokens_cache; + res->progress.processed = slot.n_prompt_tokens_processed; + res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; + } else { + res->content = tkn.text_to_send; + res->tokens = {tkn.tok}; + } - res->verbose = slot.params.verbose; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; slot.update_chat_msg(res->oaicompat_msg_diffs); @@ -2431,18 +2599,20 @@ struct server_context { res->prompt = slot.prompt_tokens.detokenize(ctx, true); res->response_fields = std::move(slot.params.response_fields); - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; - res->n_tokens_cached = slot.n_past; + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache; + res->n_tokens_cached = slot.n_past; res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; res->stop = slot.stop; res->post_sampling_probs = slot.params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->stream = slot.params.stream; - res->oaicompat = slot.params.oaicompat; + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->include_usage = slot.params.include_usage; + res->oaicompat = slot.params.oaicompat; res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); @@ -3190,6 +3360,7 @@ struct server_context { slot.n_past--; } + slot.n_prompt_tokens_cache = slot.n_past; slot.n_prompt_tokens_processed = 0; } @@ -3206,7 +3377,8 @@ struct server_context { llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); // there is no common part left - slot.n_past = 0; + slot.n_past = 0; + slot.n_prompt_tokens_cache = 0; } SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); @@ -3400,6 +3572,13 @@ struct server_context { n_batch = llama_n_batch(ctx); for (auto &slot : slots) { + // optionally send prompt processing progress + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.params.stream && slot.params.return_progress) { + send_partial_response(slot, {}, true); + } + } + if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) { continue; // continue loop of slots }