diff --git a/src/tokenizers/tokenizer.cpp b/src/tokenizers/tokenizer.cpp
index ebbc1a506..556cadd84 100644
--- a/src/tokenizers/tokenizer.cpp
+++ b/src/tokenizers/tokenizer.cpp
@@ -107,13 +107,24 @@ void Tokenizer::pad_tokens(std::vector<int>& tokens,
 
   if (final_length > out_tokens.size()) {
     const size_t pad_count = final_length - out_tokens.size();
-    out_tokens.insert(out_tokens.end(), pad_count, PAD_TOKEN_ID);
-
-    if (use_weights) {
-      out_weights.insert(out_weights.end(), pad_count, 1.0f);
-    }
-    if (use_mask) {
-      out_mask.insert(out_mask.end(), pad_count, 0.0f);
+    if (pad_left) {
+      out_tokens.insert(out_tokens.begin(), pad_count, PAD_TOKEN_ID);
+
+      if (use_weights) {
+        out_weights.insert(out_weights.begin(), pad_count, 1.0f);
+      }
+      if (use_mask) {
+        out_mask.insert(out_mask.begin(), pad_count, 0.0f);
+      }
+    } else {
+      out_tokens.insert(out_tokens.end(), pad_count, PAD_TOKEN_ID);
+
+      if (use_weights) {
+        out_weights.insert(out_weights.end(), pad_count, 1.0f);
+      }
+      if (use_mask) {
+        out_mask.insert(out_mask.end(), pad_count, 0.0f);
+      }
     }
   }
 };
diff --git a/src/tokenizers/tokenizer.h b/src/tokenizers/tokenizer.h
index 21ba067df..e044285bb 100644
--- a/src/tokenizers/tokenizer.h
+++ b/src/tokenizers/tokenizer.h
@@ -14,6 +14,7 @@ class Tokenizer {
   std::vector<std::string> special_tokens;
   bool add_bos_token = false;
   bool add_eos_token = false;
+  bool pad_left = false;
   std::string end_of_word_suffix;
 
   virtual std::string decode_token(int token_id) const = 0;
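
For context, a minimal standalone sketch of the padding behavior this diff introduces. This is not the repository's code: the free function below collapses the out_tokens/out_weights/out_mask handling into one token vector plus one mask for brevity, and PAD_TOKEN_ID = 0 is an assumed value rather than the project's actual constant.

#include <cstdio>
#include <vector>

// Assumed pad id; the repository defines its own PAD_TOKEN_ID.
constexpr int PAD_TOKEN_ID = 0;

// Simplified stand-in for Tokenizer::pad_tokens: pads `tokens` and a
// parallel attention `mask` out to `final_length`, on the left when
// `pad_left` is true, otherwise on the right (the pre-diff behavior).
void pad_tokens(std::vector<int>& tokens, std::vector<float>& mask,
                size_t final_length, bool pad_left) {
  if (final_length <= tokens.size()) return;
  const size_t pad_count = final_length - tokens.size();
  tokens.insert(pad_left ? tokens.begin() : tokens.end(),
                pad_count, PAD_TOKEN_ID);
  // Pad positions get mask 0.0f, matching the diff's out_mask handling.
  mask.insert(pad_left ? mask.begin() : mask.end(), pad_count, 0.0f);
}

int main() {
  std::vector<int> tokens = {5, 6, 7};
  std::vector<float> mask = {1.0f, 1.0f, 1.0f};
  pad_tokens(tokens, mask, 5, /*pad_left=*/true);
  for (int t : tokens) std::printf("%d ", t);  // prints: 0 0 5 6 7
  std::printf("\n");
  return 0;
}

The usual motivation for left padding is batched autoregressive generation: with pad tokens in front, every row's real tokens end at the same position, so the last-step logits line up across the batch. In both branches the mask is extended with 0.0f, as in the diff, so padded positions stay excluded downstream.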