Commit d3ee5e8

vasqu and ArthurZucker committed
[Mistral Tokenizers] Fix tokenizer detection (#42389)
* fix
* sanity check
* style
* comments
* make it v5 explicit
* make explicit fixes possible in local tokenizers
* remove hub usage on local
* fix
* extend test for no config case
* move mistral patch outside to separate fn
* fix local path only
* add a test
* make sure test does not pass before this PR
* styling
* make sure it exists
* fix
* fix
* rename
* up
* last nit i hope lord

---------

Co-authored-by: Arthur <arthur.zucker@gmail.com>
1 parent 2915fb3 commit d3ee5e8
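In practical terms, this commit tightens the detection that decides whether a tokenizer needs the Mistral pre-tokenizer regex correction, so that local, non-Mistral tokenizers no longer trigger the fix or its warning. For reference, a minimal sketch of the user-facing opt-in that the warning below points to (the model id is the one cited in the warning link; any affected Mistral-derived checkpoint is loaded the same way):

from transformers import AutoTokenizer

# Sketch only: explicitly opting in to the corrected regex applies the fix and
# silences the warning emitted for checkpoints detected as Mistral-derived.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    fix_mistral_regex=True,
)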

4 files changed: +149 -33 lines


src/transformers/tokenization_utils_base.py

Lines changed: 83 additions & 32 deletions
@@ -2046,12 +2046,13 @@ def from_pretrained(
                 template = template.removesuffix(".jinja")
                 vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
 
+        remote_files = []
         if not is_local and not local_files_only:
             try:
                 remote_files = list_repo_files(pretrained_model_name_or_path)
             except Exception:
                 remote_files = []
-        else:
+        elif pretrained_model_name_or_path and os.path.isdir(pretrained_model_name_or_path):
             remote_files = os.listdir(pretrained_model_name_or_path)
 
         if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
@@ -2385,57 +2386,108 @@ def _from_pretrained(
         except NotImplementedError:
             vocab_size = 0
 
+        # Optionally patches mistral tokenizers with wrong regex
         if (
             vocab_size > 100000
             and hasattr(tokenizer, "_tokenizer")
             and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None
         ):
-            from huggingface_hub import model_info
+            tokenizer = cls._patch_mistral_regex(
+                tokenizer,
+                pretrained_model_name_or_path,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                _commit_hash=_commit_hash,
+                _is_local=_is_local,
+                init_kwargs=init_kwargs,
+                fix_mistral_regex=kwargs.get("fix_mistral_regex"),
+            )
 
-            def is_base_mistral(model_id: str) -> bool:
-                model = model_info(model_id)
-                if model.tags is not None:
-                    if re.search("base_model:.*mistralai", "".join(model.tags)):
-                        return True
-                return False
+        return tokenizer
 
-            if _is_local or is_base_mistral(pretrained_model_name_or_path):
-                _config_file = cached_file(
-                    pretrained_model_name_or_path,
-                    "config.json",
-                    cache_dir=cache_dir,
-                    token=token,
-                    local_files_only=local_files_only,
-                    _raise_exceptions_for_missing_entries=False,
-                    _raise_exceptions_for_connection_errors=False,
-                    _commit_hash=_commit_hash,
-                )
-                if _config_file is not None:
-                    with open(_config_file, encoding="utf-8") as f:
-                        _config = json.load(f)
-                    transformers_version = _config.get("transformers_version")
+    @classmethod
+    def _patch_mistral_regex(
+        cls,
+        tokenizer,
+        pretrained_model_name_or_path,
+        token=None,
+        cache_dir=None,
+        local_files_only=False,
+        _commit_hash=None,
+        _is_local=False,
+        init_kwargs=None,
+        fix_mistral_regex=None,
+    ):
+        """
+        Patches mistral related tokenizers with incorrect regex if detected
+        1) Local file with an associated config saved next to it
+            >> Model type one of the mistral models (on older versions)
+        2) Remote models on the hub from official mistral models
+            >> Tags including `base_model:.*mistralai`
+        """
+        from huggingface_hub import model_info
 
-                    if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
-                        if _is_local and _config.model_type not in [
+        def is_base_mistral(model_id: str) -> bool:
+            model = model_info(model_id)
+            if model.tags is not None:
+                if re.search("base_model:.*mistralai", "".join(model.tags)):
+                    return True
+            return False
+
+        if _is_local or is_base_mistral(pretrained_model_name_or_path):
+            _config_file = cached_file(
+                pretrained_model_name_or_path,
+                "config.json",
+                cache_dir=cache_dir,
+                token=token,
+                local_files_only=local_files_only,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
+                _commit_hash=_commit_hash,
+            )
+
+            # Detected using a (local) mistral tokenizer
+            mistral_config_detected = False
+            if _config_file is not None:
+                with open(_config_file, encoding="utf-8") as f:
+                    _config = json.load(f)
+                transformers_version = _config.get("transformers_version")
+                transformers_model_type = _config.get("model_type")
+
+                # Detect if we can skip the mistral fix by
+                # a) having a non-mistral tokenizer
+                # b) fixed version of transformers
+                if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
+                    if (
+                        _is_local
+                        and transformers_model_type is not None
+                        and transformers_model_type
+                        not in [
                             "mistral",
                             "mistral3",
-                            "voxstral",
+                            "voxtral",
                             "ministral",
                             "pixtral",
-                        ]:
-                            return tokenizer
+                        ]
+                    ):
+                        return tokenizer
+                elif transformers_version and version.parse(transformers_version) >= version.parse("5.0.0"):
+                    return tokenizer
 
+                mistral_config_detected = True
+
+        if mistral_config_detected or (not _is_local and is_base_mistral(pretrained_model_name_or_path)):
             # Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
-            if "fix_mistral_regex" in init_kwargs:
+            if init_kwargs and "fix_mistral_regex" in init_kwargs:
                 setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])
 
-            fix_mistral_regex = kwargs.get("fix_mistral_regex")  # not init kwargs
             # only warn if its not explicitly passed
             if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
                 setattr(tokenizer, "fix_mistral_regex", False)
                 logger.warning(
                     f"The tokenizer you are loading from '{pretrained_model_name_or_path}'"
-                    f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. "
+                    f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e."
                     " This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
                 )
             elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):

@@ -2448,7 +2500,6 @@ def is_base_mistral(model_id: str) -> bool:
                 ),
                 behavior="isolated",
             )
-
         return tokenizer
 
     @staticmethod
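After loading, the outcome of the detection is visible on the tokenizer itself: inside the detected branch the patch always exposes a `fix_mistral_regex` attribute (the user-supplied value, or `False` plus a warning when nothing was passed), while tokenizers never detected as Mistral-derived do not get the attribute at all. A rough sketch of inspecting it, reusing the model id from the warning link:

from transformers import AutoTokenizer

# Prints the explicit flag value, False (detected but not opted in), or None
# when the checkpoint was not detected as Mistral-derived in the first place.
tok = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
print(getattr(tok, "fix_mistral_regex", None))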

tests/models/auto/test_tokenization_auto.py

Lines changed: 39 additions & 0 deletions
@@ -34,9 +34,13 @@
     GPT2Tokenizer,
     GPT2TokenizerFast,
     PreTrainedTokenizerFast,
+    Qwen2Tokenizer,
+    Qwen2TokenizerFast,
+    Qwen3MoeConfig,
     RobertaTokenizer,
     RobertaTokenizerFast,
     is_tokenizers_available,
+    logging,
 )
 from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
 from transformers.models.auto.tokenization_auto import (
@@ -49,6 +53,7 @@
     DUMMY_DIFF_TOKENIZER_IDENTIFIER,
     DUMMY_UNKNOWN_IDENTIFIER,
     SMALL_MODEL_IDENTIFIER,
+    CaptureLogger,
     RequestCounter,
     is_flaky,
     require_tokenizers,
@@ -229,6 +234,40 @@ def test_auto_tokenizer_from_local_folder(self):
         self.assertIsInstance(tokenizer2, tokenizer.__class__)
         self.assertEqual(tokenizer2.vocab_size, 12)
 
+    def test_auto_tokenizer_from_local_folder_mistral_detection(self):
+        """See #42374 for reference, ensuring proper mistral detection on local tokenizers"""
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-235B-A22B-Thinking-2507")
+        config = Qwen3MoeConfig.from_pretrained("Qwen/Qwen3-235B-A22B-Thinking-2507")
+        self.assertIsInstance(tokenizer, (Qwen2Tokenizer, Qwen2TokenizerFast))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tokenizer.save_pretrained(tmp_dir)
+
+            # Case 1: Tokenizer with no config associated
+            logger = logging.get_logger("transformers.tokenization_utils_base")
+            with CaptureLogger(logger) as cl:
+                AutoTokenizer.from_pretrained(tmp_dir)
+            self.assertNotIn(
+                "with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e",
+                cl.out,
+            )
+
+            # Case 2: Tokenizer with config associated
+            # Needed to be saved along the tokenizer to detect (non)mistral
+            # for a version where the regex bug occurs
+            config_dict = config.to_diff_dict()
+            config_dict["transformers_version"] = "4.57.2"
+
+            # Manually saving to avoid versioning clashes
+            config_path = os.path.join(tmp_dir, "config.json")
+            with open(config_path, "w", encoding="utf-8") as f:
+                json.dump(config_dict, f, indent=2, sort_keys=True)
+
+            tokenizer2 = AutoTokenizer.from_pretrained(tmp_dir)
+
+            self.assertIsInstance(tokenizer2, tokenizer.__class__)
+            self.assertTrue(tokenizer2.vocab_size > 100_000)
+
     def test_auto_tokenizer_fast_no_slow(self):
         tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
         # There is no fast CTRL so this always gives us a slow tokenizer.

tests/models/llama/test_tokenization_llama.py

Lines changed: 5 additions & 1 deletion
@@ -51,7 +51,11 @@
 @require_sentencepiece
 @require_tokenizers
 class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
-    from_pretrained_id = ["hf-internal-testing/llama-tokenizer", "meta-llama/Llama-2-7b-hf"]
+    from_pretrained_id = [
+        "hf-internal-testing/llama-tokenizer",
+        "meta-llama/Llama-2-7b-hf",
+        "meta-llama/Meta-Llama-3-8B",
+    ]
     tokenizer_class = LlamaTokenizer
     rust_tokenizer_class = LlamaTokenizerFast
tests/test_tokenization_common.py

Lines changed: 22 additions & 0 deletions
@@ -4730,3 +4730,25 @@ def test_empty_input_string(self):
         for return_type, target_type in zip(tokenizer_return_type, output_tensor_type):
             output = tokenizer(empty_input_string, return_tensors=return_type)
             self.assertEqual(output.input_ids.dtype, target_type)
+
+    def test_local_files_only(self):
+        from transformers import AutoTokenizer
+
+        pretrained_list = getattr(self, "from_pretrained_id", []) or []
+        for pretrained_name in pretrained_list:
+            with self.subTest(f"AutoTokenizer ({pretrained_name})"):
+                # First cache the tokenizer files
+                try:
+                    tokenizer_cached = AutoTokenizer.from_pretrained(pretrained_name)
+
+                    # Now load with local_files_only=True
+                    tokenizer_local = AutoTokenizer.from_pretrained(pretrained_name, local_files_only=True)
+
+                    # Check that the two tokenizers are identical
+                    self.assertEqual(tokenizer_cached.get_vocab(), tokenizer_local.get_vocab())
+                    self.assertEqual(
+                        tokenizer_cached.all_special_tokens_extended,
+                        tokenizer_local.all_special_tokens_extended,
+                    )
+                except Exception as _:
+                    pass  # if the pretrained model is not loadable how could it pass locally :)
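The new common test mirrors a simple offline workflow, sketched below with one of the checkpoints from the Llama test list above: load once to populate the cache, then load again with `local_files_only=True` and expect an identical tokenizer.

from transformers import AutoTokenizer

repo = "hf-internal-testing/llama-tokenizer"  # id taken from the Llama test above
tok = AutoTokenizer.from_pretrained(repo)  # downloads and caches the tokenizer files
tok_offline = AutoTokenizer.from_pretrained(repo, local_files_only=True)  # must resolve from the cache only
assert tok.get_vocab() == tok_offline.get_vocab()
assert tok.all_special_tokens_extended == tok_offline.all_special_tokens_extended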
