From 25c451e5a044969eb91e1e481574a2bfca5130ca Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 14 Nov 2022 22:32:50 +0100
Subject: [PATCH] Adding chunking for whisper (all seq2seq actually). Very
 crude matching algorithm. (#20104)

* Very crude matching algorithm.

* Fixing tests.

* Removing comments

* Adding warning + fix short matches.

* Cleanup tests.

* Quality.

* Less noisy.

* Fixup.
---
 .../pipelines/automatic_speech_recognition.py | 69 +++++++++++++++----
 ..._pipelines_automatic_speech_recognition.py | 24 +++++--
 2 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index 04e36410ce920..e3b4ad0b6b8ea 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -30,7 +30,7 @@
     from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


-def rescale_stride(tokens_or_logits, stride, ratio):
+def rescale_stride(stride, ratio):
     """
     Rescales the stride values from audio space to tokens/logits space.

@@ -60,8 +60,47 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
         _stride_left = 0 if i == 0 else stride_left
         is_last = i + step + stride_left >= inputs_len
         _stride_right = 0 if is_last else stride_right
+
+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        chunk_len = chunk.shape[0]
+        stride = (chunk_len, _stride_left, _stride_right)
+        if processed_len != chunk.shape[-1]:
+            ratio = processed_len / chunk_len
+            stride = rescale_stride([stride], ratio)[0]
         if chunk.shape[0] > _stride_left:
-            yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
+            yield {"is_last": is_last, "stride": stride, **processed}
+
+
+def _find_longest_common_sequence(sequences, tokenizer):
+    # TODO: Use a faster algorithm; this can probably be done in O(n)
+    # using a suffix array.
+    # It might be tedious to do because of fault tolerance.
+    # We actually have a really good property, which is that the total
+    # sequence MUST be those subsequences in order.
+    # The algorithm should also be more tolerant to errors.
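+    # Rough shape of the greedy merge implemented below: strip special tokens,
+    # then for each new chunk pick the overlap length whose tokens best match
+    # the tail of the running sequence (a small epsilon favors longer perfect
+    # matches) and append only the remainder past that overlap.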
+    sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
+    for new_seq in sequences[1:]:
+        new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]
+
+        index = 0
+        max_ = 0.0
+        for i in range(1, len(new_sequence) + 1):
+            # epsilon to favor long perfect matches
+            eps = i / 10000.0
+            matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
+            matching = matches / i + eps
+            if matches > 1 and matching > max_:
+                index = i
+                max_ = matching
+        sequence.extend(new_sequence[index:])
+    return np.array(sequence)


 class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
@@ -188,6 +223,8 @@ def _sanitize_parameters(self, **kwargs):
             preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
         if "stride_length_s" in kwargs:
             preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
+        if "ignore_warning" in kwargs:
+            preprocess_params["ignore_warning"] = kwargs["ignore_warning"]

         postprocess_params = {}
         if "decoder_kwargs" in kwargs:
@@ -197,7 +234,7 @@ def _sanitize_parameters(self, **kwargs):

         return preprocess_params, {}, postprocess_params

-    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
         if isinstance(inputs, str):
             with open(inputs, "rb") as f:
                 inputs = f.read()
@@ -249,10 +286,14 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

         if chunk_length_s:
-            if self.type not in {"ctc", "ctc_with_lm"}:
-                raise ValueError(
-                    "`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
+            if self.type == "seq2seq" and not ignore_warning:
+                logger.warning(
+                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
+                    " be entirely accurate and will have caveats. More information:"
+                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
+                    " ignore_warning=True)"
                 )
+                self._preprocess_params["ignore_warning"] = True

             if stride_length_s is None:
                 stride_length_s = chunk_length_s / 6
@@ -262,7 +303,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
             # XXX: Carefully, this variable will not exist in `seq2seq` setting.
             # Currently chunking is not possible at this level for `seq2seq` so
             # it's ok.
-            align_to = self.model.config.inputs_to_logits_ratio
+            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
             chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
             stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
             stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
@@ -329,9 +370,12 @@ def _forward(self, model_inputs):
                 # the pieces are to be concatenated.
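+                # `inputs_to_logits_ratio` is the number of input samples per
+                # logits frame, so its inverse rescales the audio-space stride
+                # into logits space before it is sent along.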
                 ratio = 1 / self.model.config.inputs_to_logits_ratio
                 if isinstance(stride, tuple):
-                    out["stride"] = rescale_stride(logits, [stride], ratio)[0]
+                    out["stride"] = rescale_stride([stride], ratio)[0]
                 else:
-                    out["stride"] = rescale_stride(logits, stride, ratio)
+                    out["stride"] = rescale_stride(stride, ratio)
         # Leftover
         extra = model_inputs
         return {"is_last": is_last, **out, **extra}
@@ -347,10 +388,11 @@ def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, retu

         final_items = []
         key = "logits" if self.type == "ctc_with_lm" else "tokens"
+        stride = None
         for outputs in model_outputs:
             items = outputs[key].numpy()
             stride = outputs.pop("stride", None)
-            if stride is not None:
+            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                 total_n, left, right = stride
                 # Total_n might be < logits.shape[1]
                 # because of padding, that's why
@@ -359,8 +401,11 @@ def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, retu
                 right_n = total_n - right
                 items = items[:, left:right_n]
             final_items.append(items)
-        items = np.concatenate(final_items, axis=1)
-        items = items.squeeze(0)
+        if stride and self.type == "seq2seq":
+            items = _find_longest_common_sequence(final_items, self.tokenizer)
+        else:
+            items = np.concatenate(final_items, axis=1)
+            items = items.squeeze(0)
         if self.type == "ctc_with_lm":
             if decoder_kwargs is None:
                 decoder_kwargs = {}
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 5338f59aacbbf..84ceb9fce84e6 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -144,12 +144,8 @@ def test_small_model_pt(self):
         waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
         output = speech_recognizer(waveform)
         self.assertEqual(output, {"text": "(Applaudissements)"})
-        with self.assertRaises(ValueError) as v:
-            _ = speech_recognizer(waveform, chunk_length_s=10)
-        self.assertEqual(
-            str(v.exception),
-            "`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
-        )
+        output = speech_recognizer(waveform, chunk_length_s=10)
+        self.assertEqual(output, {"text": "(Applaudissements)"})

         # Non CTC models cannot use return_timestamps
         with self.assertRaises(ValueError) as v:
@@ -261,6 +257,22 @@ def test_torch_large(self):
         output = speech_recognizer(filename)
         self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

+    @require_torch
+    @slow
+    def test_torch_whisper(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-tiny",
+            framework="pt",
+        )
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        filename = ds[40]["file"]
+        output = speech_recognizer(filename)
+        self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
+
+        output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
+        self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
+
     @require_torch
     @slow
     def test_torch_speech_encoder_decoder(self):
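
For readers trying this out: a minimal sketch of the chunking path this patch enables, mirroring the `test_torch_whisper` test above. The checkpoint name and the `chunk_length_s`/`ignore_warning` parameters come straight from the diff; the zero-filled waveform is only a stand-in for real 16 kHz mono audio.

import numpy as np

from transformers import pipeline

# Seq2seq (encoder-decoder) checkpoint: chunk-merging for these is what this patch adds.
asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny",
    framework="pt",
)

# Stand-in audio: 10 seconds of silence at 16 kHz (any 1-D float32 waveform or a filename works).
waveform = np.zeros(16000 * 10, dtype=np.float32)

# `chunk_length_s` routes the audio through `chunk_iter`; the per-chunk token
# sequences are merged back together in `postprocess` via
# `_find_longest_common_sequence`. `ignore_warning=True` silences the
# "very experimental" warning added by this patch.
output = asr(waveform, chunk_length_s=5, ignore_warning=True)
print(output["text"])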