Llama tokenizer inconsistency for the newline character for convert_tokens_to_ids #31030

Open
2 of 4 tasks
JackCai1206 opened this issue May 25, 2024 · 4 comments
Labels
Core: Tokenization Internals of the library; Tokenization.

Comments

@JackCai1206

JackCai1206 commented May 25, 2024

System Info

transformers 4.41.0
torch 2.3.0
GPU: NVIDIA GeForce RTX 4090, CUDA version 12.3

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am trying to get the token id for the newline character for Llama 3 and found this odd inconsistency: convert_tokens_to_ids('\n') returns None, but tokenizing '\n' yields id 198, and tokenizer.convert_ids_to_tokens(198) returns 'Ċ'.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", padding_side='left')

id1 = tokenizer.convert_tokens_to_ids('\n')
print(id1) # None

id2 = tokenizer('\n')['input_ids'][1]  # index 1 skips the <|begin_of_text|> token (128000)
print(id2) # 198

newline_tok = tokenizer.convert_ids_to_tokens(198)
print(newline_tok) # Ċ

id3 = tokenizer('Ċ')['input_ids'][1]
print(id3) # 128

id4 = tokenizer.convert_tokens_to_ids('Ċ')
print(id4) # 198

Expected behavior

I expected the output of convert_tokens_to_ids('\n') to be 198, matching the id produced when tokenizing '\n'.

@itazap
Contributor

itazap commented May 27, 2024

Hello! 🤗

The string 'Ċ' is actually encoded as 2 tokens (after the begin_of_text token):

tokenizer('Ċ')['input_ids'] # [128000, 128, 232], i.e. [128, 232] after the BOS token

Id 128 on its own does not correspond to 'Ċ':

tokenizer.decode([128]) # �

So, to decode 'Ċ' you will need:

tokenizer.decode([128, 232]) # Ċ

Why you are seeing 'Ċ' as opposed to '\n':

This is the byte-level BPE algorithm at work: whitespace bytes such as newline and tab are mapped onto printable placeholder characters ('\n' becomes 'Ċ'), so the vocabulary entry for the newline token is the string 'Ċ'. When you pass the literal string 'Ċ' back to the tokenizer, it is treated as ordinary text, and because 'Ċ' is a multi-byte character in UTF-8 it is split into multiple ids, 2 in this case. It is explained well in this comment here and here.
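
A minimal sketch of how to look up the newline id given this byte-to-placeholder mapping (assuming access to the meta-llama/Meta-Llama-3-8B checkpoint; this is an illustration, not an official recipe):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# The vocabulary stores the byte-level placeholder 'Ċ', not the raw '\n' string,
# so tokenize first and then convert the resulting token string to its id.
tokens = tokenizer.tokenize('\n')                  # ['Ċ']
ids = tokenizer.convert_tokens_to_ids(tokens)      # [198]
print(tokens, ids)

# Looking up the raw string '\n' directly fails because '\n' is not a vocabulary entry:
print(tokenizer.convert_tokens_to_ids('\n'))       # None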

I hope I answered your questions! Feel free to reply with any further questions

@amyeroberts added the Core: Tokenization label May 28, 2024
@JackCai1206
Author

JackCai1206 commented May 29, 2024

When I use GenerationConfig, I want to initialize it like so:

GenerationConfig(
    max_new_tokens=config.max_new_tokens,
    num_return_sequences=1,
    return_dict_in_generate=True, 
    stop_strings='\n', # Doesn't work
    pad_token_id=tokenizer.pad_token_id)

Using the newline character as a stop string doesn't work for Llama 3 because generate internally does something similar to convert_tokens_to_ids and gets back None, so the '\n' stop string is never recognized. The current workaround is to set eos_token_id=[198] in GenerationConfig, but I want to be able to use the stop_strings argument.
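
For reference, a minimal sketch of the workaround described above, looking up the newline id programmatically instead of hard-coding 198 (assumes model and tokenizer are an already-loaded Llama 3 checkpoint and its tokenizer; the prompt and max_new_tokens value are placeholders):

from transformers import GenerationConfig

# '\n' is stored in the vocab as the byte-level placeholder 'Ċ', so tokenize first.
newline_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('\n'))[0]  # 198 for Llama 3

generation_config = GenerationConfig(
    max_new_tokens=32,
    num_return_sequences=1,
    return_dict_in_generate=True,
    eos_token_id=[newline_id],          # workaround: stop on the newline token id
    pad_token_id=tokenizer.pad_token_id,
)

inputs = tokenizer("Some prompt", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)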

@itazap
Contributor

itazap commented May 30, 2024

Can you please share a small reproducer?

@Acejoy

Acejoy commented Jun 1, 2024

Hey @JackCai1206, were you able to get stop_strings='\n' working? I have the same issue and also want to use the stop_strings parameter to stop generation.
