LLaMATokenizerFast works abnormally #23818
Also have two questions related to this. First, loading a fast tokenizer from a saved slow one takes very long:
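(The snippet itself did not survive in this copy of the thread; below is a minimal sketch of the kind of timing comparison meant, assuming a slow tokenizer was previously saved to a hypothetical local "saved_slow_tokenizer" directory.)

```python
import time

from transformers import LlamaTokenizer, LlamaTokenizerFast

# Hypothetical directory written beforehand by LlamaTokenizer.save_pretrained.
path = "saved_slow_tokenizer"

start = time.time()
slow = LlamaTokenizer.from_pretrained(path)
print(f"slow tokenizer loaded in {time.time() - start:.1f}s")

start = time.time()
# Loading the fast class from slow (SentencePiece) files triggers a full
# conversion, which is the step reported as taking very long.
fast = LlamaTokenizerFast.from_pretrained(path)
print(f"fast tokenizer converted and loaded in {time.time() - start:.1f}s")
```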
This is not the case for other tokenizers. Second, for a new model I'm working on (#23460), I wonder how to get the same behaviour between slow and fast tokenizers for the following:
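(The exact code is missing from this copy of the thread; here is a minimal sketch of such a check, with "path/to/new-model" as a placeholder for the model from #23460.)

```python
from transformers import AutoTokenizer

checkpoint = "path/to/new-model"  # placeholder for the model added in #23460

slow = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
fast = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)

text = "Hello world</s>"
# The special-token string is handled differently by the two tokenizers,
# so the encoded ids do not match.
assert slow(text)["input_ids"] == fast(text)["input_ids"]
```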
=> this assertion fails since the input_ids differ:
cc'ing @ArthurZucker and @Narsil here
Hey! Thanks for opening this issue.
In the
Indeed, sorry for the confusion. I added a different token
@ArthurZucker How is the progress on this?
I am still working on this, top priority! My PR did not fix it yet, so I am opening a new one just for llama and will see about the other ones.
Thanks for working on this! I appreciate the update and look forward to getting the issue resolved.
Update: in order to fix this, the
@ArthurZucker
>>> from transformers import AutoTokenizer
>>> lt = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
>>> lt
LlamaTokenizerFast(name_or_path='huggyllama/llama-7b', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True)}, clean_up_tokenization_spaces=False)
>>> lt.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"})
>>> lt
LlamaTokenizerFast(name_or_path='huggyllama/llama-7b', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False)
>>> lt("ok</s>")
{'input_ids': [1, 3431, 829, 29879, 29958], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}
It seems that the problem still exists?
As shown in your example in #23889, if I do not modify the
Yes. Basically, you have to correctly add the tokens when converting, otherwise the underlying regex is not properly updated. We are thinking of adding a

In [2]: lt.add_special_tokens({"eos_token": AddedToken("<//s>", normalized = False)})
Out[2]: 1
In [3]: lt.encode("Another tests<//s>")
Out[3]: [1, 7280, 6987, 32000]
In [4]: lt.add_special_tokens({"eos_token": AddedToken("<//s>", normalized = True)})
Out[4]: 0
In [5]: lt.encode("Another tests<//s>")
Out[5]: [1, 7280, 6987, 32000]
In [6]: lt.add_special_tokens({"eos_token": AddedToken("<///s>", normalized = True)})
Out[6]: 1
In [7]: lt.encode("Another tests<///s>")
Out[7]: [1, 7280, 6987, 29966, 6658, 29879, 29958]
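For reference, a sketch of that workaround applied to the original </s> case, re-registering the EOS token as an AddedToken with normalized=False. Whether re-adding an already-registered special token takes effect can depend on the transformers/tokenizers versions, so treat this as an illustration rather than a guaranteed fix:

```python
from transformers import AddedToken, AutoTokenizer

tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

# In the example above, normalized=False is what made the added token match as a
# single id, while normalized=True left the string split into several pieces.
tok.add_special_tokens(
    {"eos_token": AddedToken("</s>", rstrip=False, lstrip=False, normalized=False)}
)

print(tok("ok</s>")["input_ids"])
```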
Thank you for your kind guidance!
System Info
platform==Ubuntu18.04
python==3.10
transformers==4.29.2
Who can help?
@ArthurZucker
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Since </s> is the special token of LLaMATokenizer(Fast), it is expected that </s> can be recognized as a single token when encoding the text. However, it can be shown that the two tokenizers behave differently (see the sketch below). Also, LLaMATokenizerFast returns token_type_ids but LLaMATokenizer does not.
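(The reproduction snippet is not preserved in this copy of the issue; below is a minimal sketch of the reported difference, with the fast output taken from the ids shown earlier in the thread.)

```python
from transformers import LlamaTokenizer, LlamaTokenizerFast

slow = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
fast = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")

print(slow("ok</s>"))
# slow: "</s>" comes back as a single EOS id, and no token_type_ids are returned

print(fast("ok</s>"))
# fast (4.29.2): {'input_ids': [1, 3431, 829, 29879, 29958], 'token_type_ids': [0, 0, 0, 0, 0],
#                 'attention_mask': [1, 1, 1, 1, 1]}
# "</s>" is split into several pieces, and token_type_ids are included
```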
Expected behavior
LLaMATokenizerFast should be consistent with LLaMATokenizer.