4 changes: 0 additions & 4 deletions common/chat-parser.cpp
@@ -192,10 +192,6 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
if (!rest.empty()) {
handle_reasoning(rest, /* closed */ !is_partial());
}
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
// if (!syntax_.thinking_forced_open) {
// throw common_chat_msg_partial_exception(end_think);
// }
return true;
}
}
51 changes: 47 additions & 4 deletions common/chat.cpp
@@ -9,6 +9,7 @@
#include <minja/chat-template.hpp>
#include <minja/minja.hpp>

#include <cctype>
#include <cstdio>
#include <exception>
#include <iostream>
@@ -148,6 +149,7 @@ struct templates_params {
bool add_bos;
bool add_eos;
bool is_inference = true;
bool supports_enable_thinking = false;
};

common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -170,10 +172,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
msg.content = "test";
dummy_inputs.messages = {msg};
dummy_inputs.enable_thinking = false;
const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
dummy_inputs.enable_thinking = true;
const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
const auto rendered = common_chat_templates_apply(chat_templates, dummy_inputs);
return rendered.supports_enable_thinking;
}

template <>
@@ -826,6 +826,7 @@ static std::string apply(

static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;

auto tool_call_schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -943,6 +944,7 @@ static void common_chat_parse_generic(common_chat_msg_parser & builder) {

static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
@@ -988,6 +990,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat

static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
data.preserved_tokens = {
@@ -1068,6 +1071,7 @@ static void common_chat_parse_magistral(common_chat_msg_parser & builder) {

static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;

auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
@@ -1201,6 +1205,7 @@ static void expect_tool_parameters(const std::string & name, const json & parame
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;
if (!inputs.tools.is_null()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -1280,6 +1285,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te

static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;

// Generate the prompt using the apply() function with the template
data.prompt = apply(tmpl, inputs);
@@ -1341,6 +1347,7 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_

static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;

// Generate the prompt using the apply() function with the template
data.prompt = apply(tmpl, inputs);
@@ -1465,6 +1472,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w

static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;
auto prompt = apply(tmpl, inputs);

// Hacks to fix the official (broken) prompt.
@@ -1539,6 +1547,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_

static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;

// Pass thinking context for DeepSeek V3.1 template
json additional_context = {
@@ -1684,6 +1693,7 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {

static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;
auto prompt = apply(tmpl, inputs);

// Check if we need to replace the return token with end token during
@@ -1903,6 +1913,7 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;
const std::optional<json> tools_override = json();
const std::optional<json> additional_context = json {
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
@@ -1961,6 +1972,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
// If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -2037,6 +2049,7 @@ static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder)
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;

if (!inputs.tools.is_null()) {
std::string python_code_argument_name;
@@ -2120,6 +2133,7 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser

static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;

json extra_context = json {
{"enable_thinking", inputs.enable_thinking},
@@ -2313,6 +2327,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {

static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;

// Pass thinking context for Granite template
json additional_context = {
@@ -2587,6 +2602,7 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {

static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.supports_enable_thinking = inputs.supports_enable_thinking;
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
data.grammar_lazy = false;
@@ -2598,6 +2614,23 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
} else {
data.grammar = inputs.grammar;
}

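// If the template can react to enable_thinking and the rendered prompt ends with an
// opening <think> tag, either close the tag right away (thinking disabled) or mark
// the reasoning block as forced open so the parser expects a matching </think>.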
if (inputs.supports_enable_thinking) {
static constexpr size_t think_tag_len = 7; // strlen("<think>")
size_t prompt_trimmed_size = data.prompt.size();
while (prompt_trimmed_size > 0 &&
std::isspace(static_cast<unsigned char>(data.prompt[prompt_trimmed_size - 1]))) {
--prompt_trimmed_size;
}
if (prompt_trimmed_size >= think_tag_len &&
data.prompt.compare(prompt_trimmed_size - think_tag_len, think_tag_len, "<think>") == 0) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
}
return data;
}

@@ -2607,6 +2640,7 @@ static common_chat_params common_chat_params_init_seed_oss(
const common_chat_templates_inputs & inputs)
{
common_chat_params data;
data.supports_enable_thinking = params.supports_enable_thinking;
data.prompt = apply(tmpl, params);
data.format = COMMON_CHAT_FORMAT_SEED_OSS;
if (string_ends_with(data.prompt, "<seed:think>")) {
@@ -2680,6 +2714,15 @@ static common_chat_params common_chat_templates_apply_jinja(
params.extra_context[el.first] = json::parse(el.second);
}

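// Probe whether the template reacts to the enable_thinking flag at all by rendering
// the prompt with the flag on and off and comparing the two results.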
{
auto params_with_thinking = params;
params_with_thinking.enable_thinking = true;
auto params_without_thinking = params;
params_without_thinking.enable_thinking = false;
params.supports_enable_thinking =
apply(tmpl, params_with_thinking) != apply(tmpl, params_without_thinking);
}

if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
1 change: 1 addition & 0 deletions common/chat.h
@@ -144,6 +144,7 @@ struct common_chat_params {
std::string grammar;
bool grammar_lazy = false;
bool thinking_forced_open = false;
bool supports_enable_thinking = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
83 changes: 83 additions & 0 deletions tests/test-chat.cpp
@@ -1330,6 +1330,89 @@ static void test_template_output_parsers() {
// /* expect_grammar_triggered= */ true,
// /* test_grammar_if_triggered= */ false);
}
{
// Generic fallback template that appends <think> when add_generation_prompt is true.
static const char * tmpl_str = R"(
{% for message in messages %}
<|{{ message.role }}|>
{{ message.content }}
{% endfor %}
{% if add_generation_prompt %}<|assistant|>
<think>
{% endif %}
)";

auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, tmpl_str));

common_chat_templates_inputs inputs_base;
inputs_base.messages = { message_user };
inputs_base.add_generation_prompt = true;

auto inputs_no_thinking = inputs_base;
inputs_no_thinking.enable_thinking = false;
auto params_no_thinking = common_chat_templates_apply(tmpls.get(), inputs_no_thinking);
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_no_thinking.format);
assert_equals(false, params_no_thinking.thinking_forced_open);
assert_equals(false, params_no_thinking.supports_enable_thinking);
assert_equals(true, string_ends_with(string_strip(params_no_thinking.prompt), "<think>"));

auto inputs_with_thinking = inputs_base;
inputs_with_thinking.enable_thinking = true;
auto params_with_thinking = common_chat_templates_apply(tmpls.get(), inputs_with_thinking);
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_with_thinking.format);
assert_equals(false, params_with_thinking.thinking_forced_open);
assert_equals(false, params_with_thinking.supports_enable_thinking);
assert_equals(true, string_ends_with(string_strip(params_with_thinking.prompt), "<think>"));

assert_equals(false, common_chat_templates_support_enable_thinking(tmpls.get()));
}
{
// Template that conditionally appends <think> when enable_thinking is true.
static const char * tmpl_str = R"(
{% for message in messages %}
<|{{ message.role }}|>
{{ message.content }}
{% endfor %}
{% if add_generation_prompt %}<|assistant|>
{% if enable_thinking %}<think>{% endif %}
{% endif %}
)";

auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, tmpl_str));

common_chat_templates_inputs inputs_base;
inputs_base.messages = { message_user };
inputs_base.add_generation_prompt = true;

auto inputs_no_thinking = inputs_base;
inputs_no_thinking.enable_thinking = false;
auto params_no_thinking = common_chat_templates_apply(tmpls.get(), inputs_no_thinking);
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_no_thinking.format);
assert_equals(false, params_no_thinking.thinking_forced_open);
assert_equals(true, params_no_thinking.supports_enable_thinking);
assert_equals(false, string_ends_with(string_strip(params_no_thinking.prompt), "<think>"));

auto inputs_with_thinking = inputs_base;
inputs_with_thinking.enable_thinking = true;
auto params_with_thinking = common_chat_templates_apply(tmpls.get(), inputs_with_thinking);
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_with_thinking.format);
assert_equals(true, params_with_thinking.thinking_forced_open);
assert_equals(true, params_with_thinking.supports_enable_thinking);
assert_equals(true, string_ends_with(string_strip(params_with_thinking.prompt), "<think>"));

assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get()));

common_chat_syntax syntax;
syntax.format = params_with_thinking.format;
syntax.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
syntax.thinking_forced_open = params_with_thinking.thinking_forced_open;

assert_msg_equals(simple_assist_msg("Final answer", "Reasoning trace"),
common_chat_parse(
"Reasoning trace</think>Final answer",
/* is_partial= */ false,
syntax));
}
{
// Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");