diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 96ba8f533ef1b..7f28557ddedc5 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -391,3 +391,14 @@ std::optional common_chat_msg_parse void common_chat_msg_parser::clear_tools() { result_.tool_calls.clear(); } + +void common_chat_msg_parser::remove_content_suffix(size_t len) { + if (len == 0 || result_.content.empty()) { + return; + } + if (len >= result_.content.size()) { + result_.content.clear(); + return; + } + result_.content.erase(result_.content.size() - len); +} diff --git a/common/chat-parser.h b/common/chat-parser.h index 0e64c341a50aa..d5e4fd19b43ad 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -117,4 +117,6 @@ class common_chat_msg_parser { ); void clear_tools(); + + void remove_content_suffix(size_t len); }; diff --git a/common/chat.cpp b/common/chat.cpp index e2bacdcf52753..1ea5b847ca219 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2116,6 +2116,17 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { "|" // match 5 (function name again) ); + static const std::vector wrapper_open_tags = { + "", + "", + "", + "", + "", + "", + "", + "", + }; + while (auto res = builder.try_find_regex(open_regex)) { const auto & block_start = res->groups[1]; std::string block_end = block_start.empty() ? "" : "```"; @@ -2142,6 +2153,27 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { throw common_chat_msg_partial_exception("failed to parse tool call"); } } else { + auto prelude = res->prelude; + bool prelude_wrappers_only = true; + auto trimmed_prelude = string_strip(prelude); + while (!trimmed_prelude.empty()) { + bool matched_wrapper = false; + for (const auto & tag : wrapper_open_tags) { + if (string_starts_with(trimmed_prelude, tag)) { + trimmed_prelude = string_strip(trimmed_prelude.substr(tag.size())); + matched_wrapper = true; + break; + } + } + if (!matched_wrapper) { + prelude_wrappers_only = false; + break; + } + } + if (!prelude.empty() && prelude_wrappers_only) { + builder.remove_content_suffix(prelude.size()); + } + auto function_name = builder.str(res->groups[4]); if (function_name.empty()) { function_name = builder.str(res->groups[5]); @@ -2149,6 +2181,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { GGML_ASSERT(!function_name.empty()); close_tag = ""; + const bool had_block_start = res->prelude.find("```") != std::string::npos; if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { @@ -2156,10 +2189,38 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { } builder.consume_spaces(); builder.consume_literal(close_tag); + + static const std::vector wrapper_close_tags = { + "", + "", + "", + "", + "", + "", + "", + }; + + while (true) { + builder.consume_spaces(); + bool matched_wrapper = false; + for (const auto & wrapper_close : wrapper_close_tags) { + if (builder.try_consume_literal(wrapper_close)) { + matched_wrapper = true; + break; + } + } + if (!matched_wrapper) { + break; + } + } + builder.consume_spaces(); if (!block_end.empty()) { builder.consume_literal(block_end); builder.consume_spaces(); + } else if (had_block_start) { + builder.try_consume_literal("```"); + builder.consume_spaces(); } } } diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index ce0f4b0a2a9f3..efd300f525428 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -840,6 +840,16 @@ static void test_template_output_parsers() { "", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + "\n" + "{\"arg1\": 1}\n" + "\n" + "\n", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, common_chat_parse(