From 3375c20b859d21a1dfe9b38ab6579747bf439173 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Nov 2025 16:55:27 +0100 Subject: [PATCH 1/8] git mv --- tools/server/{server.cpp => server-context.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tools/server/{server.cpp => server-context.cpp} (100%) diff --git a/tools/server/server.cpp b/tools/server/server-context.cpp similarity index 100% rename from tools/server/server.cpp rename to tools/server/server-context.cpp From 0150602bdb29ae106a85475c11c04d5cc1209c6b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Nov 2025 18:08:12 +0100 Subject: [PATCH 2/8] add server-context.h --- tools/server/CMakeLists.txt | 2 + tools/server/server-context.cpp | 1065 ++++++++++++------------------- tools/server/server-context.h | 266 ++++++++ tools/server/server.cpp | 232 +++++++ 4 files changed, 902 insertions(+), 663 deletions(-) create mode 100644 tools/server/server-context.h create mode 100644 tools/server/server.cpp diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 7fbca320162..d8623621f3f 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -21,6 +21,8 @@ set(TARGET_SRCS server-queue.h server-common.cpp server-common.h + server-context.cpp + server-context.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 96b2df27f7e..aadd0671590 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1,3 +1,4 @@ +#include "server-context.h" #include "server-common.h" #include "server-http.h" #include "server-task.h" @@ -31,22 +32,6 @@ using json = nlohmann::ordered_json; -constexpr int HTTP_POLLING_SECONDS = 1; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - static bool server_task_type_need_embd(server_task_type task_type) { switch (task_type) { case SERVER_TASK_TYPE_EMBEDDING: @@ -411,108 +396,78 @@ struct server_slot { } }; -struct server_metrics { - int64_t t_start = 0; +void server_slots_t::clear() { + for (auto & slot_ptr : data) { + delete slot_ptr; + } + data.clear(); +} - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; +server_slot & server_slots_t::create() { + auto instance = new server_slot(); + data.push_back(instance); + return *instance; +} - uint64_t n_tokens_max = 0; +server_slots_t::~server_slots_t() { + clear(); +} - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; +// +// server_metrics +// - void init() { - t_start = ggml_time_us(); - } +void server_metrics::init() { + t_start = ggml_time_us(); +} - void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; +void server_metrics::on_prompt_eval(const server_slot & slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); +} - void on_prediction(const server_slot & slot) { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; - } +void server_metrics::on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; +} - void on_decoded(const std::vector & slots) { - n_decode_total++; - for (const auto & slot : slots) { - if (slot.is_processing()) { - n_busy_slots_total++; - } - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); +void server_metrics::on_decoded(const server_slots_t & slots) { + n_decode_total++; + for (size_t i = 0; i < slots.size(); i++) { + const auto & slot = slots[i]; + if (slot.is_processing()) { + n_busy_slots_total++; } + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); } +} - void reset_bucket() { - n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; - } -}; - -struct server_context { - common_params params_base; - - // note: keep these alive - they determine the lifetime of the model, context, etc. - common_init_result llama_init; - common_init_result llama_init_dft; - - llama_model * model = nullptr; - llama_context * ctx = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - const llama_vocab * vocab = nullptr; - bool vocab_dft_compatible = true; - - llama_model * model_dft = nullptr; - - llama_context_params cparams_dft; - - llama_batch batch {}; - - bool add_bos_token = true; - - int32_t n_ctx; // total context for all clients / slots - - // slots / clients - std::vector slots; - - int slots_debug = 0; - - server_queue queue_tasks; - server_response queue_results; - - std::unique_ptr prompt_cache; +void server_metrics::reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; +} - server_metrics metrics; - // Necessary similarity of prompt for slot selection - float slot_prompt_similarity = 0.0f; +// +// server_context +// - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; +// TODO @ngxson : the only purpose of this extern "C" is to keep the first indentation level +// this was done to avoid massive changes in while doing the recent refactoring, avoiding merge conflicts +// we can remove this once things are more stable - ~server_context() { +extern "C" { + server_context::~server_context() { mtmd_free(mctx); // Clear any sampling context @@ -533,7 +488,7 @@ struct server_context { } // load the model and initialize llama_context - bool load_model(const common_params & params) { + bool server_context::load_model(const common_params & params) { SRV_INF("loading model '%s'\n", params.model.path.c_str()); params_base = params; @@ -653,7 +608,7 @@ struct server_context { } // initialize slots and server-related data - void init() { + void server_context::init() { SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); const int n_ctx_train = llama_model_n_ctx_train(model); @@ -665,7 +620,7 @@ struct server_context { } for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; + server_slot & slot = slots.create(); slot.id = i; slot.ctx = ctx; @@ -700,8 +655,6 @@ struct server_context { }; slot.reset(); - - slots.push_back(std::move(slot)); } { @@ -759,7 +712,7 @@ struct server_context { common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); } - server_slot * get_slot_by_id(int id) { + server_slot * server_context::get_slot_by_id(int id) { for (server_slot & slot : slots) { if (slot.id == id) { return &slot; @@ -769,7 +722,7 @@ struct server_context { return nullptr; } - server_slot * get_available_slot(const server_task & task) { + server_slot * server_context::get_available_slot(const server_task & task) { server_slot * ret = nullptr; bool update_cache = false; @@ -873,7 +826,7 @@ struct server_context { return ret; } - void clear_slot(server_slot & slot) const { + void server_context::clear_slot(server_slot & slot) const { GGML_ASSERT(!slot.is_processing()); SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); @@ -887,7 +840,7 @@ struct server_context { // - smarter decision which slot to clear (LRU or longest prompt?) // - move slot to level 2 cache instead of removing? // - instead of purging, try to store and resume later? - bool try_clear_idle_slots() { + bool server_context::try_clear_idle_slots() { bool res = false; if (!params_base.kv_unified) { @@ -914,7 +867,7 @@ struct server_context { return res; } - bool launch_slot_with_task(server_slot & slot, server_task && task) { + bool server_context::launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); if (!are_lora_equal(task.params.lora, slot.lora)) { @@ -1016,7 +969,7 @@ struct server_context { return true; } - bool process_token(completion_token_output & result, server_slot & slot) { + bool server_context::process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = result.text_to_send; slot.sampled = result.tok; @@ -1147,7 +1100,7 @@ struct server_context { return slot.has_next_token; // continue } - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { + void server_context::populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { size_t n_probs = slot.task->params.sampling.n_probs; size_t n_vocab = llama_vocab_n_tokens(vocab); @@ -1197,15 +1150,11 @@ struct server_context { } } - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, error, type); - } - - void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + void server_context::send_error(const server_slot & slot, const std::string & error, const enum error_type type) { send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); } - void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { + void server_context::send_error(const int id_task, const std::string & error, const enum error_type type, const int32_t n_prompt_tokens, const int32_t n_ctx) { SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { @@ -1223,7 +1172,7 @@ struct server_context { } // if multimodal is enabled, send an error and return false - bool check_no_mtmd(const int id_task) { + bool server_context::check_no_mtmd(const int id_task) { if (mctx) { send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); return false; @@ -1231,7 +1180,7 @@ struct server_context { return true; } - void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { + void server_context::send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { auto res = std::make_unique(); res->id = slot.task->id; @@ -1272,7 +1221,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_final_response(server_slot & slot) { + void server_context::send_final_response(server_slot & slot) { auto res = std::make_unique(); res->id = slot.task->id; @@ -1323,7 +1272,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_embedding(const server_slot & slot, const llama_batch & batch) { + void server_context::send_embedding(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); res->id = slot.task->id; res->index = slot.task->index; @@ -1368,7 +1317,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_rerank(const server_slot & slot, const llama_batch & batch) { + void server_context::send_rerank(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); res->id = slot.task->id; res->index = slot.task->index; @@ -1403,7 +1352,7 @@ struct server_context { // Functions to process the task // - void process_single_task(server_task && task) { + void server_context::process_single_task(server_task && task) { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: case SERVER_TASK_TYPE_INFILL: @@ -1623,7 +1572,7 @@ struct server_context { } } - void update_slots() { + void server_context::update_slots() { // check if all slots are idle { bool all_idle = true; @@ -2472,17 +2421,14 @@ struct server_context { SRV_DBG("%s", "run slots completed\n"); } - json model_meta() const { - return json { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; + // + // Utility functions + // + + int server_context::get_slot_n_ctx() const { + return slots[0].n_ctx; } -}; +} // generator-like API for server responses, support pooling connection state and aggregating results @@ -2597,24 +2543,21 @@ struct server_res_generator : server_http_res { } }; -struct server_routes { - const common_params & params; - server_context & ctx_server; - server_http_context & ctx_http; // for reading is_ready - server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) - : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {} -public: - // handlers using lambda function, so that they can capture `this` without `std::bind` - server_http_context::handler_t get_health = [this](const server_http_req &) { +// +// server_routes +// + +void server_routes::init_routes() { + this->get_health = [this](const server_http_req &) { // error and loading states are handled by middleware auto res = std::make_unique(ctx_server); res->ok({{"status", "ok"}}); return res; }; - server_http_context::handler_t get_metrics = [this](const server_http_req &) { + this->get_metrics = [this](const server_http_req &) { auto res = std::make_unique(ctx_server); if (!params.endpoint_metrics) { res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); @@ -2718,7 +2661,7 @@ struct server_routes { return res; }; - server_http_context::handler_t get_slots = [this](const server_http_req & req) { + this->get_slots = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); if (!params.endpoint_slots) { res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); @@ -2759,7 +2702,7 @@ struct server_routes { return res; }; - server_http_context::handler_t post_slots = [this](const server_http_req & req) { + this->post_slots = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); if (params.slot_save_path.empty()) { res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); @@ -2790,7 +2733,7 @@ struct server_routes { } }; - server_http_context::handler_t get_props = [this](const server_http_req &) { + this->get_props = [this](const server_http_req &) { auto res = std::make_unique(ctx_server); json default_generation_settings_for_props; @@ -2801,7 +2744,7 @@ struct server_routes { default_generation_settings_for_props = json { {"params", params.to_json(true)}, - {"n_ctx", ctx_server.slots[0].n_ctx}, + {"n_ctx", ctx_server.get_slot_n_ctx()}, }; } @@ -2834,7 +2777,7 @@ struct server_routes { return res; }; - server_http_context::handler_t post_props = [this](const server_http_req &) { + this->post_props = [this](const server_http_req &) { auto res = std::make_unique(ctx_server); if (!params.endpoint_props) { res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); @@ -2846,7 +2789,7 @@ struct server_routes { return res; }; - server_http_context::handler_t get_api_show = [this](const server_http_req &) { + this->get_api_show = [this](const server_http_req &) { auto res = std::make_unique(ctx_server); bool has_mtmd = ctx_server.mctx != nullptr; json data = { @@ -2855,7 +2798,7 @@ struct server_routes { }, { "model_info", { - { "llama.context_length", ctx_server.slots.back().n_ctx, }, + { "llama.context_length", ctx_server.get_slot_n_ctx() }, } }, {"modelfile", ""}, @@ -2877,7 +2820,7 @@ struct server_routes { return res; }; - server_http_context::handler_t post_infill = [this](const server_http_req & req) { + this->post_infill = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); // check model compatibility std::string err; @@ -2941,7 +2884,7 @@ struct server_routes { data.at("input_extra"), ctx_server.params_base.n_batch, ctx_server.params_base.n_predict, - ctx_server.slots[0].n_ctx, // TODO: there should be a better way + ctx_server.get_slot_n_ctx(), ctx_server.params_base.spm_infill, tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. ); @@ -2955,7 +2898,7 @@ struct server_routes { TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible }; - server_http_context::handler_t post_completions = [this](const server_http_req & req) { + this->post_completions = [this](const server_http_req & req) { std::vector files; // dummy const json body = json::parse(req.body); return handle_completions_impl( @@ -2966,7 +2909,7 @@ struct server_routes { TASK_RESPONSE_TYPE_NONE); }; - server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) { + this->post_completions_oai = [this](const server_http_req & req) { std::vector files; // dummy const json body = json::parse(req.body); return handle_completions_impl( @@ -2977,7 +2920,7 @@ struct server_routes { TASK_RESPONSE_TYPE_OAI_CMPL); }; - server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) { + this->post_chat_completions = [this](const server_http_req & req) { std::vector files; json body = json::parse(req.body); json body_parsed = oaicompat_chat_params_parse( @@ -2992,7 +2935,7 @@ struct server_routes { TASK_RESPONSE_TYPE_OAI_CHAT); }; - server_http_context::handler_t post_anthropic_messages = [this](const server_http_req & req) { + this->post_anthropic_messages = [this](const server_http_req & req) { std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); json body_parsed = oaicompat_chat_params_parse( @@ -3007,7 +2950,7 @@ struct server_routes { TASK_RESPONSE_TYPE_ANTHROPIC); }; - server_http_context::handler_t post_anthropic_count_tokens = [this](const server_http_req & req) { + this->post_anthropic_count_tokens = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); @@ -3024,7 +2967,7 @@ struct server_routes { }; // same with handle_chat_completions, but without inference part - server_http_context::handler_t post_apply_template = [this](const server_http_req & req) { + this->post_apply_template = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); std::vector files; // dummy, unused json body = json::parse(req.body); @@ -3036,7 +2979,7 @@ struct server_routes { return res; }; - server_http_context::handler_t get_models = [this](const server_http_req &) { + this->get_models = [this](const server_http_req &) { auto res = std::make_unique(ctx_server); bool is_model_ready = ctx_http.is_ready.load(); json model_meta = nullptr; @@ -3083,7 +3026,7 @@ struct server_routes { return res; }; - server_http_context::handler_t post_tokenize = [this](const server_http_req & req) { + this->post_tokenize = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); const json body = json::parse(req.body); json tokens_response = json::array(); @@ -3124,7 +3067,7 @@ struct server_routes { return res; }; - server_http_context::handler_t post_detokenize = [this](const server_http_req & req) { + this->post_detokenize = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); const json body = json::parse(req.body); @@ -3138,15 +3081,15 @@ struct server_routes { return res; }; - server_http_context::handler_t post_embeddings = [this](const server_http_req & req) { + this->post_embeddings = [this](const server_http_req & req) { return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE); }; - server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) { + this->post_embeddings_oai = [this](const server_http_req & req) { return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD); }; - server_http_context::handler_t post_rerank = [this](const server_http_req & req) { + this->post_rerank = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); @@ -3226,7 +3169,7 @@ struct server_routes { return res; }; - server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) { + this->get_lora_adapters = [this](const server_http_req &) { auto res = std::make_unique(ctx_server); json result = json::array(); const auto & loras = ctx_server.params_base.lora_adapters; @@ -3257,7 +3200,7 @@ struct server_routes { return res; }; - server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) { + this->post_lora_adapters = [this](const server_http_req & req) { auto res = std::make_unique(ctx_server); const json body = json::parse(req.body); if (!body.is_array()) { @@ -3287,575 +3230,371 @@ struct server_routes { res->ok(result->to_json()); return res; }; +} -private: - std::unique_ptr handle_completions_impl( - server_task_type type, - const json & data, - const std::vector & files, - const std::function & should_stop, - task_response_type res_type) { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - auto res = std::make_unique(ctx_server); - auto completion_id = gen_chatcmplid(); - auto & rd = res->rd; - - try { - std::vector tasks; - - const auto & prompt = data.at("prompt"); - // TODO: this log can become very long, put it behind a flag or think about a more compact format - //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - - // process prompt - std::vector inputs; - - if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { - // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } else { - // Everything else, including multimodal completions. - inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, - data); - task.id_slot = json_value(data, "id_slot", -1); +std::unique_ptr server_routes::handle_completions_impl( + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + task_response_type res_type) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - // OAI-compat - task.params.res_type = res_type; - task.params.oaicompat_cmpl_id = completion_id; - // oaicompat_model is already populated by params_from_json_cmpl + auto res = std::make_unique(ctx_server); + auto completion_id = gen_chatcmplid(); + auto & rd = res->rd; - tasks.push_back(std::move(task)); - } + try { + std::vector tasks; - rd.post_tasks(std::move(tasks)); - } catch (const std::exception & e) { - res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return res; - } + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - bool stream = json_value(data, "stream", false); - - if (!stream) { - // non-stream, wait for the results - auto all_results = rd.wait_for_all(should_stop); - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - json arr = json::array(); - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json()); - } - // if single request, return single object instead of array - res->ok(arr.size() == 1 ? arr[0] : arr); - } + // process prompt + std::vector inputs; + if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { + // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); } else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd.next(should_stop); - if (first_result == nullptr) { - return res; // connection is closed - } else if (first_result->is_error()) { - res->error(first_result->to_json()); - return res; - } else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // next responses are streamed - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - res->data = format_anthropic_sse(first_result->to_json()); - } else { - res->data = format_oai_sse(first_result->to_json()); // to be sent immediately - } - res->status = 200; - res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { - if (should_stop()) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - if (!res_this->data.empty()) { - // flush the first chunk - output = std::move(res_this->data); - res_this->data.clear(); - return true; - } - - server_response_reader & rd = res_this->rd; - - // check if there is more data - if (!rd.has_next()) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - // Anthropic doesn't send [DONE], message_stop was already sent - output = ""; - } else if (res_type != TASK_RESPONSE_TYPE_NONE) { - output = "data: [DONE]\n\n"; - } else { - output = ""; - } - SRV_DBG("%s", "all results received, terminating stream\n"); - return false; // no more data, terminate - } - - // receive subsequent results - auto result = rd.next(should_stop); - if (result == nullptr) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - // send the results - json res_json = result->to_json(); - if (result->is_error()) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - output = format_anthropic_sse({ - {"event", "error"}, - {"data", res_json}, - }); - } else { - output = format_oai_sse(json {{ "error", res_json }}); - } - SRV_DBG("%s", "error received during streaming, terminating stream\n"); - return false; // terminate on error - } else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - output = format_anthropic_sse(res_json); - } else { - output = format_oai_sse(res_json); - } - } - - // has next data, continue - return true; - }; - } - - return res; - } - - std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + // Everything else, including multimodal completions. + inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); } + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_slot = json_value(data, "id_slot", -1); - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + // OAI-compat + task.params.res_type = res_type; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl - if (result->is_error()) { - res->error(result->to_json()); - return res; + tasks.push_back(std::move(task)); } - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); + rd.post_tasks(std::move(tasks)); + } catch (const std::exception & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); return res; } - std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot) { - auto res = std::make_unique(ctx_server); - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - } + bool stream = json_value(data, "stream", false); - std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding) { - res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + if (!stream) { + // non-stream, wait for the results + auto all_results = rd.wait_for_all(should_stop); + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); return res; + } else { + json arr = json::array(); + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + arr.push_back(res->to_json()); + } + // if single request, return single object instead of array + res->ok(arr.size() == 1 ? arr[0] : arr); } - if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + } else { + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 + server_task_result_ptr first_result = rd.next(should_stop); + if (first_result == nullptr) { + return res; // connection is closed + } else if (first_result->is_error()) { + res->error(first_result->to_json()); return res; + } else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); } - const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } else if (body.contains("content")) { - res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); + // next responses are streamed + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + res->data = format_anthropic_sse(first_result->to_json()); } else { - res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; + res->data = format_oai_sse(first_result->to_json()); // to be sent immediately } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } else if (format != "float") { - res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return res; + res->status = 200; + res->content_type = "text/event-stream"; + res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { + if (should_stop()) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met } - } - auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - for (const auto & tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return res; + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; } - } - int embd_normalize = 2; // default to Euclidean/L2 norm - if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); - } - } + server_response_reader & rd = res_this->rd; - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tokenized_prompts[i]); + // check if there is more data + if (!rd.has_next()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + // Anthropic doesn't send [DONE], message_stop was already sent + output = ""; + } else if (res_type != TASK_RESPONSE_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; + } + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate + } - // OAI-compat - task.params.res_type = res_type; - task.params.embd_normalize = embd_normalize; + // receive subsequent results + auto result = rd.next(should_stop); + if (result == nullptr) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } - tasks.push_back(std::move(task)); + // send the results + json res_json = result->to_json(); + if (result->is_error()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse({ + {"event", "error"}, + {"data", res_json}, + }); + } else { + output = format_oai_sse(json {{ "error", res_json }}); + } + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error + } else { + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse(res_json); + } else { + output = format_oai_sse(res_json); + } } - rd.post_tasks(std::move(tasks)); - } - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); + // has next data, continue + return true; + }; + } - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } + return res; +} - // write JSON response - json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD - ? format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - res->ok(root); +std::unique_ptr server_routes::handle_slots_save(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return res; } -}; + std::string filepath = params.slot_save_path + filename; -static std::function shutdown_handler; -static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; -static inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { - // in case it hangs, we can force terminate the server by hitting Ctrl+C twice - // this is for better developer experience, we can remove when the server is stable enough - fprintf(stderr, "Received second interrupt, terminating immediately.\n"); - exit(1); + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); } - shutdown_handler(signal); -} + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); -// wrapper function that handles exceptions and logs errors -// this is to make sure handler_t never throws exceptions; instead, it returns an error response -static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) { - return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr { - std::string message; - try { - return func(req); - } catch (const std::exception & e) { - message = e.what(); - } catch (...) { - message = "unknown error"; - } - - auto res = std::make_unique(); - res->status = 500; - try { - json error_data = format_error_response(message, ERROR_TYPE_SERVER); - res->status = json_value(error_data, "code", 500); - res->data = safe_json_to_str({{ "error", error_data }}); - LOG_WRN("got exception: %s\n", res->data.c_str()); - } catch (const std::exception & e) { - LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); - res->data = "Internal Server Error"; - } + if (result->is_error()) { + res->error(result->to_json()); return res; - }; -} + } -int main(int argc, char ** argv) { - // own arguments required by this example - common_params params; + res->ok(result->to_json()); + return res; +} - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { - return 1; +std::unique_ptr server_routes::handle_slots_restore(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; } + std::string filepath = params.slot_save_path + filename; - // TODO: should we have a separate n_parallel parameter for the server? - // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 - // TODO: this is a common configuration that is suitable for most local use cases - // however, overriding the parameters is a bit confusing - figure out something more intuitive - if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { - LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; - params.n_parallel = 4; - params.kv_unified = true; + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); } - common_init(); + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); - // struct that contains llama context and inference - server_context ctx_server; - - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + if (result->is_error()) { + res->error(result->to_json()); + return res; + } - llama_backend_init(); - llama_numa_init(params.numa); + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; +} - LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); - LOG_INF("\n"); - LOG_INF("%s\n", common_params_get_system_info(params).c_str()); - LOG_INF("\n"); +std::unique_ptr server_routes::handle_slots_erase(const server_http_req &, int id_slot) { + auto res = std::make_unique(ctx_server); + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = task_id; + task.slot_action.slot_id = id_slot; - server_http_context ctx_http; - if (!ctx_http.init(params)) { - LOG_ERR("%s: failed to initialize HTTP server\n", __func__); - return 1; + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); } - // - // Router - // + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); - // register API routes - server_routes routes(params, ctx_server, ctx_http); - - ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) - ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) - ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); - ctx_http.get ("/props", ex_wrapper(routes.get_props)); - ctx_http.post("/props", ex_wrapper(routes.post_props)); - ctx_http.post("/api/show", ex_wrapper(routes.get_api_show)); - ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) - ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) - ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check) - ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy - ctx_http.post("/completions", ex_wrapper(routes.post_completions)); - ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai)); - ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); - ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); - ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint - ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API - ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting - ctx_http.post("/infill", ex_wrapper(routes.post_infill)); - ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy - ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); - ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai)); - ctx_http.post("/rerank", ex_wrapper(routes.post_rerank)); - ctx_http.post("/reranking", ex_wrapper(routes.post_rerank)); - ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank)); - ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank)); - ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize)); - ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize)); - ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template)); - // LoRA adapters hotswap - ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters)); - ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters)); - // Save & load slots - ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); - ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); + if (result->is_error()) { + res->error(result->to_json()); + return res; + } - // - // Start the server - // + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; +} - // setup clean up function, to be called before exit - auto clean_up = [&ctx_http, &ctx_server]() { - SRV_INF("%s: cleaning up before exit...\n", __func__); - ctx_http.stop(); - ctx_server.queue_results.terminate(); - llama_backend_free(); - }; +std::unique_ptr server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding) { + res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } - // start the HTTP server before loading the model to be able to serve /health requests - if (!ctx_http.start()) { - clean_up(); - LOG_ERR("%s: exiting due to HTTP server error\n", __func__); - return 1; + if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return res; } - // load the model - LOG_INF("%s: loading model\n", __func__); + const json body = json::parse(req.body); - if (!ctx_server.load_model(params)) { - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); + // for the shape of input/content, see tokenize_input_prompts() + json prompt; + if (body.count("input") != 0) { + prompt = body.at("input"); + } else if (body.contains("content")) { + res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); + } else { + res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } else if (format != "float") { + res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return res; } - LOG_ERR("%s: exiting due to model loading error\n", __func__); - return 1; } - ctx_server.init(); - ctx_http.is_ready.store(true); + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } - LOG_INF("%s: model loaded\n", __func__); + int embd_normalize = 2; // default to Euclidean/L2 norm + if (body.count("embd_normalize") != 0) { + embd_normalize = body.at("embd_normalize"); + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + } + } - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { - ctx_server.process_single_task(std::move(task)); - }); + // create and queue the task + json responses = json::array(); + server_response_reader rd(ctx_server); + { + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - ctx_server.queue_tasks.on_update_slots([&ctx_server]() { - ctx_server.update_slots(); - }); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); - shutdown_handler = [&](int) { - // this will unblock start_loop() - ctx_server.queue_tasks.terminate(); - }; + // OAI-compat + task.params.res_type = res_type; + task.params.embd_normalize = embd_normalize; - // TODO: refactor in common/console -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = signal_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); - sigaction(SIGTERM, &sigint_action, NULL); -#elif defined (_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } - LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); - LOG_INF("%s: starting the main loop...\n", __func__); - // this call blocks the main thread until queue_tasks.terminate() is called - ctx_server.queue_tasks.start_loop(); + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } } - llama_memory_breakdown_print(ctx_server.ctx); - return 0; + // write JSON response + json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD + ? format_embeddings_response_oaicompat(body, responses, use_base64) + : json(responses); + res->ok(root); + return res; } diff --git a/tools/server/server-context.h b/tools/server/server-context.h new file mode 100644 index 00000000000..5c48f8e8082 --- /dev/null +++ b/tools/server/server-context.h @@ -0,0 +1,266 @@ +#include "server-common.h" +#include "server-http.h" +#include "server-task.h" +#include "server-queue.h" + +#include "arg.h" +#include "common.h" +#include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" +#include "mtmd.h" +#include "mtmd-helper.h" + +#include +#include +#include +#include +#include +#include +#include + +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +using json = nlohmann::ordered_json; + +constexpr int HTTP_POLLING_SECONDS = 1; + +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + +// forward declarations +struct server_slot; + +// proxy for std::vector to allow forward declaration of server_slot +struct server_slots_t { + ~server_slots_t(); + std::vector data; + size_t size() const { return data.size(); } + server_slot & operator[](size_t idx) { return *(data[idx]); } + server_slot & operator[](size_t idx) const { return *(data[idx]); } + void clear(); + server_slot & create(); + struct iterator { + typename std::vector::iterator it; + iterator(typename std::vector::iterator i) : it(i) {} + server_slot & operator*() { return **it; } + iterator & operator++() { ++it; return *this; } + bool operator!=(const iterator& other) const { return it != other.it; } + }; + iterator begin() { return iterator(data.begin()); } + iterator end() { return iterator(data.end()); } +}; + +struct server_metrics { + int64_t t_start = 0; + + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_tokens_max = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init(); + void on_prompt_eval(const server_slot & slot); + void on_prediction(const server_slot & slot); + void on_decoded(const server_slots_t & slots); + void reset_bucket(); +}; + +struct server_context { +public: + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; + + const llama_vocab * vocab = nullptr; + bool vocab_dft_compatible = true; + + // multimodal + mtmd_context * mctx = nullptr; + + server_queue queue_tasks; + server_response queue_results; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + +private: + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch {}; + + bool add_bos_token = true; + + int32_t n_ctx; // total context for all clients / slots + + // slots / clients + server_slots_t slots; + + int slots_debug = 0; + + std::unique_ptr prompt_cache; + + server_metrics metrics; + +public: + ~server_context(); + + // load the model and initialize llama_context + bool load_model(const common_params & params); + + // initialize slots and server-related data + void init(); + + server_slot * get_slot_by_id(int id); + + server_slot * get_available_slot(const server_task & task); + + void clear_slot(server_slot & slot) const; + + // return true if at least one slot has been cleared + // TODO: improve logic + // - smarter decision which slot to clear (LRU or longest prompt?) + // - move slot to level 2 cache instead of removing? + // - instead of purging, try to store and resume later? + bool try_clear_idle_slots(); + + bool launch_slot_with_task(server_slot & slot, server_task && task); + + bool process_token(completion_token_output & result, server_slot & slot); + + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const; + + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER); + + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0); + + // if multimodal is enabled, send an error and return false + bool check_no_mtmd(const int id_task); + + void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress); + + void send_final_response(server_slot & slot); + + void send_embedding(const server_slot & slot, const llama_batch & batch); + + void send_rerank(const server_slot & slot, const llama_batch & batch); + + // + // Functions to process the task + // + + void process_single_task(server_task && task); + + void update_slots(); + + // + // Utility functions + // + + int get_slot_n_ctx() const; + + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, + }; + } +}; + + +struct server_res_generator; + +struct server_routes { + const common_params & params; + server_context & ctx_server; + server_http_context & ctx_http; // for reading is_ready + server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) + : params(params), ctx_server(ctx_server), ctx_http(ctx_http) { + init_routes(); + } + +public: + void init_routes(); + // handlers using lambda function, so that they can capture `this` without `std::bind` + server_http_context::handler_t get_health; + server_http_context::handler_t get_metrics; + server_http_context::handler_t get_slots; + server_http_context::handler_t post_slots; + server_http_context::handler_t get_props; + server_http_context::handler_t post_props; + server_http_context::handler_t get_api_show; + server_http_context::handler_t post_infill; + server_http_context::handler_t post_completions; + server_http_context::handler_t post_completions_oai; + server_http_context::handler_t post_chat_completions; + server_http_context::handler_t post_anthropic_messages; + server_http_context::handler_t post_anthropic_count_tokens; + server_http_context::handler_t post_apply_template; + server_http_context::handler_t get_models; + server_http_context::handler_t post_tokenize; + server_http_context::handler_t post_detokenize; + server_http_context::handler_t post_embeddings; + server_http_context::handler_t post_embeddings_oai; + server_http_context::handler_t post_rerank; + server_http_context::handler_t get_lora_adapters; + server_http_context::handler_t post_lora_adapters; +private: + std::unique_ptr handle_completions_impl( + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + task_response_type res_type); + std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot); + std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot); + std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot); + std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type); +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp new file mode 100644 index 00000000000..b9a9335ff26 --- /dev/null +++ b/tools/server/server.cpp @@ -0,0 +1,232 @@ +#include "server-context.h" +#include "server-http.h" + +#include "arg.h" +#include "common.h" +#include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" +#include "mtmd.h" +#include "mtmd-helper.h" + +#include +#include +#include +#include +#include +#include +#include + +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + + +static std::function shutdown_handler; +static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + +static inline void signal_handler(int signal) { + if (is_terminating.test_and_set()) { + // in case it hangs, we can force terminate the server by hitting Ctrl+C twice + // this is for better developer experience, we can remove when the server is stable enough + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + + shutdown_handler(signal); +} + +// wrapper function that handles exceptions and logs errors +// this is to make sure handler_t never throws exceptions; instead, it returns an error response +static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) { + return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr { + std::string message; + try { + return func(req); + } catch (const std::exception & e) { + message = e.what(); + } catch (...) { + message = "unknown error"; + } + + auto res = std::make_unique(); + res->status = 500; + try { + json error_data = format_error_response(message, ERROR_TYPE_SERVER); + res->status = json_value(error_data, "code", 500); + res->data = safe_json_to_str({{ "error", error_data }}); + LOG_WRN("got exception: %s\n", res->data.c_str()); + } catch (const std::exception & e) { + LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); + res->data = "Internal Server Error"; + } + return res; + }; +} + +int main(int argc, char ** argv) { + // own arguments required by this example + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { + return 1; + } + + // TODO: should we have a separate n_parallel parameter for the server? + // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 + // TODO: this is a common configuration that is suitable for most local use cases + // however, overriding the parameters is a bit confusing - figure out something more intuitive + if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { + LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); + + params.n_parallel = 4; + params.kv_unified = true; + } + + common_init(); + + // struct that contains llama context and inference + server_context ctx_server; + + // Necessary similarity of prompt for slot selection + ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; + + llama_backend_init(); + llama_numa_init(params.numa); + + LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + + server_http_context ctx_http; + if (!ctx_http.init(params)) { + LOG_ERR("%s: failed to initialize HTTP server\n", __func__); + return 1; + } + + // + // Router + // + + // register API routes + server_routes routes(params, ctx_server, ctx_http); + + ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) + ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) + ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); + ctx_http.get ("/props", ex_wrapper(routes.get_props)); + ctx_http.post("/props", ex_wrapper(routes.post_props)); + ctx_http.post("/api/show", ex_wrapper(routes.get_api_show)); + ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) + ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) + ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check) + ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy + ctx_http.post("/completions", ex_wrapper(routes.post_completions)); + ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai)); + ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); + ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); + ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint + ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API + ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting + ctx_http.post("/infill", ex_wrapper(routes.post_infill)); + ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy + ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); + ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai)); + ctx_http.post("/rerank", ex_wrapper(routes.post_rerank)); + ctx_http.post("/reranking", ex_wrapper(routes.post_rerank)); + ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank)); + ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank)); + ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize)); + ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize)); + ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template)); + // LoRA adapters hotswap + ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters)); + ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters)); + // Save & load slots + ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); + ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); + + // + // Start the server + // + + // setup clean up function, to be called before exit + auto clean_up = [&ctx_http, &ctx_server]() { + SRV_INF("%s: cleaning up before exit...\n", __func__); + ctx_http.stop(); + ctx_server.queue_results.terminate(); + llama_backend_free(); + }; + + // start the HTTP server before loading the model to be able to serve /health requests + if (!ctx_http.start()) { + clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); + return 1; + } + + // load the model + LOG_INF("%s: loading model\n", __func__); + + if (!ctx_server.load_model(params)) { + clean_up(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } + LOG_ERR("%s: exiting due to model loading error\n", __func__); + return 1; + } + + ctx_server.init(); + ctx_http.is_ready.store(true); + + LOG_INF("%s: model loaded\n", __func__); + + ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { + ctx_server.process_single_task(std::move(task)); + }); + + ctx_server.queue_tasks.on_update_slots([&ctx_server]() { + ctx_server.update_slots(); + }); + + shutdown_handler = [&](int) { + // this will unblock start_loop() + ctx_server.queue_tasks.terminate(); + }; + + // TODO: refactor in common/console +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); + LOG_INF("%s: starting the main loop...\n", __func__); + // this call blocks the main thread until queue_tasks.terminate() is called + ctx_server.queue_tasks.start_loop(); + + clean_up(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } + llama_memory_breakdown_print(ctx_server.ctx); + + return 0; +} From 9a7b4f36c851526a93bc36e846351f287ee51452 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Nov 2025 19:04:42 +0100 Subject: [PATCH 3/8] add server-context.h --- tools/server/server-context.cpp | 260 ++++++++++++++++++++++---------- tools/server/server-context.h | 172 ++------------------- tools/server/server.cpp | 21 +-- 3 files changed, 201 insertions(+), 252 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index aadd0671590..0853a5b8373 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -396,78 +396,119 @@ struct server_slot { } }; -void server_slots_t::clear() { - for (auto & slot_ptr : data) { - delete slot_ptr; - } - data.clear(); -} - -server_slot & server_slots_t::create() { - auto instance = new server_slot(); - data.push_back(instance); - return *instance; -} - -server_slots_t::~server_slots_t() { - clear(); -} - // // server_metrics // -void server_metrics::init() { - t_start = ggml_time_us(); -} +struct server_metrics { + int64_t t_start = 0; -void server_metrics::on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); -} + uint64_t n_tokens_max = 0; -void server_metrics::on_prediction(const server_slot & slot) { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; -} + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init() { + t_start = ggml_time_us(); + } + + void on_prompt_eval(const server_slot & slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; -void server_metrics::on_decoded(const server_slots_t & slots) { - n_decode_total++; - for (size_t i = 0; i < slots.size(); i++) { - const auto & slot = slots[i]; - if (slot.is_processing()) { - n_busy_slots_total++; - } n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); } -} -void server_metrics::reset_bucket() { - n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; -} + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; + } + + void on_decoded(const std::vector & slots) { + n_decode_total++; + for (const auto & slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); + } + } + + void reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; + } +}; // -// server_context +// server_context_impl (private implementation) // -// TODO @ngxson : the only purpose of this extern "C" is to keep the first indentation level -// this was done to avoid massive changes in while doing the recent refactoring, avoiding merge conflicts -// we can remove this once things are more stable +struct server_context_impl { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + const llama_vocab * vocab = nullptr; + bool vocab_dft_compatible = true; + + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; -extern "C" { - server_context::~server_context() { + llama_batch batch {}; + + bool add_bos_token = true; + + int32_t n_ctx; // total context for all clients / slots + + // slots / clients + std::vector slots; + + int slots_debug = 0; + + server_queue queue_tasks; + server_response queue_results; + + std::unique_ptr prompt_cache; + + server_metrics metrics; + + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + ~server_context_impl() { mtmd_free(mctx); // Clear any sampling context @@ -488,7 +529,7 @@ extern "C" { } // load the model and initialize llama_context - bool server_context::load_model(const common_params & params) { + bool load_model(const common_params & params) { SRV_INF("loading model '%s'\n", params.model.path.c_str()); params_base = params; @@ -608,7 +649,19 @@ extern "C" { } // initialize slots and server-related data - void server_context::init() { + void init() { + // wiring up server queues + queue_tasks.on_new_task([this](server_task && task) { + process_single_task(std::move(task)); + }); + queue_tasks.on_update_slots([this]() { + update_slots(); + }); + + // Necessary similarity of prompt for slot selection + slot_prompt_similarity = params_base.slot_prompt_similarity; + + // setup slots SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); const int n_ctx_train = llama_model_n_ctx_train(model); @@ -620,7 +673,7 @@ extern "C" { } for (int i = 0; i < params_base.n_parallel; i++) { - server_slot & slot = slots.create(); + server_slot slot; slot.id = i; slot.ctx = ctx; @@ -655,6 +708,8 @@ extern "C" { }; slot.reset(); + + slots.push_back(std::move(slot)); } { @@ -712,7 +767,7 @@ extern "C" { common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); } - server_slot * server_context::get_slot_by_id(int id) { + server_slot * get_slot_by_id(int id) { for (server_slot & slot : slots) { if (slot.id == id) { return &slot; @@ -722,7 +777,7 @@ extern "C" { return nullptr; } - server_slot * server_context::get_available_slot(const server_task & task) { + server_slot * get_available_slot(const server_task & task) { server_slot * ret = nullptr; bool update_cache = false; @@ -826,7 +881,7 @@ extern "C" { return ret; } - void server_context::clear_slot(server_slot & slot) const { + void clear_slot(server_slot & slot) const { GGML_ASSERT(!slot.is_processing()); SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); @@ -840,7 +895,7 @@ extern "C" { // - smarter decision which slot to clear (LRU or longest prompt?) // - move slot to level 2 cache instead of removing? // - instead of purging, try to store and resume later? - bool server_context::try_clear_idle_slots() { + bool try_clear_idle_slots() { bool res = false; if (!params_base.kv_unified) { @@ -867,7 +922,7 @@ extern "C" { return res; } - bool server_context::launch_slot_with_task(server_slot & slot, server_task && task) { + bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); if (!are_lora_equal(task.params.lora, slot.lora)) { @@ -969,7 +1024,7 @@ extern "C" { return true; } - bool server_context::process_token(completion_token_output & result, server_slot & slot) { + bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = result.text_to_send; slot.sampled = result.tok; @@ -1100,7 +1155,7 @@ extern "C" { return slot.has_next_token; // continue } - void server_context::populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { size_t n_probs = slot.task->params.sampling.n_probs; size_t n_vocab = llama_vocab_n_tokens(vocab); @@ -1150,11 +1205,15 @@ extern "C" { } } - void server_context::send_error(const server_slot & slot, const std::string & error, const enum error_type type) { + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); } - void server_context::send_error(const int id_task, const std::string & error, const enum error_type type, const int32_t n_prompt_tokens, const int32_t n_ctx) { + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { @@ -1172,7 +1231,7 @@ extern "C" { } // if multimodal is enabled, send an error and return false - bool server_context::check_no_mtmd(const int id_task) { + bool check_no_mtmd(const int id_task) { if (mctx) { send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); return false; @@ -1180,7 +1239,7 @@ extern "C" { return true; } - void server_context::send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { + void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { auto res = std::make_unique(); res->id = slot.task->id; @@ -1221,7 +1280,7 @@ extern "C" { queue_results.send(std::move(res)); } - void server_context::send_final_response(server_slot & slot) { + void send_final_response(server_slot & slot) { auto res = std::make_unique(); res->id = slot.task->id; @@ -1272,7 +1331,7 @@ extern "C" { queue_results.send(std::move(res)); } - void server_context::send_embedding(const server_slot & slot, const llama_batch & batch) { + void send_embedding(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); res->id = slot.task->id; res->index = slot.task->index; @@ -1317,7 +1376,7 @@ extern "C" { queue_results.send(std::move(res)); } - void server_context::send_rerank(const server_slot & slot, const llama_batch & batch) { + void send_rerank(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); res->id = slot.task->id; res->index = slot.task->index; @@ -1352,7 +1411,7 @@ extern "C" { // Functions to process the task // - void server_context::process_single_task(server_task && task) { + void process_single_task(server_task && task) { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: case SERVER_TASK_TYPE_INFILL: @@ -1572,7 +1631,7 @@ extern "C" { } } - void server_context::update_slots() { + void update_slots() { // check if all slots are idle { bool all_idle = true; @@ -2421,24 +2480,65 @@ extern "C" { SRV_DBG("%s", "run slots completed\n"); } - // - // Utility functions - // + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, + }; + } - int server_context::get_slot_n_ctx() const { - return slots[0].n_ctx; + int get_slot_n_ctx() { + return slots.back().n_ctx; + } +}; + +void server_context_impl_deleter::operator()(server_context_impl * ptr) const { + if (ptr) { + delete ptr; } } +// +// server_context (public API) +// + +server_context::server_context() : impl(new server_context_impl()) {} +server_context::~server_context() = default; + +void server_context::init() { + impl->init(); +} + +bool server_context::load_model(const common_params & params) { + return impl->load_model(params); +} + +void server_context::start_loop() { + impl->queue_tasks.start_loop(); +} + +void server_context::terminate() { + impl->queue_tasks.terminate(); +} + +llama_context * server_context::get_llama_context() const { + return impl->ctx; +} + + // generator-like API for server responses, support pooling connection state and aggregating results struct server_response_reader { std::unordered_set id_tasks; - server_context & ctx_server; + server_context_impl & ctx_server; size_t received_count = 0; bool cancelled = false; - server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {} + server_response_reader(server_context_impl & ctx_server) : ctx_server(ctx_server) {} ~server_response_reader() { stop(); } @@ -2532,7 +2632,7 @@ struct server_response_reader { // generator-like API for HTTP response generation struct server_res_generator : server_http_res { server_response_reader rd; - server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {} + server_res_generator(server_context_impl & ctx_server_) : rd(ctx_server_) {} void ok(const json & response_data) { status = 200; data = safe_json_to_str(response_data); diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 5c48f8e8082..b799754145d 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -47,172 +47,32 @@ enum server_state { SERVER_STATE_READY, // Server is ready and model is loaded }; -// forward declarations -struct server_slot; - -// proxy for std::vector to allow forward declaration of server_slot -struct server_slots_t { - ~server_slots_t(); - std::vector data; - size_t size() const { return data.size(); } - server_slot & operator[](size_t idx) { return *(data[idx]); } - server_slot & operator[](size_t idx) const { return *(data[idx]); } - void clear(); - server_slot & create(); - struct iterator { - typename std::vector::iterator it; - iterator(typename std::vector::iterator i) : it(i) {} - server_slot & operator*() { return **it; } - iterator & operator++() { ++it; return *this; } - bool operator!=(const iterator& other) const { return it != other.it; } - }; - iterator begin() { return iterator(data.begin()); } - iterator end() { return iterator(data.end()); } -}; - -struct server_metrics { - int64_t t_start = 0; - - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - void init(); - void on_prompt_eval(const server_slot & slot); - void on_prediction(const server_slot & slot); - void on_decoded(const server_slots_t & slots); - void reset_bucket(); +struct server_context_impl; // private implementation +struct server_context_impl_deleter { + void operator()(server_context_impl * p) const; }; struct server_context { -public: - common_params params_base; - - // note: keep these alive - they determine the lifetime of the model, context, etc. - common_init_result llama_init; - common_init_result llama_init_dft; - - llama_model * model = nullptr; - llama_context * ctx = nullptr; - - const llama_vocab * vocab = nullptr; - bool vocab_dft_compatible = true; - - // multimodal - mtmd_context * mctx = nullptr; - - server_queue queue_tasks; - server_response queue_results; - - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; - - // Necessary similarity of prompt for slot selection - float slot_prompt_similarity = 0.0f; - -private: - llama_model * model_dft = nullptr; - - llama_context_params cparams_dft; - - llama_batch batch {}; - - bool add_bos_token = true; - - int32_t n_ctx; // total context for all clients / slots - - // slots / clients - server_slots_t slots; - - int slots_debug = 0; - - std::unique_ptr prompt_cache; + std::unique_ptr impl; - server_metrics metrics; - -public: + server_context(); ~server_context(); - // load the model and initialize llama_context - bool load_model(const common_params & params); - // initialize slots and server-related data void init(); - server_slot * get_slot_by_id(int id); - - server_slot * get_available_slot(const server_task & task); - - void clear_slot(server_slot & slot) const; - - // return true if at least one slot has been cleared - // TODO: improve logic - // - smarter decision which slot to clear (LRU or longest prompt?) - // - move slot to level 2 cache instead of removing? - // - instead of purging, try to store and resume later? - bool try_clear_idle_slots(); - - bool launch_slot_with_task(server_slot & slot, server_task && task); - - bool process_token(completion_token_output & result, server_slot & slot); - - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const; - - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, error, type); - } - - void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER); - - void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0); - - // if multimodal is enabled, send an error and return false - bool check_no_mtmd(const int id_task); - - void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress); - - void send_final_response(server_slot & slot); - - void send_embedding(const server_slot & slot, const llama_batch & batch); - - void send_rerank(const server_slot & slot, const llama_batch & batch); - - // - // Functions to process the task - // - - void process_single_task(server_task && task); - - void update_slots(); + // load the model and initialize llama_context + // returns true on success + bool load_model(const common_params & params); - // - // Utility functions - // + // this function will block main thread until termination + void start_loop(); - int get_slot_n_ctx() const; + // terminate main loop (will unblock start_loop) + void terminate(); - json model_meta() const { - return json { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; - } + // get the underlaying llama_context + llama_context * get_llama_context() const; }; @@ -220,10 +80,10 @@ struct server_res_generator; struct server_routes { const common_params & params; - server_context & ctx_server; + server_context_impl & ctx_server; server_http_context & ctx_http; // for reading is_ready server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) - : params(params), ctx_server(ctx_server), ctx_http(ctx_http) { + : params(params), ctx_server(*ctx_server.impl.get()), ctx_http(ctx_http) { init_routes(); } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index b9a9335ff26..d6603a5299f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -94,9 +94,6 @@ int main(int argc, char ** argv) { // struct that contains llama context and inference server_context ctx_server; - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - llama_backend_init(); llama_numa_init(params.numa); @@ -161,7 +158,7 @@ int main(int argc, char ** argv) { auto clean_up = [&ctx_http, &ctx_server]() { SRV_INF("%s: cleaning up before exit...\n", __func__); ctx_http.stop(); - ctx_server.queue_results.terminate(); + ctx_server.terminate(); llama_backend_free(); }; @@ -189,17 +186,9 @@ int main(int argc, char ** argv) { LOG_INF("%s: model loaded\n", __func__); - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { - ctx_server.process_single_task(std::move(task)); - }); - - ctx_server.queue_tasks.on_update_slots([&ctx_server]() { - ctx_server.update_slots(); - }); - shutdown_handler = [&](int) { // this will unblock start_loop() - ctx_server.queue_tasks.terminate(); + ctx_server.terminate(); }; // TODO: refactor in common/console @@ -219,14 +208,14 @@ int main(int argc, char ** argv) { LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); LOG_INF("%s: starting the main loop...\n", __func__); - // this call blocks the main thread until queue_tasks.terminate() is called - ctx_server.queue_tasks.start_loop(); + // this call blocks the main thread until ctx_server.terminate() is called + ctx_server.start_loop(); clean_up(); if (ctx_http.thread.joinable()) { ctx_http.thread.join(); } - llama_memory_breakdown_print(ctx_server.ctx); + llama_memory_breakdown_print(ctx_server.get_llama_context()); return 0; } From 239c7a2615ef761f2442ae37248f854de53f1f37 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Nov 2025 19:25:31 +0100 Subject: [PATCH 4/8] clean up headers --- tools/server/server-context.cpp | 367 +++++++++++++++++--------------- tools/server/server-context.h | 46 +--- tools/server/server.cpp | 19 +- 3 files changed, 196 insertions(+), 236 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0853a5b8373..b3285b42626 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -13,12 +13,9 @@ #include "mtmd.h" #include "mtmd-helper.h" -#include #include #include #include -#include -#include #include // fix problem with std::min and std::max @@ -32,6 +29,22 @@ using json = nlohmann::ordered_json; +constexpr int HTTP_POLLING_SECONDS = 1; + +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + static bool server_task_type_need_embd(server_task_type task_type) { switch (task_type) { case SERVER_TASK_TYPE_EMBEDDING: @@ -2649,6 +2662,178 @@ struct server_res_generator : server_http_res { // server_routes // +static std::unique_ptr handle_completions_impl( + server_context_impl & ctx_server, + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + task_response_type res_type) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + auto res = std::make_unique(ctx_server); + auto completion_id = gen_chatcmplid(); + auto & rd = res->rd; + + try { + std::vector tasks; + + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + + // process prompt + std::vector inputs; + + if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { + // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } else { + // Everything else, including multimodal completions. + inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + } + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.res_type = res_type; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(std::move(task)); + } + + rd.post_tasks(std::move(tasks)); + } catch (const std::exception & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool stream = json_value(data, "stream", false); + + if (!stream) { + // non-stream, wait for the results + auto all_results = rd.wait_for_all(should_stop); + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + json arr = json::array(); + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + arr.push_back(res->to_json()); + } + // if single request, return single object instead of array + res->ok(arr.size() == 1 ? arr[0] : arr); + } + + } else { + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 + server_task_result_ptr first_result = rd.next(should_stop); + if (first_result == nullptr) { + return res; // connection is closed + } else if (first_result->is_error()) { + res->error(first_result->to_json()); + return res; + } else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); + } + + // next responses are streamed + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + res->data = format_anthropic_sse(first_result->to_json()); + } else { + res->data = format_oai_sse(first_result->to_json()); // to be sent immediately + } + res->status = 200; + res->content_type = "text/event-stream"; + res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { + if (should_stop()) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; + } + + server_response_reader & rd = res_this->rd; + + // check if there is more data + if (!rd.has_next()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + // Anthropic doesn't send [DONE], message_stop was already sent + output = ""; + } else if (res_type != TASK_RESPONSE_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; + } + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate + } + + // receive subsequent results + auto result = rd.next(should_stop); + if (result == nullptr) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + // send the results + json res_json = result->to_json(); + if (result->is_error()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse({ + {"event", "error"}, + {"data", res_json}, + }); + } else { + output = format_oai_sse(json {{ "error", res_json }}); + } + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error + } else { + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse(res_json); + } else { + output = format_oai_sse(res_json); + } + } + + // has next data, continue + return true; + }; + } + + return res; +} + void server_routes::init_routes() { this->get_health = [this](const server_http_req &) { // error and loading states are handled by middleware @@ -2991,6 +3176,7 @@ void server_routes::init_routes() { std::vector files; // dummy return handle_completions_impl( + ctx_server, SERVER_TASK_TYPE_INFILL, data, files, @@ -3002,6 +3188,7 @@ void server_routes::init_routes() { std::vector files; // dummy const json body = json::parse(req.body); return handle_completions_impl( + ctx_server, SERVER_TASK_TYPE_COMPLETION, body, files, @@ -3013,6 +3200,7 @@ void server_routes::init_routes() { std::vector files; // dummy const json body = json::parse(req.body); return handle_completions_impl( + ctx_server, SERVER_TASK_TYPE_COMPLETION, body, files, @@ -3028,6 +3216,7 @@ void server_routes::init_routes() { ctx_server.oai_parser_opt, files); return handle_completions_impl( + ctx_server, SERVER_TASK_TYPE_COMPLETION, body_parsed, files, @@ -3043,6 +3232,7 @@ void server_routes::init_routes() { ctx_server.oai_parser_opt, files); return handle_completions_impl( + ctx_server, SERVER_TASK_TYPE_COMPLETION, body_parsed, files, @@ -3332,177 +3522,6 @@ void server_routes::init_routes() { }; } -std::unique_ptr server_routes::handle_completions_impl( - server_task_type type, - const json & data, - const std::vector & files, - const std::function & should_stop, - task_response_type res_type) { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - auto res = std::make_unique(ctx_server); - auto completion_id = gen_chatcmplid(); - auto & rd = res->rd; - - try { - std::vector tasks; - - const auto & prompt = data.at("prompt"); - // TODO: this log can become very long, put it behind a flag or think about a more compact format - //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - - // process prompt - std::vector inputs; - - if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { - // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } else { - // Everything else, including multimodal completions. - inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, - data); - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.res_type = res_type; - task.params.oaicompat_cmpl_id = completion_id; - // oaicompat_model is already populated by params_from_json_cmpl - - tasks.push_back(std::move(task)); - } - - rd.post_tasks(std::move(tasks)); - } catch (const std::exception & e) { - res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool stream = json_value(data, "stream", false); - - if (!stream) { - // non-stream, wait for the results - auto all_results = rd.wait_for_all(should_stop); - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - json arr = json::array(); - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json()); - } - // if single request, return single object instead of array - res->ok(arr.size() == 1 ? arr[0] : arr); - } - - } else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd.next(should_stop); - if (first_result == nullptr) { - return res; // connection is closed - } else if (first_result->is_error()) { - res->error(first_result->to_json()); - return res; - } else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // next responses are streamed - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - res->data = format_anthropic_sse(first_result->to_json()); - } else { - res->data = format_oai_sse(first_result->to_json()); // to be sent immediately - } - res->status = 200; - res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { - if (should_stop()) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - if (!res_this->data.empty()) { - // flush the first chunk - output = std::move(res_this->data); - res_this->data.clear(); - return true; - } - - server_response_reader & rd = res_this->rd; - - // check if there is more data - if (!rd.has_next()) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - // Anthropic doesn't send [DONE], message_stop was already sent - output = ""; - } else if (res_type != TASK_RESPONSE_TYPE_NONE) { - output = "data: [DONE]\n\n"; - } else { - output = ""; - } - SRV_DBG("%s", "all results received, terminating stream\n"); - return false; // no more data, terminate - } - - // receive subsequent results - auto result = rd.next(should_stop); - if (result == nullptr) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - // send the results - json res_json = result->to_json(); - if (result->is_error()) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - output = format_anthropic_sse({ - {"event", "error"}, - {"data", res_json}, - }); - } else { - output = format_oai_sse(json {{ "error", res_json }}); - } - SRV_DBG("%s", "error received during streaming, terminating stream\n"); - return false; // terminate on error - } else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - output = format_anthropic_sse(res_json); - } else { - output = format_oai_sse(res_json); - } - } - - // has next data, continue - return true; - }; - } - - return res; -} - std::unique_ptr server_routes::handle_slots_save(const server_http_req & req, int id_slot) { auto res = std::make_unique(ctx_server); const json request_data = json::parse(req.body); diff --git a/tools/server/server-context.h b/tools/server/server-context.h index b799754145d..18337d0ee5c 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -1,51 +1,13 @@ #include "server-common.h" #include "server-http.h" #include "server-task.h" -#include "server-queue.h" -#include "arg.h" #include "common.h" #include "llama.h" -#include "log.h" -#include "sampling.h" -#include "speculative.h" -#include "mtmd.h" -#include "mtmd-helper.h" -#include #include #include #include -#include -#include -#include - -// fix problem with std::min and std::max -#if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX -#endif -#include -#endif - -using json = nlohmann::ordered_json; - -constexpr int HTTP_POLLING_SECONDS = 1; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; struct server_context_impl; // private implementation struct server_context_impl_deleter { @@ -76,6 +38,7 @@ struct server_context { }; +// forward declarations struct server_res_generator; struct server_routes { @@ -113,12 +76,7 @@ struct server_routes { server_http_context::handler_t get_lora_adapters; server_http_context::handler_t post_lora_adapters; private: - std::unique_ptr handle_completions_impl( - server_task_type type, - const json & data, - const std::vector & files, - const std::function & should_stop, - task_response_type res_type); + // TODO: move these outside of server_routes? std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot); std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot); std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index d6603a5299f..6570c83c3a3 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -5,27 +5,10 @@ #include "common.h" #include "llama.h" #include "log.h" -#include "sampling.h" -#include "speculative.h" -#include "mtmd.h" -#include "mtmd-helper.h" #include -#include -#include -#include #include -#include -#include - -// fix problem with std::min and std::max -#if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX -#endif -#include -#endif +#include // for std::thread::hardware_concurrency static std::function shutdown_handler; From 26204aede77246605882c47b29375c225c9ad43f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 29 Nov 2025 20:42:53 +0200 Subject: [PATCH 5/8] cont : cleanup --- tools/server/server-context.cpp | 6 ------ tools/server/server-context.h | 20 +++++++------------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b3285b42626..ac470c5ed5d 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2509,12 +2509,6 @@ struct server_context_impl { } }; -void server_context_impl_deleter::operator()(server_context_impl * ptr) const { - if (ptr) { - delete ptr; - } -} - // // server_context (public API) // diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 18337d0ee5c..5259ce416ca 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -1,21 +1,15 @@ -#include "server-common.h" #include "server-http.h" #include "server-task.h" -#include "common.h" -#include "llama.h" +#include #include -#include #include struct server_context_impl; // private implementation -struct server_context_impl_deleter { - void operator()(server_context_impl * p) const; -}; struct server_context { - std::unique_ptr impl; + std::unique_ptr impl; server_context(); ~server_context(); @@ -42,15 +36,11 @@ struct server_context { struct server_res_generator; struct server_routes { - const common_params & params; - server_context_impl & ctx_server; - server_http_context & ctx_http; // for reading is_ready server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) - : params(params), ctx_server(*ctx_server.impl.get()), ctx_http(ctx_http) { + : params(params), ctx_server(*ctx_server.impl), ctx_http(ctx_http) { init_routes(); } -public: void init_routes(); // handlers using lambda function, so that they can capture `this` without `std::bind` server_http_context::handler_t get_health; @@ -81,4 +71,8 @@ struct server_routes { std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot); std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot); std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type); + + const common_params & params; + server_context_impl & ctx_server; + server_http_context & ctx_http; // for reading is_ready }; From 8b510b476849a01d30452feca14ccc77ec700059 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Nov 2025 20:07:23 +0100 Subject: [PATCH 6/8] also expose server_response_reader (to be used by CLI) --- tools/server/server-context.cpp | 107 +++----------------------------- tools/server/server-context.h | 5 ++ tools/server/server-queue.cpp | 83 +++++++++++++++++++++++++ tools/server/server-queue.h | 36 +++++++++++ 4 files changed, 131 insertions(+), 100 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ac470c5ed5d..e902b0220c1 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2536,110 +2536,17 @@ llama_context * server_context::get_llama_context() const { return impl->ctx; } +std::pair server_context::get_queues() { + return { impl->queue_tasks, impl->queue_results }; +} -// generator-like API for server responses, support pooling connection state and aggregating results -struct server_response_reader { - std::unordered_set id_tasks; - server_context_impl & ctx_server; - size_t received_count = 0; - bool cancelled = false; - - server_response_reader(server_context_impl & ctx_server) : ctx_server(ctx_server) {} - ~server_response_reader() { - stop(); - } - - void post_tasks(std::vector && tasks) { - id_tasks = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); - } - - bool has_next() const { - return !cancelled && received_count < id_tasks.size(); - } - - // return nullptr if should_stop() is true before receiving a result - // note: if one error is received, it will stop further processing and return error result - server_task_result_ptr next(const std::function & should_stop) { - while (true) { - server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - if (result == nullptr) { - // timeout, check stop condition - if (should_stop()) { - SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); - return nullptr; - } - } else { - if (result->is_error()) { - stop(); // cancel remaining tasks - SRV_DBG("%s", "received error result, stopping further processing\n"); - return result; - } - if (result->is_stop()) { - received_count++; - } - return result; - } - } - - // should not reach here - } - - struct batch_response { - bool is_terminated = false; // if true, indicates that processing was stopped before all results were received - std::vector results; - server_task_result_ptr error; // nullptr if no error - }; - - batch_response wait_for_all(const std::function & should_stop) { - batch_response batch_res; - batch_res.results.resize(id_tasks.size()); - while (has_next()) { - auto res = next(should_stop); - if (res == nullptr) { - batch_res.is_terminated = true; - return batch_res; - } - if (res->is_error()) { - batch_res.error = std::move(res); - return batch_res; - } - const size_t idx = res->get_index(); - GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); - GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); - batch_res.results[idx] = std::move(res); - } - return batch_res; - } - - void stop() { - ctx_server.queue_results.remove_waiting_task_ids(id_tasks); - if (has_next() && !cancelled) { - // if tasks is not finished yet, cancel them - cancelled = true; - std::vector cancel_tasks; - cancel_tasks.reserve(id_tasks.size()); - for (const auto & id_task : id_tasks) { - SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_task; - ctx_server.queue_results.remove_waiting_task_id(id_task); - cancel_tasks.push_back(std::move(task)); - } - // push to beginning of the queue, so it has highest priority - ctx_server.queue_tasks.post(std::move(cancel_tasks), true); - } else { - SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); - } - } -}; // generator-like API for HTTP response generation struct server_res_generator : server_http_res { server_response_reader rd; - server_res_generator(server_context_impl & ctx_server_) : rd(ctx_server_) {} + server_res_generator(server_context_impl & ctx_server) + : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {} void ok(const json & response_data) { status = 200; data = safe_json_to_str(response_data); @@ -3410,7 +3317,7 @@ void server_routes::init_routes() { // create and queue the task json responses = json::array(); - server_response_reader rd(ctx_server); + server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); { std::vector tasks; tasks.reserve(documents.size()); @@ -3669,7 +3576,7 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons // create and queue the task json responses = json::array(); - server_response_reader rd(ctx_server); + server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); { std::vector tasks; for (size_t i = 0; i < tokenized_prompts.size(); i++) { diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 5259ce416ca..1a0e19e32d1 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -1,5 +1,6 @@ #include "server-http.h" #include "server-task.h" +#include "server-queue.h" #include @@ -29,6 +30,10 @@ struct server_context { // get the underlaying llama_context llama_context * get_llama_context() const; + + // get the underlaying queue_tasks and queue_results + // used by CLI application + std::pair get_queues(); }; diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 5a74fd76ac3..9f0d57ae843 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -266,3 +266,86 @@ void server_response::terminate() { running = false; condition_results.notify_all(); } + +// +// server_response_reader +// + +void server_response_reader::post_tasks(std::vector && tasks) { + id_tasks = server_task::get_list_id(tasks); + queue_results.add_waiting_tasks(tasks); + queue_tasks.post(std::move(tasks)); +} + +bool server_response_reader::has_next() const { + return !cancelled && received_count < id_tasks.size(); +} + +// return nullptr if should_stop() is true before receiving a result +// note: if one error is received, it will stop further processing and return error result +server_task_result_ptr server_response_reader::next(const std::function & should_stop) { + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds); + if (result == nullptr) { + // timeout, check stop condition + if (should_stop()) { + SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); + return nullptr; + } + } else { + if (result->is_error()) { + stop(); // cancel remaining tasks + SRV_DBG("%s", "received error result, stopping further processing\n"); + return result; + } + if (result->is_stop()) { + received_count++; + } + return result; + } + } + + // should not reach here +} + +server_response_reader::batch_response server_response_reader::wait_for_all(const std::function & should_stop) { + batch_response batch_res; + batch_res.results.resize(id_tasks.size()); + while (has_next()) { + auto res = next(should_stop); + if (res == nullptr) { + batch_res.is_terminated = true; + return batch_res; + } + if (res->is_error()) { + batch_res.error = std::move(res); + return batch_res; + } + const size_t idx = res->get_index(); + GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); + GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); + batch_res.results[idx] = std::move(res); + } + return batch_res; +} + +void server_response_reader::stop() { + queue_results.remove_waiting_task_ids(id_tasks); + if (has_next() && !cancelled) { + // if tasks is not finished yet, cancel them + cancelled = true; + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(std::move(cancel_tasks), true); + } else { + SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); + } +} diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 47ef58425ea..209d2017c7e 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -108,3 +108,39 @@ struct server_response { // terminate the waiting loop void terminate(); }; + +// utility class to make working with server_queue and server_response easier +// it provides a generator-like API for server responses +// support pooling connection state and aggregating multiple results +struct server_response_reader { + std::unordered_set id_tasks; + server_queue & queue_tasks; + server_response & queue_results; + size_t received_count = 0; + bool cancelled = false; + int polling_interval_seconds; + + // should_stop function will be called each polling_interval_seconds + server_response_reader(std::pair server_queues, int polling_interval_seconds) + : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {} + ~server_response_reader() { + stop(); + } + + void post_tasks(std::vector && tasks); + bool has_next() const; + + // return nullptr if should_stop() is true before receiving a result + // note: if one error is received, it will stop further processing and return error result + server_task_result_ptr next(const std::function & should_stop); + + struct batch_response { + bool is_terminated = false; // if true, indicates that processing was stopped before all results were received + std::vector results; + server_task_result_ptr error; // nullptr if no error + }; + // aggregate multiple results + batch_response wait_for_all(const std::function & should_stop); + + void stop(); +}; From 2141a3e7daf6e6b983cb737478627dc354aaaa24 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Nov 2025 20:07:31 +0100 Subject: [PATCH 7/8] fix windows build --- tools/server/server.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 6570c83c3a3..4ed1660ab7f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -10,6 +10,9 @@ #include #include // for std::thread::hardware_concurrency +#if defined(_WIN32) +#include +#endif static std::function shutdown_handler; static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; From c5f8cdf320536ca01cd56a183e69acbadba53e12 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Nov 2025 20:21:25 +0100 Subject: [PATCH 8/8] decouple server_routes and server_http --- tools/server/server-context.cpp | 3 +-- tools/server/server-context.h | 6 +++--- tools/server/server.cpp | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index e902b0220c1..2bf3924df90 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3172,9 +3172,8 @@ void server_routes::init_routes() { this->get_models = [this](const server_http_req &) { auto res = std::make_unique(ctx_server); - bool is_model_ready = ctx_http.is_ready.load(); json model_meta = nullptr; - if (is_model_ready) { + if (is_ready()) { model_meta = ctx_server.model_meta(); } bool has_mtmd = ctx_server.mctx != nullptr; diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 1a0e19e32d1..05b4afaeeb2 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -41,8 +41,8 @@ struct server_context { struct server_res_generator; struct server_routes { - server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) - : params(params), ctx_server(*ctx_server.impl), ctx_http(ctx_http) { + server_routes(const common_params & params, server_context & ctx_server, std::function is_ready = []() { return true; }) + : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) { init_routes(); } @@ -79,5 +79,5 @@ struct server_routes { const common_params & params; server_context_impl & ctx_server; - server_http_context & ctx_http; // for reading is_ready + std::function is_ready; }; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 4ed1660ab7f..5256790db2f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -99,7 +99,7 @@ int main(int argc, char ** argv) { // // register API routes - server_routes routes(params, ctx_server, ctx_http); + server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); }); ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)