[Wav2Vec2] PyCTCDecode Integration to support language model boosted decoding #14339

Merged
Changes from 9 commits
Commits
34 commits
6de1445
up
patrickvonplaten Nov 9, 2021
b68faa9
up
patrickvonplaten Nov 9, 2021
8294efa
up
patrickvonplaten Nov 9, 2021
6ec01c2
make it cleaner
patrickvonplaten Dec 2, 2021
52afd82
correct
patrickvonplaten Dec 2, 2021
e3b0fde
make styhahalal
patrickvonplaten Dec 2, 2021
e7eb51c
add more tests
patrickvonplaten Dec 3, 2021
ff0de09
finish
patrickvonplaten Dec 3, 2021
6296938
small fix
patrickvonplaten Dec 3, 2021
84bfdf3
make style
patrickvonplaten Dec 3, 2021
4caf406
up
patrickvonplaten Dec 3, 2021
d59b594
tryout to solve cicrle ci
patrickvonplaten Dec 3, 2021
ead3873
Merge branch 'master' into pyctcdecode_integration
patrickvonplaten Dec 3, 2021
682b258
up
patrickvonplaten Dec 3, 2021
6320a5a
Merge branch 'pyctcdecode_integration' of https://github.com/patrickv…
patrickvonplaten Dec 3, 2021
53aaeff
fix more tests
patrickvonplaten Dec 3, 2021
7b24cdc
fix more tests
patrickvonplaten Dec 3, 2021
f3648f6
apply sylvains suggestions
patrickvonplaten Dec 6, 2021
f39f02c
fix import
patrickvonplaten Dec 6, 2021
19a1301
correct docs
patrickvonplaten Dec 6, 2021
88783e3
add pyctcdecode only to speech tests
patrickvonplaten Dec 6, 2021
51f3dc7
fix more tests
patrickvonplaten Dec 6, 2021
ceb6ea2
add tf, flax and pt tests
patrickvonplaten Dec 6, 2021
e2b19af
add pt
patrickvonplaten Dec 6, 2021
b1ba5dd
fix last tests
patrickvonplaten Dec 6, 2021
a52f319
fix more tests
patrickvonplaten Dec 6, 2021
66dd6d8
Apply suggestions from code review
patrickvonplaten Dec 6, 2021
d9cdb5e
change lines
patrickvonplaten Dec 6, 2021
b93b954
Merge branch 'pyctcdecode_integration' of https://github.com/patrickv…
patrickvonplaten Dec 6, 2021
0fe15e1
Apply suggestions from code review
patrickvonplaten Dec 6, 2021
2382b92
correct tests
patrickvonplaten Dec 6, 2021
8e70208
Merge branch 'pyctcdecode_integration' of https://github.com/patrickv…
patrickvonplaten Dec 6, 2021
b46df6b
correct tests
patrickvonplaten Dec 6, 2021
776d152
add doc string
patrickvonplaten Dec 8, 2021
29 changes: 19 additions & 10 deletions docs/source/model_doc/wav2vec2.rst
@@ -67,20 +67,13 @@ Wav2Vec2Processor
    :members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor


Wav2Vec2 specific outputs
Wav2Vec2ProcessorWithLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
    :members:

.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput
    :members:
.. autoclass:: transformers.Wav2Vec2ProcessorWithLM
    :members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor

.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2BaseModelOutput
    :members:

.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2ForPreTrainingOutput
    :members:


Wav2Vec2Model
@@ -143,3 +136,19 @@ FlaxWav2Vec2ForPreTraining

.. autoclass:: transformers.FlaxWav2Vec2ForPreTraining
    :members: __call__

Wav2Vec2 specific outputs
Collaborator

In all other doc pages, those go before the models, so let's leave them here for now. We can decide later to move them after the models, but then for all models at once.

Contributor Author

Sounds good - reverting this

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
    :members:

.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput
    :members:

.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2BaseModelOutput
    :members:

.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2ForPreTrainingOutput
    :members:

3 changes: 2 additions & 1 deletion setup.py
@@ -151,6 +151,7 @@
"tokenizers>=0.10.1,<0.11",
"torch>=1.0,<1.10",
"torchaudio",
"pyctcdecode>=0.2.0",
"tqdm>=4.27",
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
@@ -256,7 +257,7 @@ def run(self):
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["audio"] = deps_list("librosa")
extras["speech"] = deps_list("torchaudio") + extras["audio"] # `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
- extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]
+ extras["torch-speech"] = deps_list("torchaudio", "pyctcdecode") + extras["audio"]
extras["tf-speech"] = extras["audio"]
extras["flax-speech"] = extras["audio"]
extras["vision"] = deps_list("Pillow")
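
The `setup.py` change above relies on `deps_list` turning bare package names into pinned requirement strings. The body of `deps_list` is not shown in this diff, so the following is only a minimal sketch under that assumption; the `_deps` list and the `deps` lookup table are illustrative stand-ins, not the project's actual code.

```python
import re

# Assumed stand-in for the single list of pinned requirement strings.
_deps = ["librosa", "pyctcdecode>=0.2.0", "torchaudio"]

# Map bare package name -> full requirement string (name is everything
# before the first version-specifier character).
deps = {re.split(r"[!=<>~]", spec, maxsplit=1)[0]: spec for spec in _deps}


def deps_list(*pkgs):
    """Return the pinned requirement strings for the given package names."""
    return [deps[pkg] for pkg in pkgs]


extras = {}
extras["audio"] = deps_list("librosa")
# After this PR, the torch speech extra also pulls in pyctcdecode.
extras["torch-speech"] = deps_list("torchaudio", "pyctcdecode") + extras["audio"]
print(extras["torch-speech"])
```

With an extra defined this way, `pip install ".[torch-speech]"` would install `pyctcdecode>=0.2.0` alongside `torchaudio` and `librosa`.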
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -298,6 +298,7 @@
"Wav2Vec2CTCTokenizer",
"Wav2Vec2FeatureExtractor",
"Wav2Vec2Processor",
"Wav2Vec2ProcessorWithLM",
"Wav2Vec2Tokenizer",
],
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
@@ -2191,6 +2192,7 @@
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
Wav2Vec2ProcessorWithLM,
Wav2Vec2Tokenizer,
)
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
@@ -69,6 +69,7 @@
"tokenizers": "tokenizers>=0.10.1,<0.11",
"torch": "torch>=1.0,<1.10",
"torchaudio": "torchaudio",
"pyctcdecode": "pyctcdecode>=0.2.0",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
19 changes: 19 additions & 0 deletions src/transformers/file_utils.py
@@ -220,6 +220,14 @@
_torchaudio_available = False


_pyctcdecode_available = importlib.util.find_spec("pyctcdecode") is not None
try:
    _pyctcdecode_version = importlib_metadata.version("pyctcdecode")
    logger.debug(f"Successfully imported pyctcdecode version {_pyctcdecode_version}")
except importlib_metadata.PackageNotFoundError:
    _pyctcdecode_available = False


torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
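
The availability check added in this hunk combines two tests: the module must be importable, and an installed distribution with metadata must exist. A minimal, self-contained sketch of that pattern follows. It assumes Python 3.8+ and uses the standard-library `importlib.metadata` rather than the `importlib_metadata` name seen in the diff; `package_available` is a hypothetical helper name, not part of the PR.

```python
import importlib.metadata
import importlib.util


def package_available(name: str) -> bool:
    """Return True only if `name` is importable AND has installed distribution metadata."""
    # Step 1: is there an importable top-level module of that name?
    if importlib.util.find_spec(name) is None:
        return False
    # Step 2: is there an installed distribution of that name?
    # This guards against, e.g., a stray local directory shadowing the package.
    try:
        importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return False
    return True


print(package_available("pyctcdecode"))
```

The second step is why a module that is merely importable (a standard-library module, for instance, typically has no distribution metadata) would not count as available under this check.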
@@ -294,6 +302,10 @@ def is_torch_available():
    return _torch_available


def is_pyctcdecode_available():
    return _pyctcdecode_available


def is_torch_cuda_available():
    if is_torch_available():
        import torch
@@ -650,6 +662,12 @@ def wrapper(*args, **kwargs):
`pip install pytesseract`
"""

# docstyle-ignore
PYCTCDECODE_IMPORT_ERROR = """
{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
`pip install pyctcdecode`
"""


BACKENDS_MAPPING = OrderedDict(
[
@@ -659,6 +677,7 @@ def wrapper(*args, **kwargs):
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)),
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
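
Each `BACKENDS_MAPPING` entry pairs an availability check with an error-message template whose `{0}` placeholder is filled with the name of the object that needs the backend. The code that consumes the mapping is not part of this diff, so the sketch below only illustrates how such a mapping can be used; `require_backends`, `EXAMPLE_BACKENDS`, and both backend names are hypothetical.

```python
import importlib.util
from collections import OrderedDict

# Message template in the same style as PYCTCDECODE_IMPORT_ERROR above.
MISSING_BACKEND_ERROR = """
{0} requires the definitely_missing_backend library but it was not found in your environment.
"""

# Hypothetical mapping: backend name -> (availability check, error template).
EXAMPLE_BACKENDS = OrderedDict(
    [
        ("json", (lambda: importlib.util.find_spec("json") is not None, "{0} requires json.")),
        (
            "definitely_missing_backend",
            (lambda: importlib.util.find_spec("definitely_missing_backend") is not None, MISSING_BACKEND_ERROR),
        ),
    ]
)


def require_backends(obj, backends):
    """Raise ImportError naming `obj` if any requested backend is unavailable."""
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [EXAMPLE_BACKENDS[b][1].format(name) for b in backends if not EXAMPLE_BACKENDS[b][0]()]
    if missing:
        raise ImportError("".join(missing))


class DummyProcessor:
    """Placeholder class standing in for an object that needs an optional backend."""


require_backends(DummyProcessor, ["json"])  # passes silently: json is importable
try:
    require_backends(DummyProcessor, ["definitely_missing_backend"])
except ImportError as err:
    print(err)
```

Keeping the check and the message together in one table means a new optional dependency, such as `pyctcdecode` here, needs only one added entry.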
2 changes: 2 additions & 0 deletions src/transformers/models/wav2vec2/__init__.py
@@ -24,6 +24,7 @@
"configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"],
"feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
"processing_wav2vec2": ["Wav2Vec2Processor"],
"processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"],
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
}

@@ -59,6 +60,7 @@
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .processing_wav2vec2 import Wav2Vec2Processor
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer

if is_torch_available():