4 changes: 2 additions & 2 deletions MIGRATION_GUIDE_V5.md
@@ -339,7 +339,7 @@ _We aim for this to be fixed and released in a following release candidate in the week that follows RC0._
### Remote code incompatibility

Several module paths were removed or reworked; for example, `transformers.tokenization_utils` and `transformers.tokenization_utils_fast` no longer exist.
We'll be working on backwards compatibility for these before version 5 is fully released.
These now redirect to `transformers.tokenization_utils_sentencepiece` and `transformers.tokenization_utils_tokenizers` respectively; please update imports accordingly.
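In practice the old paths keep resolving through the new aliases, but imports should move to the v5 modules. A minimal sketch (assuming `SentencePieceBackend` is the class exposed by `tokenization_utils_sentencepiece`, as the `_import_structure` change in this PR suggests):

```python
# Old (v4) module path, now only a lazy backward-compatibility alias:
from transformers.tokenization_utils import SentencePieceBackend

# New (v5) module path that imports should be updated to:
from transformers.tokenization_utils_sentencepiece import SentencePieceBackend
```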

_We aim for this to be fixed and released in a following release candidate in the week that follows RC0._

@@ -621,4 +621,4 @@ Linked PR: https://github.com/huggingface/transformers/pull/42391.
- related to 1., it is not possible to set proxies from your script. To handle proxies, you must set the `HTTP_PROXY` / `HTTPS_PROXY` environment variables
- `hf_transfer` and therefore `HF_HUB_ENABLE_HF_TRANSFER` have been completely dropped in favor of `hf_xet`. This should be transparent for most users. Please let us know if you notice any downside!

`typer-slim` has been added as a required dependency, used to implement both the `hf` and `transformers` CLIs.
27 changes: 25 additions & 2 deletions src/transformers/__init__.py
@@ -20,6 +20,9 @@

__version__ = "5.0.0.dev0"

import importlib
import sys
import types
from pathlib import Path
from typing import TYPE_CHECKING

@@ -174,6 +177,8 @@
"quantizers": [],
"testing_utils": [],
"tokenization_python": ["PreTrainedTokenizer", "PythonBackend"],
"tokenization_utils": [],
"tokenization_utils_fast": [],
"tokenization_utils_sentencepiece": ["SentencePieceBackend"],
"tokenization_utils_base": [
"AddedToken",
@@ -764,8 +769,6 @@
from .utils.quantization_config import VptqConfig as VptqConfig
from .video_processing_utils import BaseVideoProcessor as BaseVideoProcessor
else:
import sys

_import_structure = {k: set(v) for k, v in _import_structure.items()}

import_structure = define_import_structure(Path(__file__).parent / "models", prefix="models")
@@ -779,6 +782,26 @@
extra_objects={"__version__": __version__},
)

def _create_tokenization_alias(alias: str, target: str) -> None:
"""
Lazily redirect legacy tokenization module paths to their replacements without importing heavy deps.
"""

module = types.ModuleType(alias)
module.__doc__ = f"Alias module for backward compatibility with `{target}`."

def _get_target():
return importlib.import_module(target, __name__)

module.__getattr__ = lambda name: getattr(_get_target(), name)
module.__dir__ = lambda: dir(_get_target())

sys.modules[alias] = module
setattr(sys.modules[__name__], alias.rsplit(".", 1)[-1], module)

_create_tokenization_alias(f"{__name__}.tokenization_utils_fast", ".tokenization_utils_tokenizers")
_create_tokenization_alias(f"{__name__}.tokenization_utils", ".tokenization_utils_sentencepiece")


if not is_torch_available():
logger.warning_advice(
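The helper above relies on PEP 562 module-level `__getattr__`, so the legacy paths stay importable without eagerly importing their replacements. A rough sketch of the intended behavior (hypothetical usage, not part of the diff):

```python
import sys

import transformers  # running __init__ registers the alias modules

# The alias is pre-registered in sys.modules, so this succeeds without
# loading the real tokenization_utils_sentencepiece module yet.
import transformers.tokenization_utils as legacy

assert "transformers.tokenization_utils" in sys.modules

# The first attribute access calls the alias's __getattr__, which imports
# the replacement module and proxies the lookup to it.
backend_cls = legacy.SentencePieceBackend
```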
99 changes: 41 additions & 58 deletions src/transformers/models/auto/tokenization_auto.py
@@ -59,6 +59,19 @@

logger = logging.get_logger(__name__)


class AutoTokenizerError(ValueError):
"""Base class for AutoTokenizer loading errors."""


class AutoTokenizerBackendError(AutoTokenizerError):
"""Raised when the requested tokenizer backend cannot be used."""


class AutoTokenizerLoadError(AutoTokenizerError):
"""Raised when no tokenizer can be loaded from a checkpoint."""


# V5: Simplified mapping - single tokenizer class per model type (always prefer tokenizers-based)
REGISTERED_TOKENIZER_CLASSES: dict[str, type[Any]] = {}
REGISTERED_FAST_ALIASES: dict[str, type[Any]] = {}
@@ -389,6 +402,10 @@ def load_merges(merges_file):


def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
# Bloom tokenizer classes were removed but should map to the fast backend for BC
if class_name in {"BloomTokenizer", "BloomTokenizerFast"}:
return TokenizersBackend

if class_name in REGISTERED_FAST_ALIASES:
return REGISTERED_FAST_ALIASES[class_name]

@@ -453,7 +470,7 @@ def _load_tokenizers_backend(tokenizer_class, pretrained_model_name_or_path, inp
An instantiated tokenizer object

Raises:
ValueError: If tokenizer could not be loaded with tokenizers backend
AutoTokenizerLoadError: If tokenizer could not be loaded with tokenizers backend
"""
files_loaded = []

@@ -538,7 +555,8 @@ def _load_tokenizers_backend(tokenizer_class, pretrained_model_name_or_path, inp
if "vocab" in fast_sig.parameters:
try:
vocab_ids, vocab_scores, merges = SentencePieceExtractor(resolved_spm).extract()
files_loaded.append(spm_file)
if spm_file not in files_loaded:
files_loaded.append(spm_file)
kwargs["backend"] = "tokenizers"
kwargs["files_loaded"] = files_loaded
# If tokenizer needs both vocab and merges (BPE models)
@@ -553,6 +571,14 @@ def _load_tokenizers_backend(tokenizer_class, pretrained_model_name_or_path, inp
)
except Exception:
pass
if TokenizersBackend is not None and issubclass(tokenizer_class, TokenizersBackend):
# Provide the SentencePiece model directly when tokenizer.json is absent or extraction fails.
if spm_file not in files_loaded:
files_loaded.append(spm_file)
kwargs["backend"] = "tokenizers"
kwargs["files_loaded"] = files_loaded
kwargs.setdefault("vocab_file", resolved_spm)
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
except ImportError as e:
if "sentencepiece" in str(e).lower() or "SentencePiece" in str(e):
raise ImportError(
@@ -613,9 +639,9 @@ def _load_tokenizers_backend(tokenizer_class, pretrained_model_name_or_path, inp
pass

# If all methods failed, raise an error
raise ValueError(
raise AutoTokenizerLoadError(
f"Could not load tokenizer from {pretrained_model_name_or_path} using tokenizers backend. "
"No tokenizer.json, tekken.json, vocab.json+merges.txt, vocab.txt, or compatible SentencePiece model found."
"No tokenizer.json, tekken.json, vocab.json/merges.txt, vocab.txt, or compatible SentencePiece model found."
)


@@ -643,7 +669,8 @@ def _try_load_tokenizer_with_fallbacks(tokenizer_class, pretrained_model_name_or
An instantiated tokenizer object

Raises:
ValueError: If no tokenizer could be loaded
AutoTokenizerBackendError: If a requested backend dependency is missing
AutoTokenizerLoadError: If no tokenizer could be loaded
"""
# Extract the backend parameter - default to "tokenizers" to prioritize tokenizers backend
backend = kwargs.pop("backend", "tokenizers")
@@ -659,7 +686,7 @@ def _try_load_tokenizer_with_fallbacks(tokenizer_class, pretrained_model_name_or
# Route to SentencePiece backend if requested
if backend == "sentencepiece":
if SentencePieceBackend is None:
raise ValueError(
raise AutoTokenizerBackendError(
"SentencePiece backend was requested but sentencepiece is not installed. "
"Please install it with: pip install sentencepiece"
)
@@ -690,6 +717,9 @@

# Route to tokenizers backend (default)
if backend == "tokenizers":
if tokenizer_class is None and TokenizersBackend is not None:
tokenizer_class = TokenizersBackend

if tokenizer_class is not None:
# Check if tokenizer_class inherits from PreTrainedTokenizer (but not from TokenizersBackend/SentencePieceBackend)
# These are edge cases with custom logic (e.g., BioGptTokenizer with Moses tokenization)
@@ -735,54 +765,13 @@ def _try_load_tokenizer_with_fallbacks(tokenizer_class, pretrained_model_name_or
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **custom_kwargs)

if TokenizersBackend is None:
raise ValueError(
raise AutoTokenizerBackendError(
"Tokenizers backend is the default but tokenizers library is not installed. "
"Please install it with: pip install tokenizers"
)
logger.info("Loading tokenizer with tokenizers backend")
try:
return _load_tokenizers_backend(tokenizer_class, pretrained_model_name_or_path, inputs, kwargs)
except ValueError as e:
# If tokenizers backend fails, try falling back to SentencePiece backend if available
spm_file = _find_sentencepiece_model_file(pretrained_model_name_or_path, **kwargs)
if spm_file is not None and SentencePieceBackend is not None:
logger.info(
f"Tokenizers backend failed: {e}. "
f"Falling back to SentencePieceBackend since {spm_file} file was found."
)
files_loaded = [spm_file]
kwargs["backend"] = "sentencepiece"
kwargs["files_loaded"] = files_loaded
# Resolve the SPM file path and pass it as vocab_file
resolved_vocab_file = cached_file(
pretrained_model_name_or_path,
spm_file,
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
subfolder=kwargs.get("subfolder", ""),
)
kwargs["vocab_file"] = resolved_vocab_file
if tokenizer_class is not None and issubclass(tokenizer_class, SentencePieceBackend):
logger.info(
"Falling back to SentencePiece backend using tokenizer class that inherits from SentencePieceBackend."
)
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
return SentencePieceBackend.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
# If no fallback available, try calling tokenizer class directly as last resort
if hasattr(tokenizer_class, "from_pretrained"):
logger.info(
f"Tokenizers backend failed: {e}. Trying to load tokenizer directly from tokenizer class."
)
# Filter out AutoTokenizer-specific kwargs that custom tokenizers don't accept
custom_kwargs = {k: v for k, v in kwargs.items() if k not in ["backend", "files_loaded"]}
custom_kwargs["_from_auto"] = True # Signal that this is called from AutoTokenizer
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **custom_kwargs)
# Re-raise if no fallback options available
raise

return _load_tokenizers_backend(tokenizer_class, pretrained_model_name_or_path, inputs, kwargs)

# If no tokenizer class but tokenizers backend requested, fall back to SentencePiece if available
spm_file = _find_sentencepiece_model_file(pretrained_model_name_or_path, **kwargs)
@@ -818,7 +807,7 @@ def _try_load_tokenizer_with_fallbacks(tokenizer_class, pretrained_model_name_or
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
return SentencePieceBackend.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

raise ValueError(
raise AutoTokenizerLoadError(
f"Could not load tokenizer from {pretrained_model_name_or_path}. "
"No tokenizer class could be determined and no SentencePiece model found."
)
@@ -1144,17 +1133,11 @@ def from_pretrained(

model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
tokenizer_class = TOKENIZER_MAPPING[type(config)]

tokenizer_class = TOKENIZER_MAPPING.get(type(config), TokenizersBackend)
if tokenizer_class is not None:
return _try_load_tokenizer_with_fallbacks(
tokenizer_class, pretrained_model_name_or_path, inputs, kwargs
)
else:
raise ValueError(
"This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
"in order to use this tokenizer."
)

raise ValueError(
f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
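Since both new exceptions subclass `ValueError` via `AutoTokenizerError`, existing `except ValueError` call sites keep working, while new code can catch the specific failure. A hedged usage sketch (`some/checkpoint` is a placeholder repo id):

```python
from transformers import AutoTokenizer
from transformers.models.auto.tokenization_auto import (
    AutoTokenizerBackendError,
    AutoTokenizerLoadError,
)

try:
    tok = AutoTokenizer.from_pretrained("some/checkpoint", backend="sentencepiece")
except AutoTokenizerBackendError:
    # The requested backend dependency (here, sentencepiece) is not installed.
    ...
except AutoTokenizerLoadError:
    # No usable tokenizer files were found in the checkpoint.
    ...
```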
4 changes: 4 additions & 0 deletions src/transformers/tokenization_mistral_common.py
@@ -1986,3 +1986,7 @@ def _get_validation_mode(mode: Union[str, ValidationMode]) -> ValidationMode:
if mode not in [ValidationMode.finetuning, ValidationMode.test]:
raise ValueError(_invalid_mode_msg)
return mode


# Backward compatibility alias for codebases still importing the legacy name.
MistralCommonTokenizer = MistralCommonBackend
Review comment (Member):
After vllm-project/vllm#29872, vLLM will import the new name if it can.

It's probably worth adding a warning to get people to update.
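One way to do that (a hypothetical sketch, not part of this PR) is to replace the bare alias with a PEP 562 module-level `__getattr__` that emits a `FutureWarning` the first time the legacy name is resolved:

```python
import warnings


def __getattr__(name):
    # Hypothetical deprecation shim for the legacy class name.
    if name == "MistralCommonTokenizer":
        warnings.warn(
            "`MistralCommonTokenizer` is deprecated; use `MistralCommonBackend` instead.",
            FutureWarning,
            stacklevel=2,
        )
        return MistralCommonBackend
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```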

59 changes: 59 additions & 0 deletions tests/models/auto/test_tokenization_auto.py
@@ -44,6 +44,10 @@
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
from transformers.models.auto.tokenization_auto import (
TOKENIZER_MAPPING,
AutoTokenizerBackendError,
AutoTokenizerLoadError,
_load_tokenizers_backend,
_try_load_tokenizer_with_fallbacks,
get_tokenizer_config,
tokenizer_class_from_name,
)
@@ -263,6 +267,61 @@ def test_auto_tokenizer_from_local_folder_mistral_detection(self):
self.assertIsInstance(tokenizer2, tokenizer.__class__)
self.assertTrue(tokenizer2.vocab_size > 100_000)

def test_sentencepiece_backend_missing_raises(self):
dummy_inputs = ()
dummy_kwargs = {"backend": "sentencepiece"}
with mock.patch(
"transformers.models.auto.tokenization_auto.SentencePieceBackend", None
), pytest.raises(
AutoTokenizerBackendError,
match="SentencePiece backend was requested but sentencepiece is not installed",
):
_try_load_tokenizer_with_fallbacks(None, "dummy-repo", dummy_inputs, dummy_kwargs)

def test_tokenizers_backend_missing_raises(self):
class DummyBackend:
@classmethod
def from_pretrained(cls, *args, **kwargs):
return cls()

dummy_inputs = ()
dummy_kwargs = {"backend": "tokenizers"}
with mock.patch.multiple(
"transformers.models.auto.tokenization_auto", TokenizersBackend=None, SentencePieceBackend=DummyBackend
), pytest.raises(
AutoTokenizerBackendError,
match="Tokenizers backend is the default but tokenizers library is not installed",
):
_try_load_tokenizer_with_fallbacks(DummyBackend, "dummy-repo", dummy_inputs, dummy_kwargs)

def test_auto_tokenizer_load_error_has_expected_message(self):
class DummyTokenizer:
@classmethod
def from_pretrained(cls, *args, **kwargs):
return cls()

with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(
AutoTokenizerLoadError,
match=(
"Could not load tokenizer from .* using tokenizers backend. "
"No tokenizer\\.json, tekken\\.json, vocab\\.json/merges\\.txt, vocab\\.txt, or compatible "
"SentencePiece model found."
),
):
_load_tokenizers_backend(DummyTokenizer, tmp_dir, (), {})

@require_tokenizers
def test_auto_tokenizer_loads_bloom_repo_without_tokenizer_class(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
self.assertIsInstance(tokenizer, TokenizersBackend)
self.assertTrue(tokenizer.is_fast)

@require_tokenizers
def test_auto_tokenizer_loads_sentencepiece_only_repo(self):
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-mbart")
self.assertIsInstance(tokenizer, TokenizersBackend)
self.assertTrue(tokenizer.is_fast)

def test_auto_tokenizer_fast_no_slow(self):
tokenizer = AutoTokenizer.from_pretrained("Salesforce/ctrl")
# There is no fast CTRL so this always gives us a slow tokenizer.