Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ def language_to_id(language: str) -> int:
# if task is defined it'll overwrite task ids that might have already been defined via the generation_config
replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
else:
raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
raise ValueError(f"The `{task}` task is not supported. The task should be one of `{TASK_IDS}`")
elif language is not None and hasattr(generation_config, "task_to_id"):
# if language is defined, but no task id is in `init_tokens`, default to transcribe
if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
Expand Down
39 changes: 23 additions & 16 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,13 @@ def _sanitize_parameters(
decoder_kwargs=None,
return_timestamps=None,
return_language=None,
generate_kwargs=None,
**generate_kwargs,
):
# No parameters on this pipeline right now
preprocess_params = {}
forward_params = {}
postprocess_params = {}

# Preprocess params
if chunk_length_s is not None:
if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
type_warning = (
Expand All @@ -305,14 +308,28 @@ def _sanitize_parameters(
if stride_length_s is not None:
preprocess_params["stride_length_s"] = stride_length_s

forward_params = defaultdict(dict)
if generate_kwargs is not None:
forward_params.update(generate_kwargs)
# Forward params
# BC: accept a dictionary of generation kwargs (as opposed to **generate_kwargs)
if "generate_kwargs" in generate_kwargs:
forward_params.update(generate_kwargs.pop("generate_kwargs"))
# Default use for kwargs: they are generation-time kwargs
forward_params.update(generate_kwargs)

postprocess_params = {}
if getattr(self, "assistant_model", None) is not None:
forward_params["assistant_model"] = self.assistant_model
if getattr(self, "assistant_tokenizer", None) is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

# Postprocess params
if decoder_kwargs is not None:
postprocess_params["decoder_kwargs"] = decoder_kwargs
if return_language is not None:
if self.type != "seq2seq_whisper":
raise ValueError("Only Whisper can return language for now.")
postprocess_params["return_language"] = return_language

# Parameter used in more than one place
# in some models like whisper, the generation config has a `return_timestamps` key
if hasattr(self, "generation_config") and hasattr(self.generation_config, "return_timestamps"):
return_timestamps = return_timestamps or self.generation_config.return_timestamps
Expand All @@ -335,16 +352,6 @@ def _sanitize_parameters(
)
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if return_language is not None:
if self.type != "seq2seq_whisper":
raise ValueError("Only Whisper can return language for now.")
postprocess_params["return_language"] = return_language

if getattr(self, "assistant_model", None) is not None:
forward_params["assistant_model"] = self.assistant_model
if getattr(self, "assistant_tokenizer", None) is not None:
forward_params["tokenizer"] = self.tokenizer
forward_params["assistant_tokenizer"] = self.assistant_tokenizer

return preprocess_params, forward_params, postprocess_params

Expand Down
26 changes: 26 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,6 +1791,32 @@ def test_pipeline_assisted_generation(self):
with self.assertRaises(ValueError):
_ = pipe(prompt, generate_kwargs={"num_beams": 2})

@require_torch
def test_pipeline_generation_kwargs(self):
    """Tests that we can pass kwargs to `generate`, as in the text generation pipelines"""
    # NOTE(review): network-dependent test — downloads the whisper-tiny checkpoint and a
    # one-sample dummy LibriSpeech split; assumes the hub/dataset fixtures are reachable.
    model = "openai/whisper-tiny"
    asr = pipeline("automatic-speech-recognition", model=model)
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")

    # BC: with `generate_kwargs` as a dictionary
    # (legacy call style: a single dict argument forwarded to `generate`)
    res = asr(
        dataset[0]["audio"],
        generate_kwargs={"task": "transcribe", "max_new_tokens": 256},
    )
    self.assertEqual(
        res["text"], " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
    )

    # New: kwargs forwarded to `generate`
    # (same kwargs passed directly; both call styles must produce identical output)
    res = asr(
        dataset[0]["audio"],
        max_new_tokens=256,
        task="transcribe",
    )
    self.assertEqual(
        res["text"], " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
    )


def require_ffmpeg(test_case):
"""
Expand Down