Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix probability computation in WhisperNoSpeechDetection when recomputing scores #29248

Merged
merged 3 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -1954,17 +1954,20 @@ def set_begin_index(self, begin_index):

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
is_scores_logprobs = self.is_scores_logprobs

if input_ids.shape[1] == self.begin_index:
if self.start_of_trans_offset > 1:
with torch.no_grad():
logits = self.model(**self.inputs).logits

no_speech_index = self.begin_index - self.start_of_trans_offset
no_speech_scores = logits[:, no_speech_index]
is_scores_logprobs = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch!

else:
no_speech_scores = scores

if self.is_scores_logprobs:
if is_scores_logprobs:
probs = no_speech_scores.exp()
else:
probs = no_speech_scores.float().softmax(dim=-1)
Expand Down
53 changes: 53 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2615,6 +2615,59 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self):
for i in range(num_samples):
assert decoded_all[i] == EXPECTED_TEXT[i]

@slow
def test_whisper_longform_no_speech_detection(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this slow test! Just to confirm, before the fix all the transcriptions are empty due to the no-speech probabilities exceeding 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly.

# fmt: off
EXPECTED_TEXT = [
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories. Developing the central headline pawns, definitely maneuvering and also topical night to F6.",
" Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing",
' Ladies and gentlemen, you know, I spent a lot of time right over there raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their joke swollen teats',
' Folks, you watched this show, you know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the',
" You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui,",
' You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest.',
" Folks, if you watch this show, you know I spend most of my time right over there, carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most...",
" Folks, if you watch the show and I hope you do, I spent a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines.",
]
# fmt: on

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model = model.to(torch_device)

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))

num_samples = 8

audio = ds[:num_samples]["audio"]
audios = [x["array"] for x in audio]

# Make sure the second chunk is silent
for audio in audios:
audio[15 * 16000 : 60 * 16000] = 0.0

inputs = processor(
audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
)
inputs = inputs.to(device=torch_device)

gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.2,
"temperature": (0.0,),
"compression_ratio_threshold": 1.35,
"condition_on_prev_tokens": True,
"logprob_threshold": 0.0, # Ignore logprob, use only no-speech prob
"num_beams": 5,
}

torch.manual_seed(0)
result = model.generate(**inputs, **gen_kwargs)
decoded_all = processor.batch_decode(result, skip_special_tokens=True)

for i in range(num_samples):
assert decoded_all[i] == EXPECTED_TEXT[i]


def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:
Expand Down