2 changes: 2 additions & 0 deletions tools/server/README.md
@@ -493,6 +493,8 @@ Note for `multimodal_data` in JSON object prompts. This should be an array of st
`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.

`n_cmpl`: Number of completions to generate from the current prompt. If the input contains multiple prompts, the output will contain `n_cmpl` entries per prompt.

`stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.

`stop`: Specify a JSON array of stopping strings.
24 changes: 12 additions & 12 deletions tools/server/server-common.cpp
@@ -494,6 +494,18 @@ int32_t server_tokens::process_chunk(
return 0;
}

server_tokens server_tokens::clone() const {
server_tokens res;
res.has_mtmd = has_mtmd;
res.tokens = tokens;
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
size_t idx = it->first;
const mtmd::input_chunk_ptr & chunk = it->second;
res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
}
return res;
}

Comment on lines +497 to +508
Member:

Now that we have this function, I think we can enable host-memory prompt caching with mtmd:

  • Update this code:

// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
auto & cur = states.emplace_back();
cur = {
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
/*.data =*/ std::move(state_data),
/*.checkpoints =*/ prompt.checkpoints,
};

  • Remove this condition:

// TODO: mtmd does not support prompt cache
update_cache = update_cache && (ret->mctx == nullptr);

I haven't tested, but I think the only reason that prompt caching didn't work was that it wasn't clear how to copy the server_tokens. So it's worth giving it a try after these changes.
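
For reference, a minimal sketch of what that update might look like once `server_tokens::clone()` is available (untested; the field layout is simply taken from the snippet quoted above):

// server_tokens can now be copied, so the full (possibly multimodal) prompt tokens are kept
auto & cur = states.emplace_back();
cur = {
    /*.tokens      =*/ prompt.tokens.clone(),
    /*.data        =*/ std::move(state_data),
    /*.checkpoints =*/ prompt.checkpoints,
};
// the mtmd restriction on prompt caching could then be dropped as well:
// update_cache = update_cache && (ret->mctx == nullptr); // no longer needed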

Collaborator Author:

Yes, it would be nice to enable the RAM cache for mtmd. I created an issue so we can have a look later on: #17821

//
// tokenizer and input processing utils
//
@@ -745,12 +757,6 @@ json oaicompat_completion_params_parse(const json & body) {
llama_params["stop"] = json_value(body, "stop", json::array());
}

// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::runtime_error("Only one completion choice is allowed");
}

// Handle "echo" field
if (json_value(body, "echo", false)) {
throw std::runtime_error("Only no echo is supported");
@@ -1049,12 +1055,6 @@ json oaicompat_chat_params_parse(
llama_params["chat_parser"] = chat_params.parser;
}

// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::invalid_argument("Only one completion choice is allowed");
}

// Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
if (json_value(body, "logprobs", false)) {
2 changes: 2 additions & 0 deletions tools/server/server-common.h
@@ -215,6 +215,8 @@ struct server_tokens {
llama_pos pos,
int32_t seq_id,
size_t & n_tokens_out) const;

server_tokens clone() const;
};


85 changes: 80 additions & 5 deletions tools/server/server-context.cpp
@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING,
@@ -254,6 +255,15 @@ struct server_slot {
generated_token_probs.push_back(token);
}

// note: a slot can also be either a parent or a child
bool is_parent() const {
return is_processing() && task->n_children > 0;
}

bool is_child() const {
return is_processing() && task->id_parent >= 0;
}

void release() {
if (is_processing()) {
GGML_ASSERT(task);
@@ -383,6 +393,17 @@ struct server_slot {

return res;
}

void copy_state_to(server_slot & other) const {
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
other.n_decoded = n_decoded;
other.n_remaining = n_remaining;
other.i_batch = i_batch;
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
other.n_prompt_tokens_processed = n_prompt_tokens_processed;
other.prompt = prompt.clone();
}
};


@@ -1022,7 +1043,9 @@ struct server_context_impl {

slot.task = std::make_unique<const server_task>(std::move(task));

slot.state = SLOT_STATE_STARTED;
slot.state = slot.is_child()
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;

SLT_INF(slot, "%s", "processing task\n");

@@ -1684,6 +1707,12 @@ struct server_context_impl {
GGML_ABORT("not supported by multimodal");
}

if (slot.is_parent() || slot.is_child()) {
send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
slot.release();
continue;
}

Comment on lines +1710 to +1715
Member:

Hm, what is the reason to not support context shift here?

Collaborator Author (@ngxson, Dec 6, 2025):

Not quite sure about this, but IIUC llama_kv_cache::seq_add does not have a notion of copy-on-write. For example, if a KV cell is used by 2 sequences, shifting it for one sequence will also shift it for the other.

This is fine if the current (generating) token position is synchronized among all sequences, but we don't have explicit logic to guarantee that this will always be the case.

Collaborator Author:

Also, the generation length of each sequence can be different, which can be quite difficult to keep track of.

Member:

I see, that is correct. The problem is that some of the tokens are shared when we use a unified KV cache. It would work with a split KV cache, but maybe it's not worth the extra logic branching.

Either way, context shifting is probably something that we should remove at some point - it does not have much value now that models support contexts of more than 128k tokens.

// Shift context
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;

@@ -2308,6 +2337,26 @@ struct server_context_impl {
n_batch = llama_n_batch(ctx);

for (auto & slot : slots) {
// may need to copy state to other slots
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
std::vector<server_slot *> child_slots;
for (auto & other : slots) {
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
child_slots.push_back(&other);
}
}

// we can only proceed if all child slots are having the correct tasks
if (child_slots.size() == slot.task->n_children) {
// copy state to the child slots
for (auto & child : child_slots) {
SLT_INF(slot, "copying state to child %d\n", child->id);
slot.copy_state_to(*child);
child->state = SLOT_STATE_DONE_PROMPT;
}
}
}

// optionally send prompt processing progress
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
}
tasks.reserve(inputs.size());
states.reserve(inputs.size());
int idx = 0;
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);

task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.index = idx++;

task.tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(
Expand All @@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task.params.oaicompat_model = ctx_server.model_name;
states.push_back(task.params.oaicompat_chat_syntax);

if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
for (size_t j = 0; j < task.n_children; j++) {
server_task child = task.create_child(
task.id,
ctx_server.queue_tasks.get_new_id(),
idx++);
states.push_back(child.params.oaicompat_chat_syntax);
tasks.push_back(std::move(child));
Comment on lines +2672 to +2673
Member:

I think we should improve this by making tasks and states more closely associated with each other - it feels error-prone at the moment because one might forget to update the states when adding a new task.

Does it make sense to have the task_result_state be part of the server_task itself?

Collaborator Author:

> Does it make sense to have the task_result_state be part of the server_task itself?

The principle is that the server_task will be std::move'd to the task queue, and eventually moved to the slot, so it cannot hold the task_result_state because the state needs to stay in the HTTP thread.

What I'm thinking is that we can just allow server_response_reader to create the state for each task, because currently tasks need to be posted by server_response_reader anyway

Btw, the further plan is to only expose server_response_reader to the HTTP handlers, as its API is easier to follow and it's also safer than managing the server_queue/response directly. WDYT?
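
To illustrate the direction, here is a rough, purely hypothetical sketch (none of these names are the actual llama.cpp API): the object that posts a task is also the one that creates its result state, so the two cannot drift apart.

#include <map>
#include <utility>

// hypothetical stand-ins for the real types
struct demo_task  { int id = 0; /* params, tokens, ... */ };
struct demo_state { /* e.g. oaicompat chat syntax, partial-parse state, ... */ };

struct demo_response_reader {
    std::map<int, demo_state> states; // owned by the reader, keyed by task id

    template <typename Queue>
    void post(demo_task && task, Queue & queue) {
        states.emplace(task.id, demo_state{}); // state is created at post time, never separately
        queue.post(std::move(task));           // hand the task over to the task queue
    }
};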

Collaborator Author:

I'll implement this in a follow-up PR

}
}

tasks.push_back(std::move(task));
}

@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
arr.push_back(res->to_json());
}
// if single request, return single object instead of array
res->ok(arr.size() == 1 ? arr[0] : arr);
GGML_ASSERT(!arr.empty() && "empty results");
if (arr.size() == 1) {
// if single request, return single object instead of array
res->ok(arr[0]);
} else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
// if multiple results in OAI format, we need to re-format them
json & choices = arr[0]["choices"];
for (size_t i = 1; i < arr.size(); i++) {
choices.push_back(std::move(arr[i]["choices"][0]));
}
res->ok(arr[0]);
} else {
// multi-results, non-OAI compat
res->ok(arr);
}
}
} else {
// in streaming mode, the first error must be treated as non-stream response
9 changes: 7 additions & 2 deletions tools/server/server-task.cpp
@@ -175,6 +175,7 @@ task_params server_task::params_from_json_cmpl(
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
@@ -453,6 +454,10 @@ task_params server_task::params_from_json_cmpl(
}
}

if (params.n_cmpl > params_base.n_parallel) {
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
}

return params;
}

@@ -664,7 +669,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {

json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"index", index},
{"message", msg.to_json_oaicompat<json>()},
};

@@ -1064,7 +1069,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"index", index},
{"delta", delta},
},
})},
24 changes: 24 additions & 0 deletions tools/server/server-task.h
@@ -53,6 +53,7 @@ struct task_params {
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
int32_t n_cmpl = 1; // number of completions to generate from this prompt

int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -89,6 +90,10 @@ struct server_task {
int id_target = -1;
int id_slot = -1;

// used by parallel sampling (multiple completions from same prompt)
size_t n_children = 0; // number of tasks reusing this prompt
int id_parent = -1;

// used by SERVER_TASK_TYPE_INFERENCE
task_params params;
server_tokens tokens;
@@ -130,6 +135,17 @@ struct server_task {
}
return ids;
}

server_task create_child(int id_parent, int id_child, int idx) const {
server_task copy;
copy.id = id_child;
copy.index = idx;
copy.id_parent = id_parent;
copy.params = params;
copy.type = type;
copy.tokens = tokens.clone();
return copy;
}
};

struct result_timings {
@@ -466,6 +482,14 @@ struct server_prompt {
int n_tokens() const {
return tokens.size();
}

server_prompt clone() const {
return server_prompt {
tokens.clone(),
data,
checkpoints
};
}
};

struct server_prompt_cache {
19 changes: 19 additions & 0 deletions tools/server/tests/unit/test_chat_completion.py
@@ -477,3 +477,22 @@ def make_cmpl_request():
assert last_progress["total"] > 0
assert last_progress["processed"] == last_progress["total"]
assert total_batch_count == batch_count


def test_chat_completions_multiple_choices():
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"n": 2,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
})
assert res.status_code == 200
assert len(res.body["choices"]) == 2
for choice in res.body["choices"]:
assert "assistant" == choice["message"]["role"]
assert match_regex("Suddenly", choice["message"]["content"])
assert choice["finish_reason"] == "length"