Fix probability computation in WhisperNoSpeechDetection when recomputing scores (#29248)

* Fix is_scores_logprobs in WhisperNoSpeechDetection

* Add test_whisper_longform_no_speech_detection

* Fix typo
cifkao authored and ArthurZucker committed Apr 22, 2024
1 parent e7ca8c5 commit 50c056c
Showing 2 changed files with 57 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/generation/logits_process.py
@@ -1930,17 +1930,20 @@ def set_begin_index(self, begin_index):
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        is_scores_logprobs = self.is_scores_logprobs
+
         if input_ids.shape[1] == self.begin_index:
             if self.start_of_trans_offset > 1:
                 with torch.no_grad():
                     logits = self.model(**self.inputs).logits
 
                 no_speech_index = self.begin_index - self.start_of_trans_offset
                 no_speech_scores = logits[:, no_speech_index]
+                is_scores_logprobs = False
             else:
                 no_speech_scores = scores
 
-            if self.is_scores_logprobs:
+            if is_scores_logprobs:
                 probs = no_speech_scores.exp()
             else:
                 probs = no_speech_scores.float().softmax(dim=-1)
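Why the fix tracks `is_scores_logprobs` locally: when `start_of_trans_offset > 1`, the processor recomputes `no_speech_scores` from the model's raw logits, but the old code still consulted `self.is_scores_logprobs` and applied `exp()` to those logits, producing values that are not probabilities. A minimal standalone sketch (with made-up tensor values) of the distinction:

import torch

# Scores recomputed from the model are raw logits, not log-probabilities,
# so exp() alone does not yield a probability distribution.
logits = torch.tensor([[2.0, 0.5, -1.0]])   # stand-in decoder logits
logprobs = logits.log_softmax(dim=-1)       # log-probabilities

print(logprobs.exp())                  # valid probabilities, sums to 1
print(logits.exp())                    # not probabilities: entries exceed 1
print(logits.float().softmax(dim=-1))  # what the fixed branch computes for logits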
53 changes: 53 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -2670,6 +2670,59 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self):
         for i in range(num_samples):
             assert decoded_all[i] == EXPECTED_TEXT[i]
 
+    @slow
+    def test_whisper_longform_no_speech_detection(self):
+        # fmt: off
+        EXPECTED_TEXT = [
+            " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories. Developing the central headline pawns, definitely maneuvering and also topical night to F6.",
+            " Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing",
+            ' Ladies and gentlemen, you know, I spent a lot of time right over there raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their joke swollen teats',
+            ' Folks, you watched this show, you know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the',
+            " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui,",
+            ' You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest.',
+            " Folks, if you watch this show, you know I spend most of my time right over there, carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most...",
+            " Folks, if you watch the show and I hope you do, I spent a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines.",
+        ]
+        # fmt: on
+
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model = model.to(torch_device)
+
+        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+        num_samples = 8
+
+        audio = ds[:num_samples]["audio"]
+        audios = [x["array"] for x in audio]
+
+        # Make sure the second chunk is silent
+        for audio in audios:
+            audio[15 * 16000 : 60 * 16000] = 0.0
+
+        inputs = processor(
+            audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
+        )
+        inputs = inputs.to(device=torch_device)
+
+        gen_kwargs = {
+            "return_timestamps": True,
+            "no_speech_threshold": 0.2,
+            "temperature": (0.0,),
+            "compression_ratio_threshold": 1.35,
+            "condition_on_prev_tokens": True,
+            "logprob_threshold": 0.0,  # Ignore logprob, use only no-speech prob
+            "num_beams": 5,
+        }
+
+        torch.manual_seed(0)
+        result = model.generate(**inputs, **gen_kwargs)
+        decoded_all = processor.batch_decode(result, skip_special_tokens=True)
+
+        for i in range(num_samples):
+            assert decoded_all[i] == EXPECTED_TEXT[i]
+
 
 def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
     if head_mask is None:
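For reference, a hedged usage sketch of the code path the new test exercises: long-form generation with no-speech detection enabled. The silent array below is only a stand-in for real audio, and the generation kwargs mirror the test above rather than a recommended configuration.

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 45 s of silence: everything past the first 30-second chunk is pure silence
# and should be dropped by the no-speech check instead of being hallucinated.
audio = np.zeros(45 * 16000, dtype=np.float32)

inputs = processor(
    audio, sampling_rate=16000, return_tensors="pt", truncation=False,
    padding="longest", return_attention_mask=True
)

result = model.generate(
    **inputs,
    return_timestamps=True,     # required for long-form (>30 s) inputs
    no_speech_threshold=0.2,    # skip segments whose no-speech prob exceeds this
    logprob_threshold=0.0,      # rely on the no-speech probability alone
    temperature=(0.0,),
    condition_on_prev_tokens=True,
)
print(processor.batch_decode(result, skip_special_tokens=True))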
