From 9858195481e0d29e9b720705d359f98620680a06 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Wed, 12 Apr 2023 18:10:04 +0200
Subject: [PATCH] add fast support and option (#22724)

* add fast support and option

* update based on review

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/llama/convert_llama_weights_to_hf.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* nit

* add print

* fixup

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 .../llama/convert_llama_weights_to_hf.py      | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py
index 3dc6c7d697004..9a0a2e672ff77 100644
--- a/src/transformers/models/llama/convert_llama_weights_to_hf.py
+++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -17,12 +17,22 @@
 import math
 import os
 import shutil
+import warnings
 
 import torch
 
 from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 
+try:
+    from transformers import LlamaTokenizerFast
+except ImportError as e:
+    warnings.warn(e)
+    warnings.warn(
+        "The converted tokenizer will be the `slow` tokenizer. To use the fast tokenizer, update your `tokenizers` library and re-run the tokenizer conversion."
+    )
+    LlamaTokenizerFast = None
+
 """
 Sample usage:
 
@@ -232,9 +242,10 @@ def permute(w):
 
 
 def write_tokenizer(tokenizer_path, input_tokenizer_path):
-    print(f"Fetching the tokenizer from {input_tokenizer_path}.")
     # Initialize the tokenizer based on the `spm` model
-    tokenizer = LlamaTokenizer(input_tokenizer_path)
+    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
+    print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
+    tokenizer = tokenizer_class(input_tokenizer_path)
     tokenizer.save_pretrained(tokenizer_path)
 
 
@@ -259,10 +270,8 @@ def main():
         input_base_path=os.path.join(args.input_dir, args.model_size),
         model_size=args.model_size,
     )
-    write_tokenizer(
-        tokenizer_path=args.output_dir,
-        input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
-    )
+    spm_path = os.path.join(args.input_dir, "tokenizer.model")
+    write_tokenizer(args.output_dir, spm_path)
 
 
 if __name__ == "__main__":
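
Note (not part of the patch): a minimal sketch of how a downstream consumer might mirror the same fast/slow fallback when loading the tokenizer this script produces. "/output/path" is a hypothetical --output_dir passed to the conversion script; LlamaTokenizer, LlamaTokenizerFast, and from_pretrained are existing transformers APIs.

# Prefer the fast tokenizer when the installed `tokenizers` version provides
# it, otherwise fall back to the slow sentencepiece-based implementation,
# matching the fallback the patch applies at conversion time.
try:
    from transformers import LlamaTokenizerFast as TokenizerClass
except ImportError:
    from transformers import LlamaTokenizer as TokenizerClass

# "/output/path" is a hypothetical --output_dir used when running the script.
tokenizer = TokenizerClass.from_pretrained("/output/path")
print(tokenizer("Hello world").input_ids)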