Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto processor #14465

Merged
merged 7 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/model_doc/auto.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ AutoFeatureExtractor
:members:


AutoProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AutoProcessor
:members:


AutoModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,11 @@
"CONFIG_MAPPING",
"FEATURE_EXTRACTOR_MAPPING",
"MODEL_NAMES_MAPPING",
"PROCESSOR_MAPPING",
"TOKENIZER_MAPPING",
"AutoConfig",
"AutoFeatureExtractor",
"AutoProcessor",
"AutoTokenizer",
],
"models.bart": ["BartConfig", "BartTokenizer"],
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"auto_factory": ["get_values"],
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
"processing_auto": ["PROCESSOR_MAPPING", "AutoProcessor"],
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
}

Expand Down Expand Up @@ -130,6 +131,7 @@
from .auto_factory import get_values
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from .processing_auto import PROCESSOR_MAPPING, AutoProcessor
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer

if is_torch_available():
Expand Down
10 changes: 5 additions & 5 deletions src/transformers/models/auto/feature_extraction_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object
(either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's
missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
The feature extractor class to instantiate is selected based on the :obj:`model_type` property of the config
object (either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when
it's missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
Comment on lines -84 to +86
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch

List options
Expand Down Expand Up @@ -136,10 +136,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
>>> from transformers import AutoFeatureExtractor
>>> # Download vocabulary from huggingface.co and cache.
>>> # Download feature extractor from huggingface.co and cache.
>>> feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h')
>>> # If vocabulary files are in a directory (e.g. feature extractor was saved using `save_pretrained('./test/saved_model/')`)
>>> # If feature extractor files are in a directory (e.g. feature extractor was saved using `save_pretrained('./test/saved_model/')`)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/')
"""
Expand Down
165 changes: 165 additions & 0 deletions src/transformers/models/auto/processing_auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" AutoProcessor class. """
import importlib
from collections import OrderedDict

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings,
)


PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("clip", "CLIPProcessor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutxlm", "LayoutXLMProcessor"),
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
]
)

PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)


def processor_class_from_name(class_name: str):
for module_name, processors in PROCESSOR_MAPPING_NAMES.items():
if class_name in processors:
module_name = model_type_to_module_name(module_name)

module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
break

return None


class AutoProcessor:
r"""
This is a generic processor class that will be instantiated as one of the processor classes of the library when
created with the :meth:`AutoProcessor.from_pretrained` class method.

This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""

def __init__(self):
raise EnvironmentError(
"AutoProcessor is designed to be instantiated "
"using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
)

@classmethod
@replace_list_option_in_docstrings(PROCESSOR_MAPPING_NAMES)
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
Instantiate one of the processor classes of the library from a pretrained model vocabulary.

The processor class to instantiate is selected based on the :obj:`model_type` property of the config object
(either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible):

List options

Params:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:

- a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a processor files saved using the :obj:`save_pretrained()` method,
e.g., ``./my_model_directory/``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force to (re-)download the feature extractor files and override the cached versions
if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file
exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
sgugger marked this conversation as resolved.
Show resolved Hide resolved
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final feature extractor object. If :obj:`True`,
then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where `unused_kwargs` is a
dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the
part of ``kwargs`` which has not been used to update ``feature_extractor`` and is otherwise ignored.
kwargs (:obj:`Dict[str, Any]`, `optional`):
The values in kwargs of any keys which are feature extractor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
controlled by the ``return_unused_kwargs`` keyword parameter.

.. note::

Passing :obj:`use_auth_token=True` is required when you want to use a private model.

Examples::

>>> from transformers import AutoProcessor

>>> # Download processor from huggingface.co and cache.
>>> processor = AutoProcessor.from_pretrained('facebook/wav2vec2-base-960h')

>>> # If processor files are in a directory (e.g. processor was saved using `save_pretrained('./test/saved_model/')`)
>>> processor = AutoProcessor.from_pretrained('./test/saved_model/')

"""
config = kwargs.pop("config", None)
kwargs["_from_auto"] = True

# First, look for a processor_type in the preprocessor_config
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if "processor_class" in config_dict:
processor_class = processor_class_from_name(config_dict["processor_class"])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that here I chose "processor_class". Feature extractors use feature_extractor_type but tokenizers use tokenizer_class, which I think is more adapted as the value is a class name. Which is a type if you want to go there, but by model_type we imply something like "bert" or "speech_to_text", not a class name.

return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

# Otherwise, load config, if it can be loaded.
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

model_type = config_class_to_model_type(type(config).__name__)

if getattr(config_dict, "processor_class", None) is not None:
processor_class = config.processor_class
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)

raise ValueError(
f"Unrecognized processor in {pretrained_model_name_or_path}. Should have a `processor_type` key in "
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
f"{', '.join(c for c in PROCESSOR_MAPPING_NAMES.keys())}"
)
3 changes: 2 additions & 1 deletion tests/fixtures/preprocessor_config.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"processor_class": "Wav2Vec2Processor"
}
50 changes: 50 additions & 0 deletions tests/test_processor_auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# coding=utf-8
# Copyright 2021 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor


SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_PROCESSOR_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
)
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")


class AutoFeatureExtractorTest(unittest.TestCase):
def test_processor_from_model_shortcut(self):
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(processor, Wav2Vec2Processor)

def test_processor_from_local_directory_from_key(self):
processor = AutoProcessor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
self.assertIsInstance(processor, Wav2Vec2Processor)

def test_processor_from_local_directory_from_config(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model_config = Wav2Vec2Config()
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

# save in new folder
model_config.save_pretrained(tmpdirname)
processor.save_pretrained(tmpdirname)

processor = AutoProcessor.from_pretrained(tmpdirname)

self.assertIsInstance(processor, Wav2Vec2Processor)