Adding chunking for whisper (all seq2seq actually). Very crude matching algorithm. (huggingface#20104)

* Very crude matching algorithm.

* Fixing tests.

* Removing comments

* Adding warning + fix short matches.

* Cleanup tests.

* Quality.

* Less noisy.

* Fixup.
Narsil authored and Magnus Pierrau committed Dec 15, 2022
1 parent db4291c commit d1bc02c
Showing 2 changed files with 75 additions and 18 deletions.
69 changes: 57 additions & 12 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -30,7 +30,7 @@
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


-def rescale_stride(tokens_or_logits, stride, ratio):
+def rescale_stride(stride, ratio):
     """
     Rescales the stride values from audio space to tokens/logits space.
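Since the body of rescale_stride is collapsed in this diff, here is a minimal sketch of what the new signature plausibly computes, reconstructed from the docstring above; the exact rounding behavior is an assumption:

def rescale_stride(stride, ratio):
    # Sketch only: map each (chunk_len, stride_left, stride_right) tuple,
    # measured in audio samples, into tokens/logits space by scaling with
    # `ratio`. The real helper may round differently.
    new_strides = []
    for input_n, left, right in stride:
        token_n = int(round(input_n * ratio))
        new_strides.append((token_n, int(round(left * ratio)), int(round(right * ratio))))
    return new_strides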
@@ -60,8 +60,43 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
        _stride_left = 0 if i == 0 else stride_left
        is_last = i + step + stride_left >= inputs_len
        _stride_right = 0 if is_last else stride_right
+
+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        chunk_len = chunk.shape[0]
+        stride = (chunk_len, _stride_left, _stride_right)
+        if processed_len != chunk.shape[-1]:
+            ratio = processed_len / chunk_len
+            stride = rescale_stride([stride], ratio)[0]
        if chunk.shape[0] > _stride_left:
-            yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
+            yield {"is_last": is_last, "stride": stride, **processed}
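To make the new rescaling step concrete, a worked example with assumed numbers (none of these values come from the diff itself): a feature extractor may produce far fewer frames than there are raw samples, so the sample-space stride has to be scaled before any slicing in feature space.

chunk_len = 480_000                 # assumed: 30 s of 16 kHz audio, in samples
processed_len = 3_000               # assumed: feature frames for that chunk
ratio = processed_len / chunk_len   # 0.00625
stride = (chunk_len, 0, 80_000)     # (chunk, stride_left, stride_right) in samples
rescaled = tuple(int(round(v * ratio)) for v in stride)
print(rescaled)                     # (3000, 0, 500)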


+def _find_longest_common_sequence(sequences, tokenizer):
+    # TODO: Use a faster algorithm; this can probably be done in O(n)
+    # with a suffix array, though that might be tedious given the
+    # fault tolerance involved.
+    # We do have one very useful property: the total sequence MUST
+    # contain those subsequences in order.
+    # The algorithm should also be made more tolerant to errors.
+    sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
+    for new_seq in sequences[1:]:
+        new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]
+
+        index = 0
+        max_ = 0.0
+        for i in range(1, len(new_sequence) + 1):
+            # epsilon to favor long perfect matches
+            eps = i / 10000.0
+            matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
+            matching = matches / i + eps
+            if matches > 1 and matching > max_:
+                index = i
+                max_ = matching
+        sequence.extend(new_sequence[index:])
+    return np.array(sequence)
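A toy run of the matcher, with _find_longest_common_sequence above in scope (the token IDs are made up, and the tokenizer stand-in exists only because the function reads tokenizer.all_special_ids):

import numpy as np

class FakeTokenizer:
    all_special_ids = [0]  # stand-in: only this attribute is used

chunks = [
    np.array([[0, 1, 2, 3, 5, 6, 7]]),  # first chunk; 0 is a special token
    np.array([[0, 5, 6, 7, 8, 9]]),     # second chunk starts inside the overlap
]
# The suffix "5 6 7" of the first chunk matches the prefix of the second,
# so the overlap is kept only once in the merged sequence.
print(_find_longest_common_sequence(chunks, FakeTokenizer()))
# -> [1 2 3 5 6 7 8 9]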


class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
@@ -188,6 +223,8 @@ def _sanitize_parameters(self, **kwargs):
preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
if "stride_length_s" in kwargs:
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
if "ignore_warning" in kwargs:
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]

postprocess_params = {}
if "decoder_kwargs" in kwargs:
@@ -197,7 +234,7 @@ def _sanitize_parameters(self, **kwargs):

        return preprocess_params, {}, postprocess_params

-    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
        if isinstance(inputs, str):
            with open(inputs, "rb") as f:
                inputs = f.read()
@@ -249,10 +286,14 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if chunk_length_s:
-            if self.type not in {"ctc", "ctc_with_lm"}:
-                raise ValueError(
-                    "`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
+            if self.type == "seq2seq" and not ignore_warning:
+                logger.warning(
+                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
+                    " be entirely accurate and will have caveats. More information:"
+                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
+                    " ignore_warning=True)"
                )
+                self._preprocess_params["ignore_warning"] = True
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6
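As a usage sketch of the new code path (the checkpoint name comes from the tests below; the file name and chunk length are illustrative choices):

from transformers import pipeline

asr = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny")
# 30 s windows with the default stride of chunk_length_s / 6 on each side;
# ignore_warning=True silences the experimental-feature warning above.
output = asr("long_audio.flac", chunk_length_s=30, ignore_warning=True)
print(output["text"])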

@@ -262,7 +303,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
            # XXX: Careful, this variable will not exist in `seq2seq` setting.
            # Currently chunking is not possible at this level for `seq2seq` so
            # it's ok.
-            align_to = self.model.config.inputs_to_logits_ratio
+            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
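A quick numeric check of the alignment arithmetic, assuming a 16 kHz feature extractor and the seq2seq fallback align_to = 1:

sampling_rate = 16_000                 # assumed Whisper-style sampling rate
chunk_length_s = 30
stride_length_s = chunk_length_s / 6   # the default set above: 5.0 s per side

align_to = 1                           # getattr fallback when the config has no ratio
chunk_len = int(round(chunk_length_s * sampling_rate / align_to) * align_to)
stride_left = int(round(stride_length_s * sampling_rate / align_to) * align_to)
print(chunk_len, stride_left)          # 480000 80000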
@@ -329,9 +370,9 @@ def _forward(self, model_inputs):
                # the pieces are to be concatenated.
                ratio = 1 / self.model.config.inputs_to_logits_ratio
                if isinstance(stride, tuple):
-                    out["stride"] = rescale_stride(logits, [stride], ratio)[0]
+                    out["stride"] = rescale_stride([stride], ratio)[0]
                else:
-                    out["stride"] = rescale_stride(logits, stride, ratio)
+                    out["stride"] = rescale_stride(stride, ratio)
        # Leftover
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}
@@ -347,10 +388,11 @@ def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, retu

        final_items = []
        key = "logits" if self.type == "ctc_with_lm" else "tokens"
+        stride = None
        for outputs in model_outputs:
            items = outputs[key].numpy()
            stride = outputs.pop("stride", None)
-            if stride is not None:
+            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                total_n, left, right = stride
                # Total_n might be < logits.shape[1]
                # because of padding, that's why
                right_n = total_n - right
                items = items[:, left:right_n]
            final_items.append(items)
-        items = np.concatenate(final_items, axis=1)
-        items = items.squeeze(0)
+        if stride and self.type == "seq2seq":
+            items = _find_longest_common_sequence(final_items, self.tokenizer)
+        else:
+            items = np.concatenate(final_items, axis=1)
+            items = items.squeeze(0)
        if self.type == "ctc_with_lm":
            if decoder_kwargs is None:
                decoder_kwargs = {}
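The branch above is the core of the change: CTC logits share a uniform time axis with the audio, so the overlapping frames can be trimmed off each chunk and the remainder concatenated, while seq2seq outputs are free-form token sequences with no such alignment, so the overlapping transcriptions are instead stitched together by the token-matching heuristic in _find_longest_common_sequence.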
24 changes: 18 additions & 6 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -144,12 +144,8 @@ def test_small_model_pt(self):
        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
        output = speech_recognizer(waveform)
        self.assertEqual(output, {"text": "(Applaudissements)"})
-        with self.assertRaises(ValueError) as v:
-            _ = speech_recognizer(waveform, chunk_length_s=10)
-        self.assertEqual(
-            str(v.exception),
-            "`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
-        )
+        output = speech_recognizer(waveform, chunk_length_s=10)
+        self.assertEqual(output, {"text": "(Applaudissements)"})

        # Non CTC models cannot use return_timestamps
        with self.assertRaises(ValueError) as v:
@@ -261,6 +257,22 @@ def test_torch_large(self):
        output = speech_recognizer(filename)
        self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

+    @require_torch
+    @slow
+    def test_torch_whisper(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-tiny",
+            framework="pt",
+        )
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        filename = ds[40]["file"]
+        output = speech_recognizer(filename)
+        self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
+
+        output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
+        self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
+
    @require_torch
    @slow
    def test_torch_speech_encoder_decoder(self):
