Conversation

@ngxson
Collaborator

@ngxson ngxson commented Dec 5, 2025

Fix #11142


Implementation

The requirement is that the number of slots must be equal to or larger than the number of "n" completion choices.

  1. When the task is created, we create N-1 child tasks and 1 parent task
  2. The parent task is guaranteed to be loaded first (because we push all tasks into the queue under a single lock). Child tasks are also loaded into slots at this point, but their state is set to SLOT_STATE_WAIT_OTHER
  3. We begin processing the parent's prompt
  4. When the parent's prompt processing is done, we gather all children in the SLOT_STATE_WAIT_OTHER state, then copy the parent's state into these slots via llama_memory_seq_cp (see the sketch after this list)
    • Note: if we cannot yet gather all children at this point (e.g. one of the slots is busy with another task), we wait until that task is done and all children are on their slots. This can potentially decrease the overall throughput, but it makes the implementation easier to understand
  5. Continue with sampling and token generation as usual
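
A minimal sketch of step 4. The helper name and the parent_id field are illustrative, not the real API; the actual copy happens in copy_state_to, and the slot id doubles as the llama_seq_id as in the rest of the server code:

#include <vector>
#include "llama.h"

// illustrative sketch: copy the parent's processed prompt into every waiting child
// (server_slot and SLOT_STATE_* stand in for the real server types)
static void share_prompt_with_children(llama_context * ctx, server_slot & parent, std::vector<server_slot> & slots) {
    llama_memory_t mem = llama_get_memory(ctx);
    for (auto & child : slots) {
        // parent_id is a hypothetical field linking a child slot to its parent
        if (child.parent_id != parent.id || child.state != SLOT_STATE_WAIT_OTHER) {
            continue;
        }
        // copy the parent's entire sequence into the child (p0 = p1 = -1 means the full range)
        llama_memory_seq_cp(mem, parent.id, child.id, -1, -1);
        child.state = SLOT_STATE_DONE_PROMPT; // the child can start sampling now
    }
}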

TODO:

  • fix "invalid input batch" error
  • do not allow context shifting
  • add OAI output format
  • add tests

@ngxson ngxson marked this pull request as ready for review December 5, 2025 15:40
@ngxson ngxson requested a review from ggerganov as a code owner December 5, 2025 15:40
@ngxson
Collaborator Author

ngxson commented Dec 5, 2025

@allozaur @ServeurpersoCom one application of this feature could be offering multiple response choices in the web UI. It's a low-priority feature, but I think it could be quite nice to add!

Edit: we could technically also add per-response sampling control, for example one response with temperature=0.0 and another with temperature=1.0; there are many possibilities, but we need to see what the exact use case is.

Example on ChatGPT: (screenshot)

@allozaur
Collaborator

allozaur commented Dec 5, 2025

@allozaur @ServeurpersoCom one application of this feature could be offering multiple response choices in the web UI. It's a low-priority feature, but I think it could be quite nice to add!

Edit: we could technically also add per-response sampling control, for example one response with temperature=0.0 and another with temperature=1.0; there are many possibilities, but we need to see what the exact use case is.

Example on ChatGPT: (screenshot)

Oh, absolutely! I would love to take over this one, maybe still this year?

@ngxson
Collaborator Author

ngxson commented Dec 5, 2025

Oh, absolutely! I would love to take over this one, maybe still this year?

yeah no rush! feel free to start the task as soon as this PR is merged

@github-actions github-actions bot added the python (python script changes) label Dec 5, 2025
@ITankForCAD

This is more of an idea than a desired feature, at least for the moment, but multiple generations from the same prompt would allow for "best-of-n" scenarios. optillm is a good example of this.

@ngxson
Collaborator Author

ngxson commented Dec 6, 2025

@ggerganov pinging in case you missed this PR

Member

@ggerganov ggerganov left a comment

Very nice! The implementation is much simpler than I anticipated.

Comment on lines +497 to +508
server_tokens server_tokens::clone() const {
    server_tokens res;
    res.has_mtmd = has_mtmd;
    res.tokens = tokens;
    for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
        size_t idx = it->first;
        const mtmd::input_chunk_ptr & chunk = it->second;
        res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
    }
    return res;
}

Member

Now that we have this function, I think we can enable host-memory prompt caching with mtmd:

  • Update this code:

// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
auto & cur = states.emplace_back();
cur = {
    /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
    /*.data =*/ std::move(state_data),
    /*.checkpoints =*/ prompt.checkpoints,
};

  • Remove this condition:

// TODO: mtmd does not support prompt cache
update_cache = update_cache && (ret->mctx == nullptr);

I haven't tested, but I think the only reason that prompt caching didn't work was that we weren't sure how to copy the server_tokens. So it's worth giving it a try after these changes.
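
Untested sketch of the first change - assuming prompt.tokens here is the server_tokens being cached, the workaround just becomes a clone() call:

// use the deep copy instead of the text-only workaround, so that media chunks
// are preserved in the host-memory cache entry
auto & cur = states.emplace_back();
cur = {
    /*.tokens =*/ prompt.tokens.clone(),
    /*.data =*/ std::move(state_data),
    /*.checkpoints =*/ prompt.checkpoints,
};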

Collaborator Author

Yes, it would be nice to enable the RAM cache for mtmd. I created an issue so we can have a look at it later: #17821

Comment on lines +1710 to +1715
if (slot.is_parent() || slot.is_child()) {
    send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
    slot.release();
    continue;
}

Member

Hm, what is the reason to not support context shift here?

Collaborator Author

@ngxson ngxson Dec 6, 2025

Not quite sure about this, but IIUC llama_kv_cache::seq_add does not have a notion of copy-on-write. For example, if a KV cell is used by 2 sequences, shifting it for one sequence will also cause it to be shifted for the other.

This is fine if the current (generating) token position is synchronized among all sequences, but we don't have explicit logic to guarantee that this will always happen.
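
Rough illustration of the concern - n_keep, n_discard and the helper are made up, but the calls are the usual seq_rm + seq_add pair used for a context shift:

#include "llama.h"

// shift only the parent's sequence after discarding part of its context
static void shift_parent_context(llama_context * ctx, llama_seq_id parent_seq, int32_t n_keep, int32_t n_discard) {
    llama_memory_t mem = llama_get_memory(ctx);
    // drop [n_keep, n_keep + n_discard) from the parent's sequence ...
    llama_memory_seq_rm (mem, parent_seq, n_keep, n_keep + n_discard);
    // ... and move everything after it down by n_discard (p1 = -1 means "to the end")
    llama_memory_seq_add(mem, parent_seq, n_keep + n_discard, -1, -n_discard);
    // with a unified KV cache the moved cells can still be shared with the child
    // sequences, so their positions change as well - there is no copy-on-write
}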

Collaborator Author

Also, the generation length of each sequence can be different, which can be quite difficult to keep track of.

Member

I see, that is correct. The problem is that some of the tokens are shared when we use a unified KV cache. It would work with a split KV cache, but maybe it's not worth the extra logic branching.

Either way, context shifting is probably something that we should remove at some point - it does not have much value with today's models, which support contexts of more than 128k tokens.

    slot.copy_state_to(*child);
    child->state = SLOT_STATE_DONE_PROMPT;
}
slot.state = SLOT_STATE_DONE_PROMPT;
Member

Is this line needed?

Collaborator Author

removed in ea7f066

Comment on lines +2673 to +2674
states.push_back(child.params.oaicompat_chat_syntax);
tasks.push_back(std::move(child));
Member

I think we should improve this by making tasks and states more closely associated with each other - it feels error-prone at the moment, because one might forget to update the states when adding a new task.

Does it make sense to have the task_result_state be part of the server_task itself?

Collaborator Author

Does it make sense to have the task_result_state be part of the server_task itself?

The principle is that server_task will be std::move'd to the task queue, and eventually moved to a slot, so it cannot hold task_result_state because the state needs to stay in the HTTP thread.

What I'm thinking is that we can just allow server_response_reader to create the state for each task, because currently tasks need to be posted by server_response_reader anyway
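
Very rough sketch of the idea (all names and members here are illustrative, not the final API):

#include <unordered_map>
#include <utility>

// sketch: the reader owns the per-task result state, so posting a task and
// registering its state cannot get out of sync
struct server_response_reader {
    server_queue & queue_tasks;                         // the existing task queue
    std::unordered_map<int, task_result_state> states;  // keyed by task id

    void post(server_task && task) {
        states.emplace(task.id, task_result_state{});   // e.g. built from task.params
        queue_tasks.post(std::move(task));
    }
};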

Btw, the further plan is to only expose server_response_reader to HTTP handlers, as the API is easier to follow and it's also safer than managing the server_queue/response directly. WDYT?

Collaborator Author

I'll implement this in a follow-up PR

@ggerganov
Member

ggerganov commented Dec 6, 2025

Can we make this work with the /completions and /infill endpoints?

Edit: nvm it works - just use "n_cmpl" instead of "n"

@ngxson
Collaborator Author

ngxson commented Dec 6, 2025

btw for /completions and /infill, I added support for both n_cmpl and n fields

@ngxson ngxson merged commit c42712b into ggml-org:master Dec 6, 2025
64 of 75 checks passed
@jacekpoplawski
Contributor

Do I understand correctly that with this change, instead of sending multiple separate requests with the same prompt, I can now send a single request and it will be faster?

@ServeurpersoCom
Collaborator

ServeurpersoCom commented Dec 6, 2025

Do I understand correctly that with this change, instead of sending multiple separate requests with the same prompt, I can now send a single request and it will be faster?

Try it with -np, --parallel (not tested yet, I'm not sure)

@ngxson
Collaborator Author

ngxson commented Dec 6, 2025

Do I understand correctly that with this change, instead of sending multiple separate requests with the same prompt, I can now send a single request and it will be faster?

Yes, it will be faster - the "n" option allows the prompt to be processed exactly once.

@ServeurpersoCom
Collaborator

Great work on this PR!

I can confirm parallel sequences work perfectly. Here's my test setup:

# Server with 4 parallel slots
/root/llama.cpp.pascal/build/bin/llama-server \
  --port 8082 -ngl 999 \
  -ctk q8_0 -ctv q8_0 -fa on --mlock \
  -np 4 -kvu --ctx-size 32768 \
  --models-dir /var/www/ia/models

# Testing n=4 with streaming
curl -N https://www.serveurperso.com/ia/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "unsloth/Qwen3-Coder-30B-A3B-Instruct-GGUF",
    "messages": [{"role": "user", "content": "Raconte une histoire"}],
    "n": 4,
    "stream": true
  }'
  
data: {"choices":[{"finish_reason":null,"index":0,"delta":{"content":"rena"}}],"created":1765035211,"id":"chatcmpl-6sAuIJ14qwTCRYq7lgPcbPvisCDZv6A5","model":"unsloth/Qwen3-Coder-30B-A3B-Instruct-GGUF","system_fingerprint":"b7359-7b09f44a5","object":"chat.completion.chunk"}

data: {"choices":[{"finish_reason":null,"index":1,"delta":{"content":"érie"}}],"created":1765035211,"id":"chatcmpl-6sAuIJ14qwTCRYq7lgPcbPvisCDZv6A5","model":"unsloth/Qwen3-Coder-30B-A3B-Instruct-GGUF","system_fingerprint":"b7359-7b09f44a5","object":"chat.completion.chunk"}

data: {"choices":[{"finish_reason":null,"index":2,"delta":{"content":" ton"}}],"created":1765035211,"id":"chatcmpl-6sAuIJ14qwTCRYq7lgPcbPvisCDZv6A5","model":"unsloth/Qwen3-Coder-30B-A3B-Instruct-GGUF","system_fingerprint":"b7359-7b09f44a5","object":"chat.completion.chunk"}

data: {"choices":[{"finish_reason":null,"index":3,"delta":{"content":"res"}}],"created":1765035211,"id":"chatcmpl-6sAuIJ14qwTCRYq7lgPcbPvisCDZv6A5","model":"unsloth/Qwen3-Coder-30B-A3B-Instruct-GGUF","system_fingerprint":"b7359-7b09f44a5","object":"chat.completion.chunk"}

The implementation correctly:

  • Processes the prompt once (shared via llama_memory_seq_cp)
  • Generates 4 different completions in parallel using 4 slots
  • Returns proper SSE stream with "index": 0-3 for each choice

This is especially efficient when memory-bound, since parallel batching allows better compute utilization while waiting on memory bandwidth - I'm getting 3 to 4x total throughput!

JayZenith pushed a commit to JayZenith/llama.cpp that referenced this pull request Dec 7, 2025
…ggml-org#17775)

* backend support

* server: support multiple generations from one prompt (OAI "n" option)

* fix invalid batch

* format oai

* clean up

* disable ctx shift

* add test

* update comments

* fix style

* add n_cmpl to docs [no ci]

* allowing using both n_cmpl and n