diff --git a/common/arg.cpp b/common/arg.cpp index 4203da4a0a6..f8d84c93ff3 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2574,12 +2574,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); add_opt(common_arg( {"--reasoning-budget"}, "N", - "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)", + "controls the maximum number of thinking tokens allowed; -1 for unlimited, 0 to disable thinking, or a positive value to limit thinking tokens (default: -1)", [](common_params & params, int value) { - if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); } + if (value < -1) { throw std::invalid_argument("invalid value: must be -1 (unlimited), 0 (disabled), or a positive number"); } params.reasoning_budget = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET")); + add_opt(common_arg( + {"--reasoning-force-close-message"}, "STRING", + string_format( + "if specified, forces the model to close its reasoning/thoughts when generating this message (default: %s)\n", + params.reasoning_force_close_message.c_str() + ), + [](common_params & params, const std::string & value) { + params.reasoning_force_close_message = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_FORCE_CLOSE_MESSAGE")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( diff --git a/common/chat-parser-xml-toolcall.cpp b/common/chat-parser-xml-toolcall.cpp index a80900ff8d8..03654957fce 100644 --- a/common/chat-parser-xml-toolcall.cpp +++ b/common/chat-parser-xml-toolcall.cpp @@ -705,6 +705,9 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons // Parse content bool reasoning_unclosed = builder.syntax().thinking_forced_open; + if (reasoning_unclosed) { + builder.mark_reasoning_active(end_think); + } std::string unclosed_reasoning_content(""); for (;;) { auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start); @@ -730,6 +733,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons } } else { reasoning_unclosed = false; + builder.mark_reasoning_closed(); std::string reasoning_content; if (pos == std::string::npos) { reasoning_content = std::move(content); @@ -766,6 +770,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons bool toolcall_in_think = false; for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) { if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) { + builder.mark_reasoning_active(end_think); if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size()); builder.add_reasoning_content(reasoning_content); @@ -773,6 +778,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons } else { think_start = think_end + end_think.size() - 1; } + builder.mark_reasoning_closed(); } else { // This start is in thinking block, skip this tool call // This start is in thinking block @@ -782,6 +788,7 @@ inline void 
parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start; } reasoning_unclosed = true; + builder.mark_reasoning_active(end_think); content.resize(think_start); toolcall_in_think = true; } diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index d740dac0651..7a7cfdf0f33 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -156,6 +156,20 @@ void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_ result_.reasoning_content += reasoning_content; } +void common_chat_msg_parser::mark_reasoning_active(const std::string & end_tag) { + result_.reasoning_status.detected = true; + result_.reasoning_status.active = true; + if (!end_tag.empty()) { + result_.reasoning_status.end_tag = end_tag; + } +} + +void common_chat_msg_parser::mark_reasoning_closed() { + if (result_.reasoning_status.detected) { + result_.reasoning_status.active = false; + } +} + bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) { if (name.empty()) { return false; @@ -329,11 +343,13 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think const size_t saved_pos = pos_; const size_t saved_content_size = result_.content.size(); const size_t saved_reasoning_size = result_.reasoning_content.size(); + const auto saved_reasoning_status = result_.reasoning_status; auto restore_state = [&]() { move_to(saved_pos); result_.content.resize(saved_content_size); result_.reasoning_content.resize(saved_reasoning_size); + result_.reasoning_status = saved_reasoning_status; }; // Allow leading whitespace to be preserved as content when reasoning is present at the start @@ -370,9 +386,11 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think if (whitespace_end > pos_) { add_content(input_.substr(pos_, whitespace_end - pos_)); } + mark_reasoning_active(end_think); set_reasoning_prefix(cursor); cursor += start_think.size(); } else if (syntax_.thinking_forced_open) { + mark_reasoning_active(end_think); cursor = whitespace_end; } else { restore_state(); @@ -398,8 +416,10 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think if (end_pos > cursor) { handle_reasoning(input_.substr(cursor, end_pos - cursor), /* closed */ true); + mark_reasoning_closed(); } else { handle_reasoning("", /* closed */ true); + mark_reasoning_closed(); } cursor = end_pos + end_think.size(); @@ -420,6 +440,7 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think move_to(input_.size()); return true; } + mark_reasoning_active(end_think); set_reasoning_prefix(cursor); cursor += start_think.size(); continue; @@ -1492,10 +1513,12 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co return common_chat_peg_parse(syntax.parser, input, is_partial, syntax); } common_chat_msg_parser builder(input, is_partial, syntax); + bool partial_exception_caught = false; try { common_chat_parse(builder); } catch (const common_chat_msg_partial_exception & ex) { LOG_DBG("Partial parse: %s\n", ex.what()); + partial_exception_caught = true; if (!is_partial) { builder.clear_tools(); builder.move_to(0); @@ -1503,6 +1526,11 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co } } auto msg = builder.result(); + // Mark tool_call_in_progress if we caught a partial exception during partial parsing + // and there are 
tool calls in progress (indicates incomplete tool call parsing) + if (is_partial && partial_exception_caught && !msg.tool_calls.empty()) { + msg.tool_call_in_progress = true; + } if (!is_partial) { LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); } diff --git a/common/chat-parser.h b/common/chat-parser.h index 78c4b74c2db..9d8c8ff9a28 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -56,6 +56,10 @@ class common_chat_msg_parser { // Appends to the result.reasoning_content field void add_reasoning_content(const std::string & reasoning_content); + // Track reasoning status to expose start/end markers to callers + void mark_reasoning_active(const std::string & end_tag); + void mark_reasoning_closed(); + // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything. bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); diff --git a/common/chat.h b/common/chat.h index 6085510a402..2a7058b035e 100644 --- a/common/chat.h +++ b/common/chat.h @@ -22,6 +22,19 @@ struct common_chat_tool_call { } }; +struct common_chat_reasoning_status { + bool detected = false; // a reasoning block start was observed + bool active = false; // we are currently inside a reasoning block (not closed yet) + std::string end_tag; // closing tag to use when forcing a close + + bool operator==(const common_chat_reasoning_status & other) const { + return detected == other.detected && active == other.active && end_tag == other.end_tag; + } + bool operator!=(const common_chat_reasoning_status & other) const { + return !(*this == other); + } +}; + struct common_chat_msg_content_part { std::string type; std::string text; @@ -37,6 +50,8 @@ struct common_chat_msg { std::vector content_parts; std::vector tool_calls; std::string reasoning_content; + common_chat_reasoning_status reasoning_status; + bool tool_call_in_progress = false; std::string tool_name; std::string tool_call_id; @@ -63,6 +78,7 @@ struct common_chat_msg { && content_parts == other.content_parts && tool_calls == other.tool_calls && reasoning_content == other.reasoning_content + && reasoning_status == other.reasoning_status && tool_name == other.tool_name && tool_call_id == other.tool_call_id; } diff --git a/common/common.cpp b/common/common.cpp index 0497f90a280..c3e114b99ed 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1078,6 +1078,14 @@ struct common_init_result common_init_from_params(common_params & params) { common_init_sampler_from_model(model, params.sampling); + // Allow models to override the forced reasoning close message via GGUF metadata + if (params.reasoning_force_close_message == COMMON_DEFAULT_REASONING_FORCE_CLOSE_MESSAGE) { + char buf[512] = {0}; + if (llama_model_meta_val_str(model, "tokenizer.ggml.reasoning_force_close_message", buf, sizeof(buf)) > 0) { + params.reasoning_force_close_message = buf; + } + } + const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); diff --git a/common/common.h b/common/common.h index d28e48991c3..66075c62356 100644 --- a/common/common.h +++ b/common/common.h @@ -102,6 +102,8 @@ enum llama_example { LLAMA_EXAMPLE_COUNT, }; +inline constexpr const char * COMMON_DEFAULT_REASONING_FORCE_CLOSE_MESSAGE = "... 
I now conclude my reasoning and will provide the final answer."; + enum common_sampler_type { COMMON_SAMPLER_TYPE_NONE = 0, COMMON_SAMPLER_TYPE_DRY = 1, @@ -466,6 +468,7 @@ struct common_params { bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; + std::string reasoning_force_close_message = COMMON_DEFAULT_REASONING_FORCE_CLOSE_MESSAGE; bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response std::vector api_keys; diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp index 4766518fe69..f8e0644d36d 100644 --- a/tests/test-chat-parser.cpp +++ b/tests/test-chat-parser.cpp @@ -119,6 +119,9 @@ static void test_reasoning() { auto msg = common_chat_parse(input, false, syntax); assert_equals(variant, std::string("Pense"), msg.reasoning_content); assert_equals(variant, std::string("Bonjour"), msg.content); + assert_equals(variant, true, msg.reasoning_status.detected); + assert_equals(variant, false, msg.reasoning_status.active); + assert_equals(variant, std::string(""), msg.reasoning_status.end_tag); } { const std::string variant("llama_3_inline_think"); @@ -133,6 +136,9 @@ static void test_reasoning() { auto msg = common_chat_parse(input, false, syntax); assert_equals(variant, std::string("Plan"), msg.reasoning_content); assert_equals(variant, std::string("Réponse"), msg.content); + assert_equals(variant, true, msg.reasoning_status.detected); + assert_equals(variant, false, msg.reasoning_status.active); + assert_equals(variant, std::string(""), msg.reasoning_status.end_tag); } // Test DeepSeek V3.1 parsing - reasoning content followed by "" and then regular content { diff --git a/tools/server/README.md b/tools/server/README.md index f98fb44c7bc..2680517fca7 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -203,7 +203,8 @@ For the ful list of features, please refer to [server's changelog](https://githu | `--jinja` | use jinja template for chat (default: enabled)

(env: LLAMA_ARG_JINJA) | | `--no-jinja` | disable jinja template for chat (default: enabled)

(env: LLAMA_ARG_NO_JINJA) | | `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`
(default: auto)
(env: LLAMA_ARG_THINK) | -| `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | +| `--reasoning-budget N` | controls the maximum number of thinking tokens allowed; -1 for unlimited, 0 to disable thinking, or a positive value to limit thinking tokens. When the budget is exceeded, the server automatically injects a closing `</think>` and continues with the final answer. Individual OpenAI-compatible requests can override this value with `thinking_budget_tokens`. (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | +| `--reasoning-force-close-message STRING` | when the reasoning budget is exceeded, this message, followed by the reasoning closing tag, is injected into the model's output to prompt it to stop thinking and produce the final answer; it can also be overridden per request with `reasoning_force_close_message`. (default: '... I now conclude my reasoning and will provide the final answer.')
(env: LLAMA_ARG_THINK_FORCE_CLOSE_MESSAGE) | | `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | | `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)
when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled

(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) | diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 4578f8d7a9f..eb49c755695 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -18,6 +18,8 @@ #include #include #include +#include +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -47,6 +49,13 @@ enum server_state { SERVER_STATE_READY, // Server is ready and model is loaded }; +enum reasoning_state { + REASONING_STATE_NONE, + REASONING_STATE_REASONING, + REASONING_STATE_PENDING_FORCE_CLOSE, + REASONING_STATE_FINISHED, +}; + static bool server_task_type_need_embd(server_task_type task_type) { switch (task_type) { case SERVER_TASK_TYPE_EMBEDDING: @@ -113,6 +122,12 @@ struct server_slot { bool has_new_line = false; bool truncated = false; + // reasoning budget tracking + int32_t n_reasoning_tokens = 0; // number of tokens generated while in reasoning/thinking mode + reasoning_state reasoning = REASONING_STATE_NONE; // are we currently in reasoning mode + std::string reasoning_end_tag; // the closing tag to inject when budget is exceeded (e.g., "") + std::deque forced_tokens; // tokens we must feed back to the model (e.g., forced ) + stop_type stop; std::string stopping_word; @@ -162,9 +177,11 @@ struct server_slot { size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; + int64_t t_start_reasoning; int64_t t_start_generation; double t_prompt_processing; // ms + double t_reasoning_token_generation; // ms double t_token_generation; // ms std::function callback_on_release; @@ -188,6 +205,13 @@ struct server_slot { drafted.clear(); i_batch_dft.clear(); + + // reset reasoning budget tracking + n_reasoning_tokens = 0; + reasoning = REASONING_STATE_NONE; + reasoning_end_tag = ""; + forced_tokens.clear(); + generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -372,15 +396,20 @@ struct server_slot { const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + const double t_reasoning = t_reasoning_token_generation / n_reasoning_tokens; + const double n_reasoning_second = 1e3 / t_reasoning_token_generation * n_reasoning_tokens; + const double t_gen = t_token_generation / n_decoded; const double n_gen_second = 1e3 / t_token_generation * n_decoded; SLT_INF(*this, "\n" "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " reasoning time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" " total time = %10.2f ms / %5d tokens\n", t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_reasoning_token_generation, n_reasoning_tokens, t_reasoning, n_reasoning_second, t_token_generation, n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); @@ -1079,6 +1108,13 @@ struct server_context_impl { ? 
SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt : SLOT_STATE_STARTED; + // Initialize reasoning tracking + slot.forced_tokens.clear(); + slot.n_reasoning_tokens = 0; + slot.reasoning = REASONING_STATE_NONE; + slot.reasoning_end_tag.clear(); + + SLT_INF(slot, "%s", "processing task\n"); return true; @@ -1154,6 +1190,85 @@ struct server_context_impl { SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); } + const int32_t reasoning_budget = (slot.task ? slot.task->params.reasoning_budget : params_base.reasoning_budget); + + // check reasoning budget limit + // Track reasoning tokens using the chat parser to detect reasoning segments consistently across formats + // When the budget is exceeded we enqueue the closing tag tokens so they get sent to the client + // and fed back into the model before continuing normal generation + if (slot.has_next_token && reasoning_budget > 0 && slot.reasoning != REASONING_STATE_FINISHED) { + const auto parsed_msg = common_chat_parse( + slot.generated_text, + /* is_partial = */ true, + slot.task->params.oaicompat_chat_syntax); + const auto & rstatus = parsed_msg.reasoning_status; + + if (rstatus.active && slot.reasoning != REASONING_STATE_PENDING_FORCE_CLOSE) { + if (slot.reasoning != REASONING_STATE_REASONING) { + SLT_DBG(slot, "detected reasoning start via parser%s\n", ""); + slot.reasoning = REASONING_STATE_REASONING; + slot.reasoning_end_tag = rstatus.end_tag; + slot.n_reasoning_tokens = 0; + slot.t_start_reasoning = ggml_time_us(); + } + } else if (!rstatus.active && slot.reasoning == REASONING_STATE_REASONING) { + SLT_DBG(slot, "detected reasoning end '%s' via parser\n", rstatus.end_tag.c_str()); + slot.reasoning = REASONING_STATE_FINISHED; + slot.t_reasoning_token_generation = (ggml_time_us() - slot.t_start_reasoning) / 1e3; + } + + if (slot.reasoning == REASONING_STATE_REASONING) { + slot.n_reasoning_tokens++; + + // Detect if we are in the middle of emitting a tool call this step. + // The parser sets tool_call_in_progress when it catches a partial exception + // while parsing tool calls, indicating incomplete tool call parsing. + // We also check for tool call diffs in this token as a fallback. + if (!parsed_msg.tool_call_in_progress && slot.n_reasoning_tokens >= reasoning_budget) { + SLT_INF(slot, "reasoning budget exceeded, forcing close with '%s', n_reasoning_tokens = %d, reasoning_budget = %d\n", + slot.reasoning_end_tag.c_str(), slot.n_reasoning_tokens, reasoning_budget); + + auto fail_close = [&](const char * reason) { + SLT_WRN(slot, "failed to inject reasoning close tag (%s) -> stopping generation\n", reason); + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + }; + + if (slot.reasoning_end_tag.empty()) { + fail_close("no closing tag detected"); + } else { + const std::string forced_message = slot.task->params.reasoning_force_close_message.empty() + ? 
std::string(COMMON_DEFAULT_REASONING_FORCE_CLOSE_MESSAGE) + : slot.task->params.reasoning_force_close_message; + const std::string forced_injection = forced_message + slot.reasoning_end_tag; + + llama_tokens closing_tokens; + try { + closing_tokens = common_tokenize(ctx, forced_injection, /*add_special=*/false, /*parse_special=*/true); + } catch (const std::exception & err) { + SLT_WRN(slot, "tokenization error while forcing reasoning close: %s\n", err.what()); + fail_close("tokenization error"); + closing_tokens.clear(); + } + + if (!closing_tokens.empty()) { + slot.forced_tokens.insert(slot.forced_tokens.end(), closing_tokens.begin(), closing_tokens.end()); + slot.reasoning = REASONING_STATE_PENDING_FORCE_CLOSE; + } else if (slot.has_next_token) { + fail_close("closing tag produced no tokens"); + } + } + } + } else if (slot.reasoning == REASONING_STATE_PENDING_FORCE_CLOSE) { + // We've already scheduled the forced close, wait until it's done + if (slot.forced_tokens.empty()) { + SLT_DBG(slot, "completed forced reasoning close with '%s'\n", slot.reasoning_end_tag.c_str()); + slot.reasoning = REASONING_STATE_FINISHED; + slot.t_reasoning_token_generation = (ggml_time_us() - slot.t_start_reasoning) / 1e3; + } + } + } + if (slot.has_new_line) { // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent if (slot.task->params.n_indent > 0) { @@ -2484,7 +2599,15 @@ struct server_context_impl { const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + const bool has_forced_token = !slot.forced_tokens.empty(); + llama_token id = 0; + + if (has_forced_token) { + id = slot.forced_tokens.front(); + slot.forced_tokens.pop_front(); + } else { + id = common_sampler_sample(slot.smpl, ctx, tok_idx); + } slot.i_batch = -1; @@ -2522,7 +2645,7 @@ struct server_context_impl { // speculative decoding - main model sample and accept for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) { + if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty() || !slot.forced_tokens.empty()) { continue; } diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 360826062b1..db53b19b775 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -130,6 +130,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, + {"reasoning_force_close_message", reasoning_force_close_message}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -137,6 +138,7 @@ json task_params::to_json(bool only_metrics) const { {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"lora", lora}, + {"thinking_budget_tokens", reasoning_budget}, }; } @@ -159,8 +161,8 @@ task_params server_task::params_from_json_cmpl( defaults.speculative = params_base.speculative; defaults.n_keep = params_base.n_keep; defaults.n_predict = params_base.n_predict; - defaults.n_cache_reuse = params_base.n_cache_reuse; defaults.antiprompt = params_base.antiprompt; + defaults.reasoning_force_close_message = params_base.reasoning_force_close_message; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = 
params_base.verbosity > 9; @@ -182,6 +184,9 @@ task_params server_task::params_from_json_cmpl( params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); params.response_fields = json_value(data, "response_fields", std::vector()); + params.reasoning_budget = json_value(data, "thinking_budget_tokens", params_base.reasoning_budget); + params.reasoning_force_close_message = json_value(data, "reasoning_force_close_message", defaults.reasoning_force_close_message); + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 9011ff944b9..1425c0b85e3 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -72,11 +72,13 @@ struct task_params { struct common_params_speculative speculative; // response formatting - bool verbose = false; - task_response_type res_type = TASK_RESPONSE_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; + int32_t reasoning_budget; + std::string reasoning_force_close_message; // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) diff --git a/tools/server/tests/unit/test_reasoning_budget_stream.py b/tools/server/tests/unit/test_reasoning_budget_stream.py new file mode 100644 index 00000000000..690ee293705 --- /dev/null +++ b/tools/server/tests/unit/test_reasoning_budget_stream.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +import pytest + +from utils import ServerPreset, ServerProcess + +server: ServerProcess + + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.jinja = True + server.reasoning_budget = 1 + server.chat_template_file = "../../../models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja" + server.reasoning_force_close_message = "!!!ABORT REASONING" + server.start() + yield + server.stop() + + +def test_reasoning_budget_forces_close(): + """Ensure reasoning budget triggers forced close injection with end tag.""" + res = server.make_request("POST", "/v1/chat/completions", data={ + "model": server.model_alias or "test", + "messages": [ + {"role": "user", "content": "Tell me a short story."}, + ], + "max_tokens": 32, + }) + + assert res.status_code == 200 + body = res.body + assert "choices" in body and body["choices"], "no choices returned" + + message = body["choices"][0]["message"] + reasoning_content = message.get("reasoning_content", "") + + assert server.reasoning_force_close_message in reasoning_content, "reasoning force close message not found in reasoning content" + +def test_reasoning_custom_budget(): + """Ensure reasoning budget triggers forced close injection with end tag.""" + res = server.make_request("POST", "/v1/chat/completions", data={ + "model": server.model_alias or "test", + "messages": [ + {"role": "user", "content": "Tell me a short story."}, + ], + "max_tokens": 32, + "thinking_budget_tokens": 5 + }) + + assert res.status_code == 200 + body = res.body + assert "choices" in body and body["choices"], "no choices returned" + + message = body["choices"][0]["message"] + 
reasoning_content = message.get("reasoning_content", "") + + reasoning_before_abort = reasoning_content.split(server.reasoning_force_close_message)[0] + assert len(reasoning_before_abort.split()) > 1, "reasoning content too short before force close" + + assert server.reasoning_force_close_message in reasoning_content, "reasoning force close message not found in reasoning content" \ No newline at end of file diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 48e7403602f..fb6edf41499 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -95,6 +95,7 @@ class ServerProcess: jinja: bool | None = None reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None reasoning_budget: int | None = None + reasoning_force_close_message: str | None = None chat_template: str | None = None chat_template_file: str | None = None server_path: str | None = None @@ -222,6 +223,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(("--reasoning-format", self.reasoning_format)) if self.reasoning_budget is not None: server_args.extend(("--reasoning-budget", self.reasoning_budget)) + if self.reasoning_force_close_message is not None: + server_args.extend(("--reasoning-force-close-message", self.reasoning_force_close_message)) if self.chat_template: server_args.extend(["--chat-template", self.chat_template]) if self.chat_template_file:
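
For reference, a minimal client-side sketch of the per-request overrides wired up in server-task.cpp above, assuming a llama-server built with this patch is listening on localhost:8080 (URL, model name, prompt and override values are placeholders):

import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "test",
        "messages": [{"role": "user", "content": "Explain quicksort briefly."}],
        "max_tokens": 256,
        # per-request override of --reasoning-budget
        "thinking_budget_tokens": 64,
        # per-request override of --reasoning-force-close-message
        "reasoning_force_close_message": "Time is up, answering now.",
    },
    timeout=60,
)
msg = resp.json()["choices"][0]["message"]
print(msg.get("reasoning_content", ""))  # reasoning, force-closed once the budget is hit
print(msg["content"])                    # final answer generated after the injected close tag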