47 changes: 22 additions & 25 deletions src/pre_tokenizer.cpp
@@ -8,7 +8,7 @@

// Local
#include <pytorch/tokenizers/pre_tokenizer.h>
#include <pytorch/tokenizers/third-party/llama.cpp-unicode/unicode.h>
#include <unicode.h>

// Standard
#include <algorithm>
@@ -63,37 +63,35 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const {
"Missing pretokenizers for PreTokenizer of type Sequence");
}
std::vector<PreTokenizer::Ptr> pretoks;
std::transform(
pretokenizers->begin(),
pretokenizers->end(),
std::back_inserter(pretoks),
[](const PreTokenizerConfig& cfg) { return cfg.create(); });
std::transform(pretokenizers->begin(), pretokenizers->end(),
std::back_inserter(pretoks),
[](const PreTokenizerConfig &cfg) { return cfg.create(); });
return PreTokenizer::Ptr(new SequencePreTokenizer(pretoks));
}
throw std::runtime_error("Unsupported PreTokenizer type: " + type);
}

PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {
PreTokenizerConfig &PreTokenizerConfig::parse_json(const json &json_config) {
type = json_config.at("type");
if (type == "Split") {
try {
pattern = json_config.at("pattern");
} catch (json::out_of_range&) {
} catch (json::out_of_range &) {
}
} else if (type == "Digits") {
try {
individual_digits = json_config.at("individual_digits");
} catch (json::out_of_range&) {
} catch (json::out_of_range &) {
}
} else if (type == "ByteLevel") {
try {
add_prefix_space = json_config.at("add_prefix_space");
} catch (json::out_of_range&) {
} catch (json::out_of_range &) {
}
// TODO: trim_offsets, use_regex
} else if (type == "Sequence") {
pretokenizers = std::vector<PreTokenizerConfig>();
for (const auto& entry : json_config.at("pretokenizers")) {
for (const auto &entry : json_config.at("pretokenizers")) {
pretokenizers->push_back(PreTokenizerConfig().parse_json(entry));
}
} else {
@@ -104,14 +102,14 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {

// RegexPreTokenizer ///////////////////////////////////////////////////////////

RegexPreTokenizer::Re2UPtr RegexPreTokenizer::create_regex_(
const std::string& pattern) {
RegexPreTokenizer::Re2UPtr
RegexPreTokenizer::create_regex_(const std::string &pattern) {
assert(!pattern.empty());
return std::make_unique<re2::RE2>("(" + pattern + ")");
}

std::vector<std::string> RegexPreTokenizer::pre_tokenize(
re2::StringPiece input) const {
std::vector<std::string>
RegexPreTokenizer::pre_tokenize(re2::StringPiece input) const {
std::vector<std::string> result;
std::string piece;
while (RE2::FindAndConsume(&input, *regex_, &piece)) {
@@ -138,14 +136,13 @@ constexpr char GPT2_EXPR[] =
// Construction //
//////////////////

ByteLevelPreTokenizer::ByteLevelPreTokenizer(
bool add_prefix_space,
const std::string& pattern)
ByteLevelPreTokenizer::ByteLevelPreTokenizer(bool add_prefix_space,
const std::string &pattern)
: pattern_(pattern.empty() ? GPT2_EXPR : pattern),
add_prefix_space_(add_prefix_space) {}

std::vector<std::string> ByteLevelPreTokenizer::pre_tokenize(
re2::StringPiece input) const {
std::vector<std::string>
ByteLevelPreTokenizer::pre_tokenize(re2::StringPiece input) const {
// Add the prefix space if configured to do so
std::string input_str(input);
if (add_prefix_space_ && !input_str.empty() && input_str[0] != ' ') {
@@ -161,13 +158,13 @@ SequencePreTokenizer::SequencePreTokenizer(
std::vector<PreTokenizer::Ptr> pre_tokenizers)
: pre_tokenizers_(std::move(pre_tokenizers)) {}

std::vector<std::string> SequencePreTokenizer::pre_tokenize(
re2::StringPiece input) const {
std::vector<std::string>
SequencePreTokenizer::pre_tokenize(re2::StringPiece input) const {
std::vector<std::string> pieces{std::string(input)};
for (const auto& pre_tokenizer : pre_tokenizers_) {
for (const auto &pre_tokenizer : pre_tokenizers_) {
std::vector<std::string> new_pieces;
for (const auto& piece : pieces) {
for (const auto& subpiece : pre_tokenizer->pre_tokenize(piece)) {
for (const auto &piece : pieces) {
for (const auto &subpiece : pre_tokenizer->pre_tokenize(piece)) {
new_pieces.push_back(subpiece);
}
}
8 changes: 4 additions & 4 deletions src/token_decoder.cpp
@@ -16,7 +16,7 @@
#include <nlohmann/json.hpp>

// Local
#include <pytorch/tokenizers/third-party/llama.cpp-unicode/unicode.h>
#include <unicode.h>

using json = nlohmann::json;

@@ -37,7 +37,7 @@ TokenDecoder::Ptr TokenDecoderConfig::create() const {
throw std::runtime_error("Unsupported TokenDecoder type: " + type);
}

TokenDecoderConfig& TokenDecoderConfig::parse_json(const json& json_config) {
TokenDecoderConfig &TokenDecoderConfig::parse_json(const json &json_config) {
type = json_config.at("type");
if (type == "ByteLevel") {
// No parameters to parse
@@ -54,7 +54,7 @@ namespace {
// Copied from llama.cpp
// CITE:
// https://github.com/ggerganov/llama.cpp/blob/master/src/llama-vocab.cpp#L20
static std::string format(const char* fmt, ...) {
static std::string format(const char *fmt, ...) {
va_list ap;
va_list ap2;
va_start(ap, fmt);
@@ -84,7 +84,7 @@ std::string ByteLevelTokenDecoder::decode(re2::StringPiece token) const {
const auto utf8 = unicode_cpt_to_utf8(cpt);
try {
decoded_text += unicode_utf8_to_byte(utf8);
} catch (const std::out_of_range& /*e*/) {
} catch (const std::out_of_range & /*e*/) {
decoded_text += "[UNK_BYTE_0x";
for (const auto c : utf8) {
decoded_text += format("%02x", (uint8_t)c);
20 changes: 3 additions & 17 deletions targets.bzl
@@ -1,8 +1,7 @@
load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "FBCODE")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime", "get_executorch_supported_platforms")
load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob")

PLATFORMS = (CXX, ANDROID, APPLE, FBCODE)
PLATFORMS = get_executorch_supported_platforms()

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
@@ -68,19 +67,6 @@ def define_common_targets():
platforms = PLATFORMS,
)

runtime.cxx_library(
name = "unicode",
srcs = [
"third-party/llama.cpp-unicode/src/unicode.cpp",
"third-party/llama.cpp-unicode/src/unicode-data.cpp",
],
exported_headers = subdir_glob([
("include", "pytorch/tokenizers/third-party/llama.cpp-unicode/*.h"),
]),
header_namespace = "",
platforms = PLATFORMS,
)

runtime.cxx_library(
name = "hf_tokenizer",
srcs = [
@@ -91,7 +77,7 @@ def define_common_targets():
],
exported_deps = [
":headers",
":unicode",
"//pytorch/tokenizers/third-party:unicode",
],
visibility = [
"@EXECUTORCH_CLIENTS",
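
With the in-repo unicode rule removed from targets.bzl, the relocated library is now reached through its fully qualified label, as the hf_tokenizer dependency above shows. A minimal sketch of what another consumer in this file would presumably look like (the target name and source file are hypothetical; the label and attributes mirror the ones already used in this diff):

runtime.cxx_library(
    name = "example_unicode_consumer",  # hypothetical consumer target, for illustration only
    srcs = ["src/example.cpp"],  # hypothetical source file
    exported_deps = [
        ":headers",
        # the unicode helpers now live in third-party, so the dependency is the
        # fully qualified label rather than the removed in-file ":unicode" rule
        "//pytorch/tokenizers/third-party:unicode",
    ],
    visibility = ["@EXECUTORCH_CLIENTS"],
    platforms = PLATFORMS,
)
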
13 changes: 13 additions & 0 deletions third-party/TARGETS
@@ -1,4 +1,5 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob")

oncall("executorch")

@@ -45,3 +46,15 @@ runtime.cxx_library(
visibility = ["PUBLIC"],
_is_external_target = True,
)

runtime.cxx_library(
name = "unicode",
srcs = [
"llama.cpp-unicode/src/unicode.cpp",
"llama.cpp-unicode/src/unicode-data.cpp",
],
exported_headers = subdir_glob([
("include", "*.h"),
]),
header_namespace = "",
)
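
Because this target sets header_namespace = "" and globs its headers out of an include directory, the exported headers land at the root of the include path, which is presumably what lets the C++ sources in this PR drop the long pytorch/tokenizers/third-party/llama.cpp-unicode/ prefix. A rough sketch of the resulting mapping, written as comments and assuming standard Buck subdir_glob semantics:

# subdir_glob([("include", "*.h")]) with header_namespace = "" exports, roughly,
#   include/unicode.h       ->  #include <unicode.h>
#   include/unicode-data.h  ->  #include "unicode-data.h"
# (paths relative to this third-party directory), matching the shortened
# includes used by src/pre_tokenizer.cpp, src/token_decoder.cpp, and
# llama.cpp-unicode/src/unicode-data.cpp in this diff.
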
2 changes: 1 addition & 1 deletion third-party/llama.cpp-unicode/src/unicode-data.cpp
@@ -27,7 +27,7 @@ SOFTWARE.

// generated with scripts/gen-unicode-data.py

#include <pytorch/tokenizers/third-party/llama.cpp-unicode/unicode-data.h>
#include "unicode-data.h"

#include <cstdint>
#include <unordered_map>