Original Llama-3 tokenizer behaves differently from transformers version #31187

Closed
2 of 4 tasks
chawins opened this issue Jun 2, 2024 · 2 comments

Comments

@chawins

chawins commented Jun 2, 2024

System Info

transformers==4.41.2
tiktoken==0.7.0 and 0.4.0

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The issue can be reproduced with the following snippet.

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
ids = tokenizer("! ! !").input_ids
print(ids)  # >> [128000, 0, 758, 758]
print(tokenizer.decode(ids))  # >> <|begin_of_text|>!!! (the spaces are gone; is this wrong?)

# Llama-3 from https://github.com/meta-llama/llama3
# Download tokenizer.model from https://llama.meta.com/llama-downloads/
from llama import Tokenizer
tokenizer = Tokenizer("path/to/tokenizer.model")
ids = tokenizer.encode("! ! !", bos=True, eos=False)
print(ids)  # >> [128000, 0, 758, 758]
print(tokenizer.decode(ids))  # >> <|begin_of_text|>! ! ! (this is expected)

Expected behavior

Encoding a string and decoding it back should return the same string. The original tokenizer does this, but the transformers version drops the whitespace. Is this expected?

I know that the original Llama-3 tokenizer is now based on tiktoken. Is that the reason for this difference?
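The expectation above is the round-trip property decode(encode(s)) == s. A minimal sketch of that check, assuming any pair of encode/decode callables; round_trips, byte_encode and byte_decode are hypothetical helpers, and the one-token-per-byte tokenizer is a toy stand-in, not tiktoken or the Llama-3 vocabulary.

```python
from typing import Callable, Sequence


def round_trips(
    encode: Callable[[str], Sequence[int]],
    decode: Callable[[Sequence[int]], str],
    text: str,
) -> bool:
    """Return True if decoding the encoded text reproduces it exactly."""
    return decode(encode(text)) == text


# Hypothetical toy tokenizer: one token id per UTF-8 byte.
def byte_encode(text: str) -> list[int]:
    return list(text.encode("utf-8"))


def byte_decode(ids: Sequence[int]) -> str:
    return bytes(ids).decode("utf-8")


for sample in ["! ! !", "a  b\tc\n", "héllo !"]:
    print(repr(sample), round_trips(byte_encode, byte_decode, sample))
```

A byte-level scheme like this is lossless by construction, so when a real tokenizer fails this check it suggests that a post-processing step in decoding, rather than the vocabulary, is altering the text.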

@ArthurZucker
Collaborator

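Hey! This is caused by a default in transformers, clean_up_tokenization_spaces, which removes the space before punctuation such as "!" when decoding. Disabling it explicitly:
print(tokenizer.decode(tokenizer("! ! !").input_ids, clean_up_tokenization_spaces=False))
should do the trick.
Let's change the default to False and deprecate it, cc @itazap!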
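To see why the spaces vanish, here is a standalone sketch of the kind of string cleanup applied when clean_up_tokenization_spaces is enabled. The replacement list mirrors, as far as I know, the rules in PreTrainedTokenizerBase.clean_up_tokenization in transformers around 4.41; treat the exact list as an assumption and check the installed version's source.

```python
def clean_up_tokenization(out_string: str) -> str:
    """Sketch of the decode-time cleanup (assumed rule list).

    Each rule removes whitespace next to a punctuation mark or an English
    contraction, which suits tokenizers that join word pieces with spaces.
    """
    replacements = [
        (" .", "."),
        (" ?", "?"),
        (" !", "!"),
        (" ,", ","),
        (" ' ", "'"),
        (" n't", "n't"),
        (" 'm", "'m"),
        (" 's", "'s"),
        (" 've", "'ve"),
        (" 're", "'re"),
    ]
    for old, new in replacements:
        out_string = out_string.replace(old, new)
    return out_string


raw = "<|begin_of_text|>! ! !"  # the text the Llama-3 ids decode to
print(clean_up_tokenization(raw))  # <|begin_of_text|>!!!
```

The " !" rule fires twice on "! ! !", which accounts for the "!!!" output in the report. Passing clean_up_tokenization_spaces=False to decode skips this step, so the tokenizer's own decoding is returned unchanged.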

@chawins
Author

chawins commented Jun 2, 2024

Awesome! I confirmed that clean_up_tokenization_spaces=False fixed the issue. Thanks a lot for pointing it out.
