This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 12e8bb9

fix: prompt_tokens slot
1 parent f17598f commit 12e8bb9

File tree

1 file changed: +127 -123 lines changed


context/llama_server_context.h

Lines changed: 127 additions & 123 deletions
@@ -331,6 +331,10 @@ struct llama_client_slot {
   int32_t num_prompt_tokens_processed = 0;
 
   json prompt;
+
+  // when a task is submitted, we first tokenize the prompt and store it here
+  std::vector<llama_token> prompt_tokens;
+
   std::string generated_text;
   llama_token sampled;
   std::vector<llama_token> cache_tokens;
@@ -867,6 +871,7 @@ struct llama_server_context {
     slot->ctx_sampling = llama_sampling_init(slot->sparams);
     llama_set_rng_seed(ctx, slot->params.seed);
     slot->command = LOAD_PROMPT;
+    slot->prompt_tokens.clear();
 
     all_slots_are_idle = false;
 
@@ -1037,7 +1042,7 @@ struct llama_server_context {
       slot.has_next_token = false;
     }
 
-    if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) {
+    if (result.tok == llama_token_eos(model)) {
       slot.stopped_eos = true;
       slot.has_next_token = false;
       LOG_VERBOSE("eos token found", {});
@@ -1651,139 +1656,138 @@ struct llama_server_context {
 
       // need process the prompt
       if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
-        slot.state = PROCESSING;
-        slot.command = NONE;
-        std::vector<llama_token> prompt_tokens;
-        slot.t_start_process_prompt = ggml_time_us();
-        slot.t_start_genereration = 0;
-
-        if (slot.infill) {
-          bool suff_rm_leading_spc = true;
-          if (params.input_suffix.find_first_of(' ') == 0 &&
-              params.input_suffix.size() > 1) {
-            params.input_suffix.erase(0, 1);
-            suff_rm_leading_spc = false;
-          }
-          auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-          auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-          const int space_token =
-              29871; // TODO: this should not be hardcoded
-          if (suff_rm_leading_spc && !suffix_tokens.empty() &&
-              suffix_tokens[0] == space_token) {
-            suffix_tokens.erase(suffix_tokens.begin());
-          }
-
-          prefix_tokens.insert(prefix_tokens.begin(),
-              llama_token_prefix(model));
-          prefix_tokens.insert(prefix_tokens.begin(),
-              llama_token_bos(model)); // always add BOS
-          prefix_tokens.insert(prefix_tokens.end(),
-              llama_token_suffix(model));
-          prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(),
-              suffix_tokens.end());
-          prefix_tokens.push_back(llama_token_middle(model));
-          prompt_tokens = prefix_tokens;
-        } else {
-          prompt_tokens = tokenize(
-              slot.prompt,
-              system_prompt.empty() &&
-                  add_bos_token); // add BOS if there isn't system prompt
-        }
-
-        slot.n_past = 0;
-        slot.num_prompt_tokens = prompt_tokens.size();
-
-        LOG_DEBUG << "prompt tokenized - "
-                  << " id_slot: " << slot.id << ", task_id: " << slot.task_id
-                  << ", n_ctx: " << slot.n_ctx
-                  << ", n_keep: " << slot.params.n_keep
-                  << ", n_prompt_tokens: " << slot.num_prompt_tokens;
-        // << ", prompt_tokens: "
-        // << tokens_to_str(ctx, prompt_tokens.cbegin(),
-        // prompt_tokens.cend());
+        auto& prompt_tokens = slot.prompt_tokens;
+
+        // we haven't tokenized the prompt yet - do it now:
+        if (prompt_tokens.empty()) {
+          slot.t_start_process_prompt = ggml_time_us();
+          slot.t_start_genereration = 0;
+
+          if (slot.infill) {
+            bool suff_rm_leading_spc = true;
+            if (params.input_suffix.find_first_of(' ') == 0 &&
+                params.input_suffix.size() > 1) {
+              params.input_suffix.erase(0, 1);
+              suff_rm_leading_spc = false;
+            }
+            auto prefix_tokens = tokenize(slot.params.input_prefix, false);
+            auto suffix_tokens = tokenize(slot.params.input_suffix, false);
+
+            const int space_token =
+                29871; // TODO: this should not be hardcoded
+            if (suff_rm_leading_spc && !suffix_tokens.empty() &&
+                suffix_tokens[0] == space_token) {
+              suffix_tokens.erase(suffix_tokens.begin());
+            }
 
-        if (slot.embedding) {
-          // this prompt is too large to process - discard it
-          if (slot.num_prompt_tokens > n_ubatch) {
-            LOG_DEBUG << "embedding: num_promt_tokens: "
-                      << slot.num_prompt_tokens << ", n_ubatch: " << n_ubatch;
-            slot.state = PROCESSING;
-            slot.command = NONE;
-            slot.release();
-            slot.print_timings();
-            send_final_response(slot);
-            continue;
-          }
-        } else {
-          if (slot.params.n_keep < 0) {
-            slot.params.n_keep = slot.num_prompt_tokens;
-          }
-          slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-          // if input prompt is too big, truncate it
-          if (slot.num_prompt_tokens >= slot.n_ctx) {
-            const int n_left = slot.n_ctx - slot.params.n_keep;
-            const int n_block_size = n_left / 2;
-            const int erased_blocks =
-                (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) /
-                n_block_size;
-
-            std::vector<llama_token> new_tokens(
-                prompt_tokens.begin(),
-                prompt_tokens.begin() + slot.params.n_keep);
-            new_tokens.insert(new_tokens.end(),
-                prompt_tokens.begin() + slot.params.n_keep +
-                    erased_blocks * n_block_size,
-                prompt_tokens.end());
-
-            LOG_DEBUG << "input truncated - "
-                      << "n_ctx: " << slot.n_ctx << ", n_keep"
-                      << slot.params.n_keep << ", n_left: " << n_left
-                      << ", new_tokens: "
-                      << tokens_to_str(ctx, new_tokens.cbegin(),
-                             new_tokens.cend());
-            slot.truncated = true;
-            prompt_tokens = new_tokens;
-
-            slot.num_prompt_tokens = prompt_tokens.size();
-            GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+            prefix_tokens.insert(prefix_tokens.begin(),
+                llama_token_prefix(model));
+            prefix_tokens.insert(prefix_tokens.begin(),
+                llama_token_bos(model)); // always add BOS
+            prefix_tokens.insert(prefix_tokens.end(),
+                llama_token_suffix(model));
+            prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(),
+                suffix_tokens.end());
+            prefix_tokens.push_back(llama_token_middle(model));
+            prompt_tokens = prefix_tokens;
+          } else {
+            prompt_tokens = tokenize(
+                slot.prompt,
+                system_prompt.empty() &&
+                    add_bos_token); // add BOS if there isn't system prompt
           }
 
-          llama_sampling_reset(slot.ctx_sampling);
-
-          if (!slot.params.cache_prompt) {
-            slot.n_past = 0;
-            slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
+          slot.n_past = 0;
+          slot.num_prompt_tokens = prompt_tokens.size();
+
+          LOG_DEBUG << "prompt tokenized - "
+                    << " id_slot: " << slot.id
+                    << ", task_id: " << slot.task_id
+                    << ", n_ctx: " << slot.n_ctx
+                    << ", n_keep: " << slot.params.n_keep
+                    << ", n_prompt_tokens: " << slot.num_prompt_tokens;
+          // << ", prompt_tokens: "
+          // << tokens_to_str(ctx, prompt_tokens.cbegin(),
+          // prompt_tokens.cend());
+
+          if (slot.embedding) {
+            // this prompt is too large to process - discard it
+            if (slot.num_prompt_tokens > n_ubatch) {
+              LOG_DEBUG << "embedding: num_promt_tokens: "
+                        << slot.num_prompt_tokens
+                        << ", n_ubatch: " << n_ubatch;
+              slot.state = PROCESSING;
+              slot.command = NONE;
+              slot.release();
+              slot.print_timings();
+              send_final_response(slot);
+              continue;
+            }
           } else {
-            // push the prompt into the sampling context (do not apply grammar)
-            for (auto& token : prompt_tokens) {
-              llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
+            if (slot.params.n_keep < 0) {
+              slot.params.n_keep = slot.num_prompt_tokens;
+            }
+            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+            // if input prompt is too big, truncate it
+            if (slot.num_prompt_tokens >= slot.n_ctx) {
+              const int n_left = slot.n_ctx - slot.params.n_keep;
+              const int n_block_size = n_left / 2;
+              const int erased_blocks = (slot.num_prompt_tokens -
+                  slot.params.n_keep - n_block_size) /
+                  n_block_size;
+
+              std::vector<llama_token> new_tokens(
+                  prompt_tokens.begin(),
+                  prompt_tokens.begin() + slot.params.n_keep);
+              new_tokens.insert(new_tokens.end(),
+                  prompt_tokens.begin() + slot.params.n_keep +
+                      erased_blocks * n_block_size,
+                  prompt_tokens.end());
+
+              LOG_DEBUG << "input truncated - "
+                        << "n_ctx: " << slot.n_ctx << ", n_keep"
+                        << slot.params.n_keep << ", n_left: " << n_left
+                        << ", new_tokens: "
+                        << tokens_to_str(ctx, new_tokens.cbegin(),
+                               new_tokens.cend());
+              slot.truncated = true;
+              prompt_tokens = new_tokens;
+
+              slot.num_prompt_tokens = prompt_tokens.size();
+              GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
             }
 
-            slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
-            slot.num_prompt_tokens_processed =
-                slot.num_prompt_tokens - slot.n_past;
+            llama_sampling_reset(slot.ctx_sampling);
 
-            LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n",
-                    slot.id, slot.n_past, slot.num_prompt_tokens_processed);
-          }
-        }
+            if (!slot.params.cache_prompt) {
+              slot.n_past = 0;
+              slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
+            } else {
+              // push the prompt into the sampling context (do not apply grammar)
+              for (auto& token : prompt_tokens) {
+                llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
+              }
 
-        // LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id,
-        // (int)system_tokens.size() + slot.n_past);
+              slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+              slot.num_prompt_tokens_processed =
+                  slot.num_prompt_tokens - slot.n_past;
 
-        // llama_kv_cache_seq_rm(ctx, slot.id,
-        // system_tokens.size() + slot.n_past, -1);
+              LOG_TEE(
+                  "slot %d : in cache: %i tokens | to process: %i tokens\n",
+                  slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+            }
+          }
 
-        // slot.cache_tokens = prompt_tokens;
+          if (slot.n_past == slot.num_prompt_tokens) {
+            // we have to evaluate at least 1 token to generate logits.
+            LOG_DEBUG << "slot " << slot.id
+                      << " : we have to evaluate at least 1 token to "
+                         "generate logits";
+            slot.n_past--;
+          }
 
-        if (slot.n_past == slot.num_prompt_tokens) {
-          // we have to evaluate at least 1 token to generate logits.
-          LOG_DEBUG << "slot " << slot.id
-                    << " : we have to evaluate at least 1 token to "
-                       "generate logits";
-          slot.n_past--;
+          slot.num_prompt_tokens_processed = 0;
         }
 
       if (slot.embedding) {
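
The truncation branch in the hunk above compresses the block-erase rule into a few lines of integer arithmetic: keep the first n_keep tokens, drop whole blocks of n_block_size = (n_ctx - n_keep) / 2 tokens from the middle, and keep the tail, so the result always fits in the context window. A standalone sketch with made-up numbers, assuming a hypothetical truncate_prompt() helper (not code from this header):

// Hedged sketch of the truncation arithmetic shown in the diff above;
// truncate_prompt() is a hypothetical helper, not part of llama_server_context.h.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

static std::vector<llama_token> truncate_prompt(
    const std::vector<llama_token>& prompt_tokens, int n_ctx, int n_keep) {
  const int num_prompt_tokens = static_cast<int>(prompt_tokens.size());
  if (num_prompt_tokens < n_ctx) {
    return prompt_tokens;  // already fits, nothing to erase
  }
  const int n_left = n_ctx - n_keep;    // room left after the kept prefix
  const int n_block_size = n_left / 2;  // erase the middle in half-sized blocks
  const int erased_blocks =
      (num_prompt_tokens - n_keep - n_block_size) / n_block_size;

  // keep the first n_keep tokens, skip erased_blocks * n_block_size tokens,
  // then keep the rest of the prompt
  std::vector<llama_token> new_tokens(prompt_tokens.begin(),
                                      prompt_tokens.begin() + n_keep);
  new_tokens.insert(
      new_tokens.end(),
      prompt_tokens.begin() + n_keep + erased_blocks * n_block_size,
      prompt_tokens.end());
  assert(static_cast<int>(new_tokens.size()) < n_ctx);  // same invariant as the GGML_ASSERT
  return new_tokens;
}

int main() {
  // made-up numbers: a 3000-token prompt, 2048-token context, keep 256 tokens
  std::vector<llama_token> prompt(3000);
  const auto truncated = truncate_prompt(prompt, /*n_ctx=*/2048, /*n_keep=*/256);
  // n_left = 1792, n_block_size = 896, erased_blocks = 2,
  // so 256 + (3000 - 2048) = 1208 tokens remain
  std::printf("kept %zu of %zu tokens\n", truncated.size(), prompt.size());
  return 0;
}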
