Custom feature extractor #15630

Merged (5 commits, Feb 11, 2022)
35 changes: 35 additions & 0 deletions src/transformers/feature_extraction_utils.py
@@ -26,6 +26,7 @@

from requests import HTTPError

from .dynamic_module_utils import custom_object_save
from .file_utils import (
FEATURE_EXTRACTOR_NAME,
EntryNotFoundError,
@@ -205,6 +206,8 @@ class FeatureExtractionMixin:
extractors.
"""

_auto_class = None

def __init__(self, **kwargs):
"""Set elements of `kwargs` as attributes."""
# Pop "processor_class" as it should be saved as private attribute
@@ -316,6 +319,12 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self)

os.makedirs(save_directory, exist_ok=True)
# If we save using the predefined names, we can load using `from_pretrained`
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
@@ -539,3 +548,29 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):

def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"

@classmethod
def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
"""
Register this class with a given auto class. This should only be used for custom feature extractors as the ones
in the library are already mapped with `AutoFeatureExtractor`.

<Tip warning={true}>

This API is experimental and may have some slight breaking changes in the next releases.

</Tip>

Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
The auto class to register this new feature extractor with.
"""
if not isinstance(auto_class, str):
auto_class = auto_class.__name__

import transformers.models.auto as auto_module

if not hasattr(auto_module, auto_class):
raise ValueError(f"{auto_class} is not a valid auto class.")

cls._auto_class = auto_class
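
To make the new registration hook concrete, here is a minimal usage sketch (not part of this PR; the subclass name and save path are illustrative, and the class is assumed to live in its own module file so `custom_object_save` can copy it alongside the config):

```python
from transformers import Wav2Vec2FeatureExtractor


class MyFeatureExtractor(Wav2Vec2FeatureExtractor):
    pass


# Register the custom class with the auto API (defaults to "AutoFeatureExtractor"), then save.
# save_pretrained() copies the module defining MyFeatureExtractor next to the config and records
# an auto_map entry so the class can later be resolved through AutoFeatureExtractor.
MyFeatureExtractor.register_for_auto_class()
feature_extractor = MyFeatureExtractor()
feature_extractor.save_pretrained("./my-custom-feature-extractor")
```

Loading the result back through `AutoFeatureExtractor.from_pretrained` then requires `trust_remote_code=True`, as shown in the auto class changes below.
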
162 changes: 140 additions & 22 deletions src/transformers/models/auto/feature_extraction_auto.py
@@ -14,23 +14,28 @@
# limitations under the License.
""" AutoFeatureExtractor class."""
import importlib
import json
import os
from collections import OrderedDict
from typing import Dict, Optional, Union

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
from ...utils import logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)

FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitFeatureExtractor"),
@@ -66,6 +71,96 @@ def feature_extractor_class_from_name(class_name: str):
return None


def get_feature_extractor_config(
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
):
"""
Loads the feature extractor configuration from a pretrained model feature extractor configuration.

Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:

- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
under a user or organization name, like `dbmdz/bert-base-german-cased`.
- a path to a *directory* containing a configuration file saved using the
[`~FeatureExtractionMixin.save_pretrained`] method, e.g., `./my_model_directory/`.

cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision(`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the feature extractor configuration from local files.

<Tip>

Passing `use_auth_token=True` is required when you want to use a private model.

</Tip>

Returns:
`Dict`: The configuration of the feature extractor.

Examples:

```python
# Download configuration from huggingface.co and cache.
feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
# This model does not have a feature extractor config so the result will be an empty dict.
feature_extractor_config = get_feature_extractor_config("bert-base-uncased")

# Save a pretrained feature extractor locally and you can reload its config
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor.save_pretrained("feature-extractor-test")
feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
```"""
resolved_config_file = get_file_from_repo(
pretrained_model_name_or_path,
FEATURE_EXTRACTOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
)
if resolved_config_file is None:
logger.info(
"Could not locate the feature extractor configuration file, will try to use the model config instead."
)
return {}

with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)


class AutoFeatureExtractor:
r"""
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
@@ -128,6 +223,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
function returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
`kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to `True` for repositories you trust and in which you have read the code, as it will
execute code present on the Hub on your local machine.
kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are feature extractor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
@@ -151,35 +250,54 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
```"""
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
kwargs["_from_auto"] = True

is_feature_extraction_file = os.path.isfile(pretrained_model_name_or_path)
is_directory = os.path.isdir(pretrained_model_name_or_path) and os.path.exists(
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
)

has_local_config = (
os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False
)
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
feature_extractor_class = config_dict.get("feature_extractor_type", None)
feature_extractor_auto_map = None
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]

# load config, if it can be loaded
if not is_feature_extraction_file and (has_local_config or not is_directory):
# If we don't find the feature extractor class in the feature extractor config, let's try the model config.
if feature_extractor_class is None and feature_extractor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
# It could be in `config.feature_extractor_type`
feature_extractor_class = getattr(config, "feature_extractor_type", None)
if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]

kwargs["_from_auto"] = True
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if feature_extractor_class is not None:
# If we have custom code for a feature extractor, we get the proper class.
if feature_extractor_auto_map is not None:
if not trust_remote_code:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
"in that repo on your local machine. Make sure you have read the code there to avoid "
"malicious use, then set the option `trust_remote_code=True` to remove this error."
)
if kwargs.get("revision", None) is None:
logger.warning(
"Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
"code to ensure no malicious code has been contributed in a newer revision."
)

model_type = config_class_to_model_type(type(config).__name__)
module_file, class_name = feature_extractor_auto_map.split(".")
feature_extractor_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
else:
feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)

if "feature_extractor_type" in config_dict:
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
return feature_extractor_class.from_dict(config_dict, **kwargs)
elif model_type is not None:
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
# Last try: we use the FEATURE_EXTRACTOR_MAPPING.
elif type(config) in FEATURE_EXTRACTOR_MAPPING:
feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
return feature_extractor_class.from_dict(config_dict, **kwargs)

Reviewer comment (Member) on lines -173 to 298: Welcome change!

raise ValueError(
f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following "
"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
)
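
For context, a hedged sketch of how the new `trust_remote_code` path is exercised, using the test repository referenced in the tests below:

```python
from transformers import AutoFeatureExtractor

# The repository defines its feature extractor class in its own module on the Hub, so
# trust_remote_code=True is required: the module is downloaded and executed locally.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
)
print(feature_extractor.__class__.__name__)  # expected: "NewFeatureExtractor" (see the test below)
```
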
6 changes: 6 additions & 0 deletions tests/test_feature_extraction_auto.py
@@ -82,3 +82,9 @@ def test_feature_extractor_not_found(self):
"hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
):
_ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")

def test_from_pretrained_dynamic_feature_extractor(self):
model = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
)
self.assertEqual(model.__class__.__name__, "NewFeatureExtractor")
53 changes: 53 additions & 0 deletions tests/test_feature_extraction_common.py
@@ -16,9 +16,21 @@

import json
import os
import sys
import tempfile
import unittest
from pathlib import Path

from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import PASS, USER, is_staging_test


sys.path.append(str(Path(__file__).parent.parent / "utils"))

from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402


if is_torch_available():
@@ -29,6 +41,9 @@
from PIL import Image


SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")


def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
@@ -99,3 +114,41 @@ def test_feat_extract_from_and_save_pretrained(self):
def test_init_without_params(self):
feat_extract = self.feature_extraction_class()
self.assertIsNotNone(feat_extract)


@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._token = login(username=USER, password=PASS)

@classmethod
def tearDownClass(cls):
try:
delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
except HTTPError:
pass

def test_push_to_hub_dynamic_feature_extractor(self):
CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)

with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-feature-extractor", use_auth_token=self._token)
feature_extractor.save_pretrained(tmp_dir)

# This has added the proper auto_map field to the config
self.assertDictEqual(
feature_extractor.auto_map,
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
)
# The code has been copied from fixtures
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py")))

repo.push_to_hub()

new_feature_extractor = AutoFeatureExtractor.from_pretrained(
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
)
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
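
As an aside, a hedged sketch of what the saved configuration is expected to contain after the steps above (the directory name is illustrative; the auto_map value mirrors the assertion in the test):

```python
import json
import os

# Illustrative only: after save_pretrained() on a registered custom class, the serialized
# preprocessor_config.json should carry the auto_map entry asserted above, next to the
# copied custom_feature_extraction.py module.
with open(os.path.join("my-saved-dir", "preprocessor_config.json"), encoding="utf-8") as f:
    saved_config = json.load(f)
assert saved_config["auto_map"] == {
    "AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"
}
```
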
5 changes: 5 additions & 0 deletions utils/test_module/custom_feature_extraction.py
@@ -0,0 +1,5 @@
from transformers import Wav2Vec2FeatureExtractor


class CustomFeatureExtractor(Wav2Vec2FeatureExtractor):
pass