diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 7365782e7d6d8..ff83102788d49 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -432,7 +432,7 @@ std::optional common_chat_msg_parse if (is_arguments_path({})) { // Entire JSON is the arguments and was parsed fully. return consume_json_result { - partial->json.dump(), + partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true), /* .is_partial = */ false, }; } @@ -444,7 +444,7 @@ std::optional common_chat_msg_parse std::vector path; std::function remove_unsupported_healings_and_dump_args = [&](const json & j) -> json { if (is_arguments_path(path)) { - auto arguments = j.dump(); + auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true); if (is_partial() && !partial->healing_marker.marker.empty()) { auto idx = arguments.find(partial->healing_marker.json_dump_marker); if (idx != std::string::npos) { diff --git a/common/json-partial.cpp b/common/json-partial.cpp index d9d91699899f7..919927dc32446 100644 --- a/common/json-partial.cpp +++ b/common/json-partial.cpp @@ -5,6 +5,7 @@ #include #include +#include using json = nlohmann::ordered_json; @@ -168,6 +169,47 @@ bool common_json_parse( } } + // Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX + static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)"); + + auto is_high_surrogate = [&](const std::string & s) { + // Check if a partial of a high surrogate (U+D800-U+DBFF) + return s.length() >= 4 && + s[0] == '\\' && s[1] == 'u' && + std::tolower(s[2]) == 'd' && + (s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b'); + }; + + // Initialize the unicode marker to a low surrogate to handle the edge case + // where a high surrogate (U+D800-U+DBFF) is immediately followed by a + // backslash (\) + std::string unicode_marker_padding = "udc00"; + std::smatch last_unicode_seq; + + if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) { + std::smatch second_last_seq; + std::string prelude = str.substr(0, last_unicode_seq.position()); + + // Pad the escape sequence with 0s until it forms a complete sequence of 6 characters + unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0'); + + if (is_high_surrogate(last_unicode_seq.str())) { + // If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF) + unicode_marker_padding += "\\udc00"; + } else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) { + if (is_high_surrogate(second_last_seq.str())) { + // If this follows a high surrogate, pad it to be a low surrogate + if (last_unicode_seq.length() == 2) { + unicode_marker_padding = "dc00"; + } else if (last_unicode_seq.length() == 3) { + unicode_marker_padding = "c00"; + } else { + // The original unicode_marker_padding is already padded with 0s + } + } + } + } + const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$"; if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) { @@ -186,6 +228,9 @@ bool common_json_parse( } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { // Was inside an object value string after an escape str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; + } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) { + // Was inside an object value string after a partial unicode escape + str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing; } else { // find last : auto last_pos = str.find_last_of(':'); @@ -205,6 +250,9 @@ bool common_json_parse( } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { // Was inside an array value string after an escape str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; + } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) { + // Was inside an array value string after a partial unicode escape + str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing; } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) { // Had just finished a value str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing; @@ -230,6 +278,9 @@ bool common_json_parse( } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) { // Was inside an object key string after an escape str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing; + } else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) { + // Was inside an object key string after a partial unicode escape + str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing; } else { auto last_pos = str.find_last_of(':'); if (last_pos == std::string::npos) { diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp index 0b275befb8bf4..4766518fe6955 100644 --- a/tests/test-chat-parser.cpp +++ b/tests/test-chat-parser.cpp @@ -524,6 +524,64 @@ static void test_json_with_dumped_args() { R"({"foo": "bar", "args": {"arg1": [)", R"({"foo":"bar","args":"{\"arg1\":["})" ); + + // Unicode tests + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\u)", + R"({"foo":"bar","args":"{\"arg1\":\"\\u"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\u0)", + R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\u00)", + R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\u000)", + R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\u0000)", + R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud8)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud80)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud800)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud800\)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud800\u)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud800\ud)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud800\udc)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})" + ); + test_with_args( + R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)", + R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})" + ); } static void test_positions() { diff --git a/tests/test-json-partial.cpp b/tests/test-json-partial.cpp index bc136beceb9ae..39da9276ef459 100644 --- a/tests/test-json-partial.cpp +++ b/tests/test-json-partial.cpp @@ -58,7 +58,7 @@ static void test_json_healing() { for (const auto & input : inputs) { common_json out; assert_equals(true, common_json_parse(input, "$foo", out)); - assert_equals(expected, out.json.dump()); + assert_equals(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true)); assert_equals(expected_marker, out.healing_marker.json_dump_marker); } }; @@ -228,6 +228,56 @@ static void test_json_healing() { R"({"key":"$foo"})", R"(:"$foo)" ); + // Test unicode escape sequences + test( + { + R"({"a":"\u)", + }, + R"({"a":"\u0000$foo"})", + R"(0000$foo)" + ); + test( + { + R"({"a":"\u00)", + }, + R"({"a":"\u0000$foo"})", + R"(00$foo)" + ); + test( + { + R"({"a":"\ud300)", + }, + R"({"a":"\ud300$foo"})", + R"($foo)" + ); + test( + { + R"({"a":"\ud800)", + }, + R"({"a":"\ud800\udc00$foo"})", + R"(\udc00$foo)" + ); + test( + { + R"({"a":"\ud800\)", + }, + R"({"a":"\ud800\udc00$foo"})", + R"(udc00$foo)" + ); + test( + { + R"({"a":"\ud800\u)", + }, + R"({"a":"\ud800\udc00$foo"})", + R"(dc00$foo)" + ); + test( + { + R"({"a":"\ud800\udc00)", + }, + R"({"a":"\ud800\udc00$foo"})", + R"($foo)" + ); } int main() {