server: support multiple generations from one prompt (OAI "n" option) #17775
Changes from all commits: 15ce574, 0d842cb, bf33d13, a768a5e, 5cc3156, 2a7728f, e066071, 46f6fd2, b65ee64, 6fb3226, ea7f066
```diff
@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
 // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
-    SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
+    SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
+    SLOT_STATE_STARTED,    // after assigning a task and about to process prompt
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
```
```diff
@@ -254,6 +255,15 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }

+    // note: a slot can also be either a parent or a child
+    bool is_parent() const {
+        return is_processing() && task->n_children > 0;
+    }
+
+    bool is_child() const {
+        return is_processing() && task->id_parent >= 0;
+    }
+
     void release() {
         if (is_processing()) {
             GGML_ASSERT(task);
```
```diff
@@ -383,6 +393,17 @@ struct server_slot {
         return res;
     }

+    void copy_state_to(server_slot & other) const {
+        llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
+        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
+        other.n_decoded                 = n_decoded;
+        other.n_remaining               = n_remaining;
+        other.i_batch                   = i_batch;
+        other.n_prompt_tokens_cache     = n_prompt_tokens_cache;
+        other.n_prompt_tokens_processed = n_prompt_tokens_processed;
+        other.prompt                    = prompt.clone();
+    }
 };
```
```diff
@@ -1022,7 +1043,9 @@ struct server_context_impl {
         slot.task = std::make_unique<const server_task>(std::move(task));

-        slot.state = SLOT_STATE_STARTED;
+        slot.state = slot.is_child()
+            ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
+            : SLOT_STATE_STARTED;

         SLT_INF(slot, "%s", "processing task\n");
```
```diff
@@ -1684,6 +1707,12 @@ struct server_context_impl {
                     GGML_ABORT("not supported by multimodal");
                 }

+                if (slot.is_parent() || slot.is_child()) {
+                    send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
+                    slot.release();
+                    continue;
+                }
+
```
Comment on lines +1710 to +1715:

**Member:** Hm, what is the reason to not support context shift here?

**Collaborator (Author):** Not quite sure about this, but IIUC this is fine if the current (generating) token position is synchronized among all sequences, but we don't have explicit logic to guarantee that this will always happen.

**Collaborator (Author):** Also, the generation length of each sequence can be different, which can be quite difficult to keep track of.

**Member:** I see, that is correct. The problem is that some of the tokens are shared when we use a unified KV cache. It would work with a split KV cache, but maybe it's not worth the extra logic branching. Either way, context shifting is probably something that we should remove at some point - it does not have much value with today's models that support more than 128k tokens of context.
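To make the concern concrete, here is a rough sketch of what a per-sequence context shift amounts to, expressed with the `llama_memory_*` API that the PR already uses in `copy_state_to`. The helper name and its arguments are illustrative, not the server's actual code:

```cpp
#include "llama.h"

// Illustrative helper (not the server's actual implementation): discard n_discard
// tokens after the first n_keep and slide the rest of one sequence back to close
// the gap. This is only safe when seq_id owns all of its KV cells.
static void shift_context_for_seq(llama_context * ctx, llama_seq_id seq_id,
                                  int n_keep, int n_past, int n_discard) {
    llama_memory_t mem = llama_get_memory(ctx);

    // remove the span [n_keep, n_keep + n_discard) from this sequence ...
    llama_memory_seq_rm (mem, seq_id, n_keep, n_keep + n_discard);
    // ... and shift the remaining positions back by n_discard
    llama_memory_seq_add(mem, seq_id, n_keep + n_discard, n_past, -n_discard);
}
```

With a unified KV cache, the prompt cells created for the parent are shared by its children (via `llama_memory_seq_cp` in `copy_state_to` above), and each sequence may be at a different generated length, so shifting one sequence independently would move cells that its siblings still address at the original positions - hence the guard that simply rejects context shift for parent/child slots.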
```diff
                 // Shift context
                 int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
```
```diff
@@ -2308,6 +2337,26 @@ struct server_context_impl {
         n_batch = llama_n_batch(ctx);

         for (auto & slot : slots) {
+            // may need to copy state to other slots
+            if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
+                std::vector<server_slot *> child_slots;
+                for (auto & other : slots) {
+                    if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
+                        child_slots.push_back(&other);
+                    }
+                }
+
+                // we can only proceed if all child slots are having the correct tasks
+                if (child_slots.size() == slot.task->n_children) {
+                    // copy state to the child slots
+                    for (auto & child : child_slots) {
+                        SLT_INF(slot, "copying state to child %d\n", child->id);
+                        slot.copy_state_to(*child);
+                        child->state = SLOT_STATE_DONE_PROMPT;
+                    }
+                }
+            }
+
             // optionally send prompt processing progress
             if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
                 if (slot.task->params.stream && slot.task->params.return_progress) {
```
```diff
@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     }
     tasks.reserve(inputs.size());
     states.reserve(inputs.size());
+    int idx = 0;
     for (size_t i = 0; i < inputs.size(); i++) {
         server_task task = server_task(type);

         task.id = ctx_server.queue_tasks.get_new_id();
-        task.index = i;
+        task.index = idx++;

         task.tokens = std::move(inputs[i]);
         task.params = server_task::params_from_json_cmpl(
```
```diff
@@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         task.params.oaicompat_model = ctx_server.model_name;
         states.push_back(task.params.oaicompat_chat_syntax);

+        if (task.params.n_cmpl > 1) {
+            task.n_children = task.params.n_cmpl - 1;
+            for (size_t j = 0; j < task.n_children; j++) {
+                server_task child = task.create_child(
+                        task.id,
+                        ctx_server.queue_tasks.get_new_id(),
+                        idx++);
+                states.push_back(child.params.oaicompat_chat_syntax);
+                tasks.push_back(std::move(child));
```
Comment on lines +2672 to +2673:

**Member:** I think we should improve this by making […]. Does it make sense to have the […]?

**Collaborator (Author):** The principle is that […]. What I'm thinking is that we can just allow […]. Btw, the further plan is to only expose […].

**Collaborator (Author):** I'll implement this in a follow-up PR.
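The `create_child` implementation is not part of this excerpt. As a rough mental model only (the `toy_task` stand-in type and its defaults are assumptions, not the PR's code), the parent/child bookkeeping behind the `n`/`n_cmpl` option can be pictured like this:

```cpp
#include <string>
#include <vector>

// Simplified stand-in for server_task, only to illustrate the bookkeeping used above
// (field names mirror the diff; the real type also holds tokens, sampling params, etc.).
struct toy_task {
    int         id         = -1;
    int         index      = 0;   // position of this generation in the aggregated result
    int         id_parent  = -1;  // >= 0 marks the task (and later its slot) as a child
    size_t      n_children = 0;   // > 0 marks the task as a parent
    std::string prompt;           // children reuse the parent's prompt

    toy_task create_child(int parent_id, int new_id, int new_index) const {
        toy_task child   = *this;      // same prompt and params as the parent
        child.id         = new_id;     // fresh task id (queue_tasks.get_new_id() above)
        child.index      = new_index;  // consecutive index, continuing the parent's idx
        child.id_parent  = parent_id;
        child.n_children = 0;          // a child never spawns children of its own
        return child;
    }
};

// Expanding n_cmpl = 3 yields one parent task plus two children; the children later
// wait in SLOT_STATE_WAIT_OTHER until copy_state_to() seeds them with the parent's
// processed prompt.
static std::vector<toy_task> expand(const toy_task & parent_in, int n_cmpl, int & next_id, int & idx) {
    std::vector<toy_task> tasks;
    toy_task parent   = parent_in;
    parent.index      = idx++;
    parent.n_children = n_cmpl > 1 ? n_cmpl - 1 : 0;
    for (int j = 0; j < (int) parent.n_children; j++) {
        tasks.push_back(parent.create_child(parent.id, next_id++, idx++));
    }
    tasks.push_back(parent);
    return tasks;
}
```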
```diff
+            }
+        }
+
         tasks.push_back(std::move(task));
     }
```
```diff
@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
                 GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
                 arr.push_back(res->to_json());
             }
-            // if single request, return single object instead of array
-            res->ok(arr.size() == 1 ? arr[0] : arr);
+            GGML_ASSERT(!arr.empty() && "empty results");
+            if (arr.size() == 1) {
+                // if single request, return single object instead of array
+                res->ok(arr[0]);
+            } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
+                // if multiple results in OAI format, we need to re-format them
+                json & choices = arr[0]["choices"];
+                for (size_t i = 1; i < arr.size(); i++) {
+                    choices.push_back(std::move(arr[i]["choices"][0]));
+                }
+                res->ok(arr[0]);
+            } else {
+                // multi-results, non-OAI compat
+                res->ok(arr);
+            }
         }
     } else {
         // in streaming mode, the first error must be treated as non-stream response
```
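For illustration, here is what the re-formatting branch above does to the per-task results, sketched with nlohmann::json (the ids, indices, and texts are made up; the real results carry full OAI completion objects). With `"n": 2` each task produces a single-choice result, and the handler folds the extra choices into the first object so the client receives one OAI-style response:

```cpp
#include <nlohmann/json.hpp>
#include <cstdio>

using json = nlohmann::ordered_json;

int main() {
    // two single-choice results produced for the same request (illustrative values)
    json arr = json::array({
        { { "id", "cmpl-0" }, { "choices", json::array({ { { "index", 0 }, { "text", "first completion"  } } }) } },
        { { "id", "cmpl-0" }, { "choices", json::array({ { { "index", 1 }, { "text", "second completion" } } }) } },
    });

    // same merge as in the diff: move each extra result's choice into the first result
    json & choices = arr[0]["choices"];
    for (size_t i = 1; i < arr.size(); i++) {
        choices.push_back(std::move(arr[i]["choices"][0]));
    }

    printf("%s\n", arr[0].dump(2).c_str()); // one object whose "choices" has 2 entries
    return 0;
}
```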
Now that we have this function, I think we can enable host-memory prompt caching with mtmd:

- llama.cpp/tools/server/server-task.cpp, lines 1396 to 1404 in 6fb3226
- llama.cpp/tools/server/server-context.cpp, lines 886 to 889 in 6fb3226

I haven't tested, but I think the only reason that prompt caching didn't work was because […] wasn't sure how to copy the `server_tokens`. So it's worth giving it a try after these changes.

Yes, it will be nice to enable RAM cache for mtmd. I created an issue so we can have a look later on: #17821
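As a rough illustration of the idea being discussed (hypothetical types; the real server would work with `server_tokens` and the slot's prompt, whose `clone()` is shown in `copy_state_to` above), host-memory prompt caching boils down to keeping a deep copy of an already-processed prompt and looking it up again for a later request:

```cpp
#include <map>
#include <vector>

// Hypothetical sketch only: the real cache stores richer state (server_tokens with
// mtmd chunks, KV data), but the blocker discussed above - being able to deep-copy
// the prompt - is the same clone() operation used by copy_state_to().
struct toy_prompt {
    std::vector<int> tokens;
    toy_prompt clone() const { return *this; } // deep copy into host memory
};

struct toy_prompt_cache {
    std::map<std::vector<int>, toy_prompt> entries; // keyed by the processed tokens

    void save(const toy_prompt & p) {
        entries[p.tokens] = p.clone();
    }

    const toy_prompt * find(const std::vector<int> & tokens) const {
        auto it = entries.find(tokens);
        return it == entries.end() ? nullptr : &it->second;
    }
};
```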