
Special token handling breaks idempotency of sentencepiece due to extra spaces #31513

Open
cat-state opened this issue May 9, 2024 · 29 comments · May be fixed by #33988
Labels
Core: Tokenization Internals of the library; Tokenization.

Comments

@cat-state

cat-state commented May 9, 2024

SentencePiece tokenizers have the property that Decode(Encode(Normalize(input))) == Normalize(input). This property is very useful when combining and re-inferring prompts. However, when used through tokenizers with special tokens added for BOS/EOS etc., decoding will inject an extra space around the special tokens: <s>A becomes <s> A, and each further encode/decode round adds yet another space.

A previous issue was raised about this but was incorrectly closed as intended behavior/unfixable: huggingface/tokenizers#1237. Although not all tokenizers have this property, sentencepiece is very widely used now due to Llama and Mistral, so it would make sense for this behavior to be preserved.

There could be two fixes for this: either don't add the extra space, or tokenize <s> A the same as <s>A (I think this could be accomplished by changing the AddedToken params for these tokens).
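
For reference, a minimal round-trip sketch of the property breaking (the checkpoint here is only an example; any sentencepiece-based tokenizer with added special tokens shows the same pattern):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b-delta-v1.1", add_bos_token=False)
text = "<s>A"
for _ in range(3):
    ids = tokenizer(text)["input_ids"]
    text = tokenizer.decode(ids)
    print(repr(text))  # per this report, an extra space keeps appearing around "<s>" on each round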

@ArthurZucker
Collaborator

Do you have a reproducer?
I'd love to fix it, but I'm not sure this is still happening

@ArthurZucker
Collaborator

Llama-based tokenizers don't have this issue anymore; it was fixed by the Metaspace refactoring.

@ArthurZucker
Collaborator

Are you using legacy=False? (Mistral does not.)

@ArthurZucker
Collaborator

Also, the snippet shared:

from transformers import LlamaTokenizer
model_id = "lmsys/vicuna-13b-delta-v1.1"
tokenizer = LlamaTokenizer.from_pretrained(model_id, add_bos_token=False)
message = "<s>hello</s>"
decoded = tokenizer.decode(tokenizer(message)['input_ids'])
print(decoded, decoded == message)

This is on the transformers side, not tokenizers. I'll open a PR right away; it's super weird that it was not caught until now.


This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@ArthurZucker
Collaborator

Update:

In [2]: tokenizer.tokenize(message)
Out[2]: ['<s>', '▁hello', '</s>']

This is kind of expected: we add a prefix space at the beginning.
I think we can try to fix it, but it might break a lot of stuff 😢

@ArthurZucker ArthurZucker transferred this issue from huggingface/tokenizers Jun 20, 2024
@ArthurZucker
Collaborator

cc @itazap

@itazap
Collaborator

itazap commented Jun 21, 2024

Hi! #31315 will fix this with legacy=False, add_prefix_space=False. Will comment when merged

tokenizer = LlamaTokenizer.from_pretrained(model_id, add_bos_token=False, legacy=False, add_prefix_space=False)
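
With those flags, the original round trip should hold again, e.g. (sketch):

message = "<s>hello</s>"
decoded = tokenizer.decode(tokenizer(message)["input_ids"])
print(decoded, decoded == message)  # expected: True once #31315 is merged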

@huggingface huggingface deleted a comment from github-actions bot Jul 16, 2024
@huggingface huggingface deleted a comment from github-actions bot Aug 12, 2024
@amyeroberts amyeroberts added the Core: Tokenization Internals of the library; Tokenization. label Aug 12, 2024
@vince62s

vince62s commented Sep 9, 2024

Hi,
While working with TowerInstruct-7B-v0.2 I think I am having the same issue:

import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Unbabel/TowerInstruct-7B-v0.2", padding_side='left', legacy=False)
prompt = f"<|im_start|>user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|>\n<|im_start|>assistant\n"

input_ids = tokenizer(prompt, return_tensors="pt", padding=True, max_length=256, truncation=True).input_ids.cuda()
print(input_ids)
print(prompt)
outputs = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
print(outputs)

This gives the following, with an extra space after the added tokens <|im_start|> and <|im_end|>:

tensor([[    1, 32006,  1404,    13,  4300,  9632,   278,  1494,  1426,   515,
          4223,   964,  5332, 29889,    13, 24636, 29901, 15043,  3186,    13,
         29954,  3504, 29901, 32005, 29871,    13, 32006, 20255,    13]],
       device='cuda:0')
<|im_start|>user
Translate the following text from English into German.
English: Hello world
German:<|im_end|>
<|im_start|>assistant

['<s><|im_start|> user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|> \n<|im_start|> assistant\n']

If so, will the PR be merged shortly?

Thanks.

@itazap
Collaborator

itazap commented Sep 9, 2024

Hey @vince62s 😊! Passing add_prefix_space=False should address this:

tokenizer = AutoTokenizer.from_pretrained("Unbabel/TowerInstruct-7B-v0.2", padding_side='left', legacy=False, add_prefix_space=False)

# Output
'<s><|im_start|>user
Translate the following text from English into German.
English: Hello world
German:<|im_end|>
<|im_start|>assistant
'

Would this be suitable for your use-case?

@vince62s

vince62s commented Sep 9, 2024

Well, I thought this setting was part of the unmerged #31315, but there is some strange behavior.
As soon as I add add_prefix_space=True or False, the space indeed disappears; is this expected?
Another side effect is that "assistant" is then broken into two pieces, because the vocab has "▁assistant" but not "assistant".

@itazap
Collaborator

itazap commented Sep 10, 2024

They should have different behaviours, but for False it is correct that the space disappears. I think if "assistant" is not in the vocab and there is no space before it in your prompt, it should be split up (I compared this with using the slow / sentencepiece tokenizer). Would you agree?
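
For example, with the add_prefix_space=False tokenizer from above (a sketch; the exact pieces depend on the vocab):

# Only the space-prefixed "▁assistant" exists as a single token in this vocab,
# so the bare word is expected to be split into sub-pieces.
print(tokenizer.tokenize("assistant"))   # e.g. split into two sub-pieces
print(tokenizer.tokenize(" assistant"))  # e.g. ['▁assistant']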

@vince62s

vince62s commented Oct 4, 2024

@ArthurZucker @itazap

There is really something strange with the tokenizer behavior.

Using Unbabel/TowerInstruct-7B-v0.2 sentencepiece tokenizer, which is the llama2 one.

With no add_prefix_space flag, I am getting the following:

import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Unbabel/TowerInstruct-7B-v0.2", padding_side='left')
prompt = f"<|im_start|>user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer(prompt, return_tensors="pt", padding=True, max_length=256, truncation=True).input_ids.cuda()
print(input_ids)
print(prompt)
outputs = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
print(outputs)
tensor([[    1, 32006,  1404,    13,  4300,  9632,   278,  1494,  1426,   515,
          4223,   964,  5332, 29889,    13, 24636, 29901, 15043,  3186,    13,
         29954,  3504, 29901, 32005, 29871,    13, 32006, 20255,    13]],
       device='cuda:0')
<|im_start|>user
Translate the following text from English into German.
English: Hello world
German:<|im_end|>
<|im_start|>assistant

['<s><|im_start|> user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|> \n<|im_start|> assistant\n']

So you can see the space added three times: before "user", between "<|im_end|>" and "\n", and before "assistant".

As said before, if we add the add_prefix_space flag with False or True, then the space disappears.

import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Unbabel/TowerInstruct-7B-v0.2", padding_side='left', add_prefix_space=True)
prompt = f"<|im_start|>user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer(prompt, return_tensors="pt", padding=True, max_length=256, truncation=True).input_ids.cuda()
print(input_ids)
print(prompt)
outputs = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
print(outputs)
tensor([[    1, 32006,  1792,    13,  4300,  9632,   278,  1494,  1426,   515,
          4223,   964,  5332, 29889,    13, 24636, 29901, 15043,  3186,    13,
         29954,  3504, 29901, 32005,    13, 32006,   465, 22137,    13]],
       device='cuda:0')
<|im_start|>user
Translate the following text from English into German.
English: Hello world
German:<|im_end|>
<|im_start|>assistant

['<s><|im_start|>user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|>\n<|im_start|>assistant\n']

You can note that the tokens are not the same (1792 = "user" instead of 1404 = "▁user", and "assistant" broken into 465, 22137 instead of 20255 = "▁assistant").

=> Why the same behavior with False or True?

NOW,
with utter-project/EuroLLM-1.7B-Instruct, which uses another sentencepiece model.
Without the flag, the space is again added:

import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("utter-project/EuroLLM-1.7B-Instruct", padding_side='left')
prompt = f"<|im_start|>user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer(prompt, return_tensors="pt", padding=True, max_length=256, truncation=True).input_ids.cuda()
print(input_ids)
print(prompt)
outputs = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
print(outputs)
tensor([[     1,      3,  15236,    271,  31702,  31817,    557,   5302,   6001,
           1061,   6771,   2023,   5256, 119735,    271,  31601, 119782,  97849,
           4437,    271,  60457, 119782,      4, 119715,    271,      3,  58406,
            271]], device='cuda:0')
<|im_start|>user
Translate the following text from English into German.
English: Hello world
German:<|im_end|>
<|im_start|>assistant

['<s><|im_start|> user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|> \n<|im_start|> assistant\n']

With the flag set to True, the space is added (which is not the same behavior as with the Llama 2 one).
With the flag set to False, I am getting the correct behavior:

import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("utter-project/EuroLLM-1.7B-Instruct", padding_side='left', add_prefix_space=False)
prompt = f"<|im_start|>user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer(prompt, return_tensors="pt", padding=True, max_length=256, truncation=True).input_ids.cuda()
print(input_ids)
print(prompt)
outputs = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
print(outputs)
tensor([[     1,      3,  13676,    271,  31702,  31817,    557,   5302,   6001,
           1061,   6771,   2023,   5256, 119735,    271,  31601, 119782,  97849,
           4437,    271,  60457, 119782,      4,    271,      3,    788,  35441,
            271]], device='cuda:0')
<|im_start|>user
Translate the following text from English into German.
English: Hello world
German:<|im_end|>
<|im_start|>assistant

['<s><|im_start|>user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|>\n<|im_start|>assistant\n']

Again, the tokens are of course not the same.

Can someone clarify exactly what is going on?

Thanks

@ArthurZucker
Collaborator

=> Why the same behavior with False or True?
I think there is confusion here. add_prefix_space should only add a prefix space at the beginning of the sentence. If the first token is a special token, it won't do it. Thus you have the same behaviour.

For your second case, you should not use batch_decode, or if you do, use clean_up_tokenization_spaces=False (a warning should pop up).
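
For the decode call in the script above, that would look like:

outputs = tokenizer.batch_decode(
    input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)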

@vince62s

vince62s commented Oct 5, 2024

Still unclear. For the first model (the Llama 2 one), why does it trigger a different token for "user" when not using the flag vs using the flag?

@ArthurZucker
Collaborator

add_prefix_space affects both encoding and decoding.
Your "user" is not at the beginning of the sentence. Thus if add_prefix_space works as expected, the space should be added if you call tokenizer.encode("user") with True and not appear with False.

Now, it also affects decoding: basically, decoding removes the added space. But that means decoding tokenizer.encode("user") is expected to give the same result either way!
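
A small sketch of that point, using the tokenizer discussed in this thread (the ids in the comments are the ones quoted earlier; the exact behaviour depends on your transformers version):

from transformers import AutoTokenizer

for prefix in (True, False):
    tok = AutoTokenizer.from_pretrained(
        "Unbabel/TowerInstruct-7B-v0.2", legacy=False, add_prefix_space=prefix
    )
    ids = tok.encode("user", add_special_tokens=False)
    # Per the explanation above: True -> "▁user" (1404), and decode strips the added space;
    # False -> "user" (1792). Both are expected to decode back to "user".
    print(prefix, ids, repr(tok.decode(ids)))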

@ArthurZucker
Collaborator

The prompt is: f"<|im_start|>user\nTranslate the following text from English into German.\nEnglish: Hello world\nGerman:<|im_end|>\n<|im_start|>assistant\n"
so there is no way we are adding an extra space, unless you have the bug from the old Metaspace tokenizer! 🤗

When you set add_prefix_space, the bug is fixed by forcing the pretokenizer. Print it in those cases and you'll see that they are different!

@vince62s

vince62s commented Oct 5, 2024

I am not even talking about decoding at this point, just encoding.

@ArthurZucker
Collaborator

Your script has outputs = tokenizer.batch_decode(input_ids, skip_special_tokens=False), so I thought this was part of the confusion! (one has a space, the other no space)

To be honest, most of the issues are because of the legacy path that is the default. Or am I missing something?

@vince62s

vince62s commented Oct 6, 2024

OK, keeping legacy=False:

In [82]: import torch
    ...: import os
    ...: from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
    ...: 
    ...: hfpath="Unbabel/TowerInstruct-7B-v0.2"
    ...: localpath="/mnt/InternalCrucial4/LLM_work/TowerInstruct-7b-v0.2/tokenizer.model"
    ...: 
    ...: for prefix in [True, False, None]:
    ...:     for legacy in [False]: #, True, None]:
    ...:         tokenizer = AutoTokenizer.from_pretrained(hfpath, padding_side='left', add_prefix_space=prefix, legacy=legacy)
    ...:         prompt = f"<|im_start|>system\n<|im_end|>\n<|im_start|>user\n"
    ...:         input_ids = tokenizer(prompt, return_tensors="pt", padding=True, max_length=256, truncation=True).input_ids.cuda()
    ...:         print(prefix, legacy)
    ...:         print(input_ids)
    ...: 
True False
tensor([[    1, 32006,  5205,    13, 32005,    13, 32006,  1792,    13]],
       device='cuda:0')
False False
tensor([[    1, 32006,  5205,    13, 32005,    13, 32006,  1792,    13]],
       device='cuda:0')
None False
tensor([[    1, 32006,  1788,    13, 32005, 29871,    13, 32006,  1404,    13]],
       device='cuda:0')

How do you explain the choice of 5205 ("system") vs 1788 ("▁system"),
and at the end 1792 ("user") vs 1404 ("▁user"),
depending on whether add_prefix_space is set or not set?

@ArthurZucker
Collaborator

ArthurZucker commented Oct 6, 2024

When you set add_prefix_space to None, it falls into a case where it defaults to legacy.
This is where we have a "bug": you set legacy to False, but you are not getting legacy behaviour (because a fast version already exists, so from_slow is not used).

Again, you should always print tokenizer._tokenizer; you'll see that there is a difference there (1 and 2 have Metaspace as a pre_tokenizer, while 3 has no pre_tokenizer but a Prepend normalizer).
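
For example (a sketch; what gets printed depends on the installed transformers/tokenizers versions):

from transformers import AutoTokenizer

for prefix in (True, False, None):
    tok = AutoTokenizer.from_pretrained(
        "Unbabel/TowerInstruct-7B-v0.2", legacy=False, add_prefix_space=prefix
    )
    # backend_tokenizer is the underlying `tokenizers.Tokenizer` (the same object as tok._tokenizer)
    print(prefix, tok.backend_tokenizer.pre_tokenizer, tok.backend_tokenizer.normalizer)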

@ArthurZucker
Collaborator

Does that make more sense for you? (linked PR)

@vince62s

vince62s commented Oct 6, 2024

Does that make more sense for you? (linked PR)

If you merge a PR like this, my understanding is that you will break the behavior of a model like this one: https://huggingface.co/utter-project/EuroLLM-1.7B-Instruct/discussions/6
"Break" is maybe too strong, but it would create a discrepancy between training and inference.

@ArthurZucker
Collaborator

I am not sure I understand what you expect from me at this point 😅 I can't "fix" the fact that the issue contaminated the initial training!

@vince62s

vince62s commented Oct 7, 2024

My point is that it is still very unclear. We are talking about Llama, but look at Mistral with sentencepiece (not Tekken):

import torch
import os
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

#hfpath="Unbabel/TowerInstruct-7B-v0.2"
#localpath="/mnt/InternalCrucial4/LLM_work/TowerInstruct-7b-v0.2/tokenizer.model"
#hfpath="utter-project/EuroLLM-1.7B-Instruct"
#localpath="/mnt/InternalCrucial4/LLM_work/eurollm-1.7b-instruct/tokenizer.model"
#hfpath="meta-llama/Llama-3.1-8B-Instruct"
hfpath="mistralai/Mistral-7B-Instruct-v0.3"

for prefix in [True, False, None]:
    for legacy in [False, True, None]:
        tokenizer = AutoTokenizer.from_pretrained(hfpath, padding_side='left', add_prefix_space=prefix, legacy=legacy) 
        #prompt = f"<|im_start|> system\n<|im_end|>\n<|im_start|> user\n"
        #prompt = f"<|start_header_id|>user<|end_header_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"
        prompt = f"[INST]user[/INST]\n[INST]assistant[/INST]\n"
        input_ids = tokenizer(prompt, return_tensors="pt", padding=False, max_length=256, truncation=True).input_ids.cuda()
        print(prefix, legacy)
        print(input_ids)

It will give you:

True False
tensor([[    1,     3,  2606,     4,   781,     3,  1257, 11911,     4,   781]],
       device='cuda:0')
True True
tensor([[    1,     3,  2956,     4, 29473,   781,     3, 14660,     4, 29473,
           781]], device='cuda:0')
True None
tensor([[    1,     3,  2956,     4, 29473,   781,     3, 14660,     4, 29473,
           781]], device='cuda:0')
False False
tensor([[    1,     3,  2606,     4,   781,     3,  1257, 11911,     4,   781]],
       device='cuda:0')
False True
tensor([[    1,     3,  2606,     4,   781,     3,  1257, 11911,     4,   781]],
       device='cuda:0')
False None
tensor([[    1,     3,  2606,     4,   781,     3,  1257, 11911,     4,   781]],
       device='cuda:0')
None False
tensor([[    1,     3,  2606,     4,   781,     3,  1257, 11911,     4,   781]],
       device='cuda:0')
None True
tensor([[    1,     3,  2606,     4,   781,     3,  1257, 11911,     4,   781]],
       device='cuda:0')
None None
tensor([[    1,     3,  2606,     4,   781,     3,  1257, 11911,     4,   781]],
       device='cuda:0')

Again, 2606 = "user", 2956 = "▁user", [1257, 11911] = ["ass", "istant"], and 14660 = "▁assistant".

When you look at the legacy code of Mistral here:
https://github.com/mistralai/mistral-common/blob/main/tests/test_tokenize_v2.py#L16-L29
it looks like they use "▁user" and "▁assistant" when the word is directly attached to the special token.

Is my explanation of the issue clearer?

EDIT:
I am reading above one of your comments saying Mistral uses legacy=True.
But even looking only at the lines for legacy=True:
when add_prefix_space is set to True, then 2956 ("▁user") and 14660 ("▁assistant") are used (same behavior as the Mistral repo);
when add_prefix_space is set to False (or None before your PR), then 2606 ("user") and [1257, 11911] are used.

Is there somewhere a patch that forces add_prefix_space to True for Mistral?

@pandora-s-git

We have this document that may help in understanding this issue: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md

It explains exactly why each template is slightly different.

@vince62s

vince62s commented Oct 8, 2024

Thanks @pandora-s-git, this doc is very clear. However, the fact that in some versions the word "user" becomes "▁user" and in others stays "user", and likewise "assistant" becomes "▁assistant" or ("ass", "istant"), for the same model: does it trigger a big difference in quality in the end? (My understanding is that MistralInstruct-v0.3 supports both the V2 and V3 tokenizers.)

@pandora-s-git
Copy link

pandora-s-git commented Oct 8, 2024

From experience it can have a huge impact on completion. Specifically, let's say I provide the model (tokenizer V2 or V3) with:
<s>[INST]_user_message[/INST]_assistant_message</s>[INST]_new_user_message[/INST]
The model was trained with this format, so it will want to output a token starting with a whitespace. If we force the model to start with a whitespace ourselves:
<s>[INST]_user_message[/INST]_assistant_message</s>[INST]_new_user_message[/INST]_
This can have a considerable impact on the distribution, because it's a situation that should never occur in the training data, for example.

I hope this answers your question, but in my experience these whitespaces have a lot more importance than one may think. Also, be careful with Tekken: the reason V3-Tekken is considered a V3 template is that the Mistral Common implementation of the template is exactly the same as for the normal V3, the only difference being that V3 uses sentencepiece and V3-Tekken uses Tiktoken. This difference actually impacts the template itself if we use the string representations, which become:
<s>[INST]user_message[/INST]assistant_message</s>[INST]new_user_message[/INST]

And the tokenizer vocab being completely different, it will of course tokenize differently.

The V2 and V3 (sentencepiece) tokenizers are very similar; the only difference between them is in the tool calling.

@vince62s

vince62s commented Oct 8, 2024

If I may, and this is valid for all models in fact: it would be great to post in the model card the token IDs of an expected prompt with special tokens, so that one can verify that the HF flags are set correctly for both fine-tuning and inference. I have the impression this issue is hugely overlooked. Anyway, thanks for your answers.
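
For example, a hedged sketch of such a check (the helper and the reference ids are hypothetical; the model authors would have to publish them in the card):

from transformers import AutoTokenizer

def check_reference_tokenization(model_id, prompt, reference_ids, **tokenizer_kwargs):
    """Hypothetical helper: verify that the locally configured tokenizer reproduces
    the token ids published in the model card for a reference prompt."""
    tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
    ids = tokenizer(prompt)["input_ids"]
    if ids != reference_ids:
        raise ValueError(f"Tokenization mismatch: got {ids}, expected {reference_ids}")
    return True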
