All chunks except the first one ignore `num_beams` in Whisper long-form transcription
#29312
Comments
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This sounds like a pretty huge bug. Do you have a reproducer?
@ArthurZucker How is this for a reproducer?

```python
import datasets
import numpy as np
from transformers import AutoProcessor, GenerationMixin, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Wrap GenerationMixin.generate so every per-chunk call prints its num_beams.
# args[1] is the generation_config that the Whisper fallback code passes positionally.
orig_generate = GenerationMixin.generate

def generate(self, *args, **kwargs):
    print("num_beams:", args[1].num_beams)
    return orig_generate(self, *args, **kwargs)

GenerationMixin.generate = generate

# Concatenate 16 FLEURS test clips into one long input so long-form (chunked) transcription is used.
ds = datasets.load_dataset(
    "google/fleurs", "en_us", split="test", trust_remote_code=True
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16000))
raw_audio = np.concatenate([x["array"].astype(np.float32) for x in ds[:16]["audio"]])

inputs = processor(
    [raw_audio],
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
    sampling_rate=16_000,
)
model.generate(
    **inputs,
    num_beams=2,
    task="transcribe",
    language="en",
)
```

Before fix:
After fix:
I've tried to modify one of the long-form transcription tests to check that
It's perfect, thanks!
System Info
`transformers` version: 4.39.0.dev0 (dabe855)
Who can help?
@patrickvonplaten @sanchit-gandhi
Reproduction
In `WhisperGenerationMixin.generate_with_fallback()`, the `kwargs` argument is modified in-place by `kwargs.pop("num_beams", 1)`, which results in the following problems:

1. If fallback occurs, later attempts within the same call use `num_beams=1` (although this has no effect provided that temperature 0 appears only once in the temperature list).
2. Since `kwargs` is passed by reference (and not via the `**kwargs` syntax), modifications to it propagate to the caller and to subsequent calls of `generate_with_fallback`, so every chunk except for the first one will always use `num_beams=1`.

Expected behavior
The specified `num_beams` should be used for all chunks.
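The failure mode can be illustrated without downloading a model. The sketch below is a simplified, hypothetical stand-in for the Whisper code path: the function names, the fixed temperature tuple, and the assumption that every temperature attempt is actually tried are all illustrative, and only the `kwargs.pop("num_beams", 1)` pattern mirrors the issue. It shows both problems and one possible fix (copying `kwargs` and reading `num_beams` with `get()`), which is not necessarily how the bug was fixed upstream.

```python
"""Simplified, hypothetical model of the num_beams bug in generate_with_fallback.

Only the kwargs.pop("num_beams", 1) pattern mirrors the issue; the function
names, the temperature schedule, and trying every temperature are assumptions.
"""

TEMPERATURES = (0.0, 0.2, 0.4)  # assumed fallback schedule, for illustration only


def generate_with_fallback_buggy(temperatures, kwargs):
    """Return the num_beams each temperature attempt would use (buggy pattern)."""
    used = []
    for temperature in temperatures:  # assume every attempt triggers fallback
        do_sample = temperature is not None and temperature > 0.0
        # pop() removes num_beams from the caller's dict. It only runs when not
        # sampling, so the first temperature-0 attempt consumes the value.
        num_beams = kwargs.pop("num_beams", 1) if not do_sample else 1
        used.append(num_beams)
    return used


def generate_with_fallback_fixed(temperatures, kwargs):
    """Same loop, but the caller's dict is never mutated."""
    kwargs = dict(kwargs)  # shallow copy: changes stay local to this call
    used = []
    for temperature in temperatures:
        do_sample = temperature is not None and temperature > 0.0
        # get() reads without removing, so repeated temperature-0 attempts keep it.
        num_beams = kwargs.get("num_beams", 1) if not do_sample else 1
        used.append(num_beams)
    return used


def transcribe_chunks(n_chunks, fallback_fn, **kwargs):
    """Hand the SAME kwargs dict to every chunk, as the issue describes."""
    return [fallback_fn(TEMPERATURES, kwargs) for _ in range(n_chunks)]


print(transcribe_chunks(3, generate_with_fallback_buggy, num_beams=2))
# [[2, 1, 1], [1, 1, 1], [1, 1, 1]]  <- chunks 2 and 3 lose num_beams=2
print(transcribe_chunks(3, generate_with_fallback_fixed, num_beams=2))
# [[2, 1, 1], [2, 1, 1], [2, 1, 1]]
print(generate_with_fallback_buggy((0.0, 0.0), {"num_beams": 2}))  # [2, 1]
print(generate_with_fallback_fixed((0.0, 0.0), {"num_beams": 2}))  # [2, 2]
```

Popping from the local copy instead of calling `get()` would fix problem 2 but not problem 1, because a second temperature-0 attempt in the same call would again see the default of 1.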