Facing issues with sequential long form decoding for long audio #29287

Closed

agandhigoto opened this issue Feb 26, 2024 · 2 comments

Comments


agandhigoto commented Feb 26, 2024

System Info

transformers - 4.38.1
python - 3.9/3.11
platform - macOS/AWS g4dn.xlarge

Who can help?

@patrickvonplaten @sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import WhisperForConditionalGeneration, WhisperProcessor
import librosa
import time
import torch

# fp16 on GPU, fp32 on CPU
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda:0" if torch.cuda.is_available() else "cpu"

SR = 16000

start_time = time.time()
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2", torch_dtype=torch_dtype)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model.to(device)

# Load the long audio file at 16 kHz
file_path = "sample.wav"
audio, _ = librosa.load(file_path, sr=SR)

# truncation=False keeps the full-length features so sequential long-form decoding is triggered
inputs = processor([audio], return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
inputs = inputs.to(device, torch_dtype)

result = model.generate(**inputs, condition_on_prev_tokens=False, temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), logprob_threshold=-1.0, compression_ratio_threshold=1.35, return_timestamps=True)

decoded = processor.batch_decode(result, skip_special_tokens=True)
print(decoded)
generation_time = time.time() - start_time
print(generation_time)

Expected behavior

  1. This code works fine for the openai/whisper-large-v3 model but fails for Distil-Whisper with the error below:
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) RuntimeError: probability tensor contains either inf, nan or element < 0

  2. It is also not fully clear to me what is meant by batching for long-form sequential decoding. Does it mean the ability to provide multiple files via
    inputs = processor([audio1, audio2, audio3...], return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
    OR the ability to use the pipeline with the batch_size param? Does long-form decoding even work with a pipeline like the one below?
    pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, return_language=True, return_timestamps=True, batch_size=16)

  3. The time taken to transcribe is pretty slow:
    openai/whisper-large-v3 with chunking in the pipeline gives an RTF of .04 (a sketch of how RTF is computed here follows this list).
    openai/whisper-large-v3 with sequential long-form decoding gives an RTF of .13, which seems to contradict the statement quoted below (assuming the 4x speed-up also applies to large-v3), OR maybe the usage of batch size with long-form sequential decoding is not fully clear to me, since there is no good documentation around it:
    "The code now fully functions for batch size > 1 (made sure that results on the four datasets is within +/- 0.1 % WER). When using batch size = 8, there is a 4x speed-up for large-v2, 2x speed-up for small (and 1.5x speed-up for tiny). The bigger the model, the larger the speed-up!"
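A minimal sketch of how the RTF numbers above are computed, assuming RTF = transcription time / audio duration ("sample.wav" is a placeholder file):

import time
import librosa

audio, sr = librosa.load("sample.wav", sr=16000)
audio_duration = len(audio) / sr  # length of the audio in seconds

start = time.time()
# ... run the transcription here, e.g. model.generate(...) or pipe(...) ...
elapsed = time.time() - start

print("RTF:", elapsed / audio_duration)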

@zucchini-nlp
Member

@agandhigoto I am not a Whisper expert, but after exploring the codebase a bit, this is what I found.

  1. The error is raised because of forced_decoder_ids in the distil-whisper model config. It was not failing for openai/whisper because that model does not have it set by default. I opened a PR to fix it; until it gets merged you can use this code as a workaround:
result = model.generate(**inputs, condition_on_prev_tokens=False, temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), logprob_threshold=-1.0, compression_ratio_threshold=1.35, forced_decoder_ids=None, return_timestamps=True)
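A sketch of an equivalent workaround (assumption: clearing the attribute on the generation config has the same effect as passing forced_decoder_ids=None on every call):

# Assumption: resetting forced_decoder_ids on the generation config mirrors
# passing forced_decoder_ids=None to generate()
model.generation_config.forced_decoder_ids = None
result = model.generate(**inputs, condition_on_prev_tokens=False, temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), logprob_threshold=-1.0, compression_ratio_threshold=1.35, return_timestamps=True)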
  2. Whisper long-form generation in the pipeline cannot do batched generation right now. You can still pass multiple samples and keep the default batch size of 1, in which case the inputs will be processed one by one sequentially. To use long-form Whisper generation with a batch size greater than 1, you can instantiate a WhisperForConditionalGeneration directly:
from transformers import WhisperForConditionalGeneration, AutoProcessor

processor = AutoProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2").to("cuda:0")

# batch_of_long_audios is a list of raw audio arrays (one per file), all sampled at 16 kHz
inputs = processor(batch_of_long_audios, return_tensors="pt", truncation=False, padding=True, return_attention_mask=True, sampling_rate=16_000)
inputs = inputs.to("cuda:0")

result = model.generate(**inputs, return_timestamps=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)
print(decoded)
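A sketch of how batch_of_long_audios could be built (assumption: the files are loaded with librosa at 16 kHz; the file names are placeholders):

import librosa

# Hypothetical file list; replace with your own long audio files
files = ["audio1.wav", "audio2.wav", "audio3.wav"]
batch_of_long_audios = [librosa.load(f, sr=16_000)[0] for f in files]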
  3. The statement you cited compares speed-ups when using batch_size=1 vs. more than 1. And, as stated in (2), batch size > 1 is not currently possible with the pipeline; please use the code above for batched generation. I measured timings on a toy set of 50 audio files, each 40-50 seconds long, with openai/whisper-large-v2 and validated that there is a speed-up with a higher batch size.
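A rough sketch of that kind of timing comparison (all_audios is a hypothetical list of long audio arrays; this is not the exact benchmark script used):

import time

def transcribe(batch):
    # batch is a list of raw audio arrays at 16 kHz
    inputs = processor(batch, return_tensors="pt", truncation=False, padding=True, return_attention_mask=True, sampling_rate=16_000)
    inputs = inputs.to("cuda:0")
    return model.generate(**inputs, return_timestamps=True)

for batch_size in (1, 8):
    start = time.time()
    for i in range(0, len(all_audios), batch_size):
        transcribe(all_audios[i:i + batch_size])
    print(f"batch_size={batch_size}: {time.time() - start:.1f}s")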

Hope this helps you understand how to use batched long-form generation in Whisper 🤗


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Apr 6, 2024