From 2961a458d2661f02b1ec13c41f416f9e633e9c1a Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Sun, 16 Mar 2025 00:42:56 -0700
Subject: [PATCH 1/4] Rely on runtime_wrapper to provide supported platforms

As titled.
---
 targets.bzl | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/targets.bzl b/targets.bzl
index b5eecbd..27fb94a 100644
--- a/targets.bzl
+++ b/targets.bzl
@@ -1,8 +1,7 @@
-load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "FBCODE")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob")
 
-PLATFORMS = (CXX, ANDROID, APPLE, FBCODE)
+PLATFORMS = runtime.get_executorch_supported_platforms()
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.

From 586fc0145fa6abffef5f7e1fc9020ad0d16535eb Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Sun, 16 Mar 2025 01:04:02 -0700
Subject: [PATCH 2/4] Update targets.bzl

---
 targets.bzl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/targets.bzl b/targets.bzl
index 27fb94a..28fd224 100644
--- a/targets.bzl
+++ b/targets.bzl
@@ -1,7 +1,7 @@
-load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime", "get_executorch_supported_platforms")
 load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob")
 
-PLATFORMS = runtime.get_executorch_supported_platforms()
+PLATFORMS = get_executorch_supported_platforms()
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.

From ec61ab1489e2d0fb6ac82b39288ce505bf8bdeca Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Sun, 16 Mar 2025 13:02:29 -0700
Subject: [PATCH 3/4] Move llama.cpp-unicode headers into
 llama.cpp-unicode/include

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 targets.bzl                                   |  15 +-
 third-party/TARGETS                           |  13 ++
 .../llama.cpp-unicode/include}/unicode-data.h |   0
 .../llama.cpp-unicode/include}/unicode.h      |   0
 .../llama.cpp-unicode/src/unicode-data.cpp    |   2 +-
 third-party/llama.cpp-unicode/src/unicode.cpp | 140 +++++++++---------
 6 files changed, 88 insertions(+), 82 deletions(-)
 rename {include/pytorch/tokenizers/third-party/llama.cpp-unicode => third-party/llama.cpp-unicode/include}/unicode-data.h (100%)
 rename {include/pytorch/tokenizers/third-party/llama.cpp-unicode => third-party/llama.cpp-unicode/include}/unicode.h (100%)

diff --git a/targets.bzl b/targets.bzl
index 28fd224..42e906b 100644
--- a/targets.bzl
+++ b/targets.bzl
@@ -67,19 +67,6 @@ def define_common_targets():
         platforms = PLATFORMS,
     )
 
-    runtime.cxx_library(
-        name = "unicode",
-        srcs = [
-            "third-party/llama.cpp-unicode/src/unicode.cpp",
-            "third-party/llama.cpp-unicode/src/unicode-data.cpp",
-        ],
-        exported_headers = subdir_glob([
-            ("include", "pytorch/tokenizers/third-party/llama.cpp-unicode/*.h"),
-        ]),
-        header_namespace = "",
-        platforms = PLATFORMS,
-    )
-
     runtime.cxx_library(
         name = "hf_tokenizer",
         srcs = [
@@ -90,7 +77,7 @@ def define_common_targets():
         ],
         exported_deps = [
             ":headers",
-            ":unicode",
+            "//pytorch/tokenizers/third-party:unicode",
         ],
         visibility = [
             "@EXECUTORCH_CLIENTS",
diff --git a/third-party/TARGETS b/third-party/TARGETS
index 978c123..16c4551 100644
--- a/third-party/TARGETS
+++ b/third-party/TARGETS
@@ -1,4 +1,5 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob")
"subdir_glob") oncall("executorch") @@ -45,3 +46,15 @@ runtime.cxx_library( visibility = ["PUBLIC"], _is_external_target = True, ) + +runtime.cxx_library( + name = "unicode", + srcs = [ + "llama.cpp-unicode/src/unicode.cpp", + "llama.cpp-unicode/src/unicode-data.cpp", + ], + exported_headers = subdir_glob([ + ("include", "*.h"), + ]), + header_namespace = "", +) diff --git a/include/pytorch/tokenizers/third-party/llama.cpp-unicode/unicode-data.h b/third-party/llama.cpp-unicode/include/unicode-data.h similarity index 100% rename from include/pytorch/tokenizers/third-party/llama.cpp-unicode/unicode-data.h rename to third-party/llama.cpp-unicode/include/unicode-data.h diff --git a/include/pytorch/tokenizers/third-party/llama.cpp-unicode/unicode.h b/third-party/llama.cpp-unicode/include/unicode.h similarity index 100% rename from include/pytorch/tokenizers/third-party/llama.cpp-unicode/unicode.h rename to third-party/llama.cpp-unicode/include/unicode.h diff --git a/third-party/llama.cpp-unicode/src/unicode-data.cpp b/third-party/llama.cpp-unicode/src/unicode-data.cpp index c924f0c..0317793 100644 --- a/third-party/llama.cpp-unicode/src/unicode-data.cpp +++ b/third-party/llama.cpp-unicode/src/unicode-data.cpp @@ -27,7 +27,7 @@ SOFTWARE. // generated with scripts/gen-unicode-data.py -#include +#include "unicode-data.h" #include #include diff --git a/third-party/llama.cpp-unicode/src/unicode.cpp b/third-party/llama.cpp-unicode/src/unicode.cpp index 152fca7..096fdce 100644 --- a/third-party/llama.cpp-unicode/src/unicode.cpp +++ b/third-party/llama.cpp-unicode/src/unicode.cpp @@ -29,8 +29,8 @@ SOFTWARE. #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING #endif -#include -#include +#include "unicode.h" +#include "unicode-data.h" #include #include @@ -62,7 +62,7 @@ size_t unicode_len_utf8(char src) { // return result; // } -uint32_t unicode_cpt_from_utf8(const std::string &utf8, size_t &offset) { +uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) { assert(offset < utf8.size()); if (!(utf8[offset + 0] & 0x80)) { auto result = utf8[offset + 0]; @@ -86,7 +86,7 @@ uint32_t unicode_cpt_from_utf8(const std::string &utf8, size_t &offset) { throw std::invalid_argument("invalid character"); } auto result = ((utf8[offset + 0] & 0x0f) << 12) | - ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); + ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); offset += 3; return result; } @@ -97,8 +97,8 @@ uint32_t unicode_cpt_from_utf8(const std::string &utf8, size_t &offset) { throw std::invalid_argument("invalid character"); } auto result = ((utf8[offset + 0] & 0x07) << 18) | - ((utf8[offset + 1] & 0x3f) << 12) | - ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f); + ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | + (utf8[offset + 3] & 0x3f); offset += 4; return result; } @@ -157,12 +157,13 @@ uint32_t unicode_cpt_from_utf8(const std::string &utf8, size_t &offset) { // } static std::vector unicode_cpt_flags_array() { - std::vector cpt_flags(MAX_CODEPOINTS, - codepoint_flags::UNDEFINED); + std::vector cpt_flags( + MAX_CODEPOINTS, codepoint_flags::UNDEFINED); assert(unicode_ranges_flags.begin()[0].first == 0); - assert(unicode_ranges_flags.begin()[unicode_ranges_flags.size() - 1].first == - MAX_CODEPOINTS); + assert( + unicode_ranges_flags.begin()[unicode_ranges_flags.size() - 1].first == + MAX_CODEPOINTS); for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) { const auto range_ini = unicode_ranges_flags.begin()[i - 1]; // 
@@ -185,7 +186,7 @@ static std::vector<codepoint_flags> unicode_cpt_flags_array() {
     cpt_flags[p.second].is_uppercase = true;
   }
 
-  for (auto &range : unicode_ranges_nfd) { // start, last, nfd
+  for (auto& range : unicode_ranges_nfd) { // start, last, nfd
     cpt_flags[range.nfd].is_nfd = true;
   }
 
@@ -240,15 +241,15 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
   return map;
 }
 
-static inline std::wstring unicode_wstring_from_utf8(const std::string &s) {
+static inline std::wstring unicode_wstring_from_utf8(const std::string& s) {
   std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
   return conv.from_bytes(s);
 }
 
-static std::vector<std::string>
-unicode_byte_encoding_process(const std::vector<std::string> &bpe_words) {
+static std::vector<std::string> unicode_byte_encoding_process(
+    const std::vector<std::string>& bpe_words) {
   std::vector<std::string> bpe_encoded_words;
-  for (const auto &word : bpe_words) {
+  for (const auto& word : bpe_words) {
     std::string text_utf;
     auto utf_word = unicode_cpts_from_utf8(word);
     for (size_t i = 0; i < utf_word.size(); ++i) {
@@ -256,7 +257,7 @@ unicode_byte_encoding_process(const std::vector<std::string> &bpe_words) {
     }
 
     std::string encoded_token;
-    for (char &c : text_utf) {
+    for (char& c : text_utf) {
       encoded_token += unicode_byte_to_utf8(c);
     }
     bpe_encoded_words.emplace_back(encoded_token);
@@ -266,9 +267,9 @@ unicode_byte_encoding_process(const std::vector<std::string> &bpe_words) {
 
 // GPT2 system regex:  's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+|
 // ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
-static std::vector<size_t>
-unicode_regex_split_custom_gpt2(const std::string &text,
-                                const std::vector<size_t> &offsets) {
+static std::vector<size_t> unicode_regex_split_custom_gpt2(
+    const std::string& text,
+    const std::vector<size_t>& offsets) {
   std::vector<size_t> bpe_offsets; // store the offset of each word
   bpe_offsets.reserve(
       offsets.size()); // Reserve memory for the approximate size
@@ -289,8 +290,8 @@ unicode_regex_split_custom_gpt2(const std::string &text,
 
     auto _get_flags = [&](const size_t pos) -> codepoint_flags {
       return (offset_ini <= pos && pos < offset_end)
-                 ? unicode_cpt_flags(cpts[pos])
-                 : codepoint_flags{};
+          ? unicode_cpt_flags(cpts[pos])
+          : codepoint_flags{};
     };
 
     size_t _prev_end = offset_ini;
@@ -395,9 +396,9 @@ unicode_regex_split_custom_gpt2(const std::string &text,
 // LLAMA3 system regex:
 // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}|
 // ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
-static std::vector<size_t>
-unicode_regex_split_custom_llama3(const std::string &text,
-                                  const std::vector<size_t> &offsets) {
+static std::vector<size_t> unicode_regex_split_custom_llama3(
+    const std::string& text,
+    const std::vector<size_t>& offsets) {
   std::vector<size_t> bpe_offsets; // store the offset of each word
   bpe_offsets.reserve(
       offsets.size()); // Reserve memory for the approximate size
@@ -418,8 +419,8 @@ unicode_regex_split_custom_llama3(const std::string &text,
 
     auto _get_flags = [&](const size_t pos) -> codepoint_flags {
       return (offset_ini <= pos && pos < offset_end)
-                 ? unicode_cpt_flags(cpts[pos])
-                 : codepoint_flags{};
+          ? unicode_cpt_flags(cpts[pos])
unicode_cpt_flags(cpts[pos]) + : codepoint_flags{}; }; size_t _prev_end = offset_ini; @@ -546,18 +547,18 @@ unicode_regex_split_custom_llama3(const std::string &text, } // use std::wregex to split the text -static std::vector -unicode_regex_split_stl(const std::wstring &wtext, - const std::wstring ®ex_expr, - const std::vector &offsets) { +static std::vector unicode_regex_split_stl( + const std::wstring& wtext, + const std::wstring& regex_expr, + const std::vector& offsets) { std::wregex expr(regex_expr); std::vector bpe_offsets; // store the offset of each word bpe_offsets.reserve( offsets.size()); // Reserve memory for the approximate size size_t start = 0; for (auto offset : offsets) { - std::wcregex_iterator it(wtext.data() + start, - wtext.data() + start + offset, expr); + std::wcregex_iterator it( + wtext.data() + start, wtext.data() + start + offset, expr); std::wcregex_iterator end; int64_t start_idx = 0; @@ -581,17 +582,18 @@ unicode_regex_split_stl(const std::wstring &wtext, } // use std::regex to split the text -static std::vector -unicode_regex_split_stl(const std::string &text, const std::string ®ex_expr, - const std::vector &offsets) { +static std::vector unicode_regex_split_stl( + const std::string& text, + const std::string& regex_expr, + const std::vector& offsets) { std::regex expr(regex_expr); std::vector bpe_offsets; // store the offset of each word bpe_offsets.reserve( offsets.size()); // Reserve memory for the approximate size size_t start = 0; for (auto offset : offsets) { - std::cregex_iterator it(text.data() + start, text.data() + start + offset, - expr); + std::cregex_iterator it( + text.data() + start, text.data() + start + offset, expr); std::cregex_iterator end; int64_t start_idx = 0; @@ -614,14 +616,15 @@ unicode_regex_split_stl(const std::string &text, const std::string ®ex_expr, return bpe_offsets; } -static std::vector -unicode_regex_split_custom(const std::string &text, - const std::string ®ex_expr, - const std::vector &offsets) { +static std::vector unicode_regex_split_custom( + const std::string& text, + const std::string& regex_expr, + const std::vector& offsets) { std::vector bpe_offsets; - if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| " - "?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { + if (regex_expr == + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| " + "?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets); } else if ( regex_expr == @@ -631,7 +634,6 @@ unicode_regex_split_custom(const std::string &text, "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^" "\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| " "?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { - bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); } @@ -671,23 +673,24 @@ std::string unicode_cpt_to_utf8(uint32_t cp) { throw std::invalid_argument("invalid codepoint"); } -std::vector -unicode_cpts_normalize_nfd(const std::vector &cpts) { - auto comp = [](const uint32_t cpt, const range_nfd &range) { +std::vector unicode_cpts_normalize_nfd( + const std::vector& cpts) { + auto comp = [](const uint32_t cpt, const range_nfd& range) { return cpt < range.first; }; std::vector result(cpts.size()); for (size_t i = 0; i < cpts.size(); ++i) { const uint32_t cpt = cpts[i]; - auto it = std::upper_bound(unicode_ranges_nfd.begin(), - unicode_ranges_nfd.end(), cpt, comp) - - 1; + auto it = + std::upper_bound( + unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - + 1; result[i] = (it->first <= cpt && cpt <= 
   }
   return result;
 }
 
-std::vector<uint32_t> unicode_cpts_from_utf8(const std::string &utf8) {
+std::vector<uint32_t> unicode_cpts_from_utf8(const std::string& utf8) {
   std::vector<uint32_t> result;
   result.reserve(utf8.size());
   size_t offset = 0;
@@ -703,7 +706,7 @@ codepoint_flags unicode_cpt_flags(const uint32_t cp) {
   return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
 }
 
-codepoint_flags unicode_cpt_flags(const std::string &utf8) {
+codepoint_flags unicode_cpt_flags(const std::string& utf8) {
   static const codepoint_flags undef(codepoint_flags::UNDEFINED);
   if (utf8.empty()) {
     return undef; // undefined
@@ -718,7 +721,7 @@ std::string unicode_byte_to_utf8(uint8_t byte) {
   return map.at(byte);
 }
 
-uint8_t unicode_utf8_to_byte(const std::string &utf8) {
+uint8_t unicode_utf8_to_byte(const std::string& utf8) {
   static std::unordered_map<std::string, uint8_t> map =
       unicode_utf8_to_byte_map();
   return map.at(utf8);
@@ -726,19 +729,22 @@ uint8_t unicode_utf8_to_byte(const std::string &utf8) {
 
 uint32_t unicode_tolower(uint32_t cp) {
   // binary search
-  auto it = std::lower_bound(unicode_map_lowercase.begin(),
-                             unicode_map_lowercase.end(), cp,
-                             [](const std::pair<uint32_t, uint32_t> &pair,
-                                uint32_t value) { return pair.first < value; });
+  auto it = std::lower_bound(
+      unicode_map_lowercase.begin(),
+      unicode_map_lowercase.end(),
+      cp,
+      [](const std::pair<uint32_t, uint32_t>& pair, uint32_t value) {
+        return pair.first < value;
+      });
   if (it != unicode_map_lowercase.end() && it->first == cp) {
     return it->second;
   }
   return cp; // Return the original code point if no lowercase mapping is found
 }
 
-std::vector<std::string>
-unicode_regex_split(const std::string &text,
-                    const std::vector<std::string> &regex_exprs) {
+std::vector<std::string> unicode_regex_split(
+    const std::string& text,
+    const std::vector<std::string>& regex_exprs) {
   // unicode categories
   static const std::map<std::string, int> k_ucat_enum = {
       {"\\p{N}", codepoint_flags::NUMBER},
@@ -753,7 +759,7 @@ unicode_regex_split(const std::string &text,
   };
 
   static const std::map<int, std::string> k_ucat_map = {
-      {codepoint_flags::NUMBER, "\x30-\x39"}, // 0-9
+      {codepoint_flags::NUMBER, "\x30-\x39"},          // 0-9
       {codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A"}, // A-Za-z
       {codepoint_flags::PUNCTUATION,
        "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-"
@@ -762,9 +768,9 @@ unicode_regex_split(const std::string &text,
 
   // compute collapsed codepoints only if needed by at least one regex
   bool need_collapse = false;
-  for (auto &regex_expr : regex_exprs) {
+  for (auto& regex_expr : regex_exprs) {
     // search for unicode categories
-    for (const auto &ucat : k_ucat_enum) {
+    for (const auto& ucat : k_ucat_enum) {
       if (std::string::npos != regex_expr.find(ucat.first)) {
         need_collapse = true;
         break;
@@ -806,7 +812,7 @@ unicode_regex_split(const std::string &text,
 
   std::vector<size_t> bpe_offsets = {cpts.size()};
 
-  for (auto &regex_expr : regex_exprs) {
+  for (auto& regex_expr : regex_exprs) {
     // first, see if we have an efficient custom regex implementation
     auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
 
@@ -821,7 +827,7 @@ unicode_regex_split(const std::string &text,
       // and replace the unicode category with the corresponding collapsed
       // representation
       bool use_collapsed = false;
-      for (auto &ucat : k_ucat_enum) {
+      for (auto& ucat : k_ucat_enum) {
        if (std::string::npos != regex_expr.find(ucat.first)) {
          use_collapsed = true;
          break;
@@ -900,7 +906,7 @@ unicode_regex_split(const std::string &text,
         // printf("regex_expr: %s\n", regex_expr.c_str());
         bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
       }
-    } catch (std::regex_error &e) {
+    } catch (std::regex_error& e) {
      fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
"Failed to process regex: '%s'\n", regex_expr.c_str()); fprintf(stderr, "Regex error: %s\n", e.what()); throw std::runtime_error("Failed to process regex"); @@ -912,7 +918,7 @@ unicode_regex_split(const std::string &text, bpe_offsets.size()); // reserve memory for the approximate size size_t start = 0; - for (size_t &offset : bpe_offsets) { + for (size_t& offset : bpe_offsets) { bpe_words.emplace_back(); for (size_t i = start; i < start + offset; ++i) { bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); From d70f5a760552d8d3bb288cdd93eebde477bb6eb0 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Mon, 17 Mar 2025 12:13:34 -0700 Subject: [PATCH 4/4] Fix the build Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- src/pre_tokenizer.cpp | 47 ++++++++++++++++++++----------------------- src/token_decoder.cpp | 8 ++++---- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/pre_tokenizer.cpp b/src/pre_tokenizer.cpp index 5e6e662..6448363 100644 --- a/src/pre_tokenizer.cpp +++ b/src/pre_tokenizer.cpp @@ -8,7 +8,7 @@ // Local #include -#include +#include // Standard #include @@ -63,37 +63,35 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const { "Missing pretokenizers for PreTokenizer of type Sequence"); } std::vector pretoks; - std::transform( - pretokenizers->begin(), - pretokenizers->end(), - std::back_inserter(pretoks), - [](const PreTokenizerConfig& cfg) { return cfg.create(); }); + std::transform(pretokenizers->begin(), pretokenizers->end(), + std::back_inserter(pretoks), + [](const PreTokenizerConfig &cfg) { return cfg.create(); }); return PreTokenizer::Ptr(new SequencePreTokenizer(pretoks)); } throw std::runtime_error("Unsupported PreTokenizer type: " + type); } -PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) { +PreTokenizerConfig &PreTokenizerConfig::parse_json(const json &json_config) { type = json_config.at("type"); if (type == "Split") { try { pattern = json_config.at("pattern"); - } catch (json::out_of_range&) { + } catch (json::out_of_range &) { } } else if (type == "Digits") { try { individual_digits = json_config.at("individual_digits"); - } catch (json::out_of_range&) { + } catch (json::out_of_range &) { } } else if (type == "ByteLevel") { try { add_prefix_space = json_config.at("add_prefix_space"); - } catch (json::out_of_range&) { + } catch (json::out_of_range &) { } // TODO: trim_offsets, use_regex } else if (type == "Sequence") { pretokenizers = std::vector(); - for (const auto& entry : json_config.at("pretokenizers")) { + for (const auto &entry : json_config.at("pretokenizers")) { pretokenizers->push_back(PreTokenizerConfig().parse_json(entry)); } } else { @@ -104,14 +102,14 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) { // RegexPreTokenizer /////////////////////////////////////////////////////////// -RegexPreTokenizer::Re2UPtr RegexPreTokenizer::create_regex_( - const std::string& pattern) { +RegexPreTokenizer::Re2UPtr +RegexPreTokenizer::create_regex_(const std::string &pattern) { assert(!pattern.empty()); return std::make_unique("(" + pattern + ")"); } -std::vector RegexPreTokenizer::pre_tokenize( - re2::StringPiece input) const { +std::vector +RegexPreTokenizer::pre_tokenize(re2::StringPiece input) const { std::vector result; std::string piece; while (RE2::FindAndConsume(&input, *regex_, &piece)) { @@ -138,14 +136,13 @@ constexpr char GPT2_EXPR[] = // Construction // ////////////////// -ByteLevelPreTokenizer::ByteLevelPreTokenizer( - bool add_prefix_space, - const std::string& 
+ByteLevelPreTokenizer::ByteLevelPreTokenizer(bool add_prefix_space,
+                                             const std::string &pattern)
     : pattern_(pattern.empty() ? GPT2_EXPR : pattern),
       add_prefix_space_(add_prefix_space) {}
 
-std::vector<std::string> ByteLevelPreTokenizer::pre_tokenize(
-    re2::StringPiece input) const {
+std::vector<std::string>
+ByteLevelPreTokenizer::pre_tokenize(re2::StringPiece input) const {
   // Add the prefix space if configured to do so
   std::string input_str(input);
   if (add_prefix_space_ && !input_str.empty() && input_str[0] != ' ') {
@@ -161,13 +158,13 @@ SequencePreTokenizer::SequencePreTokenizer(
     std::vector<PreTokenizer::Ptr> pre_tokenizers)
     : pre_tokenizers_(std::move(pre_tokenizers)) {}
 
-std::vector<std::string> SequencePreTokenizer::pre_tokenize(
-    re2::StringPiece input) const {
+std::vector<std::string>
+SequencePreTokenizer::pre_tokenize(re2::StringPiece input) const {
   std::vector<std::string> pieces{std::string(input)};
-  for (const auto& pre_tokenizer : pre_tokenizers_) {
+  for (const auto &pre_tokenizer : pre_tokenizers_) {
     std::vector<std::string> new_pieces;
-    for (const auto& piece : pieces) {
-      for (const auto& subpiece : pre_tokenizer->pre_tokenize(piece)) {
+    for (const auto &piece : pieces) {
+      for (const auto &subpiece : pre_tokenizer->pre_tokenize(piece)) {
        new_pieces.push_back(subpiece);
       }
     }
diff --git a/src/token_decoder.cpp b/src/token_decoder.cpp
index 669f6dd..c6e5b10 100644
--- a/src/token_decoder.cpp
+++ b/src/token_decoder.cpp
@@ -16,7 +16,7 @@
 
 #include <nlohmann/json.hpp>
 
 // Local
-#include <pytorch/tokenizers/third-party/llama.cpp-unicode/unicode.h>
+#include <unicode.h>
 
 using json = nlohmann::json;
@@ -37,7 +37,7 @@ TokenDecoder::Ptr TokenDecoderConfig::create() const {
   throw std::runtime_error("Unsupported TokenDecoder type: " + type);
 }
 
-TokenDecoderConfig& TokenDecoderConfig::parse_json(const json& json_config) {
+TokenDecoderConfig &TokenDecoderConfig::parse_json(const json &json_config) {
   type = json_config.at("type");
   if (type == "ByteLevel") {
     // No parameters to parse
@@ -54,7 +54,7 @@ namespace {
 // Copied from llama.cpp
 // CITE:
 // https://github.com/ggerganov/llama.cpp/blob/master/src/llama-vocab.cpp#L20
-static std::string format(const char* fmt, ...) {
+static std::string format(const char *fmt, ...) {
   va_list ap;
   va_list ap2;
   va_start(ap, fmt);
@@ -84,7 +84,7 @@ std::string ByteLevelTokenDecoder::decode(re2::StringPiece token) const {
       const auto utf8 = unicode_cpt_to_utf8(cpt);
       try {
         decoded_text += unicode_utf8_to_byte(utf8);
-      } catch (const std::out_of_range& /*e*/) {
+      } catch (const std::out_of_range & /*e*/) {
         decoded_text += "[UNK_BYTE_0x";
         for (const auto c : utf8) {
           decoded_text += format("%02x", (uint8_t)c);