diff --git a/include/tokenizers_c.h b/include/tokenizers_c.h
index 6563f53..4276ee8 100644
--- a/include/tokenizers_c.h
+++ b/include/tokenizers_c.h
@@ -16,6 +16,11 @@ extern "C" {
 
 typedef void* TokenizerHandle;
 
+typedef struct {
+  int* token_ids;
+  size_t len;
+} TokenizerEncodeResult;
+
 TokenizerHandle tokenizers_new_from_str(const char* json, size_t len);
 
 TokenizerHandle byte_level_bpe_tokenizers_new_from_str(const char* vocab, size_t vocab_len,
@@ -23,15 +28,18 @@ TokenizerHandle byte_level_bpe_tokenizers_new_from_str(const char* vocab, size_t
                                                        const char* added_tokens,
                                                        size_t added_tokens_len);
 
-void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token);
+void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token, TokenizerEncodeResult* result);
+
+void tokenizers_encode_batch(TokenizerHandle handle, const char** data, size_t* len, size_t num_seqs,
+                             int add_special_token, TokenizerEncodeResult* results);
+
+void tokenizers_free_encode_results(TokenizerEncodeResult* results, size_t num_seqs);
 
 void tokenizers_decode(TokenizerHandle handle, const uint32_t* data, size_t len,
                        int skip_special_token);
 
 void tokenizers_get_decode_str(TokenizerHandle handle, const char** data, size_t* len);
 
-void tokenizers_get_encode_ids(TokenizerHandle handle, const uint32_t** id_data, size_t* len);
-
 void tokenizers_get_vocab_size(TokenizerHandle handle, size_t* size);
 
 void tokenizers_id_to_token(TokenizerHandle handle, uint32_t id, const char** data, size_t* len);
diff --git a/include/tokenizers_cpp.h b/include/tokenizers_cpp.h
index 7de6721..d37aa57 100644
--- a/include/tokenizers_cpp.h
+++ b/include/tokenizers_cpp.h
@@ -29,6 +29,21 @@ class Tokenizer {
    */
   virtual std::vector<int32_t> Encode(const std::string& text) = 0;
 
+  /*!
+   * \brief Encode a batch of texts into ids.
+   * \param texts The input texts.
+   * \returns The encoded token ids.
+   */
+  virtual std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) {
+    // Fall back when the derived class does not implement this function.
+    std::vector<std::vector<int32_t>> ret;
+    ret.reserve(texts.size());
+    for (const auto& text : texts) {
+      ret.push_back(Encode(text));
+    }
+    return ret;
+  }
+
   /*!
    * \brief Decode token ids into text.
    * \param text The token ids.
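The header change above moves ownership of encoded ids to the caller: tokenizers_encode now fills a caller-provided TokenizerEncodeResult with a heap-allocated id buffer instead of stashing the ids inside the handle, and that buffer is released with tokenizers_free_encode_results. A minimal C++ sketch of the new single-text call sequence (EncodeOnce and the surrounding setup are illustrative, not part of this patch):

#include <string>
#include <vector>

#include "tokenizers_c.h"

// Illustrative helper: encode one string through the revised C API and take
// ownership of the returned ids.
std::vector<int> EncodeOnce(TokenizerHandle handle, const std::string& text) {
  TokenizerEncodeResult result;
  tokenizers_encode(handle, text.data(), text.length(), /*add_special_token=*/1, &result);
  // token_ids points at a buffer allocated on the Rust side; copy it out, then free it.
  std::vector<int> ids(result.token_ids, result.token_ids + result.len);
  tokenizers_free_encode_results(&result, /*num_seqs=*/1);
  return ids;
}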
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 10206a0..98ce523 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -7,7 +7,6 @@ use tokenizers::tokenizer::Tokenizer;
 
 pub struct TokenizerWrapper {
     tokenizer: Tokenizer,
-    encode_ids: Vec<u32>,
     decode_str: String,
     id_to_token_result: String,
 }
@@ -15,11 +14,16 @@ pub struct TokenizerWrapper {
 pub type Vocab = HashMap<String, u32>;
 pub type Merges = Vec<(String, String)>;
 
+#[repr(C)]
+pub struct TokenizerEncodeResult {
+    token_ids: *mut u32,
+    len: usize,
+}
+
 impl TokenizerWrapper {
     pub fn from_str(json: &str) -> TokenizerWrapper {
         TokenizerWrapper {
             tokenizer: Tokenizer::from_str(json).unwrap().into(),
-            encode_ids: Vec::new(),
             decode_str: String::new(),
             id_to_token_result: String::new(),
         }
@@ -77,16 +81,22 @@ impl TokenizerWrapper {
             .with_decoder(byte_level);
         TokenizerWrapper {
             tokenizer: tokenizer,
-            encode_ids: Vec::new(),
             decode_str: String::new(),
             id_to_token_result: String::new(),
         }
     }
 
-    pub fn encode(&mut self, text: &str, add_special_tokens: bool) {
+    pub fn encode(&mut self, text: &str, add_special_tokens: bool) -> Vec<u32> {
         let encoded = self.tokenizer.encode(text, add_special_tokens).unwrap();
-        self.encode_ids.resize(encoded.len(), 0);
-        self.encode_ids.copy_from_slice(encoded.get_ids());
+        return encoded.get_ids().to_vec();
+    }
+
+    pub fn encode_batch(&mut self, texts: Vec<&str>, add_special_tokens: bool) -> Vec<Vec<u32>> {
+        let results = self.tokenizer.encode_batch(texts, add_special_tokens).unwrap()
+            .into_iter()
+            .map(|encoded| encoded.get_ids().to_vec())
+            .collect::<Vec<Vec<u32>>>();
+        return results;
     }
 
     pub fn decode(&mut self, ids: &[u32], skip_special_tokens: bool) {
@@ -135,22 +145,53 @@ extern "C" fn tokenizers_encode(
     input_cstr: *const u8,
     len: usize,
     add_special_tokens: i32,
+    out_result: *mut TokenizerEncodeResult,
 ) {
     unsafe {
         let input_data = std::str::from_utf8(std::slice::from_raw_parts(input_cstr, len)).unwrap();
-        (*handle).encode(input_data, add_special_tokens != 0);
+        let encoded = (*handle).encode(input_data, add_special_tokens != 0);
+        let len = encoded.len();
+        *out_result = TokenizerEncodeResult {
+            token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
+            len: len,
+        };
     }
 }
 
 #[no_mangle]
-extern "C" fn tokenizers_get_encode_ids(
+extern "C" fn tokenizers_encode_batch(
     handle: *mut TokenizerWrapper,
-    out_data: *mut *mut u32,
-    out_len: *mut usize,
+    input_cstr: *const *const u8,
+    input_len: *const usize,
+    num_seqs: usize,
+    add_special_tokens: i32,
+    out_result: *mut TokenizerEncodeResult,
 ) {
     unsafe {
-        *out_data = (*handle).encode_ids.as_mut_ptr();
-        *out_len = (*handle).encode_ids.len()
+        let input_data = (0..num_seqs)
+            .map(|i| {
+                std::str::from_utf8(std::slice::from_raw_parts(*input_cstr.offset(i as isize), *input_len.offset(i as isize))).unwrap()
+            })
+            .collect::<Vec<&str>>();
+        let encoded_batch = (*handle).encode_batch(input_data, add_special_tokens != 0);
+        for (i, encoded) in encoded_batch.into_iter().enumerate() {
+            let len = encoded.len();
+            let result = TokenizerEncodeResult {
+                token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
+                len: len,
+            };
+            *out_result.offset(i as isize) = result;
+        }
+    }
+}
+
+#[no_mangle]
+extern "C" fn tokenizers_free_encode_results(results: *mut TokenizerEncodeResult, num_seqs: usize) {
+    unsafe {
+        let slice = std::slice::from_raw_parts_mut(results, num_seqs);
+        for result in &mut *slice {
+            drop(Box::from_raw(std::slice::from_raw_parts_mut(result.token_ids, result.len)));
+        }
     }
 }
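Every TokenizerEncodeResult written by these entry points carries a buffer produced with Box::into_raw, so each filled result must be handed back to tokenizers_free_encode_results exactly once; the C++ wrapper change in the next file follows this sequence. A hedged sketch of the batch contract from the caller's side (the helper name and setup are illustrative only):

#include <cstddef>
#include <string>
#include <vector>

#include "tokenizers_c.h"

// Illustrative helper: drive the batch entry point with caller-owned result slots.
std::vector<std::vector<int>> EncodeBatchViaC(TokenizerHandle handle,
                                              const std::vector<std::string>& texts) {
  std::vector<const char*> ptrs;
  std::vector<size_t> lens;
  for (const auto& t : texts) {
    ptrs.push_back(t.data());
    lens.push_back(t.length());
  }
  // One result slot per sequence; the Rust side fills each with an owned buffer.
  std::vector<TokenizerEncodeResult> results(texts.size());
  tokenizers_encode_batch(handle, ptrs.data(), lens.data(), texts.size(),
                          /*add_special_token=*/0, results.data());
  std::vector<std::vector<int>> ids;
  ids.reserve(results.size());
  for (const auto& r : results) {
    ids.emplace_back(r.token_ids, r.token_ids + r.len);
  }
  // Release every token_ids buffer allocated by the Rust wrapper in one call.
  tokenizers_free_encode_results(results.data(), results.size());
  return ids;
}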
diff --git a/src/huggingface_tokenizer.cc b/src/huggingface_tokenizer.cc
index 17138e1..3657305 100644
--- a/src/huggingface_tokenizer.cc
+++ b/src/huggingface_tokenizer.cc
@@ -28,18 +28,44 @@ class HFTokenizer : public Tokenizer {
   // use i32 to be consistent with sentencepiece
   std::vector<int32_t> Encode(const std::string& text, bool add_special_tokens) {
-    tokenizers_encode(handle_, text.data(), text.length(), static_cast<int>(add_special_tokens));
-    const uint32_t* data;
-    size_t len;
-    tokenizers_get_encode_ids(handle_, &data, &len);
-    const int32_t* data_i32 = reinterpret_cast<const int32_t*>(data);
-    auto res = std::vector<int32_t>(data_i32, data_i32 + len);
-    return res;
+    TokenizerEncodeResult result;
+    tokenizers_encode(handle_, text.data(), text.length(), static_cast<int>(add_special_tokens),
+                      &result);
+    std::vector<int32_t> ret(result.token_ids, result.token_ids + result.len);
+    tokenizers_free_encode_results(&result, 1);
+    return ret;
   }
 
-  // use i32 to be consistent with sentencepiece
-  std::vector<int32_t> Encode(const std::string& text) final {
-    return Encode(text, false);
+  // use i32 to be consistent with sentencepiece
+  std::vector<int32_t> Encode(const std::string& text) final {
+    return Encode(text, false);
+  }
+
+  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts, bool add_special_tokens) {
+    std::vector<const char*> texts_raw;
+    std::vector<size_t> seq_lens;
+    size_t num_seqs = texts.size();
+    texts_raw.reserve(num_seqs);
+    seq_lens.reserve(num_seqs);
+    for (const auto& text : texts) {
+      texts_raw.push_back(text.data());
+      seq_lens.push_back(text.length());
+    }
+    std::vector<TokenizerEncodeResult> results(num_seqs);
+    tokenizers_encode_batch(handle_, texts_raw.data(), seq_lens.data(), texts.size(),
+                            static_cast<int>(add_special_tokens), results.data());
+    std::vector<std::vector<int32_t>> ret;
+    ret.reserve(texts.size());
+    for (size_t i = 0; i < texts.size(); ++i) {
+      ret.push_back(
+          std::vector<int32_t>(results[i].token_ids, results[i].token_ids + results[i].len));
+    }
+    tokenizers_free_encode_results(results.data(), texts.size());
+    return ret;
+  }
+
+  std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) final {
+    return EncodeBatch(texts, false);
+  }
 
   // use i32 to be consistent with sentencepiece
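At the public C++ level, the net effect is that HFTokenizer now serves EncodeBatch with a single FFI round trip instead of relying on the per-text fallback in the base class. A rough usage sketch, assuming the library's existing Tokenizer::FromBlobJSON factory and a tokenizer.json blob loaded elsewhere:

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "tokenizers_cpp.h"

void EncodeBatchDemo(const std::string& json_blob) {
  // FromBlobJSON is the existing HuggingFace-tokenizer factory in this library.
  std::unique_ptr<tokenizers::Tokenizer> tok = tokenizers::Tokenizer::FromBlobJSON(json_blob);

  std::vector<std::string> texts = {"What is the capital of Canada?", "Hello world"};
  // One call for the whole batch; tokenizers without an EncodeBatch override
  // still work through the base-class per-text fallback.
  std::vector<std::vector<int32_t>> ids = tok->EncodeBatch(texts);
}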