
All chunks except the first one ignore num_beams in Whisper long-form transcription #29312

Closed
cifkao opened this issue Feb 27, 2024 · 4 comments · Fixed by #29225

@cifkao (Contributor) commented Feb 27, 2024

System Info

  • transformers version: 4.39.0.dev0 (dabe855)

Who can help?

@patrickvonplaten @sanchit-gandhi

Reproduction

In WhisperGenerationMixin.generate_with_fallback(), the kwargs argument is modified in-place by kwargs.pop("num_beams", 1), which causes two problems:

  1. Subsequent iterations of the fallback loop will use num_beams=1 (although this has no effect as long as temperature 0 appears only once in the temperature list).
  2. More importantly, since kwargs is passed by reference (not unpacked via the **kwargs syntax), the mutation propagates to the caller and to subsequent calls of generate_with_fallback, so every chunk except the first one will always use num_beams=1 (see the sketch below).
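
For illustration, a minimal standalone sketch (not the transformers code) of how popping from a dict that is passed by reference leaks into later calls:

def run_chunk(kwargs):
    # pop() mutates the caller's dict: after the first call the key is gone
    return kwargs.pop("num_beams", 1)

opts = {"num_beams": 2}
print(run_chunk(opts))  # 2
print(run_chunk(opts))  # 1 -- falls back to the default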

Expected behavior

The specified num_beams should be used for all chunks.
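
One way to avoid the mutation (a sketch; the actual fix landed with #29225 and may differ) is to pop from a shallow copy so the caller's dict stays intact:

import copy

def run_chunk(kwargs):
    kwargs = copy.copy(kwargs)  # leave the caller's dict untouched
    return kwargs.pop("num_beams", 1)

opts = {"num_beams": 2}
print(run_chunk(opts))  # 2
print(run_chunk(opts))  # 2 -- opts still has num_beams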


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker (Collaborator)

This sounds like a pretty huge bug; do you have a reproducer?

@cifkao (Contributor, Author) commented Mar 30, 2024

@ArthurZucker How is this for a reproducer?

import datasets
from transformers import AutoProcessor, GenerationMixin, WhisperForConditionalGeneration
import numpy as np

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

orig_generate = GenerationMixin.generate


def generate(self, *args, **kwargs):
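    # args[1] is the GenerationConfig that generate_with_fallback passes
    # positionally (self is bound separately, so it is not part of args)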
    print("num_beams:", args[1].num_beams)
    return orig_generate(self, *args, **kwargs)


GenerationMixin.generate = generate

ds = datasets.load_dataset(
    "google/fleurs", "en_us", split="test", trust_remote_code=True
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16000))
raw_audio = np.concatenate([x["array"].astype(np.float32) for x in ds[:16]["audio"]])

inputs = processor(
    [raw_audio],
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
    sampling_rate=16_000,
)

model.generate(
    **inputs,
    num_beams=2,
    task="transcribe",
    language="en",
)

Before fix:

num_beams: 2
num_beams: 1
num_beams: 1
num_beams: 1
num_beams: 1
num_beams: 1

After fix:

num_beams: 2
num_beams: 2
num_beams: 2
num_beams: 2
num_beams: 2
num_beams: 2

I've tried to modify one of the long-form transcription tests to check that GenerationMixin.generate() is being called with the correct arguments, but I haven't figured out how to mock it out correctly.
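
One possible approach for that, sketched here as an untested assumption rather than anything from this thread: spy on GenerationMixin.generate with unittest.mock while delegating to the real implementation, then assert on the recorded calls. This reuses model and inputs from the reproducer above and relies on the GenerationConfig being passed positionally, as the reproducer already demonstrates.

from unittest import mock

orig = GenerationMixin.generate

with mock.patch.object(GenerationMixin, "generate",
                       autospec=True, side_effect=orig) as spy:
    model.generate(**inputs, num_beams=2, task="transcribe", language="en")

# With autospec=True the model instance is recorded as the first positional
# argument, so the GenerationConfig is the third one.
for call in spy.call_args_list:
    assert call.args[2].num_beams == 2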

@ArthurZucker (Collaborator)

It's perfect, thanks!
