Skip to content

Commit

Permalink
Auto feature extractor (#11097)
Browse files Browse the repository at this point in the history
* AutoFeatureExtractor

* Init and first tests

* Tests

* Damn you gitignore

* Quality

* Defensive test for when not all backends are here

* Use pattern for Speech2Text models
  • Loading branch information
sgugger authored Apr 6, 2021
1 parent 520198f commit 403d530
Show file tree
Hide file tree
Showing 18 changed files with 309 additions and 34 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ __pycache__/
*.so

# tests and logs
tests/fixtures/*
!tests/fixtures/sample_text_no_unicode.txt
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/
Expand Down
7 changes: 7 additions & 0 deletions docs/source/model_doc/auto.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ AutoTokenizer
:members:


AutoFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AutoFeatureExtractor
:members:


AutoModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
42 changes: 34 additions & 8 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_BaseLazyModule,
is_flax_available,
is_sentencepiece_available,
is_speech_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
Expand Down Expand Up @@ -102,6 +103,7 @@
"is_py3nvml_available",
"is_sentencepiece_available",
"is_sklearn_available",
"is_speech_available",
"is_tf_available",
"is_tokenizers_available",
"is_torch_available",
Expand Down Expand Up @@ -133,9 +135,11 @@
"models.auto": [
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
"CONFIG_MAPPING",
"FEATURE_EXTRACTOR_MAPPING",
"MODEL_NAMES_MAPPING",
"TOKENIZER_MAPPING",
"AutoConfig",
"AutoFeatureExtractor",
"AutoTokenizer",
],
"models.bart": ["BartConfig", "BartTokenizer"],
Expand Down Expand Up @@ -202,7 +206,6 @@
"models.speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig",
"Speech2TextFeatureExtractor",
],
"models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
"models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
Expand Down Expand Up @@ -288,7 +291,6 @@
_import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer")
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
_import_structure["models.t5"].append("T5Tokenizer")
_import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
Expand Down Expand Up @@ -339,13 +341,28 @@

if is_sentencepiece_available():
_import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]

else:
from .utils import dummy_tokenizers_objects

_import_structure["utils.dummy_tokenizers_objects"] = [
name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
]

# Speech-specific objects
if is_speech_available():
_import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")

if is_sentencepiece_available():
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")

else:
from .utils import dummy_speech_objects

_import_structure["utils.dummy_speech_objects"] = [
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]

# Vision-specific objects
if is_vision_available():
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
Expand Down Expand Up @@ -1394,6 +1411,7 @@
is_py3nvml_available,
is_sentencepiece_available,
is_sklearn_available,
is_speech_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
Expand Down Expand Up @@ -1429,9 +1447,11 @@
from .models.auto import (
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING,
MODEL_NAMES_MAPPING,
TOKENIZER_MAPPING,
AutoConfig,
AutoFeatureExtractor,
AutoTokenizer,
)
from .models.bart import BartConfig, BartTokenizer
Expand Down Expand Up @@ -1494,11 +1514,7 @@
from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
from .models.speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Speech2TextConfig,
Speech2TextFeatureExtractor,
)
from .models.speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
Expand Down Expand Up @@ -1585,7 +1601,7 @@
from .models.mt5 import MT5Tokenizer
from .models.pegasus import PegasusTokenizer
from .models.reformer import ReformerTokenizer
from .models.speech_to_text import Speech2TextProcessor, Speech2TextTokenizer
from .models.speech_to_text import Speech2TextTokenizer
from .models.t5 import T5Tokenizer
from .models.xlm_prophetnet import XLMProphetNetTokenizer
from .models.xlm_roberta import XLMRobertaTokenizer
Expand Down Expand Up @@ -1627,9 +1643,19 @@

if is_sentencepiece_available():
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer

else:
from .utils.dummy_tokenizers_objects import *

if is_speech_available():
from .models.speech_to_text import Speech2TextFeatureExtractor

if is_sentencepiece_available():
from .models.speech_to_text import Speech2TextProcessor

else:
from .utils.dummy_speech_objects import *

if is_vision_available():
from .image_utils import ImageFeatureExtractionMixin
from .models.vit import ViTFeatureExtractor
Expand Down
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"sphinx-copybutton": "sphinx-copybutton",
"sphinx-markdown-tables": "sphinx-markdown-tables",
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
"sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
"sphinx": "sphinx==3.2.1",
"starlette": "starlette",
"tensorflow-cpu": "tensorflow-cpu>=2.3",
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,13 @@ def get_feature_extractor_dict(
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)

from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)

user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline

if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
Expand All @@ -349,6 +356,7 @@ def get_feature_extractor_dict(
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
# Load feature_extractor dict
with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
Expand Down Expand Up @@ -426,6 +434,7 @@ def to_dict(self) -> Dict[str, Any]:
:obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
"""
output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__

return output

Expand Down
18 changes: 18 additions & 0 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,11 @@ def is_torchaudio_available():
return _torchaudio_available


def is_speech_available():
    """
    Report whether the speech-specific objects can be used.

    The result currently mirrors the torchaudio availability flag; the exact dependency might evolve in the
    future.
    """
    return _torchaudio_available


def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available:
Expand Down Expand Up @@ -513,6 +518,13 @@ def wrapper(*args, **kwargs):
"""


# Message template used by `requires_speech`: `{0}` is filled via `.format(name)` with the name of the object
# (or of its class) that needs the speech backend.
# docstyle-ignore
SPEECH_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
`pip install torchaudio`
"""


# docstyle-ignore
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
Expand Down Expand Up @@ -586,6 +598,12 @@ def requires_scatter(obj):
raise ImportError(SCATTER_IMPORT_ERROR.format(name))


def requires_speech(obj):
    """
    Raise an :obj:`ImportError` mentioning ``obj`` when the speech backend is not available.

    Args:
        obj: The object being guarded. Its ``__name__`` is used in the error message when it has one, otherwise
            the name of its class is used.

    Raises:
        :obj:`ImportError`: If :func:`is_speech_available` returns a falsy value.
    """
    # Resolve the label first, exactly as the message template expects a single positional name.
    if hasattr(obj, "__name__"):
        label = obj.__name__
    else:
        label = obj.__class__.__name__

    if is_speech_available():
        return
    raise ImportError(SPEECH_IMPORT_ERROR.format(label))


def requires_vision(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_vision_available():
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

_import_structure = {
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
}

Expand Down Expand Up @@ -104,6 +105,7 @@

if TYPE_CHECKING:
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer

if is_torch_available():
Expand Down
Loading

0 comments on commit 403d530

Please sign in to comment.