Skip to content

Commit

Permalink
Fix whisper for pipeline (#19482)
Browse files Browse the repository at this point in the history
* update feature extractor params

* update attention mask handling

* fix doc and pipeline test

* add warning when skipping test

* add whisper translation and transcription test

* fix build doc test
  • Loading branch information
ArthurZucker authored and sgugger committed Oct 11, 2022
1 parent 9ae22fe commit c8bc0a0
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 24 deletions.
14 changes: 14 additions & 0 deletions src/transformers/models/whisper/feature_extraction_whisper.py
Expand Up @@ -218,6 +218,7 @@ def __call__(
return_attention_mask: Optional[bool] = None,
padding: Optional[str] = "max_length",
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
**kwargs
) -> BatchFeature:
"""
Expand Down Expand Up @@ -261,6 +262,19 @@ def __call__(
The value that is used to fill the padding values / vectors.
"""

if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)

is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
Expand Down
40 changes: 18 additions & 22 deletions src/transformers/models/whisper/modeling_whisper.py
Expand Up @@ -31,13 +31,22 @@
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_whisper import WhisperConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_PROCESSOR_FOR_DOC = "openai/whisper-tiny"
_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]


WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
Expand Down Expand Up @@ -982,7 +991,14 @@ def get_decoder(self):
return self.decoder

@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
modality="audio",
)
def forward(
self,
input_features=None,
Expand All @@ -999,26 +1015,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import WhisperModel, WhisperFeatureExtractor
>>> from datasets import load_dataset
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down
48 changes: 48 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
Expand Up @@ -26,6 +26,8 @@
AutoTokenizer,
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
Expand Down Expand Up @@ -308,6 +310,52 @@ def test_simple_s2t(self):
output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

@slow
@require_torch
@require_torchaudio
def test_simple_whisper_asr(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny.en",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to"})

@slow
@require_torch
@require_torchaudio
def test_simple_whisper_translation(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-large",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")

speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_2 = speech_recognizer_2(filename)
self.assertEqual(output, output_2)

processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
speech_translator = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_3 = speech_translator(filename)
self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."})

@slow
@require_torch
@require_torchaudio
Expand Down
12 changes: 10 additions & 2 deletions tests/pipelines/test_pipelines_common.py
Expand Up @@ -178,8 +178,16 @@ def __repr__(self):
class PipelineTestCaseMeta(type):
def __new__(mcs, name, bases, dct):
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
@skipIf(tiny_config is None, "TinyConfig does not exist")
@skipIf(checkpoint is None, "checkpoint does not exist")
@skipIf(
tiny_config is None,
"TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
" file",
)
@skipIf(
checkpoint is None,
"checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
" modeling file",
)
def test(self):
if ModelClass.__name__.endswith("ForCausalLM"):
tiny_config.is_encoder_decoder = False
Expand Down

0 comments on commit c8bc0a0

Please sign in to comment.