LLaMATokenizerFast works abnormally #23818

Closed
jiangwangyi opened this issue May 27, 2023 · 14 comments · Fixed by #24042 or #23909

@jiangwangyi (Contributor) commented May 27, 2023

System Info

platform==Ubuntu18.04
python==3.10
transformers==4.29.2

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

</s> is a special token of LLaMATokenizer(Fast), so it is expected to be recognized as a single token when encoding text. However, the two tokenizers behave differently:

>>> import transformers
>>> t1 = transformers.AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=True)
>>> t2 = transformers.AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False)
>>> text = "I love you.</s>"
>>> t1(text)
{'input_ids': [1, 306, 5360, 366, 21106, 29879, 29958], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
>>> t2(text)
{'input_ids': [1, 306, 5360, 366, 29889, 2], 'attention_mask': [1, 1, 1, 1, 1, 1]}

Also, LLaMATokenizerFast returns token_type_ids, but LLaMATokenizer does not.

Expected behavior

LLaMATokenizerFast should be consistent with LLaMATokenizer.

@jiangwangyi changed the title from "LLaMATokenizerFast cannot recognize special tokens when encoding" to "LLaMATokenizerFast works abnormally" on May 30, 2023
@NielsRogge (Contributor) commented May 30, 2023

I also have two questions related to LlamaTokenizerFast:

First, loading a fast tokenizer from a saved slow one takes very long:

from transformers import LlamaTokenizer, LlamaTokenizerFast

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.save_pretrained(".")

# the following line takes > 1 min
fast_tokenizer = LlamaTokenizerFast.from_pretrained(".")

This is not the case for other tokenizers like BertTokenizerFast.
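
One workaround, sketched under the assumption (confirmed further down) that the slow part is converting the sentencepiece model, is to do the conversion once, save the fast tokenizer, and reload it from its own tokenizer.json (the local paths here are just placeholders):

from transformers import LlamaTokenizer, LlamaTokenizerFast

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.save_pretrained("./llama-slow")  # only the sentencepiece files are saved

# this load converts the sentencepiece model to the fast format and is slow
fast_tokenizer = LlamaTokenizerFast.from_pretrained("./llama-slow")

# saving the fast tokenizer writes a tokenizer.json, so later loads skip the conversion
fast_tokenizer.save_pretrained("./llama-fast")
fast_tokenizer = LlamaTokenizerFast.from_pretrained("./llama-fast")  # quick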

Second, for a new model I'm working on (#23460) I wonder how to get the same behaviour between slow and fast tokenizers for the following:

import torch
from transformers import LlamaTokenizer, LlamaTokenizerFast

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", truncation_side="left")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.add_special_tokens({"bos_token": "</s>"})
tokenizer.add_special_tokens({"eos_token": "</s>"})
tokenizer.add_special_tokens({"unk_token": "</s>"})

fast_tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", truncation_side="left")
fast_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
fast_tokenizer.add_special_tokens({"bos_token": "</s>"})
fast_tokenizer.add_special_tokens({"eos_token": "</s>"})
fast_tokenizer.add_special_tokens({"unk_token": "</s>"})

prompt = "What is unusual about this image?"

encoding = tokenizer(prompt, return_tensors="pt")

fast_encoding = fast_tokenizer(prompt, return_tensors="pt")

for k,v in encoding.items():
    assert torch.allclose(fast_encoding[k], v)

=> this assertion fails since the input_ids differ:

tensor([[    2,  1724,   338, 22910,  1048,   445,  1967, 29973]])
tensor([[    1,  1724,   338, 22910,  1048,   445,  1967, 29973]])
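
A hedged guess at the cause: the fast tokenizer's post-processor, which prepends the BOS id, is built when the tokenizer is converted, so changing bos_token afterwards may not update it. One way to inspect this is to dump the backend tokenizer state and look at the post_processor section:

import json

state = json.loads(fast_tokenizer.backend_tokenizer.to_str())
print(state["post_processor"])  # shows which id/token is prepended as BOS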

@NielsRogge (Contributor)

cc'ing @ArthurZucker and @Narsil here

@ArthurZucker (Collaborator)

Hey! Thanks for opening this issue.

  • return_token_type_ids should be None by default and is then set based on whether "token_type_ids" is in self.model_input_names. This is specific to the fast tokenizer and is a known difference. I am not sure why this was added only in the fast tokenizer, but it is more than two years old!
  • The BPE model splits on spaces before encoding the tokens. When converting the models from slow to fast, the special tokens were added to the BPE vocabulary with a score of 0. We probably forgot to add them to the list of additional_special_tokens, which is why they are not properly split. (Quick fix: t1.additional_special_tokens = ["</s>", ...]; see the sketch after this list.)
  • @NielsRogge when you load a fast tokenizer from a slow one, it takes a long time because the sentencepiece BPE model has to be converted, which is very slow. Nothing we can do about that.
  • About your second question, the best thing would be to open a new issue. It seems like it might be another slow/fast discrepancy, but you are not completely doing this the way the API is designed! (Check that each call to add a token actually adds it!)
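
A rough sketch of the workarounds mentioned above (the token_type_ids behaviour and the quick fix; note that the follow-up below reports the quick fix is not enough for </s>, which is already the registered eos_token):

# ask the fast tokenizer not to return token_type_ids for a single call
t1("I love you.</s>", return_token_type_ids=False)

# or drop them from the inputs it produces by default
t1.model_input_names = ["input_ids", "attention_mask"]

# the quick fix suggested above: declare </s> as an additional special token
t1.additional_special_tokens = ["</s>"]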

@jiangwangyi (Contributor, Author)

In the tokenizer_config.json of huggyllama/llama-7b, </s> is already a registered special token (the eos_token). Adding </s> to t1.additional_special_tokens does not fix the problem.

@ArthurZucker (Collaborator)

Indeed, sorry for the confusion. I added a different token <//s> with add_special_tokens, which worked as expected (meaning that whether there was a space or not, the output was properly encoded), which is why the issue most probably lies in the handling of the special tokens (maybe we should not have added them to the vocab? I'll check). I'll dig into this!

@jiangwangyi (Contributor, Author)

@ArthurZucker Any progress on this?

@ArthurZucker (Collaborator)

I am still working on this, top priority! My PR did not fix it yet, so I am opening a new one just for LLaMA and will see about the other models.

@jiangwangyi (Contributor, Author)

Thanks for working on this! I appreciate the update and look forward to getting the issue resolved.

@ArthurZucker (Collaborator)

Update: in order to fix this, the tokenizer.json should be modified: the special tokens should not be normalized (so set normalized = False). There is a more profound issue, since the slow tokenizer is not bothered by this and handles it differently.
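
A minimal sketch of that edit, assuming a locally saved copy of the tokenizer files (the added_tokens layout is the standard one written by the tokenizers library):

import json

# flip the special tokens to normalized = False in a local tokenizer.json
with open("tokenizer.json") as f:
    state = json.load(f)

for added in state["added_tokens"]:
    if added["content"] in ("<s>", "</s>", "<unk>"):
        added["normalized"] = False

with open("tokenizer.json", "w") as f:
    json.dump(state, f, indent=2, ensure_ascii=False)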

@jiangwangyi (Contributor, Author) commented Jun 11, 2023

@ArthurZucker
My transformers version is 4.30.1. I did not change the tokenizer_config.json; instead, I replaced the default special tokens with add_special_tokens, like so:

>>> from transformers import AutoTokenizer
>>> lt = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
>>> lt
LlamaTokenizerFast(name_or_path='huggyllama/llama-7b', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True)}, clean_up_tokenization_spaces=False)
>>> lt.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"})
>>> lt
LlamaTokenizerFast(name_or_path='huggyllama/llama-7b', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False)
>>> lt("ok</s>")
{'input_ids': [1, 3431, 829, 29879, 29958], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}

It seems that the problem still exists?

@ArthurZucker (Collaborator)

Hey, as mentioned in #23889 as well as in #24042, the tokenizer.json has to be modified. I did not have time to open PRs on all models yet, but you still have normalized = True on the special tokens, which is why they are split.

@jiangwangyi (Contributor, Author)

As shown in your example in #23889, if I do not modify the tokenizer.json, resetting the bos_token and eos_token when initializing the fast tokenizer, or using the add_special_tokens method, does not work (the normalized=True attribute still exists), even though the special_tokens_dict attribute has been changed to {"bos_token": "<s>", "eos_token": "</s>"}. Is that true?

@ArthurZucker (Collaborator)

Yes. Basically, you have to correctly add the tokens when converting, otherwise the underlying regex is not properly updated. We are thinking of adding an update_tokens feature, which would allow modifying a token that is already part of the vocab.
See the following problem:

In [1]: from tokenizers import AddedToken

In [2]: lt.add_special_tokens({"eos_token": AddedToken("<//s>", normalized = False)})
Out[2]: 1

In [3]: lt.encode("Another tests<//s>")
Out[3]: [1, 7280, 6987, 32000]

In [4]: lt.add_special_tokens({"eos_token": AddedToken("<//s>", normalized = True)})
Out[4]: 0

In [5]: lt.encode("Another tests<//s>")
Out[5]: [1, 7280, 6987, 32000]

In [6]: lt.add_special_tokens({"eos_token": AddedToken("<///s>", normalized = True)})
Out[6]: 1

In [7]: lt.encode("Another tests<///s>")
Out[7]: [1, 7280, 6987, 29966, 6658, 29879, 29958]

@jiangwangyi (Contributor, Author)

Thank you for your kind guidance!
