Skip to content

Commit

Permalink
Speech2TextTransformer (#10175)
Browse files Browse the repository at this point in the history
* s2t

* fix config

* conversion script

* fix import

* add tokenizer

* fix tok init

* fix tokenizer

* first version working

* fix embeds

* fix lm head

* remove extra heads

* fix convert script

* handle encoder attn mask

* style

* better enc attn mask

* override _prepare_attention_mask_for_generation

* handle attn_maks in encoder and decoder

* input_ids => input_features

* enable use_cache

* remove old code

* expand embeddings if needed

* remove logits bias

* masked_lm_loss => loss

* hack tokenizer to support feature processing

* fix model_input_names

* style

* fix error message

* doc

* remove inputs_embeds

* remove input_embeds

* remove unnecessary docstring

* quality

* SpeechToText => Speech2Text

* style

* remove shared_embeds

* subsample => conv

* remove Speech2TextTransformerDecoderWrapper

* update output_lengths formula

* fix table

* remove max_position_embeddings

* update conversion scripts

* add possibility to do upper case for now

* add FeatureExtractor and Processor

* add tests for extractor

* require_torch_audio => require_torchaudio

* add processor test

* update import

* remove classification head

* attention mask is now 1D

* update docstrings

* attention mask should be of type long

* handle attention mask from generate

* alwyas return attention_mask

* fix test

* style

* doc

* Speech2TextTransformer => Speech2Text

* Speech2TextTransformerConfig => Speech2TextConfig

* remove dummy_inputs

* nit

* style

* multilinguial tok

* fix tokenizer

* add tgt_lang setter

* save lang_codes

* fix tokenizer

* add forced_bos_token_id to tokenizer

* apply review suggestions

* add torchaudio to extra deps

* add speech deps to CI

* fix dep

* add libsndfile to ci

* libsndfile1

* add speech to extras all

* libsndfile1 -> libsndfile1

* libsndfile

* libsndfile1-dev

* apt update

* add sudo to install

* update deps table

* install libsndfile1-dev on CI

* tuple to list

* init conv layer

* add model tests

* quality

* add integration tests

* skip_special_tokens

* add speech_to_text_transformer in toctree

* fix tokenizer

* fix fp16 tests

* add tokenizer tests

* fix copyright

* input_values => input_features

* doc

* add model in readme

* doc

* change checkpoint names

* fix copyright

* fix code example

* add max_model_input_sizes in tokenizer

* fix integration tests

* add do_lower_case to tokenizer

* remove clamp trick

* fix "Add modeling imports here"

* fix copyrights

* fix tests

* SpeechToTextTransformer => SpeechToText

* fix naming

* fix table formatting

* fix typo

* style

* fix typos

* remove speech dep from extras[testing]

* fix copies

* rename doc file,

* put imports under is_torch_available

* run feat extract tests when torch is available

* dummy objects for processor and extractor

* fix imports in tests

* fix import in modeling test

* fxi imports

* fix torch import

* fix imports again

* fix positional embeddings

* fix typo in import

* adapt new extractor refactor

* style

* fix torchscript test

* doc

* doc

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix docs, copied from, style

* fix docstring

* handle imports

* remove speech from all extra deps

* remove s2t from seq2seq lm mapping

* better names

* skip training tests

* add install instructions

* List => Tuple

* doc

* fix conversion script

* fix urls

* add instruction for libsndfile

* fix fp16 test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 10, 2021
1 parent efb5c0a commit d26b37e
Show file tree
Hide file tree
Showing 30 changed files with 3,833 additions and 24 deletions.
13 changes: 9 additions & 4 deletions .circleci/config.yml
Expand Up @@ -77,8 +77,9 @@ jobs:
keys:
- v0.4-torch_and_tf-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece]
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
Expand All @@ -104,8 +105,9 @@ jobs:
keys:
- v0.4-torch-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
Expand Down Expand Up @@ -157,8 +159,9 @@ jobs:
keys:
- v0.4-flax-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece]
- run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece,speech]
- save_cache:
key: v0.4-flax-{{ checksum "setup.py" }}
paths:
Expand All @@ -183,8 +186,9 @@ jobs:
keys:
- v0.4-torch-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
Expand Down Expand Up @@ -300,6 +304,7 @@ jobs:
keys:
- v0.4-build_doc-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install ."[all, docs]"
- save_cache:
Expand Down
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -227,6 +227,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[ProphetNet](https://huggingface.co/transformers/model_doc/prophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SpeechToTextTransformer](https://huggingface.co/transformers/model_doc/speech_to_text.html)** (from Facebook), released together with the paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
Expand Down
24 changes: 15 additions & 9 deletions docs/source/index.rst
Expand Up @@ -191,31 +191,34 @@ and conversion utilities for the following models:
36. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
37. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
37. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
`fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
38. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
Krishna, and Kurt W. Keutzer.
38. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
39. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
39. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
40. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
Francesco Piccinno and Julian Martin Eisenschlos.
40. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
41. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
41. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
42. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
Zhou, Abdelrahman Mohamed, Michael Auli.
42. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
43. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
43. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
44. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
44. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
45. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
Zettlemoyer and Veselin Stoyanov.
45. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
46. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.

Expand Down Expand Up @@ -304,6 +307,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RoBERTa ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Speech2Text ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SqueezeBERT ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5 ||||||
Expand Down Expand Up @@ -436,6 +441,7 @@ TensorFlow and/or Flax.
model_doc/reformer
model_doc/retribert
model_doc/roberta
model_doc/speech_to_text
model_doc/squeezebert
model_doc/t5
model_doc/tapas
Expand Down
152 changes: 152 additions & 0 deletions docs/source/model_doc/speech_to_text.rst
@@ -0,0 +1,152 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

Speech2Text
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Speech2Text model was proposed in `fairseq S2T: Fast Speech-to-Text Modeling with fairseq
<https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. It's a
transformer-based seq2seq (encoder-decoder) model designed for end-to-end Automatic Speech Recognition (ASR) and Speech
Translation (ST). It uses a convolutional downsampler to reduce the length of speech inputs by 3/4th before they are
fed into the encoder. The model is trained with standard autoregressive cross-entropy loss and generates the
transcripts/translations autoregressively. Speech2Text has been fine-tuned on several datasets for ASR and ST:
`LibriSpeech <http://www.openslr.org/12>`__, `CoVoST 2 <https://github.com/facebookresearch/covost>`__, `MuST-C
<https://ict.fbk.eu/must-c/>`__.

The original code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/speech_to_text>`__.


Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Speech2Text is a speech model that accepts a float tensor of log-mel filter-bank features extracted from the speech
signal. It's a transformer-based seq2seq model, so the transcripts/translations are generated autoregressively. The
:obj:`generate()` method can be used for inference.

The :class:`~transformers.Speech2TextFeatureExtractor` class is responsible for extracting the log-mel filter-bank
features. The :class:`~transformers.Speech2TextProcessor` wraps :class:`~transformers.Speech2TextFeatureExtractor` and
:class:`~transformers.Speech2TextTokenizer` into a single instance to both extract the input features and decode the
predicted token ids.

The feature extractor depends on :obj:`torchaudio` and the tokenizer depends on :obj:`sentencepiece` so be sure to
install those packages before running the examples. You could either install those as extra speech dependancies with
``pip install transformers"[speech, sentencepiece]"`` or install the packages seperatly with ``pip install torchaudio
sentencepiece``. Also ``torchaudio`` requires the development version of the `libsndfile
<http://www.mega-nerd.com/libsndfile/>`__ package which can be installed via a system package manager. On Ubuntu it can
be installed as follows: ``apt install libsndfile1-dev``


- ASR and Speech Translation

.. code-block::
>>> import torch
>>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
>>> processor = Speech2Textprocessor.from_pretrained("facebook/s2t-small-librispeech-asr")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_features = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt").input_features # Batch size 1
>>> generated_ids = model.generate(input_ids=input_features)
>>> transcription = processor.batch_decode(generated_ids)
- Multilingual speech translation

For multilingual speech translation models, :obj:`eos_token_id` is used as the :obj:`decoder_start_token_id` and
the target language id is forced as the first generated token. To force the target language id as the first
generated token, pass the :obj:`forced_bos_token_id` parameter to the :obj:`generate()` method. The following
example shows how to transate English speech to French text using the `facebook/s2t-medium-mustc-multilingual-st`
checkpoint.

.. code-block::
>>> import torch
>>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
>>> processor = Speech2Textprocessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_features = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt").input_features # Batch size 1
>>> generated_ids = model.generate(input_ids=input_features, forced_bos_token_id=processor.tokenizer.lang_code_to_id["fr"])
>>> translation = processor.batch_decode(generated_ids)
See the `model hub <https://huggingface.co/models?filter=speech_to_text>`__ to look for Speech2Text checkpoints.


Speech2TextConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Speech2TextConfig
:members:


Speech2TextTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Speech2TextTokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, save_vocabulary


Speech2TextFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Speech2TextFeatureExtractor
:members: __call__


Speech2TextProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Speech2TextProcessor
:members: __call__, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor


Speech2TextModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Speech2TextModel
:members: forward


Speech2TextForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Speech2TextForConditionalGeneration
:members: forward
1 change: 1 addition & 0 deletions setup.cfg
Expand Up @@ -35,6 +35,7 @@ known_third_party =
tensorflow_datasets
timeout_decorator
torch
torchaudio
torchtext
torchvision
torch_xla
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Expand Up @@ -134,6 +134,7 @@
"timeout-decorator",
"tokenizers>=0.10.1,<0.11",
"torch>=1.0",
"torchaudio",
"tqdm>=4.27",
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
Expand Down Expand Up @@ -227,14 +228,13 @@ def run(self):
extras["modelcreation"] = deps_list("cookiecutter")

extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["speech"] = deps_list("soundfile")
extras["speech"] = deps_list("soundfile", "torchaudio")

extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = (
deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets")
+ extras["retrieval"]
+ extras["modelcreation"]
+ extras["speech"]
)
extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton")
extras["quality"] = deps_list("black", "isort", "flake8")
Expand Down

0 comments on commit d26b37e

Please sign in to comment.