7 changes: 4 additions & 3 deletions examples/embedding/embedding.cpp
@@ -104,12 +104,16 @@ int main(int argc, char ** argv) {
 
     params.embedding = true;
 
+    // get max number of sequences per batch
+    const int n_seq_max = llama_max_parallel_sequences();
+
     // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
     // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
     // in order to support any number of prompts
     if (params.n_parallel == 1) {
         LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
         params.kv_unified = true;
+        params.n_parallel = n_seq_max;
     }
 
     // utilize the full context
@@ -123,9 +127,6 @@ int main(int argc, char ** argv) {
         params.n_ubatch = params.n_batch;
     }
 
-    // get max number of sequences per batch
-    const int n_seq_max = llama_max_parallel_sequences();
-
     llama_backend_init();
     llama_numa_init(params.numa);
 
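The fallback above boils down to a small decision at startup. The following standalone sketch is not part of this PR; it assumes a client that exposes the same n_parallel / kv_unified / embedding fields that embedding.cpp uses, and only restates the intent of the change:

// Illustrative sketch only - mirrors the fallback logic added above.
// Assumption: `params` exposes the same fields used in embedding.cpp.
params.embedding = true;

// get the max number of sequences per batch that the library supports
const int n_seq_max = llama_max_parallel_sequences();

if (params.n_parallel == 1) {
    // prompt count not known in advance: a unified KV cache can serve any
    // number of prompts, batching up to n_seq_max sequences at a time
    params.kv_unified = true;
    params.n_parallel = n_seq_max;
} else {
    // prompt count known in advance (e.g. --parallel 8): keep the per-sequence
    // KV cache, which is sized for exactly that many parallel sequences
}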
4 changes: 2 additions & 2 deletions src/llama-context.cpp
@@ -299,7 +299,7 @@ llama_context::llama_context(
 
     cross.v_embd.clear();
 
-    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+    const uint32_t n_seqs = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
     // avoid reserving graphs with zero outputs - assume one output per sequence
@@ -542,7 +542,7 @@ bool llama_context::memory_update(bool optimize) {
         throw std::runtime_error("failed to initialize memory context");
     }
 
-    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+    const uint32_t n_seqs = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());

Comment from the Member Author on the removed line:

I can't recall why this cparams.kv_unified check was added in #14363. It seems unnecessary now and removing it gives a better estimate of the worst-case graph.
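For intuition on why this gives a better worst-case estimate, here is a hypothetical example (the numbers are illustrative and not taken from the PR): with a unified KV cache, n_seq_max = 4 and n_ubatch = 512, the reserve graph previously assumed a single sequence, whereas it now assumes all four sequences may each produce an output within one ubatch:

// Hypothetical cparams values, for illustration only:
//   cparams.kv_unified = true, cparams.n_seq_max = 4,
//   cparams.n_ctx = 8192, cparams.n_ubatch = 512

// before: n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;  // -> 1
// after:  n_seqs = cparams.n_seq_max;                           // -> 4
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);  // -> 512

// the reserved graph now covers a 512-token ubatch spread over up to 4 sequences,
// each assumed to produce one output
auto * gf = graph_reserve(n_tokens, /*n_seqs=*/4, n_tokens, mctx.get());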