Adding chunking for whisper (all seq2seq actually). Very crude matching algorithm. #20104

Merged: 8 commits, Nov 14, 2022
Changes from 6 commits
63 changes: 52 additions & 11 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -31,7 +31,7 @@
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


-def rescale_stride(tokens_or_logits, stride, ratio):
+def rescale_stride(stride, ratio):
"""
Rescales the stride values from audio space to tokens/logits space.

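The helper's body is collapsed in this view. As an editor's sketch, an implementation consistent with the call sites in this diff could look like the following (the merged code may differ in details):

def rescale_stride(stride, ratio):
    # Map each (chunk_len, stride_left, stride_right) tuple, expressed in
    # audio samples, into the processed space by multiplying with `ratio`
    # and rounding to integer positions.
    new_strides = []
    for input_n, left, right in stride:
        token_n = int(round(input_n * ratio))
        left_n = int(round(left / input_n * token_n))
        right_n = int(round(right / input_n * token_n))
        new_strides.append((token_n, left_n, right_n))
    return new_strides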
@@ -61,8 +61,43 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
        _stride_left = 0 if i == 0 else stride_left
        is_last = i + step + stride_left >= inputs_len
        _stride_right = 0 if is_last else stride_right

+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        chunk_len = chunk.shape[0]
+        stride = (chunk_len, _stride_left, _stride_right)
+        if processed_len != chunk.shape[-1]:
+            ratio = processed_len / chunk_len
+            stride = rescale_stride([stride], ratio)[0]
        if chunk.shape[0] > _stride_left:
-            yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
+            yield {"is_last": is_last, "stride": stride, **processed}

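A quick worked example of the rescaling performed above, with made-up numbers (10 s of 16 kHz audio and an assumed 1,000 feature frames; Whisper's extractor actually pads to a fixed length, so treat these values as illustrative only):

chunk_len = 160_000                         # 10 s at 16 kHz, in samples
stride = (chunk_len, 16_000, 16_000)        # 1 s of overlap on each side
processed_len = 1_000                       # assumed feature-frame count
ratio = processed_len / chunk_len           # 0.00625
print(rescale_stride([stride], ratio)[0])   # (1000, 100, 100), now in frames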

def _find_longest_common_sequence(sequences, tokenizer):
    # TODO: Use a faster algorithm; this can probably be done in O(n)
    # using a suffix array.
    # It might be tedious to implement because of the fault tolerance.
    # We actually have a really useful property: the total sequence
    # MUST contain those subsequences in order.
    # The algorithm should also be more tolerant of errors.
    sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
    for new_seq in sequences[1:]:
        new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]

        index = 0
        max_ = 0.0
        for i in range(1, len(new_sequence) + 1):
            # epsilon to favor long perfect matches
            eps = i / 10000.0
            matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
            matching = matches / i + eps
            if matches > 1 and matching > max_:
                index = i
                max_ = matching
        sequence.extend(new_sequence[index:])
    return np.array(sequence)

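A toy run of the merging logic (editor's illustration; the stub tokenizer only exposes the `all_special_ids` attribute the helper reads, with token id 0 standing in for a special token):

import numpy as np

class _ToyTokenizer:
    all_special_ids = [0]  # hypothetical: id 0 is the only special token

# Two chunk decodings that overlap on the tokens 7 and 8.
sequences = [
    [np.array([0, 5, 6, 7, 8])],
    [np.array([0, 7, 8, 9, 10])],
]
print(_find_longest_common_sequence(sequences, _ToyTokenizer()))
# -> [ 5  6  7  8  9 10]: the shared "7 8" border is emitted only once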

class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
@@ -250,9 +285,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

if chunk_length_s:
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
if self.type == "seq2seq":
logger.warning(
"Using `chunk_length_s` is very experimental. The results will not necessarily be entirely"
" accurate and will have caveats. More information:"
" https://github.com/huggingface/transformers/pull/20104"
)
Collaborator: Can we add some logic to only throw this warning once? Users are complaining Transformers is too verbose.

Contributor (author): Is there already an existing way to do that? Otherwise I can create some tool for it. Is there any other location where we could add this "single" warning? (Will add in a different PR.)

Collaborator: We use a dict in the state, like this one. No need to overengineer another solution IMO.

Contributor (author): Done.

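For reference, a minimal sketch of the warn-once pattern the reviewers describe, assuming a module-level dict is an acceptable place for the state (names here are hypothetical, not the PR's final implementation):

_warnings_issued = {}  # hypothetical module-level state, keyed by warning name

def warn_once(key, message):
    # Emit each distinct warning a single time per process.
    if not _warnings_issued.get(key):
        logger.warning(message)
        _warnings_issued[key] = True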
            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6
@@ -263,7 +300,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
            # XXX: Careful, this variable will not exist in the `seq2seq` setting.
            # Currently chunking is not possible at this level for `seq2seq`, so
            # it's ok.
-            align_to = self.model.config.inputs_to_logits_ratio
+            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
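With assumed numbers (16 kHz audio, `chunk_length_s=20`, and a hypothetical CTC `inputs_to_logits_ratio` of 320; seq2seq configs without the attribute now fall back to `align_to = 1`), the rounding keeps every length a multiple of the ratio:

sampling_rate = 16_000
align_to = 320                                                          # assumed CTC ratio
chunk_len = int(round(20 * sampling_rate / align_to) * align_to)        # 320000 samples
stride_len = int(round(20 / 6 * sampling_rate / align_to) * align_to)   # 53440 = 167 * 320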
@@ -334,9 +371,9 @@ def _forward(self, model_inputs):
                # the pieces are to be concatenated.
                ratio = 1 / self.model.config.inputs_to_logits_ratio
                if isinstance(stride, tuple):
-                    out["stride"] = rescale_stride(logits, [stride], ratio)[0]
+                    out["stride"] = rescale_stride([stride], ratio)[0]
                else:
-                    out["stride"] = rescale_stride(logits, stride, ratio)
+                    out["stride"] = rescale_stride(stride, ratio)
        # Leftover
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}
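Continuing the numbers from the preprocessing example above, the sample-space stride is converted to logits space here:

ratio = 1 / 320                             # assumed inputs_to_logits_ratio
stride = (320_000, 53_440, 53_440)          # stride in audio samples
print(rescale_stride([stride], ratio)[0])   # (1000, 167, 167), in logit frames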
@@ -352,10 +389,11 @@ def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, retu

        final_items = []
        key = "logits" if self.type == "ctc_with_lm" else "tokens"
+        stride = None
        for outputs in model_outputs:
            items = outputs[key].numpy()
            stride = outputs.pop("stride", None)
-            if stride is not None:
+            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                total_n, left, right = stride
                # Total_n might be < logits.shape[1]
                # because of padding, that's why
@@ -364,8 +402,11 @@ def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, retu
                right_n = total_n - right
                items = items[:, left:right_n]
            final_items.append(items)
-        items = np.concatenate(final_items, axis=1)
-        items = items.squeeze(0)
+        if stride and self.type == "seq2seq":
+            items = _find_longest_common_sequence(final_items, self.tokenizer)
+        else:
+            items = np.concatenate(final_items, axis=1)
+            items = items.squeeze(0)
        if self.type == "ctc_with_lm":
            if decoder_kwargs is None:
                decoder_kwargs = {}
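To make the CTC trimming concrete, a small illustration with hypothetical shapes (padding can make the actual frame count smaller than `total_n`, which is why slicing by `left:right_n` is safe):

import numpy as np

items = np.zeros((1, 95, 32))       # one chunk: 95 logit frames, 32-symbol vocab
total_n, left, right = 100, 10, 20  # stride already rescaled to logits space
right_n = total_n - right           # 80
trimmed = items[:, left:right_n]    # keep frames 10..79; overlaps are dropped
print(trimmed.shape)                # (1, 70, 32)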
24 changes: 18 additions & 6 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -144,12 +144,8 @@ def test_small_model_pt(self):
        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
        output = speech_recognizer(waveform)
        self.assertEqual(output, {"text": "(Applaudissements)"})
-        with self.assertRaises(ValueError) as v:
-            _ = speech_recognizer(waveform, chunk_length_s=10)
-        self.assertEqual(
-            str(v.exception),
-            "`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
-        )
+        output = speech_recognizer(waveform, chunk_length_s=10)
+        self.assertEqual(output, {"text": "(Applaudissements)"})

        # Non CTC models cannot use return_timestamps
        with self.assertRaises(ValueError) as v:
@@ -261,6 +257,22 @@ def test_torch_large(self):
        output = speech_recognizer(filename)
        self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

+    @require_torch
+    @slow
+    def test_torch_whisper(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-tiny",
+            framework="pt",
+        )
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        filename = ds[40]["file"]
+        output = speech_recognizer(filename)
+        self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
+
+        output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
+        self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
Comment on lines +273 to +274

Collaborator: Nice

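Taken together, the tests above exercise what users can now do from the pipeline API; a minimal usage sketch (the audio path is a placeholder):

from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
# chunk_length_s enables the experimental chunking for seq2seq models;
# stride_length_s defaults to chunk_length_s / 6 on each side.
output = asr("long_audio.wav", chunk_length_s=30)
print(output["text"])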
    @require_torch
    @slow
    def test_torch_speech_encoder_decoder(self):