
[LLaMA3] 'add_bos_token=True, add_eos_token=True' seems not taking effect #30947

Open
2 of 4 tasks
kiva12138 opened this issue May 22, 2024 · 5 comments
Labels

  • Core: Tokenization (Internals of the library; Tokenization)
  • Feature request (Request for a new feature)

Comments

kiva12138 commented May 22, 2024

System Info

Platform = Windows
PyTorch = 2.3.0
Transformers = 4.41.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoTokenizer

LLaMAPath = '/path/to/llama3-8b'

# The following two lines yield the same result: both outputs contain the BOS token and no EOS token
tokenizer = AutoTokenizer.from_pretrained(LLaMAPath, add_bos_token=True, add_eos_token=True)
# tokenizer = AutoTokenizer.from_pretrained(LLaMAPath, add_bos_token=False, add_eos_token=False)

tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"}) 
inputs = tokenizer(['hi, how are you today?'], padding=True, return_tensors='pt')
print(inputs)

Both of the lines above produce [128000, 6151, 11, 1268, 527, 499, 3432, 30]

Expected behavior

I think that when using tokenizer = AutoTokenizer.from_pretrained(LLaMAPath, add_bos_token=True, add_eos_token=True), we should get [128000, 6151, 11, 1268, 527, 499, 3432, 30, 128001],

and when using tokenizer = AutoTokenizer.from_pretrained(LLaMAPath, add_bos_token=False, add_eos_token=False), we should get [6151, 11, 1268, 527, 499, 3432, 30].

@kiva12138 kiva12138 changed the title [LLaMA] 'add_bos_token=True, add_eos_token=True' seems not taking effect [LLaMA3] 'add_bos_token=True, add_eos_token=True' seems not taking effect May 22, 2024
@eyloncaplan

I'm having the same issue. Neither of these changes the encodings:
tokenizer.add_bos_token = False
tokenizer.add_eos_token = True

@amyeroberts
Collaborator

cc @ArthurZucker

@amyeroberts amyeroberts added the Core: Tokenization Internals of the library; Tokenization. label May 22, 2024
@ArthurZucker
Collaborator

ArthurZucker commented May 23, 2024

Hey! This is related to #30607: the tokenizer for Llama 3 is a PreTrainedTokenizerFast, not the LlamaTokenizer or a LlamaTokenizerFast. Though it might actually be good to support an easy way to add BOS and EOS. Currently, what you have to do is update the TemplateProcessor, which is fairly annoying (not beginner friendly).

That's something which should be handled on the tokenizers side.
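To make the mechanism concrete: in the tokenizers library, special tokens like BOS/EOS are attached by the tokenizer's post-processor, not by flags checked at encode time. The toy sketch below (an invented three-word vocabulary, nothing to do with Llama 3's real vocabulary) shows that without a TemplateProcessing post-processor nothing is added, and with one the BOS/EOS wrapping appears:

```python
# Toy demonstration of tokenizers' TemplateProcessing post-processor,
# the component that decides whether BOS/EOS tokens are appended.
# The vocabulary here is made up purely for illustration.
from tokenizers import Tokenizer, models, processors

vocab = {"<bos>": 0, "<eos>": 1, "hi": 2}
tok = Tokenizer(models.WordLevel(vocab, unk_token="hi"))

# No post-processor: no special tokens are added.
print(tok.encode("hi").ids)  # [2]

# With a template, BOS and EOS wrap every single-sequence encoding.
tok.post_processor = processors.TemplateProcessing(
    single="<bos>:0 $A:0 <eos>:0",
    special_tokens=[("<bos>", 0), ("<eos>", 1)],
)
print(tok.encode("hi").ids)  # [0, 2, 1]
```

This is why flipping add_bos_token/add_eos_token on a plain PreTrainedTokenizerFast has no effect: the template baked into the post-processor is what actually runs.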

@ArthurZucker ArthurZucker added the Feature request Request for a new feature label May 23, 2024
@eyloncaplan

eyloncaplan commented May 23, 2024


@ArthurZucker I think it's called TemplateProcessing, not TemplateProcessor. For those wondering, this is how I used it to get the tokenizer to add the EOS token:

from tokenizers import processors  # provides Sequence, ByteLevel, TemplateProcessing

bos = "<|begin_of_text|>"
eos = "<|end_of_text|>"
tokenizer._tokenizer.post_processor = processors.Sequence(
    [
        processors.ByteLevel(trim_offsets=False),
        processors.TemplateProcessing(
            single=f"{bos}:0 $A:0 {eos}:0",
            pair=f"{bos}:0 $A:0 {bos}:1 $B:1 {eos}:1",
            special_tokens=[
                (bos, tokenizer.bos_token_id),
                (eos, tokenizer.eos_token_id),
            ],
        ),
    ]
)

Now I'm worried that the padding tokens won't get added properly, but that's a different issue...

@ArthurZucker
Collaborator

The padding token is unrelated; it's added if you ask the tokenizer to pad the input!
And yes, thanks for providing the snippet @eyloncaplan 😉
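That behavior can be illustrated on the same kind of toy tokenizer (again with an invented vocabulary, not Llama 3's): the pad token is only emitted once padding is actually requested:

```python
# Toy illustration: the pad token only appears when padding is enabled.
# The vocabulary and token names here are made up for demonstration.
from tokenizers import Tokenizer, models, pre_tokenizers

vocab = {"<pad>": 0, "hi": 1, "there": 2}
tok = Tokenizer(models.WordLevel(vocab, unk_token="hi"))
tok.pre_tokenizer = pre_tokenizers.Whitespace()

# No padding configured: sequences keep their natural lengths.
print([e.ids for e in tok.encode_batch(["hi", "hi there"])])  # [[1], [1, 2]]

# Ask for padding: the shorter sequence is padded with <pad> (id 0).
tok.enable_padding(pad_id=0, pad_token="<pad>")
print([e.ids for e in tok.encode_batch(["hi", "hi there"])])  # [[1, 0], [1, 2]]
```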
