incorrect special_tokens_mask #35897

@MostHumble

Description

System Info

  • transformers version: 4.47.1
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Python version: 3.11.11
  • Huggingface_hub version: 0.27.1
  • Safetensors version: 0.5.2
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu121 (True)
  • Tensorflow version (GPU?): 2.17.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.10.2 (gpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Using distributed or parallel set-up in script?: False
  • Using GPU in script?: False
  • GPU type: Tesla T4

Who can help?

@ArthurZucker @itazap

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceTB/SmolLM2-135M"  # same behavior with gpt2
device = "cuda"  # for GPU usage or "cpu" for CPU usage
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)  # not needed for the repro itself

# The prompt already contains two special tokens, yet the returned
# special_tokens_mask is all zeros:
outputs = tokenizer("<|endoftext|><|im_start|>y", return_tensors="pt", return_special_tokens_mask=True)

outputs:

{'input_ids': tensor([[  0,   1, 105]]), ..., 'special_tokens_mask': tensor([[0, 0, 0]])}
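For comparison, recomputing the mask through the public get_special_tokens_mask API does flag the two tokens (a sketch; with already_has_special_tokens=True the base implementation simply checks each id against tokenizer.all_special_ids instead of relying on the fast tokenizer's encoding):

# Workaround sketch: derive the mask from all_special_ids rather than
# from the fast tokenizer's encoding.
ids = outputs["input_ids"][0].tolist()
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
print(mask)  # [1, 1, 0], matching the expected behavior below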

Expected behavior

outputs:

{'input_ids': tensor([[  0,   1, 105]]), ..., 'special_tokens_mask': tensor([[1, 1, 0]])}

given that both <|endoftext|> and <|im_start|> are registered as special tokens:

tokenizer.special_tokens_map
{'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'unk_token': '<|endoftext|>',
 'additional_special_tokens': ['<|endoftext|>',
  '<|im_start|>',
  '<|im_end|>',
...
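
As a quick sanity check (sketch), the first two ids in the repro do decode to these special tokens and are listed in all_special_ids:

# Sanity-check sketch: ids 0 and 1 decode to the two special tokens and
# both appear in all_special_ids, so the mask should start with 1, 1.
print(tokenizer.convert_ids_to_tokens([0, 1, 105]))
# e.g. ['<|endoftext|>', '<|im_start|>', 'y']  (last piece may render differently)
print([i in tokenizer.all_special_ids for i in (0, 1, 105)])
# expected: [True, True, False]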
