diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp
index b3362519a68f3..5a7923b9e1063 100644
--- a/common/chat-parser.cpp
+++ b/common/chat-parser.cpp
@@ -192,10 +192,6 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
             if (!rest.empty()) {
                 handle_reasoning(rest, /* closed */ !is_partial());
             }
-            // Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
-            // if (!syntax_.thinking_forced_open) {
-            //     throw common_chat_msg_partial_exception(end_think);
-            // }
             return true;
         }
     }
diff --git a/common/chat.cpp b/common/chat.cpp
index afbb2a2bdd3c4..af282a61bee4d 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -9,6 +9,7 @@
 #include <minja/chat-template.hpp>
 #include <minja/minja.hpp>
 
+#include <cctype>
 #include <cstdio>
 #include <exception>
 #include <iostream>
@@ -148,6 +149,7 @@ struct templates_params {
     bool add_bos;
     bool add_eos;
     bool is_inference = true;
+    bool supports_enable_thinking = false;
 };
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -170,10 +172,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
     msg.content = "test";
     dummy_inputs.messages = {msg};
     dummy_inputs.enable_thinking = false;
-    const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
-    dummy_inputs.enable_thinking = true;
-    const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
-    return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
+    const auto rendered = common_chat_templates_apply(chat_templates, dummy_inputs);
+    return rendered.supports_enable_thinking;
 }
 
 template <>
@@ -826,6 +826,7 @@ static std::string apply(
 
 static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
 
     auto tool_call_schemas = json::array();
     foreach_function(inputs.tools, [&](const json & tool) {
@@ -943,6 +944,7 @@ static void common_chat_parse_generic(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
     data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
@@ -988,6 +990,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
 
 static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
     data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
     data.preserved_tokens = {
@@ -1068,6 +1071,7 @@ static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
 
     auto adjusted_messages = json::array();
     for (const auto & msg : inputs.messages) {
@@ -1201,6 +1205,7 @@ static void expect_tool_parameters(const std::string & name, const json & parame
 static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
     if (!inputs.tools.is_null()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -1280,6 +1285,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
 
 static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
 
     // Generate the prompt using the apply() function with the template
     data.prompt = apply(tmpl, inputs);
@@ -1341,6 +1347,7 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
 
 static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
 
     // Generate the prompt using the apply() function with the template
     data.prompt = apply(tmpl, inputs);
@@ -1465,6 +1472,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
 
 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
     auto prompt = apply(tmpl, inputs);
 
     // Hacks to fix the official (broken) prompt.
@@ -1539,6 +1547,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
 
 static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
 
     // Pass thinking context for DeepSeek V3.1 template
     json additional_context = {
@@ -1684,6 +1693,7 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
     auto prompt = apply(tmpl, inputs);
 
     // Check if we need to replace the return token with end token during
@@ -1903,6 +1913,7 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
     const std::optional<json> tools_override = json();
     const std::optional<json> additional_context = json {
         {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
@@ -1961,6 +1972,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
     data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -2037,6 +2049,7 @@ static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder)
 static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
 
     if (!inputs.tools.is_null()) {
         std::string python_code_argument_name;
@@ -2120,6 +2133,7 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
 
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
 
     json extra_context = json {
         {"enable_thinking", inputs.enable_thinking},
@@ -2313,6 +2327,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
 
     // Pass thinking context for Granite template
     json additional_context = {
@@ -2587,6 +2602,7 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
+    data.supports_enable_thinking = inputs.supports_enable_thinking;
     data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
@@ -2598,6 +2614,23 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
     } else {
         data.grammar = inputs.grammar;
     }
+
+    if (inputs.supports_enable_thinking) {
+        static constexpr size_t think_tag_len = 7; // strlen("<think>")
+        size_t prompt_trimmed_size = data.prompt.size();
+        while (prompt_trimmed_size > 0 &&
+               std::isspace(static_cast<unsigned char>(data.prompt[prompt_trimmed_size - 1]))) {
+            --prompt_trimmed_size;
+        }
+        if (prompt_trimmed_size >= think_tag_len &&
+            data.prompt.compare(prompt_trimmed_size - think_tag_len, think_tag_len, "<think>") == 0) {
+            if (!inputs.enable_thinking) {
+                data.prompt += "</think>";
+            } else {
+                data.thinking_forced_open = true;
+            }
+        }
+    }
 
     return data;
 }
@@ -2607,6 +2640,7 @@ static common_chat_params common_chat_params_init_seed_oss(
     const common_chat_templates_inputs & inputs)
 {
     common_chat_params data;
+    data.supports_enable_thinking = params.supports_enable_thinking;
     data.prompt = apply(tmpl, params);
     data.format = COMMON_CHAT_FORMAT_SEED_OSS;
     if (string_ends_with(data.prompt, "<seed:think>")) {
@@ -2680,6 +2714,15 @@ static common_chat_params common_chat_templates_apply_jinja(
         params.extra_context[el.first] = json::parse(el.second);
     }
 
+    {
+        auto params_with_thinking = params;
+        params_with_thinking.enable_thinking = true;
+        auto params_without_thinking = params;
+        params_without_thinking.enable_thinking = false;
+        params.supports_enable_thinking =
+            apply(tmpl, params_with_thinking) != apply(tmpl, params_without_thinking);
+    }
+
     if (!inputs.json_schema.empty()) {
         params.json_schema = json::parse(inputs.json_schema);
     }
diff --git a/common/chat.h b/common/chat.h
index a1afe574bd0ca..23f0c5bf1559e 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -144,6 +144,7 @@ struct common_chat_params {
     std::string grammar;
     bool grammar_lazy = false;
    bool thinking_forced_open = false;
+    bool supports_enable_thinking = false;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string> preserved_tokens;
     std::vector<std::string> additional_stops;
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 52e23b5ac61f5..dd7b091b55e8b 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -1330,6 +1330,89 @@ static void test_template_output_parsers() {
         // /* expect_grammar_triggered= */ true,
         // /* test_grammar_if_triggered= */ false);
     }
+    {
+        // Generic fallback template that appends <think> when add_generation_prompt is true.
+        static const char * tmpl_str = R"(
+{% for message in messages %}
+<|{{ message.role }}|>
+{{ message.content }}
+{% endfor %}
+{% if add_generation_prompt %}<|assistant|>
+<think>
+{% endif %}
+)";
+
+        auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, tmpl_str));
+
+        common_chat_templates_inputs inputs_base;
+        inputs_base.messages = { message_user };
+        inputs_base.add_generation_prompt = true;
+
+        auto inputs_no_thinking = inputs_base;
+        inputs_no_thinking.enable_thinking = false;
+        auto params_no_thinking = common_chat_templates_apply(tmpls.get(), inputs_no_thinking);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_no_thinking.format);
+        assert_equals(false, params_no_thinking.thinking_forced_open);
+        assert_equals(false, params_no_thinking.supports_enable_thinking);
+        assert_equals(true, string_ends_with(string_strip(params_no_thinking.prompt), "<think>"));
+
+        auto inputs_with_thinking = inputs_base;
+        inputs_with_thinking.enable_thinking = true;
+        auto params_with_thinking = common_chat_templates_apply(tmpls.get(), inputs_with_thinking);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_with_thinking.format);
+        assert_equals(false, params_with_thinking.thinking_forced_open);
+        assert_equals(false, params_with_thinking.supports_enable_thinking);
+        assert_equals(true, string_ends_with(string_strip(params_with_thinking.prompt), "<think>"));
+
+        assert_equals(false, common_chat_templates_support_enable_thinking(tmpls.get()));
+    }
+    {
+        // Template that conditionally appends <think> when enable_thinking is true.
+        static const char * tmpl_str = R"(
+{% for message in messages %}
+<|{{ message.role }}|>
+{{ message.content }}
+{% endfor %}
+{% if add_generation_prompt %}<|assistant|>
+{% if enable_thinking %}<think>{% endif %}
+{% endif %}
+)";
+
+        auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, tmpl_str));
+
+        common_chat_templates_inputs inputs_base;
+        inputs_base.messages = { message_user };
+        inputs_base.add_generation_prompt = true;
+
+        auto inputs_no_thinking = inputs_base;
+        inputs_no_thinking.enable_thinking = false;
+        auto params_no_thinking = common_chat_templates_apply(tmpls.get(), inputs_no_thinking);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_no_thinking.format);
+        assert_equals(false, params_no_thinking.thinking_forced_open);
+        assert_equals(true, params_no_thinking.supports_enable_thinking);
+        assert_equals(false, string_ends_with(string_strip(params_no_thinking.prompt), "<think>"));
+
+        auto inputs_with_thinking = inputs_base;
+        inputs_with_thinking.enable_thinking = true;
+        auto params_with_thinking = common_chat_templates_apply(tmpls.get(), inputs_with_thinking);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_with_thinking.format);
+        assert_equals(true, params_with_thinking.thinking_forced_open);
+        assert_equals(true, params_with_thinking.supports_enable_thinking);
+        assert_equals(true, string_ends_with(string_strip(params_with_thinking.prompt), "<think>"));
+
+        assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get()));
+
+        common_chat_syntax syntax;
+        syntax.format = params_with_thinking.format;
+        syntax.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
+        syntax.thinking_forced_open = params_with_thinking.thinking_forced_open;
+
+        assert_msg_equals(simple_assist_msg("Final answer", "Reasoning trace"),
+                          common_chat_parse(
+                              "Reasoning trace</think>Final answer",
+                              /* is_partial= */ false,
+                              syntax));
+    }
     {
         // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
         auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");