diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index b65b7cfcd9168..e535c3dbdea0f 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -26,6 +26,7 @@ from requests import HTTPError +from .dynamic_module_utils import custom_object_save from .file_utils import ( FEATURE_EXTRACTOR_NAME, EntryNotFoundError, @@ -205,6 +206,8 @@ class FeatureExtractionMixin: extractors. """ + _auto_class = None + def __init__(self, **kwargs): """Set elements of `kwargs` as attributes.""" # Pop "processor_class" as it should be saved as private attribute @@ -316,6 +319,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): """ if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + # If we have a custom feature extractor, we copy the file defining it to the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + os.makedirs(save_directory, exist_ok=True) # If we save using the predefined names, we can load using `from_pretrained` output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME) @@ -539,3 +548,29 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]): def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" + + @classmethod + def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"): + """ + Register this class with a given auto class. This should only be used for custom feature extractors, as the ones + in the library are already mapped with `AutoFeatureExtractor`. + + <Tip warning={true}> + + This API is experimental and may have some slight breaking changes in the next releases. + + </Tip> + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`): + The auto class to register this new feature extractor with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index a146c611fb963..93e0fc1ba9bd3 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -14,23 +14,28 @@ # limitations under the License.
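For reference, a minimal sketch (not part of the diff) of how the `register_for_auto_class` hook added above is meant to be used. The `MyFeatureExtractor` class and the output directory are hypothetical, and the class has to live in its own module file so that `custom_object_save` can copy that file next to the saved configuration:

```python
# my_feature_extraction.py -- hypothetical module defining a custom feature extractor
from transformers import Wav2Vec2FeatureExtractor


class MyFeatureExtractor(Wav2Vec2FeatureExtractor):
    pass


# Register the custom class with AutoFeatureExtractor. After this, save_pretrained() also
# copies my_feature_extraction.py into the target folder and records an `auto_map` entry in
# preprocessor_config.json so the class can later be reloaded from the Hub.
MyFeatureExtractor.register_for_auto_class("AutoFeatureExtractor")

feature_extractor = MyFeatureExtractor()
feature_extractor.save_pretrained("./my-custom-feature-extractor")
```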
""" AutoFeatureExtractor class.""" import importlib +import json import os from collections import OrderedDict +from typing import Dict, Optional, Union # Build the list of all feature extractors from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module from ...feature_extraction_utils import FeatureExtractionMixin -from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME +from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo +from ...utils import logging from .auto_factory import _LazyAutoMapping from .configuration_auto import ( CONFIG_MAPPING_NAMES, AutoConfig, - config_class_to_model_type, model_type_to_module_name, replace_list_option_in_docstrings, ) +logger = logging.get_logger(__name__) + FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( [ ("beit", "BeitFeatureExtractor"), @@ -66,6 +71,96 @@ def feature_extractor_class_from_name(class_name: str): return None +def get_feature_extractor_config( + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Loads the tokenizer configuration from a pretrained model tokenizer configuration. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision(`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `Dict`: The configuration of the tokenizer. 
+ + Examples: + + ```python + # Download configuration from huggingface.co and cache. + feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h") + # This model does not have a feature extractor config so the result will be an empty dict. + feature_extractor_config = get_feature_extractor_config("bert-base-uncased") + + # Save a pretrained feature extractor locally and you can reload its config + from transformers import AutoFeatureExtractor + + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") + feature_extractor.save_pretrained("feature-extractor-test") + feature_extractor_config = get_feature_extractor_config("feature-extractor-test") + ```""" + resolved_config_file = get_file_from_repo( + pretrained_model_name_or_path, + FEATURE_EXTRACTOR_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + if resolved_config_file is None: + logger.info( + "Could not locate the feature extractor configuration file, will try to use the model config instead." + ) + return {} + + with open(resolved_config_file, encoding="utf-8") as reader: + return json.load(reader) + + class AutoFeatureExtractor: r""" This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the @@ -128,6 +223,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. kwargs (`Dict[str, Any]`, *optional*): The values in kwargs of any keys which are feature extractor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is @@ -151,35 +250,54 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): >>> feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/") ```""" config = kwargs.pop("config", None) + trust_remote_code = kwargs.pop("trust_remote_code", False) kwargs["_from_auto"] = True - is_feature_extraction_file = os.path.isfile(pretrained_model_name_or_path) - is_directory = os.path.isdir(pretrained_model_name_or_path) and os.path.exists( - os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) - ) - - has_local_config = ( - os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False - ) + config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) + feature_extractor_class = config_dict.get("feature_extractor_type", None) + feature_extractor_auto_map = None + if "AutoFeatureExtractor" in config_dict.get("auto_map", {}): + feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"] - # load config, if it can be loaded - if not is_feature_extraction_file and (has_local_config or not is_directory): + # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
+ if feature_extractor_class is None and feature_extractor_auto_map is None: if not isinstance(config, PretrainedConfig): config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + # It could be in `config.feature_extractor_type` + feature_extractor_class = getattr(config, "feature_extractor_type", None) + if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map: + feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"] - kwargs["_from_auto"] = True - config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) + if feature_extractor_class is not None: + # If we have custom code for a feature extractor, we get the proper class. + if feature_extractor_auto_map is not None: + if not trust_remote_code: + raise ValueError( + f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file " + "in that repo on your local machine. Make sure you have read the code there to avoid " + "malicious use, then set the option `trust_remote_code=True` to remove this error." + ) + if kwargs.get("revision", None) is None: + logger.warning( + "Explicitly passing a `revision` is encouraged when loading a feature extractor with custom " + "code to ensure no malicious code has been contributed in a newer revision." + ) - model_type = config_class_to_model_type(type(config).__name__) + module_file, class_name = feature_extractor_auto_map.split(".") + feature_extractor_class = get_class_from_dynamic_module( + pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs + ) + else: + feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class) - if "feature_extractor_type" in config_dict: - feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"]) return feature_extractor_class.from_dict(config_dict, **kwargs) - elif model_type is not None: - return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs) + # Last try: we use the FEATURE_EXTRACTOR_MAPPING. + elif type(config) in FEATURE_EXTRACTOR_MAPPING: + feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)] + return feature_extractor_class.from_dict(config_dict, **kwargs) raise ValueError( - f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in " - f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: " - f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}" + f"Unrecognized feature extractor in {pretrained_model_name_or_path}.
Should have a " + f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} or {CONFIG_NAME}, or one of the following " + f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}" ) diff --git a/tests/test_feature_extraction_auto.py b/tests/test_feature_extraction_auto.py index c827b0a656916..da5386bd506c7 100644 --- a/tests/test_feature_extraction_auto.py +++ b/tests/test_feature_extraction_auto.py @@ -82,3 +82,9 @@ def test_feature_extractor_not_found(self): "hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.", ): _ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model") + + def test_from_pretrained_dynamic_feature_extractor(self): + model = AutoFeatureExtractor.from_pretrained( + "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True + ) + self.assertEqual(model.__class__.__name__, "NewFeatureExtractor") diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index 217da135ca1cd..931ee2444e834 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -16,9 +16,21 @@ import json import os +import sys import tempfile +import unittest +from pathlib import Path +from huggingface_hub import Repository, delete_repo, login +from requests.exceptions import HTTPError +from transformers import AutoFeatureExtractor from transformers.file_utils import is_torch_available, is_vision_available +from transformers.testing_utils import PASS, USER, is_staging_test + + +sys.path.append(str(Path(__file__).parent.parent / "utils")) + +from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402 if is_torch_available(): @@ -29,6 +41,9 @@ from PIL import Image +SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") + + def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False): """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, or a list of PyTorch tensors if one specifies torchify=True.
@@ -99,3 +114,41 @@ def test_feat_extract_from_and_save_pretrained(self): def test_init_without_params(self): feat_extract = self.feature_extraction_class() self.assertIsNotNone(feat_extract) + + +@is_staging_test +class FeatureExtractorPushToHubTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._token = login(username=USER, password=PASS) + + @classmethod + def tearDownClass(cls): + try: + delete_repo(token=cls._token, name="test-dynamic-feature-extractor") + except HTTPError: + pass + + def test_push_to_hub_dynamic_feature_extractor(self): + CustomFeatureExtractor.register_for_auto_class() + feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) + + with tempfile.TemporaryDirectory() as tmp_dir: + repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-feature-extractor", use_auth_token=self._token) + feature_extractor.save_pretrained(tmp_dir) + + # This has added the proper auto_map field to the config + self.assertDictEqual( + feature_extractor.auto_map, + {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"}, + ) + # The code has been copied from fixtures + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py"))) + + repo.push_to_hub() + + new_feature_extractor = AutoFeatureExtractor.from_pretrained( + f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True + ) + # Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module + self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor") diff --git a/utils/test_module/custom_feature_extraction.py b/utils/test_module/custom_feature_extraction.py new file mode 100644 index 0000000000000..de367032d8fe8 --- /dev/null +++ b/utils/test_module/custom_feature_extraction.py @@ -0,0 +1,5 @@ +from transformers import Wav2Vec2FeatureExtractor + + +class CustomFeatureExtractor(Wav2Vec2FeatureExtractor): + pass
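For reference, a short usage sketch (not part of the diff) of the loading path exercised by the new tests. It assumes network access to the `hf-internal-testing/test_dynamic_feature_extractor` repo used in `test_from_pretrained_dynamic_feature_extractor`:

```python
from transformers import AutoFeatureExtractor
from transformers.models.auto.feature_extraction_auto import get_feature_extractor_config

# Read the raw preprocessor_config.json of the repo; returns {} if the repo has no
# feature extractor configuration file.
config_dict = get_feature_extractor_config("hf-internal-testing/test_dynamic_feature_extractor")
print(config_dict.get("auto_map"))

# Loading a feature extractor whose class is defined inside the repo itself requires an
# explicit opt-in, because the remote code is executed on the local machine.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
)
print(feature_extractor.__class__.__name__)  # "NewFeatureExtractor", as asserted in the test above
```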