Add initial tiktoken and Phi3SmallTokenizer support (#729)
* add initial tiktoken support

* add vector hash and equal for bpe ranks map

* change lambda comparator

* move phi-3-small files

* final changes

* move tiktoken files from data2 to data

* add unit test

* add tokenizer module

* merge json and tiktoken impl

* fix tiktoken encoding problem

* address comments

* remove dummy tokens

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
3 people authored Aug 2, 2024
1 parent 46998e9 commit 7851b51
Showing 10 changed files with 100,631 additions and 19 deletions.
23 changes: 22 additions & 1 deletion operators/tokenizer/bpe_json.hpp
@@ -30,7 +30,23 @@ class TokenJsonConfig final {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open a json file: " + file_path.string());
}

vocab_path_ = (path(json_path) / "tokenizer.json").string();
auto vocab_file_path = path(json_path) / "tokenizer.json";
vocab_path_ = vocab_file_path.string();
std::ifstream vocab_fs = vocab_file_path.open();
if (!vocab_fs.is_open()) {
// No tokenizer.json file present; search for tokenizer module file
auto module_file_path = path(json_path) / "tokenizer_module.json";
module_path_ = module_file_path.string();
std::ifstream tok_module_ifs = module_file_path.open();
if (!tok_module_ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "No tokenizer.json or tokenizer_module.json file found.");
} else {
nlohmann::json tok_module_json_config = nlohmann::json::parse(tok_module_ifs);
auto tiktoken_path = tok_module_json_config.value("tiktoken_file", "");
vocab_file_path = path(json_path) / tiktoken_path.c_str();
vocab_path_ = vocab_file_path.string();
}
}
nlohmann::json json_config = nlohmann::json::parse(ifs);
add_bos_token_ = json_config.value("add_bos_token", false);
add_eos_token_ = json_config.value("add_eos_token", false);
@@ -66,6 +82,10 @@ class TokenJsonConfig final {

const std::string& GetVocabDataFile() const { return vocab_path_; }

const std::string& GetTikTokenModuleFile() const {
return module_path_;
}

public:
bool add_bos_token_{};
bool add_eos_token_{};
@@ -80,6 +100,7 @@

private:
std::string vocab_path_;
std::string module_path_;
};

} // namespace ort_extensions::bpe
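
For reference, a minimal sketch of the fallback that TokenJsonConfig::Load performs above: if tokenizer.json is absent, tokenizer_module.json is parsed and its "tiktoken_file" entry names the vocabulary file. Only the two file names and the "tiktoken_file" key come from the code above; the std::filesystem calls stand in for the library's own path wrapper, and the .tiktoken file name is hypothetical.

#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>

#include <nlohmann/json.hpp>

// Sketch only: mirrors the lookup order of TokenJsonConfig::Load.
std::string ResolveVocabFile(const std::filesystem::path& dir) {
  auto vocab = dir / "tokenizer.json";
  if (std::filesystem::exists(vocab)) {
    return vocab.string();  // HuggingFace-style JSON vocabulary
  }
  std::ifstream module_ifs(dir / "tokenizer_module.json");
  if (!module_ifs.is_open()) {
    throw std::runtime_error("No tokenizer.json or tokenizer_module.json file found.");
  }
  // e.g. {"tiktoken_file": "cl100k_base.tiktoken"} -- file name hypothetical
  auto module_json = nlohmann::json::parse(module_ifs);
  return (dir / module_json.value("tiktoken_file", "")).string();
}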
162 changes: 150 additions & 12 deletions operators/tokenizer/bpe_kernels.cc
@@ -8,6 +8,8 @@
#include "bpe_json.hpp"
#include "bpe_tokenizer.hpp"

#include "base64.h"

#include <optional>
#include <limits>

@@ -552,13 +554,140 @@ SpmTokenizer::SpmTokenizer()

JsonFastTokenizer::JsonFastTokenizer() : KernelBpeTokenizer(kGPT2Configuration) {}

/*
Read more here: https://github.com/huggingface/transformers/blob/60bb571e993b7d73257fb64044726b569fef9403/src/transformers/convert_slow_tokenizer.py#L1454
Note: this is similar to the BPE CreateByteEncoder; however, for decoding the .tiktoken bytes
we need to store the strings rather than their IDs, and therefore need a separate map.
*/
void JsonFastTokenizer::CreateUnicodeByteEncoder() {
char32_t index = 256;
for (char32_t i = 0; i < 256; ++i) {
if (i < 33 || (i >= 127 && i < 161) || i == 173) {
unicode_byte_encoder_[i] = ustring::EncodeUTF8Char(index++);
} else {
unicode_byte_encoder_[i] = ustring::EncodeUTF8Char(i);
}
}
}

std::string JsonFastTokenizer::TokenBytesToString(std::vector<uint8_t>& bytes) {
std::string result;
for (auto c : bytes) {
result += unicode_byte_encoder_[static_cast<unsigned char>(c)];
}
return result;
}

// Custom hash function for the vector key
struct VectorHash {
size_t operator()(const std::vector<uint8_t>& v) const {
std::hash<uint8_t> hasher;
size_t seed = 0;
for (uint8_t i : v) {
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};

// Custom equality function for the vector key
struct VectorEqual {
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
return a == b;
}
};

OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
}

// consider using a SAX parser for large JSON files
nlohmann::json tok_json;
std::ifstream module_ifs;

// The following vocab and merges are only used in the tiktoken case,
// but they are accessed outside the scope below
std::unordered_map<std::string, uint32_t> vocab;
std::vector<std::pair<std::string, std::string>> merges;

if (tiktoken_) {
std::string module_file = config.GetTikTokenModuleFile();

module_ifs = path(module_file).open();
if (!module_ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open module file: " + module_file);
}

std::unordered_map<std::vector<uint8_t>, uint32_t, VectorHash, VectorEqual> bpe_ranks;

std::string line;
while (std::getline(ifs, line)) {
if (!line.empty()) {
std::istringstream lineStream(line);
std::string token;
uint32_t rank;
while (lineStream >> token >> rank) {
// Decode the base64-encoded token (the rank was parsed as an integer by the stream)
std::vector<uint8_t> decoded_token;
base64_decode(token, decoded_token);
// Store bpe token and rank
bpe_ranks[decoded_token] = rank;
}
}
}

std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> byte_merges;

bbpe_tokenizer_ = std::make_unique<BpeModel>();
JsonFastTokenizer::CreateUnicodeByteEncoder();

for (const auto& item : bpe_ranks) {
std::vector<uint8_t> token = item.first;
uint32_t rank = item.second;
vocab[JsonFastTokenizer::TokenBytesToString(token)] = rank;

if (token.size() == 1) {
continue;
}

std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> local;
for (size_t index = 1; index < token.size(); index++) {
std::vector<uint8_t> piece_l(token.begin(), token.begin() + index);
std::vector<uint8_t> piece_r(token.begin() + index, token.end());
if (bpe_ranks.count(piece_l) && bpe_ranks.count(piece_r)) {
local.emplace_back(piece_l, piece_r, rank);
}
}

auto compare_bpe_tuples = [&](const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& a,
const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& b) {
// Compare based on the rank of the left piece, then the rank of the right piece, in bpe_ranks
return bpe_ranks[std::get<0>(a)] < bpe_ranks[std::get<0>(b)] ||
(bpe_ranks[std::get<0>(a)] == bpe_ranks[std::get<0>(b)] && bpe_ranks[std::get<1>(a)] < bpe_ranks[std::get<1>(b)]);
};

std::sort(local.begin(), local.end(), compare_bpe_tuples);

byte_merges.insert(byte_merges.end(), local.begin(), local.end());
}

// Custom comparator that compares the third element of the tuples
auto compare_merge_tuples = [&](const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& a,
const std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>& b) {
return std::get<2>(a) < std::get<2>(b);
};

std::sort(byte_merges.begin(), byte_merges.end(), compare_merge_tuples);

// Populate merges
for (auto& val : byte_merges) {
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)), JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
}
}

const char token_sub[] = "Tokenizer";
model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
json_conf_.name_ = model_name_.c_str();
@@ -570,18 +699,27 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
// re-bind the configuration object
bpe_conf_ = json_conf_;

// consider using a SAX parser for large JSON files
nlohmann::json tok_json;
ifs >> tok_json;
auto model_node = tok_json.find("model");
if (model_node == tok_json.end()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
}
OrtxStatus status;
if (tiktoken_) {
status = bbpe_tokenizer_->Load(vocab,
merges,
bpe_conf_.get().GetSpecialTokens().c_str(),
false);

bbpe_tokenizer_ = std::make_unique<BpeModel>();
auto status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
module_ifs >> tok_json;
} else {
ifs >> tok_json;
auto model_node = tok_json.find("model");
if (model_node == tok_json.end()) {
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
}

bbpe_tokenizer_ = std::make_unique<BpeModel>();
status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
}


auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
@@ -640,4 +778,4 @@ OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
}
}
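
The crux of the tiktoken branch above: a .tiktoken file stores only base64-encoded tokens and their ranks, so the merge table has to be reconstructed. Every split of a multi-byte token whose two halves are themselves ranked tokens is a candidate merge, and sorting candidates by the merged token's rank recovers the BPE merge order. A self-contained sketch with a tiny made-up vocabulary (shown already base64-decoded; the kernel itself uses the base64.h helper and unordered maps keyed with the VectorHash/VectorEqual functors above):

#include <algorithm>
#include <cstdint>
#include <map>
#include <tuple>
#include <vector>

int main() {
  // Each .tiktoken line is "<base64 token> <rank>"; the tokens below are
  // already decoded to raw bytes.
  std::map<std::vector<uint8_t>, uint32_t> bpe_ranks = {
      {{'a'}, 0}, {{'b'}, 1}, {{'a', 'b'}, 2}, {{'a', 'b', 'b'}, 3}};

  // A merge (l, r) is kept when l, r, and l+r are all ranked tokens.
  std::vector<std::tuple<std::vector<uint8_t>, std::vector<uint8_t>, uint32_t>> merges;
  for (const auto& [token, rank] : bpe_ranks) {
    if (token.size() < 2) continue;
    for (size_t i = 1; i < token.size(); ++i) {
      std::vector<uint8_t> l(token.begin(), token.begin() + i);
      std::vector<uint8_t> r(token.begin() + i, token.end());
      if (bpe_ranks.count(l) && bpe_ranks.count(r)) merges.emplace_back(l, r, rank);
    }
  }
  std::sort(merges.begin(), merges.end(),
            [](const auto& a, const auto& b) { return std::get<2>(a) < std::get<2>(b); });
  // Result: ("a","b") at rank 2, then ("ab","b") at rank 3.
}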
24 changes: 24 additions & 0 deletions operators/tokenizer/bpe_kernels.h
@@ -107,6 +107,10 @@ struct SpmTokenizer : KernelBpeTokenizer {
class JsonFastTokenizer : KernelBpeTokenizer {
public:
JsonFastTokenizer();
bool tiktoken_ = false;
std::string unicode_byte_encoder_[256] = {};
void CreateUnicodeByteEncoder();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
@@ -121,3 +125,23 @@ class JsonFastTokenizer : KernelBpeTokenizer {
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};

class TikTokenizer : KernelBpeTokenizer {
public:
TikTokenizer();
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;

public:
const auto& GetAddedTokens() const { return added_tokens_; }
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }

private:
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;
BpeModelConf json_conf_;
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
};
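
CreateUnicodeByteEncoder and TokenBytesToString declared above follow the GPT-2 byte-to-unicode convention: 188 printable bytes map to themselves, and the remaining 68 bytes (0-32, 127-160, and 173) are remapped to code points 256 and up, so that any byte sequence becomes a printable UTF-8 map key. A standalone sketch; EncodeUTF8Char here is a simplified stand-in for ustring::EncodeUTF8Char and only handles code points below 0x800:

#include <cstdio>
#include <string>

// Simplified stand-in for ustring::EncodeUTF8Char (code points < 0x800 only).
static std::string EncodeUTF8Char(char32_t c) {
  std::string s;
  if (c < 0x80) {
    s += static_cast<char>(c);
  } else {
    s += static_cast<char>(0xC0 | (c >> 6));
    s += static_cast<char>(0x80 | (c & 0x3F));
  }
  return s;
}

int main() {
  std::string table[256];
  char32_t next = 256;
  for (char32_t i = 0; i < 256; ++i) {
    bool remap = i < 33 || (i >= 127 && i < 161) || i == 173;
    table[i] = EncodeUTF8Char(remap ? next++ : i);
  }
  // Byte 0x20 (space) is the 33rd remapped byte: 256 + 32 = 0x120 -> "Ġ".
  std::printf("0x20 -> %s\n", table[0x20].c_str());
}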
41 changes: 41 additions & 0 deletions operators/tokenizer/bpe_tokenizer.hpp
@@ -194,6 +194,47 @@ class BpeModel {
return {};
}

OrtxStatus Load(std::unordered_map<std::string, uint32_t>& vocab,
std::vector<std::pair<std::string, std::string>>& merges,
const char* /* special_tokens */,
bool spm_converted) {
vocab_map_ = vocab;

if (spm_converted) {
UpdateSpmByteToken(vocab_map_);
} else {
CreateByteEncoder();
}

uint32_t index = 0;
for (const auto& [w1, w2] : merges) {
int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
if (w2.find("</w>") != std::string::npos || w1.find("</w>") != std::string::npos) {
token_length -= 4;
}
auto iw1 = GetTokenId(w1);
auto iw2 = GetTokenId(w2);
auto iww = GetTokenId(w1 + w2);
BpeNode value{iww, index++, token_length};
bpe_rank_[GetRankKey(iw1, iw2)] = value;
}

id2token_map_.resize(vocab_map_.size());
for (const auto& [t, i] : vocab_map_) {
if (i > static_cast<uint32_t>((std::numeric_limits<int32_t>::max)())) {
continue; // ignore ids that would overflow int32_t
}
if (i >= id2token_map_.size()) {
id2token_map_.resize(static_cast<size_t>(i) + 1);
}
id2token_map_[i] = t;
}

return {};
}

OrtxStatus LoadAddedTokens(const char* added_tokens) {
int id = bpe::kInvalidTokenId;
std::istringstream strm_tokens(added_tokens);
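
A hedged usage sketch of the new BpeModel::Load overload above, with a toy byte-level vocabulary; the merges are assumed to arrive already sorted by rank, which is what JsonFastTokenizer::Load arranges before calling it:

// Sketch only -- toy inputs, not a real model.
std::unordered_map<std::string, uint32_t> vocab = {{"a", 0}, {"b", 1}, {"ab", 2}};
std::vector<std::pair<std::string, std::string>> merges = {{"a", "b"}};

auto model = std::make_unique<BpeModel>();
OrtxStatus status = model->Load(vocab, merges,
                                /*special_tokens*/ "",
                                /*spm_converted*/ false);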
24 changes: 18 additions & 6 deletions shared/api/tokenizer_impl.cc
@@ -18,12 +18,24 @@ OrtxStatus TokenizerImpl::Load(const std::string& dir) {
return status;
}

auto vocab_file_path = path(dir) / "tokenizer.json";
std::ifstream vocab_fs = vocab_file_path.open();

tokenizer_ = std::make_unique<JsonFastTokenizer>();
// load the tokenizer from a config
status = tokenizer_->Load(*tok_config_);
if (status.IsOk()) {
detokenizer_ = std::make_unique<BpeStreamingDecoder>();
status = detokenizer_->Load(tok_config_, *tokenizer_);
if (!vocab_fs.is_open()) {
// No tokenizer.json file present; use TikToken tokenizer
tokenizer_->tiktoken_ = true;

// load the tokenizer from a config
status = tokenizer_->Load(*tok_config_);
} else {
// load the tokenizer from a config
status = tokenizer_->Load(*tok_config_);

if (status.IsOk()) {
detokenizer_ = std::make_unique<BpeStreamingDecoder>();
status = detokenizer_->Load(tok_config_, *tokenizer_);
}
}

return status;
@@ -34,7 +46,7 @@ OrtxStatus TokenizerImpl::BatchEncode(const std::vector<std::string_view>& inpu
for (const auto& s : input) {
ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(s)});
auto status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
OrtxStatus status = tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);

if (!status.IsOk()) {
return status;
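
Putting the pieces together, a hedged sketch of the selection logic in TokenizerImpl::Load above; constructing TokenizerImpl directly (rather than through the OrtxObject factory) and the model directory name are assumptions made for illustration:

// A directory with tokenizer_config.json plus tokenizer_module.json and a
// .tiktoken vocabulary takes the tiktoken branch; one with tokenizer.json
// takes the JSON branch. "models/phi-3-small" is a hypothetical path.
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
OrtxStatus status = tokenizer->Load("models/phi-3-small");
if (status.IsOk()) {
  std::vector<std::vector<extTokenId_t>> ids;
  status = tokenizer->BatchEncode({"hello world"}, ids);
}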
1 change: 1 addition & 0 deletions shared/api/tokenizer_impl.h
@@ -50,6 +50,7 @@ class TokenizerImpl : public OrtxObjectImpl {
std::vector<std::vector<extTokenId_t>>& t_ids) const;

private:
bool tiktoken = false;
std::string tokenizer_dir_;
std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
std::unique_ptr<JsonFastTokenizer> tokenizer_;