@@ -2046,12 +2046,13 @@ def from_pretrained(
             template = template.removesuffix(".jinja")
             vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
 
+        remote_files = []
         if not is_local and not local_files_only:
             try:
                 remote_files = list_repo_files(pretrained_model_name_or_path)
             except Exception:
                 remote_files = []
-        else:
+        elif pretrained_model_name_or_path and os.path.isdir(pretrained_model_name_or_path):
             remote_files = os.listdir(pretrained_model_name_or_path)
 
         if "tokenizer_file" in vocab_files and not re.search(vocab_files["tokenizer_file"], "".join(remote_files)):
@@ -2385,57 +2386,108 @@ def _from_pretrained(
         except NotImplementedError:
             vocab_size = 0
 
+        # Optionally patches mistral tokenizers that were saved with a wrong regex
         if (
             vocab_size > 100000
             and hasattr(tokenizer, "_tokenizer")
             and getattr(tokenizer._tokenizer, "pre_tokenizer", None) is not None
         ):
-            from huggingface_hub import model_info
+            tokenizer = cls._patch_mistral_regex(
+                tokenizer,
+                pretrained_model_name_or_path,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                _commit_hash=_commit_hash,
+                _is_local=_is_local,
+                init_kwargs=init_kwargs,
+                fix_mistral_regex=kwargs.get("fix_mistral_regex"),
+            )
 
-            def is_base_mistral(model_id: str) -> bool:
-                model = model_info(model_id)
-                if model.tags is not None:
-                    if re.search("base_model:.*mistralai", "".join(model.tags)):
-                        return True
-                return False
+        return tokenizer
 
-            if _is_local or is_base_mistral(pretrained_model_name_or_path):
-                _config_file = cached_file(
-                    pretrained_model_name_or_path,
-                    "config.json",
-                    cache_dir=cache_dir,
-                    token=token,
-                    local_files_only=local_files_only,
-                    _raise_exceptions_for_missing_entries=False,
-                    _raise_exceptions_for_connection_errors=False,
-                    _commit_hash=_commit_hash,
-                )
-                if _config_file is not None:
-                    with open(_config_file, encoding="utf-8") as f:
-                        _config = json.load(f)
-                    transformers_version = _config.get("transformers_version")
+    @classmethod
+    def _patch_mistral_regex(
+        cls,
+        tokenizer,
+        pretrained_model_name_or_path,
+        token=None,
+        cache_dir=None,
+        local_files_only=False,
+        _commit_hash=None,
+        _is_local=False,
+        init_kwargs=None,
+        fix_mistral_regex=None,
+    ):
2422+ """
2423+ Patches mistral related tokenizers with incorrect regex if detected
2424+ 1) Local file with an associated config saved next to it
2425+ >> Model type one of the mistral models (on older versions)
2426+ 2) Remote models on the hub from official mistral models
2427+ >> Tags including `base_model:.*mistralai`
2428+ """
+        from huggingface_hub import model_info
 
-                    if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
-                        if _is_local and _config.model_type not in [
+        def is_base_mistral(model_id: str) -> bool:
+            model = model_info(model_id)
+            if model.tags is not None:
+                if re.search("base_model:.*mistralai", "".join(model.tags)):
+                    return True
+            return False
+
+        if _is_local or is_base_mistral(pretrained_model_name_or_path):
+            _config_file = cached_file(
+                pretrained_model_name_or_path,
+                "config.json",
+                cache_dir=cache_dir,
+                token=token,
+                local_files_only=local_files_only,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
+                _commit_hash=_commit_hash,
+            )
+
+            # Whether a (local) mistral tokenizer was detected
+            mistral_config_detected = False
+            if _config_file is not None:
+                with open(_config_file, encoding="utf-8") as f:
+                    _config = json.load(f)
+                transformers_version = _config.get("transformers_version")
+                transformers_model_type = _config.get("model_type")
+
+                # Detect whether we can skip the mistral fix, either because of
+                # a) a non-mistral tokenizer, or
+                # b) a transformers version that already ships the fix
+                if transformers_version and version.parse(transformers_version) <= version.parse("4.57.2"):
+                    if (
+                        _is_local
+                        and transformers_model_type is not None
+                        and transformers_model_type
+                        not in [
                             "mistral",
                             "mistral3",
-                            "voxstral",
+                            "voxtral",
                             "ministral",
                             "pixtral",
-                        ]:
-                            return tokenizer
+                        ]
+                    ):
+                        return tokenizer
+                elif transformers_version and version.parse(transformers_version) >= version.parse("5.0.0"):
+                    return tokenizer
 
+                mistral_config_detected = True
+
+            if mistral_config_detected or (not _is_local and is_base_mistral(pretrained_model_name_or_path)):
                 # Expose the `fix_mistral_regex` flag on the tokenizer when provided, even if no correction is applied.
-                if "fix_mistral_regex" in init_kwargs:
+                if init_kwargs and "fix_mistral_regex" in init_kwargs:
                     setattr(tokenizer, "fix_mistral_regex", init_kwargs["fix_mistral_regex"])
 
-                fix_mistral_regex = kwargs.get("fix_mistral_regex")  # not init kwargs
                 # only warn if it's not explicitly passed
                 if fix_mistral_regex is None and not getattr(tokenizer, "fix_mistral_regex", False):
                     setattr(tokenizer, "fix_mistral_regex", False)
                     logger.warning(
                         f"The tokenizer you are loading from '{pretrained_model_name_or_path}' was saved"
-                        f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. "
+                        f" with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e."
                         " This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue."
                     )
                 elif fix_mistral_regex is True or getattr(tokenizer, "fix_mistral_regex", False):
@@ -2448,7 +2500,6 @@ def is_base_mistral(model_id: str) -> bool:
                         ),
                         behavior="isolated",
                     )
-
         return tokenizer
 
     @staticmethod
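
For readers skimming the truncated third hunk: the `behavior="isolated"` argument belongs to the `Split` pre-tokenizer that the fix installs in place of the broken one. Below is a minimal sketch of that mechanism using the `tokenizers` API; `CORRECTED_MISTRAL_REGEX` is a placeholder name for the corrected pattern, which is not reproduced here.

    from tokenizers import Regex, pre_tokenizers

    # Placeholder for the corrected pattern; the real one ships with the patch.
    CORRECTED_MISTRAL_REGEX = r"..."

    def apply_regex_fix(tokenizer):
        # Swap the fast tokenizer's pre-tokenizer for a Split rule that keeps
        # each regex match as its own piece ("isolated" behavior).
        tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Split(
            Regex(CORRECTED_MISTRAL_REGEX),
            behavior="isolated",
        )
        return tokenizer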
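
On the user side, the flag the warning asks for is an ordinary `from_pretrained` keyword argument. A sketch of the intended call, assuming `AutoTokenizer`; the checkpoint name is taken from the discussion link above and is only illustrative:

    from transformers import AutoTokenizer

    # Opt in to the corrected pre-tokenizer regex instead of only seeing the warning.
    tokenizer = AutoTokenizer.from_pretrained(
        "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        fix_mistral_regex=True,
    )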