diff --git a/.circleci/config.yml b/.circleci/config.yml index e099814ea62417..ba8fa352dbc820 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -83,6 +83,7 @@ jobs: - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,torch-speech,vision] - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html - run: pip install tensorflow_probability + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-{{ checksum "setup.py" }} paths: @@ -151,6 +152,7 @@ jobs: - run: pip install --upgrade pip - run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision] - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-{{ checksum "setup.py" }} paths: @@ -187,6 +189,7 @@ jobs: - run: pip install --upgrade pip - run: pip install .[sklearn,flax,torch,testing,sentencepiece,torch-speech,vision] - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-{{ checksum "setup.py" }} paths: @@ -217,6 +220,7 @@ jobs: - run: pip install --upgrade pip - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm] - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-torch-{{ checksum "setup.py" }} paths: @@ -252,6 +256,7 @@ jobs: - run: pip install --upgrade pip - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm] - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-torch-{{ checksum "setup.py" }} paths: @@ -278,9 +283,11 @@ jobs: keys: - v0.4-tf-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision] - run: pip install tensorflow_probability + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-tf-{{ checksum "setup.py" }} paths: @@ -312,9 +319,11 @@ jobs: keys: - v0.4-tf-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision] - run: pip install tensorflow_probability + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-tf-{{ checksum "setup.py" }} paths: @@ -341,8 +350,10 @@ jobs: keys: - v0.4-flax-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: sudo pip install .[flax,testing,sentencepiece,flax-speech,vision] + - run: pip install .[flax,testing,sentencepiece,flax-speech,vision] + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-flax-{{ checksum "setup.py" }} paths: @@ -374,8 +385,10 @@ jobs: keys: - v0.4-flax-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y 
libsndfile1-dev - run: pip install --upgrade pip - - run: sudo pip install .[flax,testing,sentencepiece,vision,flax-speech] + - run: pip install .[flax,testing,sentencepiece,vision,flax-speech] + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-flax-{{ checksum "setup.py" }} paths: @@ -407,6 +420,7 @@ jobs: - run: pip install --upgrade pip - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm] - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-torch-{{ checksum "setup.py" }} paths: @@ -443,6 +457,7 @@ jobs: - run: pip install --upgrade pip - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm] - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html + - run: pip install https://github.com/kpu/kenlm/archive/master.zip - save_cache: key: v0.4-torch-{{ checksum "setup.py" }} paths: @@ -582,7 +597,7 @@ jobs: path: ~/transformers/examples_output.txt - store_artifacts: path: ~/transformers/reports - + run_examples_torch_all: working_directory: ~/transformers docker: diff --git a/.github/workflows/self-push.yml b/.github/workflows/self-push.yml index 314b9c781a196d..b997f1db08b4b6 100644 --- a/.github/workflows/self-push.yml +++ b/.github/workflows/self-push.yml @@ -34,6 +34,7 @@ jobs: apt install -y libsndfile1-dev pip install --upgrade pip pip install .[sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm] + pip install https://github.com/kpu/kenlm/archive/master.zip - name: Launcher docker uses: actions/checkout@v2 @@ -87,6 +88,7 @@ jobs: pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html pip install --upgrade pip pip install .[sklearn,testing,sentencepiece,flax,flax-speech,vision] + pip install https://github.com/kpu/kenlm/archive/master.zip - name: Launcher docker uses: actions/checkout@v2 @@ -142,6 +144,7 @@ jobs: # apt -y update && apt install -y software-properties-common && apt -y update && add-apt-repository -y ppa:git-core/ppa && apt -y update && apt install -y git # pip install --upgrade pip # pip install .[sklearn,testing,onnxruntime,sentencepiece,tf-speech] +# pip install https://github.com/kpu/kenlm/archive/master.zip # # - name: Launcher docker # uses: actions/checkout@v2 @@ -200,7 +203,7 @@ jobs: apt install -y libsndfile1-dev pip install --upgrade pip pip install .[sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm] - + pip install https://github.com/kpu/kenlm/archive/master.zip - name: Launcher docker uses: actions/checkout@v2 with: @@ -256,6 +259,7 @@ jobs: # pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # pip install --upgrade pip # pip install .[sklearn,testing,sentencepiece,flax,flax-speech,vision] +# pip install https://github.com/kpu/kenlm/archive/master.zip # # - name: Launcher docker # uses: actions/checkout@v2 @@ -311,6 +315,7 @@ jobs: # apt -y update && apt install -y software-properties-common && apt -y update && add-apt-repository -y ppa:git-core/ppa && apt -y update && apt install -y git # pip install --upgrade pip # pip install .[sklearn,testing,onnxruntime,sentencepiece,tf-speech] +# pip install https://github.com/kpu/kenlm/archive/master.zip # # - name: Launcher docker # uses: actions/checkout@v2 diff --git a/.github/workflows/self-scheduled.yml 
b/.github/workflows/self-scheduled.yml index 5adbc4a9d45ea1..72a4b8a3d35d68 100644 --- a/.github/workflows/self-scheduled.yml +++ b/.github/workflows/self-scheduled.yml @@ -36,6 +36,7 @@ jobs: apt -y update && apt install -y libsndfile1-dev git pip install --upgrade pip pip install .[integrations,sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm] + pip install https://github.com/kpu/kenlm/archive/master.zip - name: Are GPUs recognized by our DL frameworks run: | @@ -102,6 +103,7 @@ jobs: pip install --upgrade pip pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html pip install .[flax,integrations,sklearn,testing,sentencepiece,flax-speech,vision] + pip install https://github.com/kpu/kenlm/archive/master.zip - name: Are GPUs recognized by our DL frameworks run: | @@ -141,6 +143,8 @@ jobs: apt -y update && apt install -y libsndfile1-dev git pip install --upgrade pip pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision] + pip install https://github.com/kpu/kenlm/archive/master.zip + - name: Are GPUs recognized by our DL frameworks run: | @@ -236,6 +240,7 @@ jobs: apt -y update && apt install -y libsndfile1-dev git pip install --upgrade pip pip install .[integrations,sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm] + pip install https://github.com/kpu/kenlm/archive/master.zip - name: Are GPUs recognized by our DL frameworks run: | @@ -288,6 +293,7 @@ jobs: apt -y update && apt install -y libsndfile1-dev git pip install --upgrade pip pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision] + pip install https://github.com/kpu/kenlm/archive/master.zip - name: Are GPUs recognized by our DL frameworks run: | diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index 2aef0abb86a0e9..8c3d5481ac9ffc 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -67,9 +67,19 @@ Wav2Vec2Processor :members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor +Wav2Vec2ProcessorWithLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Wav2Vec2ProcessorWithLM + :members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor + + Wav2Vec2 specific outputs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: transformers.models.wav2vec2.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput + :members: + .. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput :members: diff --git a/setup.py b/setup.py index 4d59a717f27047..b30f3dd9743c98 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ pip install -i https://testpypi.python.org/pypi transformers Check you can run the following commands: - python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))" + python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))" python -c "from transformers import *" 9. Upload the final version to actual pypi: @@ -59,7 +59,7 @@ 10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. -11. Run `make post-release` (or, for a patch release, `make post-patch`). 
If you were on a branch for the release, +11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release, you need to go back to master before executing this. """ @@ -159,6 +159,7 @@ "tokenizers>=0.10.1,<0.11", "torch>=1.0", "torchaudio", + "pyctcdecode>=0.2.0", "tqdm>=4.27", "unidic>=1.0.2", "unidic_lite>=1.0.7", @@ -262,7 +263,7 @@ def run(self): extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"] extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") -extras["audio"] = deps_list("librosa") +extras["audio"] = deps_list("librosa", "pyctcdecode") extras["speech"] = deps_list("torchaudio") + extras["audio"] # `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead extras["torch-speech"] = deps_list("torchaudio") + extras["audio"] extras["tf-speech"] = extras["audio"] diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 267a277ba49c26..2a44169f0872e9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -44,6 +44,7 @@ from .file_utils import ( _LazyModule, is_flax_available, + is_pyctcdecode_available, is_pytorch_quantization_available, is_scatter_available, is_sentencepiece_available, @@ -468,6 +469,15 @@ name for name in dir(dummy_speech_objects) if not name.startswith("_") ] +if is_pyctcdecode_available(): + _import_structure["models.wav2vec2"].append("Wav2Vec2ProcessorWithLM") +else: + from .utils import dummy_pyctcdecode_objects + + _import_structure["utils.dummy_pyctcdecode_objects"] = [ + name for name in dir(dummy_pyctcdecode_objects) if not name.startswith("_") + ] + if is_sentencepiece_available() and is_speech_available(): _import_structure["models.speech_to_text"].append("Speech2TextProcessor") else: @@ -2434,6 +2444,11 @@ else: from .utils.dummy_speech_objects import * + if is_pyctcdecode_available(): + from .models.wav2vec2 import Wav2Vec2ProcessorWithLM + else: + from .utils.dummy_pyctcdecode_objects import * + if is_speech_available() and is_sentencepiece_available(): from .models.speech_to_text import Speech2TextProcessor else: diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index b074ffe13a36ef..328f640a29c7c2 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -70,6 +70,7 @@ "tokenizers": "tokenizers>=0.10.1,<0.11", "torch": "torch>=1.0", "torchaudio": "torchaudio", + "pyctcdecode": "pyctcdecode>=0.2.0", "tqdm": "tqdm>=4.27", "unidic": "unidic>=1.0.2", "unidic_lite": "unidic_lite>=1.0.7", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index d88da95dbbf0e5..4d082e9a5c2847 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -237,6 +237,22 @@ _torchaudio_available = False +_pyctcdecode_available = importlib.util.find_spec("pyctcdecode") is not None +try: + _pyctcdecode_version = importlib_metadata.version("pyctcdecode") + logger.debug(f"Successfully imported pyctcdecode version {_pyctcdecode_version}") +except importlib_metadata.PackageNotFoundError: + _pyctcdecode_available = False + + +_librosa_available = importlib.util.find_spec("librosa") is not None +try: + _librosa_version = importlib_metadata.version("librosa") + logger.debug(f"Successfully imported librosa version {_librosa_version}") +except importlib_metadata.PackageNotFoundError: + _librosa_available = False + + torch_cache_home = 
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) old_default_cache_path = os.path.join(torch_cache_home, "transformers") # New default cache, shared with the Datasets library @@ -311,6 +327,14 @@ def is_torch_available(): return _torch_available +def is_pyctcdecode_available(): + return _pyctcdecode_available + + +def is_librosa_available(): + return _librosa_available + + def is_torch_cuda_available(): if is_torch_available(): import torch @@ -718,6 +742,12 @@ def wrapper(*args, **kwargs): `pip install pytesseract` """ +# docstyle-ignore +PYCTCDECODE_IMPORT_ERROR = """ +{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip: +`pip install pyctcdecode` +""" + BACKENDS_MAPPING = OrderedDict( [ @@ -727,6 +757,7 @@ def wrapper(*args, **kwargs): ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)), ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)), + ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)), ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)), ("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)), ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)), diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py index 445e9183034025..0ca789825b899a 100644 --- a/src/transformers/models/wav2vec2/__init__.py +++ b/src/transformers/models/wav2vec2/__init__.py @@ -17,7 +17,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available +from ...file_utils import _LazyModule, is_flax_available, is_pyctcdecode_available, is_tf_available, is_torch_available _import_structure = { @@ -27,6 +27,9 @@ "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"], } +if is_pyctcdecode_available(): + _import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"] + if is_torch_available(): _import_structure["modeling_wav2vec2"] = [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -61,6 +64,9 @@ from .processing_wav2vec2 import Wav2Vec2Processor from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer + if is_pyctcdecode_available(): + from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM + if is_torch_available(): from .modeling_wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2/processing_wav2vec2_with_lm.py new file mode 100644 index 00000000000000..b0acbfbc608752 --- /dev/null +++ b/src/transformers/models/wav2vec2/processing_wav2vec2_with_lm.py @@ -0,0 +1,358 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Speech processor class for Wav2Vec2 +""" +import os +from contextlib import contextmanager +from dataclasses import dataclass +from multiprocessing import Pool +from typing import Iterable, List, Optional, Union + +import numpy as np + +from pyctcdecode import BeamSearchDecoderCTC +from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN +from pyctcdecode.constants import ( + DEFAULT_BEAM_WIDTH, + DEFAULT_HOTWORD_WEIGHT, + DEFAULT_MIN_TOKEN_LOGP, + DEFAULT_PRUNE_LOGP, +) + +from ...feature_extraction_utils import FeatureExtractionMixin +from ...file_utils import ModelOutput, requires_backends +from ...tokenization_utils import PreTrainedTokenizer +from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor +from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer + + +@dataclass +class Wav2Vec2DecoderWithLMOutput(ModelOutput): + """ + Output type of :class:`~transformers.Wav2Vec2DecoderWithLM`, with transcription. + + Args: + text (list of :obj:`str`): + Decoded logits in text from. Usually the speech transcription. + """ + + text: Union[List[str], str] + + +class Wav2Vec2ProcessorWithLM: + r""" + Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor, a Wav2Vec2 CTC tokenizer and a decoder + with language model support into a single processor for language model boosted speech recognition decoding. + + Args: + feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`): + An instance of :class:`~transformers.Wav2Vec2FeatureExtractor`. The feature extractor is a required input. + tokenizer (:class:`~transformers.Wav2Vec2CTCTokenizer`): + An instance of :class:`~transformers.Wav2Vec2CTCTokenizer`. The tokenizer is a required input. + decoder (:obj:`pyctcdecode.BeamSearchDecoderCTC`): + An instance of :class:`pyctcdecode.BeamSearchDecoderCTC`. The decoder is a required input. + """ + + def __init__( + self, + feature_extractor: FeatureExtractionMixin, + tokenizer: PreTrainedTokenizer, + decoder: BeamSearchDecoderCTC, + ): + if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor): + raise ValueError( + f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}" + ) + if not isinstance(tokenizer, Wav2Vec2CTCTokenizer): + # TODO(PVP) - this can be relaxed in the future to allow other kinds of tokenizers + raise ValueError( + f"`tokenizer` has to be of type {Wav2Vec2CTCTokenizer.__class__}, but is {type(tokenizer)}" + ) + if not isinstance(decoder, BeamSearchDecoderCTC): + raise ValueError(f"`decoder` has to be of type {BeamSearchDecoderCTC.__class__}, but is {type(decoder)}") + + # make sure that decoder's alphabet and tokenizer's vocab match in content + missing_decoder_tokens = self.get_missing_alphabet_tokens(decoder, tokenizer) + if len(missing_decoder_tokens) > 0: + raise ValueError( + f"The tokens {missing_decoder_tokens} are defined in the tokenizer's " + "vocabulary, but not in the decoder's alphabet. " + f"Make sure to include {missing_decoder_tokens} in the decoder's alphabet." + ) + + self.feature_extractor = feature_extractor + self.tokenizer = tokenizer + self.decoder = decoder + self.current_processor = self.feature_extractor + + def save_pretrained(self, save_directory): + """ + Save the Wav2Vec2 feature_extractor, a tokenizer object and a pyctcdecode decoder to the directory + ``save_directory``, so that they can be re-loaded using the + :func:`~transformers.Wav2Vec2ProcessorWithLM.from_pretrained` class method. + + .. 
note:: + + This class method is simply calling + :meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained,` + :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained` and pyctcdecode's + :meth:`pyctcdecode.BeamSearchDecoderCTC.save_to_dir`. + + Please refer to the docstrings of the methods above for more information. + + Args: + save_directory (:obj:`str` or :obj:`os.PathLike`): + Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will + be created if it does not exist). + """ + self.feature_extractor.save_pretrained(save_directory) + self.tokenizer.save_pretrained(save_directory) + self.decoder.save_to_dir(save_directory) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + Instantiate a :class:`~transformers.Wav2Vec2ProcessorWithLM` from a pretrained Wav2Vec2 processor. + + .. note:: + + This class method is simply calling Wav2Vec2FeatureExtractor's + :meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained`, + Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`, + and :meth:`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`. + + Please refer to the docstrings of the methods above for more information. + + Args: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + This can be either: + + - a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or + namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing a feature extractor file saved using the + :meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g., + ``./my_model_directory/``. + - a path or url to a saved feature extractor JSON `file`, e.g., + ``./my_model_directory/preprocessor_config.json``. + **kwargs + Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and + :class:`~transformers.PreTrainedTokenizer` + """ + requires_backends(cls, "pyctcdecode") + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) + tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) + + if os.path.isdir(pretrained_model_name_or_path): + decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path) + else: + decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs) + + # set language model attributes + for attribute in ["alpha", "beta", "unk_score_offset", "score_boundary"]: + value = kwargs.pop(attribute, None) + + if value is not None: + cls._set_language_model_attribute(decoder, attribute, value) + + # make sure that decoder's alphabet and tokenizer's vocab match in content + missing_decoder_tokens = cls.get_missing_alphabet_tokens(decoder, tokenizer) + if len(missing_decoder_tokens) > 0: + raise ValueError( + f"The tokens {missing_decoder_tokens} are defined in the tokenizer's " + "vocabulary, but not in the decoder's alphabet. " + f"Make sure to include {missing_decoder_tokens} in the decoder's alphabet." 
+            )
+
+        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
+
+    @staticmethod
+    def _set_language_model_attribute(decoder: BeamSearchDecoderCTC, attribute: str, value: float):
+        setattr(decoder.model_container[decoder._model_key], attribute, value)
+
+    @property
+    def language_model(self):
+        return self.decoder.model_container[self.decoder._model_key]
+
+    @staticmethod
+    def get_missing_alphabet_tokens(decoder, tokenizer):
+        # we need to make sure that all of the tokenizer's tokens, except the special tokens,
+        # are present in the decoder's alphabet. Retrieve the tokens missing from the
+        # decoder's alphabet
+        tokenizer_vocab_list = list(tokenizer.get_vocab().keys())
+
+        # replace special tokens
+        for i, token in enumerate(tokenizer_vocab_list):
+            if BLANK_TOKEN_PTN.match(token):
+                tokenizer_vocab_list[i] = ""
+            if token == tokenizer.word_delimiter_token:
+                tokenizer_vocab_list[i] = " "
+            if UNK_TOKEN_PTN.match(token):
+                tokenizer_vocab_list[i] = UNK_TOKEN
+
+        # find the tokenizer tokens that are not covered by the decoder's alphabet
+        missing_tokens = set(tokenizer_vocab_list) - set(decoder._alphabet.labels)
+
+        return missing_tokens
+
+    def __call__(self, *args, **kwargs):
+        """
+        When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
+        :meth:`~transformers.Wav2Vec2FeatureExtractor.__call__` and returns its output. If used in the context
+        :meth:`~transformers.Wav2Vec2ProcessorWithLM.as_target_processor` this method forwards all its arguments to
+        Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.__call__`. Please refer to the docstring of
+        the above two methods for more information.
+        """
+        return self.current_processor(*args, **kwargs)
+
+    def pad(self, *args, **kwargs):
+        """
+        When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
+        :meth:`~transformers.Wav2Vec2FeatureExtractor.pad` and returns its output. If used in the context
+        :meth:`~transformers.Wav2Vec2ProcessorWithLM.as_target_processor` this method forwards all its arguments to
+        Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.pad`. Please refer to the docstring of the
+        above two methods for more information.
+        """
+        return self.current_processor.pad(*args, **kwargs)
+
+    def batch_decode(
+        self,
+        logits: np.ndarray,
+        num_processes: Optional[int] = None,
+        beam_width: Optional[int] = None,
+        beam_prune_logp: Optional[float] = None,
+        token_min_logp: Optional[float] = None,
+        hotwords: Optional[Iterable[str]] = None,
+        hotword_weight: Optional[float] = None,
+    ):
+        """
+        Batch decode output logits to audio transcription with language model support.
+
+        .. note::
+
+            This function makes use of Python's multiprocessing.
+
+        Args:
+            logits (:obj:`np.ndarray`):
+                The logits output vector of the model representing the log probabilities for each token.
+            num_processes (:obj:`int`, `optional`):
+                Number of processes the function should be parallelized over. Defaults to the number of
+                available CPUs.
+            beam_width (:obj:`int`, `optional`):
+                Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
+            beam_prune_logp (:obj:`float`, `optional`):
+                Beams that are much worse than the best beam will be pruned. Defaults to pyctcdecode's
+                DEFAULT_PRUNE_LOGP.
+            token_min_logp (:obj:`float`, `optional`):
+                Tokens below this logp are skipped unless they are the argmax of the frame. Defaults to pyctcdecode's
+                DEFAULT_MIN_TOKEN_LOGP.
+            hotwords (:obj:`List[str]`, `optional`):
+                List of words with extra importance, which can be OOV for the LM.
+            hotword_weight (:obj:`float`, `optional`):
+                Weight factor for hotword importance. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
+
+        Returns:
+            :class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
+
+        """
+
+        # set defaults
+        beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
+        beam_prune_logp = beam_prune_logp if beam_prune_logp is not None else DEFAULT_PRUNE_LOGP
+        token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
+        hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
+
+        # create multiprocessing pool and list numpy arrays
+        logits_list = [array for array in logits]
+        pool = Pool(num_processes)
+
+        # pyctcdecode
+        decoded_beams = self.decoder.decode_beams_batch(
+            pool,
+            logits_list=logits_list,
+            beam_width=beam_width,
+            beam_prune_logp=beam_prune_logp,
+            token_min_logp=token_min_logp,
+            hotwords=hotwords,
+            hotword_weight=hotword_weight,
+        )
+
+        # extract text
+        batch_texts = [d[0][0] for d in decoded_beams]
+
+        # more output features will be added in the future
+        return Wav2Vec2DecoderWithLMOutput(text=batch_texts)
+
+    def decode(
+        self,
+        logits: np.ndarray,
+        beam_width: Optional[int] = None,
+        beam_prune_logp: Optional[float] = None,
+        token_min_logp: Optional[float] = None,
+        hotwords: Optional[Iterable[str]] = None,
+        hotword_weight: Optional[float] = None,
+    ):
+        """
+        Decode output logits to audio transcription with language model support.
+
+        Args:
+            logits (:obj:`np.ndarray`):
+                The logits output vector of the model representing the log probabilities for each token.
+            beam_width (:obj:`int`, `optional`):
+                Maximum number of beams at each step in decoding. Defaults to pyctcdecode's DEFAULT_BEAM_WIDTH.
+            beam_prune_logp (:obj:`float`, `optional`):
+                A threshold to prune beams with log-probs less than best_beam_logp + beam_prune_logp. The value should
+                be <= 0. Defaults to pyctcdecode's DEFAULT_PRUNE_LOGP.
+            token_min_logp (:obj:`float`, `optional`):
+                Tokens with log-probs below token_min_logp are skipped unless they have the maximum log-prob for an
+                utterance. Defaults to pyctcdecode's DEFAULT_MIN_TOKEN_LOGP.
+            hotwords (:obj:`List[str]`, `optional`):
+                List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"]
+            hotword_weight (:obj:`float`, `optional`):
+                Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
+
+        Returns:
+            :class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
+
+        """
+
+        # set defaults
+        beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
+        beam_prune_logp = beam_prune_logp if beam_prune_logp is not None else DEFAULT_PRUNE_LOGP
+        token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
+        hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
+
+        # pyctcdecode
+        decoded_beams = self.decoder.decode_beams(
+            logits,
+            beam_width=beam_width,
+            beam_prune_logp=beam_prune_logp,
+            token_min_logp=token_min_logp,
+            hotwords=hotwords,
+            hotword_weight=hotword_weight,
+        )
+
+        # more output features will be added in the future
+        return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0])
+
+    @contextmanager
+    def as_target_processor(self):
+        """
+        Temporarily sets the tokenizer for processing the input.
Useful for encoding the labels when fine-tuning + Wav2Vec2. + """ + self.current_processor = self.tokenizer + yield + self.current_processor = self.feature_extractor diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index e5f96d830e0450..644405cc746c95 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -36,8 +36,10 @@ is_faiss_available, is_flax_available, is_keras2onnx_available, + is_librosa_available, is_onnx_available, is_pandas_available, + is_pyctcdecode_available, is_pytesseract_available, is_pytorch_quantization_available, is_rjieba_available, @@ -589,6 +591,26 @@ def require_deepspeed(test_case): return test_case +def require_pyctcdecode(test_case): + """ + Decorator marking a test that requires pyctcdecode + """ + if not is_pyctcdecode_available(): + return unittest.skip("test requires pyctcdecode")(test_case) + else: + return test_case + + +def require_librosa(test_case): + """ + Decorator marking a test that requires librosa + """ + if not is_librosa_available(): + return unittest.skip("test requires librosa")(test_case) + else: + return test_case + + def get_gpu_count(): """ Return the number of available gpus (regardless of whether torch or tf is used) diff --git a/src/transformers/utils/dummy_pyctcdecode_objects.py b/src/transformers/utils/dummy_pyctcdecode_objects.py new file mode 100644 index 00000000000000..fee38b3dac5dec --- /dev/null +++ b/src/transformers/utils/dummy_pyctcdecode_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..file_utils import requires_backends + + +class Wav2Vec2ProcessorWithLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pyctcdecode"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pyctcdecode"]) diff --git a/tests/test_modeling_flax_wav2vec2.py b/tests/test_modeling_flax_wav2vec2.py index d75891a1b0d5bf..f0805e1742f3a2 100644 --- a/tests/test_modeling_flax_wav2vec2.py +++ b/tests/test_modeling_flax_wav2vec2.py @@ -17,9 +17,19 @@ import unittest import numpy as np +from datasets import load_dataset from transformers import Wav2Vec2Config, is_flax_available -from transformers.testing_utils import require_datasets, require_flax, require_soundfile, slow +from transformers.testing_utils import ( + is_librosa_available, + is_pyctcdecode_available, + require_datasets, + require_flax, + require_librosa, + require_pyctcdecode, + require_soundfile, + slow, +) from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask @@ -39,6 +49,14 @@ ) +if is_pyctcdecode_available(): + from transformers import Wav2Vec2ProcessorWithLM + + +if is_librosa_available(): + import librosa + + class FlaxWav2Vec2ModelTester: def __init__( self, @@ -354,8 +372,6 @@ def test_sample_negatives_with_attn_mask(self): @slow class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase): def _load_datasamples(self, num_samples): - from datasets import load_dataset - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech speech_samples = ds.sort("id").filter( @@ -447,3 +463,22 @@ def test_inference_pretrained(self): # a random wav2vec2 model has not learned to predict the quantized latent states # => the cosine similarity between quantized states and predicted states is very likely < 0.1 self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0) + + 
@require_pyctcdecode + @require_librosa + def test_wav2vec2_with_lm(self): + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000) + + model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="np").input_values + + logits = model(input_values).logits + + transcription = processor.batch_decode(np.array(logits)).text + + self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") diff --git a/tests/test_modeling_tf_wav2vec2.py b/tests/test_modeling_tf_wav2vec2.py index 46f877f06357c1..a349b9ab82b852 100644 --- a/tests/test_modeling_tf_wav2vec2.py +++ b/tests/test_modeling_tf_wav2vec2.py @@ -21,9 +21,11 @@ import numpy as np import pytest +from datasets import load_dataset from transformers import Wav2Vec2Config, is_tf_available -from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow +from transformers.file_utils import is_librosa_available, is_pyctcdecode_available +from transformers.testing_utils import require_datasets, require_librosa, require_pyctcdecode, require_tf, slow from .test_configuration_common import ConfigTester from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor @@ -36,6 +38,14 @@ from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices +if is_pyctcdecode_available(): + from transformers import Wav2Vec2ProcessorWithLM + + +if is_librosa_available(): + import librosa + + @require_tf class TFWav2Vec2ModelTester: def __init__( @@ -474,7 +484,6 @@ def test_compute_mask_indices_overlap(self): @require_tf @slow @require_datasets -@require_soundfile class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): def _load_datasamples(self, num_samples): from datasets import load_dataset @@ -544,3 +553,22 @@ def test_inference_ctc_robust_batched(self): "his instant panic was followed by a small sharp blow high on his chest", ] self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) + + @require_pyctcdecode + @require_librosa + def test_wav2vec2_with_lm(self): + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000) + + model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="tf").input_values + + logits = model(input_values).logits + + transcription = processor.batch_decode(logits.numpy()).text + + self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index 278465341a48d1..c3a0271bd0ece7 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -18,15 +18,19 @@ import unittest import numpy as np -import pytest +from datasets import load_dataset from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from transformers import Wav2Vec2Config, is_torch_available from transformers.testing_utils import ( is_pt_flax_cross_test, + is_pyctcdecode_available, + 
is_torchaudio_available, require_datasets, + require_pyctcdecode, require_soundfile, require_torch, + require_torchaudio, slow, torch_device, ) @@ -54,6 +58,14 @@ ) +if is_torchaudio_available(): + import torchaudio + + +if is_pyctcdecode_available(): + from transformers import Wav2Vec2ProcessorWithLM + + class Wav2Vec2ModelTester: def __init__( self, @@ -331,7 +343,7 @@ def check_labels_out_of_vocab(self, config, input_values, *args): max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size + 100) - with pytest.raises(ValueError): + with self.parent.assertRaises(ValueError): model(input_values, labels=labels) def prepare_config_and_inputs_for_common(self): @@ -998,8 +1010,6 @@ def test_sample_negatives_with_mask(self): @slow class Wav2Vec2ModelIntegrationTest(unittest.TestCase): def _load_datasamples(self, num_samples): - from datasets import load_dataset - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech speech_samples = ds.sort("id").filter( @@ -1009,8 +1019,6 @@ def _load_datasamples(self, num_samples): return [x["array"] for x in speech_samples] def _load_superb(self, task, num_samples): - from datasets import load_dataset - ds = load_dataset("anton-l/superb_dummy", task, split="test") return ds[:num_samples] @@ -1337,3 +1345,27 @@ def test_inference_emotion_recognition(self): self.assertListEqual(predicted_ids.tolist(), expected_labels) self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2)) + + @require_pyctcdecode + @require_torchaudio + def test_wav2vec2_with_lm(self): + ds = load_dataset("common_voice", "es", split="test", streaming=True) + sample = next(iter(ds)) + + resampled_audio = torchaudio.functional.resample( + torch.tensor(sample["audio"]["array"]), 48_000, 16_000 + ).numpy() + + model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to( + torch_device + ) + processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm") + + input_values = processor(resampled_audio, return_tensors="pt").input_values + + with torch.no_grad(): + logits = model(input_values.to(torch_device)).logits + + transcription = processor.batch_decode(logits.cpu().numpy()).text + + self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py new file mode 100644 index 00000000000000..155e09a22eb3d8 --- /dev/null +++ b/tests/test_processor_wav2vec2_with_lm.py @@ -0,0 +1,236 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import shutil +import tempfile +import unittest +from multiprocessing import Pool + +import numpy as np + +from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available +from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor +from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES +from transformers.testing_utils import require_pyctcdecode + +from .test_feature_extraction_wav2vec2 import floats_list + + +if is_pyctcdecode_available(): + from pyctcdecode import BeamSearchDecoderCTC + from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM + + +@require_pyctcdecode +class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): + def setUp(self): + vocab = "| a b c d e f g h i j k".split() + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + + self.add_kwargs_tokens_map = { + "unk_token": "", + "bos_token": "", + "eos_token": "", + } + feature_extractor_map = { + "feature_size": 1, + "padding_value": 0.0, + "sampling_rate": 16000, + "return_attention_mask": False, + "do_normalize": True, + } + + self.tmpdirname = tempfile.mkdtemp() + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + + with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(feature_extractor_map) + "\n") + + # load decoder from hub + self.decoder_name = "hf-internal-testing/ngram-beam-search-decoder" + + def get_tokenizer(self, **kwargs_init): + kwargs = self.add_kwargs_tokens_map.copy() + kwargs.update(kwargs_init) + return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs) + + def get_feature_extractor(self, **kwargs): + return Wav2Vec2FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs) + + def get_decoder(self, **kwargs): + return BeamSearchDecoderCTC.load_from_hf_hub(self.decoder_name, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_save_load_pretrained_default(self): + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + processor.save_pretrained(self.tmpdirname) + processor = Wav2Vec2ProcessorWithLM.from_pretrained(self.tmpdirname) + + # tokenizer + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer) + + # feature extractor + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor.feature_extractor, Wav2Vec2FeatureExtractor) + + # decoder + self.assertEqual(processor.decoder._alphabet.labels, decoder._alphabet.labels) + self.assertEqual( + processor.decoder.model_container[decoder._model_key]._unigram_set, + decoder.model_container[decoder._model_key]._unigram_set, + ) + self.assertIsInstance(processor.decoder, BeamSearchDecoderCTC) + + def test_save_load_pretrained_additional_features(self): + processor = Wav2Vec2ProcessorWithLM( + tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor(), decoder=self.get_decoder() + ) + processor.save_pretrained(self.tmpdirname) + + # make sure that error is thrown when decoder alphabet doesn't match + 
processor = Wav2Vec2ProcessorWithLM.from_pretrained( + self.tmpdirname, alpha=5.0, beta=3.0, score_boundary=-7.0, unk_score_offset=3 + ) + + # decoder + self.assertEqual(processor.language_model.alpha, 5.0) + self.assertEqual(processor.language_model.beta, 3.0) + self.assertEqual(processor.language_model.score_boundary, -7.0) + self.assertEqual(processor.language_model.unk_score_offset, 3) + + def test_load_decoder_tokenizer_mismatch_content(self): + tokenizer = self.get_tokenizer() + # add token to trigger raise + tokenizer.add_tokens(["xx"]) + with self.assertRaisesRegex(ValueError, "include"): + Wav2Vec2ProcessorWithLM( + tokenizer=tokenizer, feature_extractor=self.get_feature_extractor(), decoder=self.get_decoder() + ) + + def test_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + raw_speech = floats_list((3, 1000)) + + input_feat_extract = feature_extractor(raw_speech, return_tensors="np") + input_processor = processor(raw_speech, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + input_str = "This is a test string" + + with processor.as_target_processor(): + encoded_processor = processor(input_str) + + encoded_tok = tokenizer(input_str) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key]) + + def _get_dummy_logits(self, shape=(2, 10, 16), seed=77): + np.random.seed(seed) + return np.random.rand(*shape) + + def test_decoder(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + logits = self._get_dummy_logits(shape=(10, 16), seed=13) + + decoded_processor = processor.decode(logits).text + + decoded_decoder = decoder.decode_beams(logits)[0][0] + + self.assertEqual(decoded_decoder, decoded_processor) + self.assertEqual(" ", decoded_processor) + + def test_decoder_batch(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + logits = self._get_dummy_logits() + + decoded_processor = processor.batch_decode(logits).text + + logits_list = [array for array in logits] + decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(Pool(), logits_list)] + + self.assertListEqual(decoded_decoder, decoded_processor) + self.assertListEqual([" ", " "], decoded_processor) + + def test_decoder_with_params(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder) + + logits = self._get_dummy_logits() + + beam_width = 20 + beam_prune_logp = -20.0 + token_min_logp = -4.0 + + decoded_processor_out = processor.batch_decode( + logits, + 
beam_width=beam_width, + beam_prune_logp=beam_prune_logp, + token_min_logp=token_min_logp, + ) + decoded_processor = decoded_processor_out.text + + logits_list = [array for array in logits] + decoded_decoder_out = decoder.decode_beams_batch( + Pool(), + logits_list, + beam_width=beam_width, + beam_prune_logp=beam_prune_logp, + token_min_logp=token_min_logp, + ) + + decoded_decoder = [d[0][0] for d in decoded_decoder_out] + + self.assertListEqual(decoded_decoder, decoded_processor) + self.assertListEqual([" ", " "], decoded_processor)
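
For reference, a minimal end-to-end sketch of how the new `Wav2Vec2ProcessorWithLM` is meant to be used. It mirrors the new slow integration tests above (the Spanish checkpoint, the Common Voice sample, and the expected transcription all come from those tests; the extra installs match the CI changes), so treat it as an illustration rather than an additional tested path.

```python
# LM-boosted CTC decoding with the new Wav2Vec2ProcessorWithLM.
# Extra requirements (as added to CI in this PR):
#   pip install pyctcdecode librosa
#   pip install https://github.com/kpu/kenlm/archive/master.zip
import librosa
import torch
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM

model_id = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"
model = Wav2Vec2ForCTC.from_pretrained(model_id)
# loads feature extractor, tokenizer and the pyctcdecode decoder bundled with the repo
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)

# grab one Common Voice sample and resample 48 kHz -> 16 kHz, as in the integration tests
ds = load_dataset("common_voice", "es", split="test", streaming=True)
sample = next(iter(ds))
audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)

input_values = processor(audio, return_tensors="pt").input_values
with torch.no_grad():
    logits = model(input_values).logits

# batch_decode runs pyctcdecode's beam search with the bundled n-gram LM
transcription = processor.batch_decode(logits.numpy()).text
print(transcription[0])  # the tests expect "bien y qué regalo vas a abrir primero"
```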