@@ -331,6 +331,10 @@ struct llama_client_slot {
   int32_t num_prompt_tokens_processed = 0;
 
   json prompt;
+
+  // when a task is submitted, we first tokenize the prompt and store it here
+  std::vector<llama_token> prompt_tokens;
+
   std::string generated_text;
   llama_token sampled;
   std::vector<llama_token> cache_tokens;
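The new `prompt_tokens` field caches the tokenized prompt on the slot itself: tokenization happens only the first time the slot is processed, and later passes over the same request reuse the stored vector (it is cleared again when the slot is launched, so a reused slot never replays a stale prompt). A minimal sketch of the pattern, with a hypothetical `slot_sketch` type and a byte-level stand-in tokenizer rather than the server's real `tokenize()` helper:

    #include <cstdint>
    #include <string>
    #include <vector>

    using llama_token = int32_t;  // stand-in for the llama.cpp typedef

    struct slot_sketch {
      std::string prompt;
      std::vector<llama_token> prompt_tokens;  // filled once, reused afterwards

      const std::vector<llama_token>& tokens() {
        if (prompt_tokens.empty()) {
          // hypothetical tokenizer: one token per byte, for illustration only
          for (unsigned char c : prompt) {
            prompt_tokens.push_back(static_cast<llama_token>(c));
          }
        }
        return prompt_tokens;
      }
    };

    int main() {
      slot_sketch s;
      s.prompt = "hello";
      s.tokens();               // tokenizes on first use
      s.tokens();               // second call reuses the cached prompt_tokens
      s.prompt_tokens.clear();  // analogous to the clear() on slot launch below
    }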
@@ -867,6 +871,7 @@ struct llama_server_context {
     slot->ctx_sampling = llama_sampling_init(slot->sparams);
     llama_set_rng_seed(ctx, slot->params.seed);
     slot->command = LOAD_PROMPT;
+    slot->prompt_tokens.clear();
 
     all_slots_are_idle = false;
 
@@ -1037,7 +1042,7 @@ struct llama_server_context {
       slot.has_next_token = false;
     }
 
-    if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) {
+    if (result.tok == llama_token_eos(model)) {
       slot.stopped_eos = true;
       slot.has_next_token = false;
       LOG_VERBOSE("eos token found", {});
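Dropping the `!slot.cache_tokens.empty()` guard makes end-of-sequence detection unconditional: the old code only set `stopped_eos` once `cache_tokens` held something, and since `slot.cache_tokens = prompt_tokens;` had already been commented out (see the last hunk), that guard could keep EOS from ever stopping generation. A small side-by-side sketch of the two predicates, using hypothetical helper names and only the fields visible in this hunk:

    #include <cstdint>
    #include <vector>

    using llama_token = int32_t;  // stand-in for the llama.cpp typedef

    // old check: EOS honoured only when the prompt cache was non-empty
    bool stop_on_eos_old(const std::vector<llama_token>& cache_tokens,
                         llama_token sampled, llama_token eos_id) {
      return !cache_tokens.empty() && sampled == eos_id;
    }

    // new check: the sampled token alone decides
    bool stop_on_eos_new(llama_token sampled, llama_token eos_id) {
      return sampled == eos_id;
    }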
@@ -1651,139 +1656,138 @@ struct llama_server_context {
 
       // need process the prompt
       if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
-        slot.state = PROCESSING;
-        slot.command = NONE;
-        std::vector<llama_token> prompt_tokens;
-        slot.t_start_process_prompt = ggml_time_us();
-        slot.t_start_genereration = 0;
-
-        if (slot.infill) {
-          bool suff_rm_leading_spc = true;
-          if (params.input_suffix.find_first_of(' ') == 0 &&
-              params.input_suffix.size() > 1) {
-            params.input_suffix.erase(0, 1);
-            suff_rm_leading_spc = false;
-          }
-          auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-          auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-          const int space_token =
-              29871;  // TODO: this should not be hardcoded
-          if (suff_rm_leading_spc && !suffix_tokens.empty() &&
-              suffix_tokens[0] == space_token) {
-            suffix_tokens.erase(suffix_tokens.begin());
-          }
-
-          prefix_tokens.insert(prefix_tokens.begin(),
-                               llama_token_prefix(model));
-          prefix_tokens.insert(prefix_tokens.begin(),
-                               llama_token_bos(model));  // always add BOS
-          prefix_tokens.insert(prefix_tokens.end(),
-                               llama_token_suffix(model));
-          prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(),
-                               suffix_tokens.end());
-          prefix_tokens.push_back(llama_token_middle(model));
-          prompt_tokens = prefix_tokens;
-        } else {
-          prompt_tokens = tokenize(
-              slot.prompt,
-              system_prompt.empty() &&
-                  add_bos_token);  // add BOS if there isn't system prompt
-        }
-
-        slot.n_past = 0;
-        slot.num_prompt_tokens = prompt_tokens.size();
-
-        LOG_DEBUG << "prompt tokenized - "
-                  << "id_slot: " << slot.id << ", task_id: " << slot.task_id
-                  << ", n_ctx: " << slot.n_ctx
-                  << ", n_keep: " << slot.params.n_keep
-                  << ", n_prompt_tokens: " << slot.num_prompt_tokens;
-        // << ", prompt_tokens: "
-        // << tokens_to_str(ctx, prompt_tokens.cbegin(),
-        // prompt_tokens.cend());
+        auto& prompt_tokens = slot.prompt_tokens;
+
+        // we haven't tokenized the prompt yet - do it now:
+        if (prompt_tokens.empty()) {
+          slot.t_start_process_prompt = ggml_time_us();
+          slot.t_start_genereration = 0;
+
+          if (slot.infill) {
+            bool suff_rm_leading_spc = true;
+            if (params.input_suffix.find_first_of(' ') == 0 &&
+                params.input_suffix.size() > 1) {
+              params.input_suffix.erase(0, 1);
+              suff_rm_leading_spc = false;
+            }
+            auto prefix_tokens = tokenize(slot.params.input_prefix, false);
+            auto suffix_tokens = tokenize(slot.params.input_suffix, false);
+
+            const int space_token =
+                29871;  // TODO: this should not be hardcoded
+            if (suff_rm_leading_spc && !suffix_tokens.empty() &&
+                suffix_tokens[0] == space_token) {
+              suffix_tokens.erase(suffix_tokens.begin());
+            }
 
-        if (slot.embedding) {
-          // this prompt is too large to process - discard it
-          if (slot.num_prompt_tokens > n_ubatch) {
-            LOG_DEBUG << "embedding: num_promt_tokens: "
-                      << slot.num_prompt_tokens << ", n_ubatch: " << n_ubatch;
-            slot.state = PROCESSING;
-            slot.command = NONE;
-            slot.release();
-            slot.print_timings();
-            send_final_response(slot);
-            continue;
-          }
-        } else {
-          if (slot.params.n_keep < 0) {
-            slot.params.n_keep = slot.num_prompt_tokens;
-          }
-          slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-          // if input prompt is too big, truncate it
-          if (slot.num_prompt_tokens >= slot.n_ctx) {
-            const int n_left = slot.n_ctx - slot.params.n_keep;
-            const int n_block_size = n_left / 2;
-            const int erased_blocks =
-                (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) /
-                n_block_size;
-
-            std::vector<llama_token> new_tokens(
-                prompt_tokens.begin(),
-                prompt_tokens.begin() + slot.params.n_keep);
-            new_tokens.insert(new_tokens.end(),
-                              prompt_tokens.begin() + slot.params.n_keep +
-                                  erased_blocks * n_block_size,
-                              prompt_tokens.end());
-
-            LOG_DEBUG << "input truncated - "
-                      << "n_ctx: " << slot.n_ctx << ", n_keep"
-                      << slot.params.n_keep << ", n_left: " << n_left
-                      << ", new_tokens: "
-                      << tokens_to_str(ctx, new_tokens.cbegin(),
-                                       new_tokens.cend());
-            slot.truncated = true;
-            prompt_tokens = new_tokens;
-
-            slot.num_prompt_tokens = prompt_tokens.size();
-            GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+            prefix_tokens.insert(prefix_tokens.begin(),
+                                 llama_token_prefix(model));
+            prefix_tokens.insert(prefix_tokens.begin(),
+                                 llama_token_bos(model));  // always add BOS
+            prefix_tokens.insert(prefix_tokens.end(),
+                                 llama_token_suffix(model));
+            prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(),
+                                 suffix_tokens.end());
+            prefix_tokens.push_back(llama_token_middle(model));
+            prompt_tokens = prefix_tokens;
+          } else {
+            prompt_tokens = tokenize(
+                slot.prompt,
+                system_prompt.empty() &&
+                    add_bos_token);  // add BOS if there isn't system prompt
          }
 
-          llama_sampling_reset(slot.ctx_sampling);
-
-          if (!slot.params.cache_prompt) {
-            slot.n_past = 0;
-            slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
+          slot.n_past = 0;
+          slot.num_prompt_tokens = prompt_tokens.size();
+
+          LOG_DEBUG << "prompt tokenized - "
+                    << "id_slot: " << slot.id
+                    << ", task_id: " << slot.task_id
+                    << ", n_ctx: " << slot.n_ctx
+                    << ", n_keep: " << slot.params.n_keep
+                    << ", n_prompt_tokens: " << slot.num_prompt_tokens;
+          // << ", prompt_tokens: "
+          // << tokens_to_str(ctx, prompt_tokens.cbegin(),
+          // prompt_tokens.cend());
+
+          if (slot.embedding) {
+            // this prompt is too large to process - discard it
+            if (slot.num_prompt_tokens > n_ubatch) {
+              LOG_DEBUG << "embedding: num_promt_tokens: "
+                        << slot.num_prompt_tokens
+                        << ", n_ubatch: " << n_ubatch;
+              slot.state = PROCESSING;
+              slot.command = NONE;
+              slot.release();
+              slot.print_timings();
+              send_final_response(slot);
+              continue;
+            }
          } else {
-            // push the prompt into the sampling context (do not apply grammar)
-            for (auto& token : prompt_tokens) {
-              llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
+            if (slot.params.n_keep < 0) {
+              slot.params.n_keep = slot.num_prompt_tokens;
+            }
+            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+            // if input prompt is too big, truncate it
+            if (slot.num_prompt_tokens >= slot.n_ctx) {
+              const int n_left = slot.n_ctx - slot.params.n_keep;
+              const int n_block_size = n_left / 2;
+              const int erased_blocks = (slot.num_prompt_tokens -
+                                         slot.params.n_keep - n_block_size) /
+                                        n_block_size;
+
+              std::vector<llama_token> new_tokens(
+                  prompt_tokens.begin(),
+                  prompt_tokens.begin() + slot.params.n_keep);
+              new_tokens.insert(new_tokens.end(),
+                                prompt_tokens.begin() + slot.params.n_keep +
+                                    erased_blocks * n_block_size,
+                                prompt_tokens.end());
+
+              LOG_DEBUG << "input truncated - "
+                        << "n_ctx: " << slot.n_ctx << ", n_keep"
+                        << slot.params.n_keep << ", n_left: " << n_left
+                        << ", new_tokens: "
+                        << tokens_to_str(ctx, new_tokens.cbegin(),
+                                         new_tokens.cend());
+              slot.truncated = true;
+              prompt_tokens = new_tokens;
+
+              slot.num_prompt_tokens = prompt_tokens.size();
+              GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
            }
 
-            slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
-            slot.num_prompt_tokens_processed =
-                slot.num_prompt_tokens - slot.n_past;
+            llama_sampling_reset(slot.ctx_sampling);
 
-            LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n",
-                    slot.id, slot.n_past, slot.num_prompt_tokens_processed);
-          }
-        }
+            if (!slot.params.cache_prompt) {
+              slot.n_past = 0;
+              slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
+            } else {
+              // push the prompt into the sampling context (do not apply grammar)
+              for (auto& token : prompt_tokens) {
+                llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
+              }
 
-        // LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id,
-        // (int)system_tokens.size() + slot.n_past);
+              slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+              slot.num_prompt_tokens_processed =
+                  slot.num_prompt_tokens - slot.n_past;
 
-        // llama_kv_cache_seq_rm(ctx, slot.id,
-        // system_tokens.size() + slot.n_past, -1);
+              LOG_TEE(
+                  "slot %d : in cache: %i tokens | to process: %i tokens\n",
+                  slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+            }
+          }
 
-        // slot.cache_tokens = prompt_tokens;
+          if (slot.n_past == slot.num_prompt_tokens) {
+            // we have to evaluate at least 1 token to generate logits.
+            LOG_DEBUG << "slot " << slot.id
+                      << ": we have to evaluate at least 1 token to "
+                         "generate logits";
+            slot.n_past--;
+          }
 
-        if (slot.n_past == slot.num_prompt_tokens) {
-          // we have to evaluate at least 1 token to generate logits.
-          LOG_DEBUG << "slot " << slot.id
-                    << ": we have to evaluate at least 1 token to "
-                       "generate logits";
-          slot.n_past--;
+          slot.num_prompt_tokens_processed = 0;
        }
 
        if (slot.embedding) {
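The truncation branch in the hunk above keeps the first `n_keep` tokens and removes `erased_blocks * n_block_size` tokens from the middle of the prompt so the remainder fits in the context window; afterwards, if the whole prompt is already covered by the KV cache (`slot.n_past == slot.num_prompt_tokens`), `n_past` is decremented so at least one token is re-evaluated and logits are produced. A standalone sketch of the truncation arithmetic with made-up numbers (plain `int` instead of `llama_token`, `n_ctx = 8`, `n_keep = 2`, a 12-token prompt; none of these values come from the diff):

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> prompt_tokens = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
      const int n_ctx = 8;   // pretend context size
      const int n_keep = 2;  // tokens pinned at the start of the prompt

      if (static_cast<int>(prompt_tokens.size()) >= n_ctx) {
        const int n_left = n_ctx - n_keep;    // 6
        const int n_block_size = n_left / 2;  // 3
        const int erased_blocks =
            (static_cast<int>(prompt_tokens.size()) - n_keep - n_block_size) /
            n_block_size;  // (12 - 2 - 3) / 3 = 2

        // keep the first n_keep tokens, then resume after the erased blocks
        std::vector<int> new_tokens(prompt_tokens.begin(),
                                    prompt_tokens.begin() + n_keep);
        new_tokens.insert(
            new_tokens.end(),
            prompt_tokens.begin() + n_keep + erased_blocks * n_block_size,
            prompt_tokens.end());

        assert(static_cast<int>(new_tokens.size()) < n_ctx);  // 6 < 8
        prompt_tokens = new_tokens;  // {0, 1, 8, 9, 10, 11}
      }
      return 0;
    }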