14 changes: 11 additions & 3 deletions include/tokenizers_c.h
@@ -16,22 +16,30 @@ extern "C" {

typedef void* TokenizerHandle;

typedef struct {
  int* token_ids;
  size_t len;
} TokenizerEncodeResult;

TokenizerHandle tokenizers_new_from_str(const char* json, size_t len);

TokenizerHandle byte_level_bpe_tokenizers_new_from_str(const char* vocab, size_t vocab_len,
                                                       const char* merges, size_t merges_len,
                                                       const char* added_tokens,
                                                       size_t added_tokens_len);

void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token);
void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token, TokenizerEncodeResult* result);

void tokenizers_encode_batch(TokenizerHandle handle, const char** data, size_t* len, size_t num_seqs,
                             int add_special_token, TokenizerEncodeResult* results);

void tokenizers_free_encode_results(TokenizerEncodeResult* results, size_t num_seqs);

void tokenizers_decode(TokenizerHandle handle, const uint32_t* data, size_t len,
                       int skip_special_token);

void tokenizers_get_decode_str(TokenizerHandle handle, const char** data, size_t* len);

void tokenizers_get_encode_ids(TokenizerHandle handle, const uint32_t** id_data, size_t* len);

void tokenizers_get_vocab_size(TokenizerHandle handle, size_t* size);

void tokenizers_id_to_token(TokenizerHandle handle, uint32_t id, const char** data, size_t* len);
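To make the new ownership rules concrete, here is a minimal caller-side sketch of the reworked C API: the caller provides a `TokenizerEncodeResult`, the library fills it with a Rust-allocated id buffer, and `tokenizers_free_encode_results` hands that buffer back. The `EncodeOnce` helper is illustrative, not part of the header, and a valid `handle` from `tokenizers_new_from_str` is assumed.

```cpp
#include <cstddef>
#include <string>
#include <vector>

#include "tokenizers_c.h"

// Illustrative helper (not part of the library): encode one string and copy
// the ids into a std::vector before releasing the Rust-owned buffer.
std::vector<int> EncodeOnce(TokenizerHandle handle, const std::string& text) {
  TokenizerEncodeResult result;
  tokenizers_encode(handle, text.data(), text.length(), /*add_special_token=*/1, &result);
  std::vector<int> ids(result.token_ids, result.token_ids + result.len);
  // A single result is freed by passing num_seqs = 1.
  tokenizers_free_encode_results(&result, 1);
  return ids;
}
```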
15 changes: 15 additions & 0 deletions include/tokenizers_cpp.h
@@ -29,6 +29,21 @@ class Tokenizer {
   */
  virtual std::vector<int32_t> Encode(const std::string& text) = 0;

  /*!
   * \brief Encode a batch of texts into ids.
   * \param texts The input texts.
   * \returns The encoded token ids of each input text.
   */
  virtual std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) {
    // Fall back when the derived class does not implement this function.
    std::vector<std::vector<int32_t>> ret;
    ret.reserve(texts.size());
    for (const auto& text : texts) {
      ret.push_back(Encode(text));
    }
    return ret;
  }

  /*!
   * \brief Decode token ids into text.
   * \param text The token ids.
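The virtual `EncodeBatch` keeps existing backends working: one that only implements `Encode` still gets a correct, if sequential, batch path. A short sketch of the contract a caller can rely on follows; the `tokenizers` namespace and the absence of padding or truncation in the tokenizer configuration are assumptions here, not something this diff states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

#include "tokenizers_cpp.h"

// Sketch: what EncodeBatch promises regardless of whether the backend
// overrides it. With the base-class fallback, entry i is exactly
// Encode(texts[i]); an overriding backend is expected to match that
// (assuming no padding/truncation is configured in the tokenizer).
void CheckEncodeBatch(tokenizers::Tokenizer& tok, const std::vector<std::string>& texts) {
  std::vector<std::vector<int32_t>> batch = tok.EncodeBatch(texts);
  assert(batch.size() == texts.size());        // one id vector per input, in order
  for (size_t i = 0; i < texts.size(); ++i) {
    assert(batch[i] == tok.Encode(texts[i]));  // per-text and batch paths agree
  }
}
```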
65 changes: 53 additions & 12 deletions rust/src/lib.rs
@@ -7,19 +7,23 @@ use tokenizers::tokenizer::Tokenizer;

pub struct TokenizerWrapper {
    tokenizer: Tokenizer,
    encode_ids: Vec<u32>,
    decode_str: String,
    id_to_token_result: String,
}

pub type Vocab = HashMap<String, u32>;
pub type Merges = Vec<(String, String)>;

#[repr(C)]
pub struct TokenizerEncodeResult {
    token_ids: *mut u32,
    len: usize,
}

impl TokenizerWrapper {
    pub fn from_str(json: &str) -> TokenizerWrapper {
        TokenizerWrapper {
            tokenizer: Tokenizer::from_str(json).unwrap().into(),
            encode_ids: Vec::new(),
            decode_str: String::new(),
            id_to_token_result: String::new(),
        }
@@ -77,16 +81,22 @@ impl TokenizerWrapper {
            .with_decoder(byte_level);
        TokenizerWrapper {
            tokenizer: tokenizer,
            encode_ids: Vec::new(),
            decode_str: String::new(),
            id_to_token_result: String::new(),
        }
    }

    pub fn encode(&mut self, text: &str, add_special_tokens: bool) {
    pub fn encode(&mut self, text: &str, add_special_tokens: bool) -> Vec<u32> {
        let encoded = self.tokenizer.encode(text, add_special_tokens).unwrap();
        self.encode_ids.resize(encoded.len(), 0);
        self.encode_ids.copy_from_slice(encoded.get_ids());
        return encoded.get_ids().to_vec();
    }

    pub fn encode_batch(&mut self, texts: Vec<&str>, add_special_tokens: bool) -> Vec<Vec<u32>> {
        let results = self.tokenizer.encode_batch(texts, add_special_tokens).unwrap()
            .into_iter()
            .map(|encoded| encoded.get_ids().to_vec())
            .collect::<Vec<Vec<u32>>>();
        return results;
    }

    pub fn decode(&mut self, ids: &[u32], skip_special_tokens: bool) {
@@ -135,22 +145,53 @@ extern "C" fn tokenizers_encode(
    input_cstr: *const u8,
    len: usize,
    add_special_tokens: i32,
    out_result: *mut TokenizerEncodeResult,
) {
    unsafe {
        let input_data = std::str::from_utf8(std::slice::from_raw_parts(input_cstr, len)).unwrap();
        (*handle).encode(input_data, add_special_tokens != 0);
        let encoded = (*handle).encode(input_data, add_special_tokens != 0);
        let len = encoded.len();
        *out_result = TokenizerEncodeResult {
            token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
            len: len,
        };
    }
}

#[no_mangle]
extern "C" fn tokenizers_get_encode_ids(
extern "C" fn tokenizers_encode_batch(
    handle: *mut TokenizerWrapper,
    out_data: *mut *mut u32,
    out_len: *mut usize,
    input_cstr: *const *const u8,
    input_len: *const usize,
    num_seqs: usize,
    add_special_tokens: i32,
    out_result: *mut TokenizerEncodeResult,
) {
    unsafe {
        *out_data = (*handle).encode_ids.as_mut_ptr();
        *out_len = (*handle).encode_ids.len()
        let input_data = (0..num_seqs)
            .map(|i| {
                std::str::from_utf8(std::slice::from_raw_parts(*input_cstr.offset(i as isize), *input_len.offset(i as isize))).unwrap()
            })
            .collect::<Vec<&str>>();
        let encoded_batch = (*handle).encode_batch(input_data, add_special_tokens != 0);
        for (i, encoded) in encoded_batch.into_iter().enumerate() {
            let len = encoded.len();
            let result = TokenizerEncodeResult {
                token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
                len: len,
            };
            *out_result.offset(i as isize) = result;
        }
    }
}

#[no_mangle]
extern "C" fn tokenizers_free_encode_results(results: *mut TokenizerEncodeResult, num_seqs: usize) {
    unsafe {
        let slice = std::slice::from_raw_parts_mut(results, num_seqs);
        for result in &mut *slice {
            drop(Box::from_raw(std::slice::from_raw_parts_mut(result.token_ids, result.len)));
        }
    }
}

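The ownership contract set up here is: the caller owns the array of `TokenizerEncodeResult` entries, while each `token_ids` buffer remains owned by the Rust allocator until `tokenizers_free_encode_results` reconstructs and drops the boxed slice. One way a C++ caller could make that pairing hard to get wrong is a small RAII guard; this is only a sketch, not library code, and it assumes every entry was filled by a successful encode call before the guard is destroyed.

```cpp
#include <cstddef>
#include <vector>

#include "tokenizers_c.h"

// Illustrative RAII guard (not part of the library): owns the results array on
// the C++ side and hands every Rust-allocated token_ids buffer back exactly
// once. Assumes tokenizers_encode_batch populated all entries before the
// guard goes out of scope.
class EncodeResultsGuard {
 public:
  explicit EncodeResultsGuard(size_t num_seqs) : results_(num_seqs) {}
  ~EncodeResultsGuard() { tokenizers_free_encode_results(results_.data(), results_.size()); }
  EncodeResultsGuard(const EncodeResultsGuard&) = delete;
  EncodeResultsGuard& operator=(const EncodeResultsGuard&) = delete;

  TokenizerEncodeResult* data() { return results_.data(); }
  const TokenizerEncodeResult& operator[](size_t i) const { return results_[i]; }

 private:
  std::vector<TokenizerEncodeResult> results_;
};
```

Used as `EncodeResultsGuard guard(texts.size());` followed by `tokenizers_encode_batch(..., guard.data());`, an early return can no longer leak the Rust-side buffers.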
46 changes: 36 additions & 10 deletions src/huggingface_tokenizer.cc
@@ -28,18 +28,44 @@ class HFTokenizer : public Tokenizer {

  // use i32 to be consistent with sentencepiece
  std::vector<int32_t> Encode(const std::string& text, bool add_special_tokens) {
    tokenizers_encode(handle_, text.data(), text.length(), static_cast<int>(add_special_tokens));
    const uint32_t* data;
    size_t len;
    tokenizers_get_encode_ids(handle_, &data, &len);
    const int32_t* data_i32 = reinterpret_cast<const int32_t*>(data);
    auto res = std::vector<int32_t>(data_i32, data_i32 + len);
    return res;
    TokenizerEncodeResult result;
    tokenizers_encode(handle_, text.data(), text.length(), static_cast<int>(add_special_tokens),
                      &result);
    std::vector<int32_t> ret(result.token_ids, result.token_ids + result.len);
    tokenizers_free_encode_results(&result, 1);
    return ret;
  }

  // use i32 to be consistent with sentencepiece
  std::vector<int32_t> Encode(const std::string& text) final {
    return Encode(text, false);
  // use i32 to be consistent with sentencepiece
  std::vector<int32_t> Encode(const std::string& text) final {
    return Encode(text, false);
  }

  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts, bool add_special_tokens) final {
    std::vector<const char*> texts_raw;
    std::vector<size_t> seq_lens;
    size_t num_seqs = texts.size();
    texts_raw.reserve(num_seqs);
    seq_lens.reserve(num_seqs);
    for (const auto& text : texts) {
      texts_raw.push_back(text.data());
      seq_lens.push_back(text.length());
    }
    std::vector<TokenizerEncodeResult> results(num_seqs);
    tokenizers_encode_batch(handle_, texts_raw.data(), seq_lens.data(), texts.size(),
                            static_cast<int>(add_special_tokens), results.data());
    std::vector<std::vector<int32_t>> ret;
    ret.reserve(texts.size());
    for (size_t i = 0; i < texts.size(); ++i) {
      ret.push_back(
          std::vector<int32_t>(results[i].token_ids, results[i].token_ids + results[i].len));
    }
    tokenizers_free_encode_results(results.data(), texts.size());
    return ret;
  }

  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) final {
    return EncodeBatch(texts, false);
  }

  // use i32 to be consistent with sentencepiece
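Putting the pieces together from the application side, a hedged end-to-end sketch: it assumes the library's existing `Tokenizer::FromBlobJSON` factory (which does not appear in this diff) and a `tokenizer.json` blob already read into memory.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "tokenizers_cpp.h"

// Assumes Tokenizer::FromBlobJSON exists in the public API and returns an
// owning pointer to a concrete tokenizer (e.g. HFTokenizer); that factory is
// not part of this diff.
void Demo(const std::string& json_blob) {
  std::unique_ptr<tokenizers::Tokenizer> tok = tokenizers::Tokenizer::FromBlobJSON(json_blob);
  std::vector<std::string> prompts = {"hello world", "how are you"};
  // One FFI round trip for the whole batch instead of one per prompt.
  std::vector<std::vector<int32_t>> ids = tok->EncodeBatch(prompts);
  for (size_t i = 0; i < ids.size(); ++i) {
    std::cout << prompts[i] << " -> " << ids[i].size() << " tokens\n";
  }
}
```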