27 changes: 26 additions & 1 deletion tools/server/README-dev.md
@@ -42,7 +42,15 @@ graph TD
server_response --> server_routes
```

-TODO: mention about how batching is handled by `server_slot`
### Batching

The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, it iterates through all active slots to populate this batch. For each slot, it adds either the token generated in the previous decoding step or any prompt tokens that still need to be processed.

Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.

Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.

Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
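
To make the flow concrete, here is a self-contained toy sketch of the loop described above. `Slot` and `Batch` are invented stand-ins, not the real `server_slot` and `llama_batch` types, and the real logic handles many more cases:

```cpp
#include <cstdio>
#include <vector>

// Toy stand-ins for the real server_slot / llama_batch machinery.
struct Slot {
    int              id         = -1;
    int              lora_id    = -1; // -1 = no adapter; must match to share a batch
    std::vector<int> prompt;          // prompt tokens not yet decoded
    int              last_token = -1; // token sampled in the previous decode step
};

struct Batch {
    int              lora_id  = -1;
    size_t           capacity = 8;
    std::vector<int> tokens;

    bool compatible(const Slot & s) const { return tokens.empty() || lora_id == s.lora_id; }
    bool full() const { return tokens.size() >= capacity; }
};

int main() {
    std::vector<Slot> slots = {
        { 0, -1, { 11, 12, 13 }, -1 }, // still ingesting its prompt
        { 1, -1, {},             42 }, // generating: adds one token per step
        { 2,  7, { 21, 22 },     -1 }, // different LoRA -> deferred to a later batch
    };

    Batch batch;
    for (Slot & slot : slots) {
        if (batch.full()) {
            break;
        }
        if (!batch.compatible(slot)) {
            continue; // incompatible config; handled in a later update_slots() pass
        }
        batch.lora_id = slot.lora_id;
        if (slot.prompt.empty()) {
            batch.tokens.push_back(slot.last_token); // one generated token
        } else {
            // add as many unprocessed prompt tokens as fit into the batch
            while (!batch.full() && !slot.prompt.empty()) {
                batch.tokens.push_back(slot.prompt.front());
                slot.prompt.erase(slot.prompt.begin());
            }
        }
    }

    // in the real server, llama_decode() runs here -- the main compute bottleneck
    std::printf("decoding %zu tokens (lora %d)\n", batch.tokens.size(), batch.lora_id);
    return 0;
}
```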

### Thread Management

@@ -62,6 +70,23 @@ Each incoming HTTP request is handled by its own thread managed by the HTTP libr
- All JSON formatting and chat template logic must stay in the HTTP layer.
- Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible, as sketched below.
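
A sketch of the second rule in practice. The `TaskParams` struct and its fields are invented for illustration; the real code parses into `task_params`, but the shape of the rule is the same:

```cpp
#include <string>
#include <nlohmann/json.hpp>

// Illustrative only: all JSON handling happens here, in the HTTP layer.
struct TaskParams {
    std::string prompt;
    int         n_predict   = -1;
    float       temperature = 0.8f;
};

static TaskParams parse_params(const nlohmann::json & body) {
    TaskParams p;
    p.prompt      = body.at("prompt").get<std::string>();
    p.n_predict   = body.value("n_predict", -1);
    p.temperature = body.value("temperature", 0.8f);
    return p; // only native C++ types cross into the slot machinery
}
```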

### Example trace of a request

Here is an example trace of an API request for text completion (a condensed code sketch of the reader side follows the list):

- A request arrives at the HTTP layer.
- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
- `server_res_generator` creates a new `task_result_state` for each task:
  - `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages).
  - `server_task` is moved into `server_queue` inside `server_context`.
- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
- `update_slots()` processes the task as described in the "Batching" section above.
- Results may be sent using `send_partial_response` or `send_final_response`, each of which creates a new `server_task_result` and pushes it to the response queue.
- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state.
- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
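
Condensed into code, the reader side of this trace looks roughly like this. `get_response_reader()`, `post_tasks()`, and `has_next()` are from this PR; `wait_for_next()` and `send_to_client()` are illustrative stand-ins, and the `update()`/`to_json()` calls are simplified from the real signatures:

```cpp
// Condensed reader-side sketch of the trace above; runs on the HTTP thread.
server_response_reader rd = ctx_server.get_response_reader();

rd.post_tasks(std::move(tasks)); // tasks built by the handler; one
                                 // task_result_state per task stays here

while (rd.has_next()) {
    auto response = wait_for_next(rd);   // poll the response queue
    response->update();                  // apply the HTTP-side state
    send_to_client(response->to_json()); // JSON formatting stays in this layer
}
```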

### Testing

`llama-server` includes an automated test suite based on `pytest`.
21 changes: 9 additions & 12 deletions tools/server/server-context.cpp
@@ -2589,6 +2589,10 @@ struct server_context_impl {
     int get_slot_n_ctx() {
         return slots.back().n_ctx;
     }
+
+    server_response_reader get_response_reader() {
+        return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
+    }
 };
 
 //
@@ -2618,8 +2622,8 @@ llama_context * server_context::get_llama_context() const {
     return impl->ctx;
 }
 
-std::pair<server_queue &, server_response &> server_context::get_queues() {
-    return { impl->queue_tasks, impl->queue_results };
+server_response_reader server_context::get_response_reader() {
+    return impl->get_response_reader();
 }


@@ -2628,7 +2632,7 @@ std::pair<server_queue &, server_response &> server_context::get_queues() {
 struct server_res_generator : server_http_res {
     server_response_reader rd;
     server_res_generator(server_context_impl & ctx_server)
-        : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
+        : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
     void ok(const json & response_data) {
         status = 200;
         data = safe_json_to_str(response_data);
@@ -2661,9 +2665,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
-        // tracking generation state and partial tool calls
-        std::vector<task_result_state> states;
-
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2679,7 +2680,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
-        states.reserve(inputs.size());
         int idx = 0;
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
@@ -2698,7 +2698,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
-            states.push_back(task.params.oaicompat_chat_syntax);
 
             if (task.params.n_cmpl > 1) {
                 task.n_children = task.params.n_cmpl - 1;
@@ -2707,15 +2706,13 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                     task.id,
                     ctx_server.queue_tasks.get_new_id(),
                     idx++);
-                states.push_back(child.params.oaicompat_chat_syntax);
                 tasks.push_back(std::move(child));
             }
         }
 
         tasks.push_back(std::move(task));
     }
 
-        rd.set_states(std::move(states));
[Review thread on the removed `rd.set_states(...)` line]
Member: states looks unused now - shouldn't it be removed?
Collaborator (author): yes I forgot - it's removed in e25bf4b
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -3445,7 +3442,7 @@ void server_routes::init_routes() {
 
         // create and queue the task
         json responses = json::array();
-        server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+        server_response_reader rd = ctx_server.get_response_reader();
        {
             std::vector<server_task> tasks;
             tasks.reserve(documents.size());
@@ -3705,7 +3702,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+    server_response_reader rd = ctx_server.get_response_reader();
     {
         std::vector<server_task> tasks;
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
5 changes: 2 additions & 3 deletions tools/server/server-context.h
@@ -31,9 +31,8 @@ struct server_context {
     // get the underlaying llama_context
     llama_context * get_llama_context() const;
 
-    // get the underlaying queue_tasks and queue_results
-    // used by CLI application
-    std::pair<server_queue &, server_response &> get_queues();
+    // get a new response reader, used by CLI application
+    server_response_reader get_response_reader();
 };


13 changes: 11 additions & 2 deletions tools/server/server-queue.cpp
@@ -271,12 +271,21 @@ void server_response::terminate() {
 // server_response_reader
 //
 
-void server_response_reader::set_states(std::vector<task_result_state> && states) {
-    this->states = std::move(states);
+void server_response_reader::post_task(server_task && task) {
+    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    id_tasks.insert(task.id);
+    states.push_back(task.create_state());
+    queue_results.add_waiting_task_id(task.id);
+    queue_tasks.post(std::move(task));
 }
 
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
+    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     id_tasks = server_task::get_list_id(tasks);
+    states.reserve(tasks.size());
+    for (size_t i = 0; i < tasks.size(); i++) {
+        states.push_back(tasks[i].create_state());
+    }
     queue_results.add_waiting_tasks(tasks);
     queue_tasks.post(std::move(tasks));
 }
6 changes: 3 additions & 3 deletions tools/server/server-queue.h
@@ -129,13 +129,13 @@ struct server_response_reader {
     std::vector<task_result_state> states;
 
     // should_stop function will be called each polling_interval_seconds
-    server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
-        : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
+    server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
+        : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
     ~server_response_reader() {
         stop();
     }
 
-    void set_states(std::vector<task_result_state> && states);
+    void post_task(server_task && tasks);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;

44 changes: 25 additions & 19 deletions tools/server/server-task.h
@@ -85,6 +85,25 @@ struct task_params {
     json to_json(bool only_metrics = false) const;
 };
 
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+    // tracking diffs for partial tool calls
+    std::vector<common_chat_msg_diff> diffs;
+    common_chat_syntax oaicompat_chat_syntax;
+    common_chat_msg chat_msg;
+    std::string generated_text; // append new chunks of generated text here
+    std::vector<std::string> generated_tool_call_ids;
+
+    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+    // parse partial tool calls and update the internal state
+    common_chat_msg update_chat_msg(
+        const std::string & text_added,
+        bool is_partial,
+        std::vector<common_chat_msg_diff> & diffs);
+};
+
 struct server_task {
     int id = -1; // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
@@ -149,6 +168,12 @@ struct server_task {
         copy.tokens = tokens.clone();
         return copy;
     }
+
+    // the task will be moved into queue, then onto slots
+    // however, the state must be kept by caller (e.g., HTTP thread)
+    task_result_state create_state() const {
+        return task_result_state(params.oaicompat_chat_syntax);
+    }
 };
 
 struct result_timings {
@@ -180,25 +205,6 @@ struct result_prompt_progress {
     json to_json() const;
 };
 
-// struct for tracking the state of a task (e.g., for streaming)
-struct task_result_state {
-    // tracking diffs for partial tool calls
-    std::vector<common_chat_msg_diff> diffs;
-    common_chat_syntax oaicompat_chat_syntax;
-    common_chat_msg chat_msg;
-    std::string generated_text; // append new chunks of generated text here
-    std::vector<std::string> generated_tool_call_ids;
-
-    task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
-        : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
-
-    // parse partial tool calls and update the internal state
-    common_chat_msg update_chat_msg(
-        const std::string & text_added,
-        bool is_partial,
-        std::vector<common_chat_msg_diff> & diffs);
-};
-
 struct server_task_result {
     int id = -1;
     int id_slot = -1;
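
The comments on `create_state()` describe an ownership split: the task is moved into the queue while its result state stays with the HTTP thread. A minimal sketch of that lifecycle (the glue code is illustrative; only `server_task`, `create_state()`, and `update_chat_msg()` come from this header):

```cpp
// Sketch of the ownership split around create_state() -- illustrative glue,
// not code from this PR. chunk_text / is_partial stand in for values
// produced during streaming.
server_task task(SERVER_TASK_TYPE_COMPLETION);
task_result_state state = task.create_state(); // state stays on the HTTP thread

queue_tasks.post(std::move(task));             // the task itself is moved away

// later, for each streamed chunk of generated text:
std::vector<common_chat_msg_diff> diffs;
common_chat_msg msg = state.update_chat_msg(chunk_text, /*is_partial=*/true, diffs);
// `diffs` now holds the partial tool-call deltas to forward to the client
```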