Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions common/chat-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,14 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
// Discards every tool call recorded so far in the parse result.
void common_chat_msg_parser::clear_tools() {
    auto & recorded_calls = result_.tool_calls;
    recorded_calls.clear();
}

// Drops the last `len` characters (as counted by content.size()) from the
// content accumulated in the parse result.
//   - len == 0, or content already empty: nothing happens.
//   - len >= content.size(): the content is emptied.
void common_chat_msg_parser::remove_content_suffix(size_t len) {
    auto & content = result_.content;
    const size_t total = content.size();
    if (len == 0 || total == 0) {
        return;
    }
    if (len < total) {
        // Single-argument erase removes everything from this position to the end.
        content.erase(total - len);
    } else {
        content.clear();
    }
}
2 changes: 2 additions & 0 deletions common/chat-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,6 @@ class common_chat_msg_parser {
);

// Removes all tool calls collected so far in the parser's result.
void clear_tools();

// Removes the trailing `len` characters from the result content.
// No-op when `len` is 0 or the content is empty; clears the content entirely
// when `len` is at least the content's size.
void remove_content_suffix(size_t len);
};
61 changes: 61 additions & 0 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2116,6 +2116,17 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);

// Opening wrapper tags that may precede a `<function=...>` / `<function name="...">` match.
// The prelude before the match is checked against this list: if, after stripping
// whitespace, it consists solely of these tags, it is removed from the content
// via builder.remove_content_suffix().
static const std::vector<std::string> wrapper_open_tags = {
"<tool_call>",
"<function_call>",
"<tool>",
"<tools>",
"<response>",
"<json>",
"<xml>",
"<JSON>",
};

while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
Expand All @@ -2142,24 +2153,74 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
} else {
auto prelude = res->prelude;
bool prelude_wrappers_only = true;
auto trimmed_prelude = string_strip(prelude);
while (!trimmed_prelude.empty()) {
bool matched_wrapper = false;
for (const auto & tag : wrapper_open_tags) {
if (string_starts_with(trimmed_prelude, tag)) {
trimmed_prelude = string_strip(trimmed_prelude.substr(tag.size()));
matched_wrapper = true;
break;
}
}
if (!matched_wrapper) {
prelude_wrappers_only = false;
break;
}
}
if (!prelude.empty() && prelude_wrappers_only) {
builder.remove_content_suffix(prelude.size());
}

auto function_name = builder.str(res->groups[4]);
if (function_name.empty()) {
function_name = builder.str(res->groups[5]);
}
GGML_ASSERT(!function_name.empty());

close_tag = "</function>";
const bool had_block_start = res->prelude.find("```") != std::string::npos;

if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);

// Closing wrapper tags that may follow `</function>`; the loop below consumes
// any run of them (separated by whitespace) so they do not end up in content.
// Kept in sync with wrapper_open_tags: every opening wrapper accepted in the
// prelude has its closing counterpart here, including `</function_call>`.
static const std::vector<std::string> wrapper_close_tags = {
    "</tool_call>",
    "</function_call>",
    "</tool>",
    "</tools>",
    "</response>",
    "</json>",
    "</xml>",
    "</JSON>",
};

while (true) {
builder.consume_spaces();
bool matched_wrapper = false;
for (const auto & wrapper_close : wrapper_close_tags) {
if (builder.try_consume_literal(wrapper_close)) {
matched_wrapper = true;
break;
}
}
if (!matched_wrapper) {
break;
}
}

builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
} else if (had_block_start) {
builder.try_consume_literal("```");
builder.consume_spaces();
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions tests/test-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,16 @@ static void test_template_output_parsers() {
"</function>",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
assert_msg_equals(
message_assist_call,
common_chat_parse(
"<tool_call>\n"
"<function=special_function>\n"
"{\"arg1\": 1}\n"
"</function>\n"
"</tool_call>\n",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
assert_msg_equals(
message_assist_call,
common_chat_parse(
Expand Down