diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md
index df165c34a3c..fbcd6bc1f93 100644
--- a/tools/server/README-dev.md
+++ b/tools/server/README-dev.md
@@ -42,7 +42,15 @@ graph TD
     server_response --> server_routes
 ```
 
-TODO: mention about how batching is handled by `server_slot`
+### Batching
+
+The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, it iterates through all active slots to populate this batch: each slot contributes either the token it generated in the previous decoding step or a chunk of its pending prompt tokens.
+
+Batching constraints apply: slots can be batched together only if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
+
+Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This call is the primary computational bottleneck in `update_slots()`.
+
+After decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot still has prompt tokens left to process, it yields until the next `update_slots()` iteration.
 
 ### Thread Management
 
@@ -62,6 +70,23 @@ Each incoming HTTP request is handled by its own thread managed by the HTTP libr
 - All JSON formatting and chat template logic must stay in the HTTP layer.
 - Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
 
+### Example trace of a request
+
+Here is an example trace of an API request for text completion:
+
+- A request arrives at the HTTP layer.
+- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
+- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
+- `server_res_generator` creates a new `task_result_state` for each task:
+  - `task_result_state` stays in the HTTP layer and keeps track of the current state of the response (e.g., partially parsed tool calls or thinking messages).
+  - `server_task` is moved into the `server_queue` inside `server_context`.
+- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
+- `update_slots()` processes the task as described in the "Batching" section above.
+- Results may be sent using `send_partial_response` or `send_final_response`, each of which creates a new `server_task_result` and pushes it to the response queue.
+- Meanwhile, `server_res_generator` listens on the response queue and retrieves the response.
+- Because the response object itself is stateless, `server_res_generator` calls `response->update()` to combine it with the current state.
+- `server_res_generator` then calls `response->to_json()` and passes the result to the HTTP layer.
+
 ### Testing
 
 `llama-server` includes an automated test suite based on `pytest`.
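For reference, here is a minimal sketch of the shared-batch loop described in the new "Batching" section. It is not the actual `update_slots()` code: the `server_slot` members it touches (`sampled`, `n_past`, `prompt_tokens`, `n_prompt_processed`, `is_processing()`, `is_generating()`) are hypothetical stand-ins, and the slot-compatibility checks (e.g., per-slot LoRA adapters) are omitted. `common_batch_clear`/`common_batch_add` are the real helpers from `common/common.h`.

```cpp
// Sketch only: one shared batch, filled from every active slot, then decoded.
// Slot field names are illustrative stand-ins, not the real server_slot API.
static void fill_and_decode_batch(llama_context * ctx, llama_batch & batch,
                                  std::vector<server_slot> & slots, int n_batch) {
    common_batch_clear(batch);

    for (auto & slot : slots) {
        if (!slot.is_processing()) {
            continue; // idle slots contribute nothing
        }
        if (slot.is_generating()) {
            // generation phase: add the token sampled in the previous step
            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
            slot.n_past++;
        } else {
            // prompt phase: push pending prompt tokens until the batch is full
            while (slot.n_prompt_processed < (int) slot.prompt_tokens.size() &&
                   batch.n_tokens < n_batch) {
                // only the last prompt token needs logits (for sampling)
                const bool need_logits = slot.n_prompt_processed + 1 == (int) slot.prompt_tokens.size();
                common_batch_add(batch, slot.prompt_tokens[slot.n_prompt_processed],
                                 slot.n_prompt_processed, { slot.id }, need_logits);
                slot.n_prompt_processed++;
            }
        }
    }

    if (batch.n_tokens > 0) {
        llama_decode(ctx, batch); // the primary computational bottleneck
    }
}
```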
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 3bf90510269..4578f8d7a9f 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2589,6 +2589,10 @@ struct server_context_impl {
     int get_slot_n_ctx() {
         return slots.back().n_ctx;
     }
+
+    server_response_reader get_response_reader() {
+        return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
+    }
 };
 
 //
@@ -2618,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
     return impl->ctx;
 }
 
-std::pair<server_queue &, server_response &> server_context::get_queues() {
-    return { impl->queue_tasks, impl->queue_results };
+server_response_reader server_context::get_response_reader() {
+    return impl->get_response_reader();
 }
 
@@ -2628,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
 struct server_res_generator : server_http_res {
     server_response_reader rd;
     server_res_generator(server_context_impl & ctx_server)
-        : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
+        : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
     void ok(const json & response_data) {
         status = 200;
         data = safe_json_to_str(response_data);
@@ -2661,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
-        // tracking generation state and partial tool calls
-        std::vector<task_result_state> states;
-
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2679,7 +2680,6 @@
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
-        states.reserve(inputs.size());
         int idx = 0;
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2698,7 +2698,6 @@
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
-            states.push_back(task.params.oaicompat_chat_syntax);
 
             if (task.params.n_cmpl > 1) {
                 task.n_children = task.params.n_cmpl - 1;
@@ -2707,7 +2706,6 @@
                         task.id,
                         ctx_server.queue_tasks.get_new_id(),
                         idx++);
-                    states.push_back(child.params.oaicompat_chat_syntax);
                     tasks.push_back(std::move(child));
                 }
             }
@@ -2715,7 +2713,6 @@
             tasks.push_back(std::move(task));
         }
 
-        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -3445,7 +3442,7 @@ void server_routes::init_routes() {
 
         // create and queue the task
         json responses = json::array();
-        server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+        server_response_reader rd = ctx_server.get_response_reader();
         {
             std::vector<server_task> tasks;
             tasks.reserve(documents.size());
@@ -3705,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+    server_response_reader rd = ctx_server.get_response_reader();
    {
         std::vector<server_task> tasks;
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
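The `server-context.cpp` changes above replace raw queue access with a factory method. Below is a hedged sketch of the resulting call pattern on the caller side (e.g., the CLI application mentioned in `server-context.h`); `make_task()` is a hypothetical placeholder, not a real helper.

```cpp
// Sketch of the new caller-side pattern, assuming an existing server_context.
server_response_reader rd = ctx_server.get_response_reader();

std::vector<server_task> tasks;
tasks.push_back(make_task()); // hypothetical task construction

// post_tasks() now derives a task_result_state per task internally,
// so the old external set_states() call disappears from every call site
rd.post_tasks(std::move(tasks));

while (rd.has_next()) {
    // ... poll the response queue and consume results ...
}
```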
diff --git a/tools/server/server-context.h b/tools/server/server-context.h
index 05b4afaeeb2..eaa13808779 100644
--- a/tools/server/server-context.h
+++ b/tools/server/server-context.h
@@ -31,9 +31,8 @@ struct server_context {
     // get the underlaying llama_context
     llama_context * get_llama_context() const;
 
-    // get the underlaying queue_tasks and queue_results
-    // used by CLI application
-    std::pair<server_queue &, server_response &> get_queues();
+    // get a new response reader, used by CLI application
+    server_response_reader get_response_reader();
 };
 
diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp
index 10196128db1..3cceb2bbe21 100644
--- a/tools/server/server-queue.cpp
+++ b/tools/server/server-queue.cpp
@@ -271,12 +271,21 @@ void server_response::terminate() {
 // server_response_reader
 //
 
-void server_response_reader::set_states(std::vector<task_result_state> && states) {
-    this->states = std::move(states);
+void server_response_reader::post_task(server_task && task) {
+    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    id_tasks.insert(task.id);
+    states.push_back(task.create_state());
+    queue_results.add_waiting_task_id(task.id);
+    queue_tasks.post(std::move(task));
 }
 
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
+    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     id_tasks = server_task::get_list_id(tasks);
+    states.reserve(tasks.size());
+    for (size_t i = 0; i < tasks.size(); i++) {
+        states.push_back(tasks[i].create_state());
+    }
     queue_results.add_waiting_tasks(tasks);
     queue_tasks.post(std::move(tasks));
 }
diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h
index a5c3179d8ca..726eadf4efc 100644
--- a/tools/server/server-queue.h
+++ b/tools/server/server-queue.h
@@ -129,13 +129,13 @@ struct server_response_reader {
     std::vector<task_result_state> states;
 
     // should_stop function will be called each polling_interval_seconds
-    server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
-        : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
+    server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
+        : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
     ~server_response_reader() {
         stop();
     }
 
-    void set_states(std::vector<task_result_state> && states);
+    void post_task(server_task && task);
     void post_tasks(std::vector<server_task> && tasks);
 
     bool has_next() const;
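Condensed from the hunks above, the net effect on a handler is that state bookkeeping moves out of the call site and into the reader:

```cpp
// Before this change: states tracked manually next to the tasks
std::vector<task_result_state> states;
states.push_back(task.params.oaicompat_chat_syntax);
// ...
rd.set_states(std::move(states));
rd.post_tasks(std::move(tasks));

// After this change: a single call; the reader builds each state via
// task.create_state(), and posting twice on the same reader now trips
// the GGML_ASSERT shown in server-queue.cpp
rd.post_tasks(std::move(tasks));
```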
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index da4e22a7cd8..9011ff944b9 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -85,6 +85,25 @@ struct task_params {
     json to_json(bool only_metrics = false) const;
 };
 
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
 struct server_task {
     int id = -1;    // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
@@ -149,6 +168,12 @@ struct server_task {
         copy.tokens = tokens.clone();
         return copy;
     }
+
+    // the task will be moved into the queue, then onto a slot;
+    // however, the state must be kept by the caller (e.g., the HTTP thread)
+    task_result_state create_state() const {
+        return task_result_state(params.oaicompat_chat_syntax);
+    }
 };
 
 struct result_timings {
@@ -180,25 +205,6 @@
     json to_json() const;
 };
 
-// struct for tracking the state of a task (e.g., for streaming)
-struct task_result_state {
-    // tracking diffs for partial tool calls
-    std::vector<common_chat_msg_diff> diffs;
-    common_chat_syntax oaicompat_chat_syntax;
-    common_chat_msg chat_msg;
-    std::string generated_text; // append new chunks of generated text here
-    std::vector<std::string> generated_tool_call_ids;
-
-    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
-        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
-
-    // parse partial tool calls and update the internal state
-    common_chat_msg update_chat_msg(
-        const std::string & text_added,
-        bool is_partial,
-        std::vector<common_chat_msg_diff> & diffs);
-};
-
 struct server_task_result {
     int id = -1;
     int id_slot = -1;
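To make the ownership comment on `create_state()` concrete, here is a hedged sketch of the intended split for a single task; `make_task()` is again a hypothetical placeholder for real task construction.

```cpp
// The server_task is moved into the queue (and later onto a slot), while
// the task_result_state produced by create_state() stays with the reader
// on the HTTP thread.
server_response_reader rd = ctx_server.get_response_reader();

server_task task = make_task();
rd.post_task(std::move(task)); // stores task.create_state() before the move

// results coming back from the worker are stateless; the reader's stored
// state (chat syntax, partial tool calls, accumulated text) is what turns
// each partial result into a complete response, per the README trace
```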