Skip to content

Commit

Permalink
Fixing batching for ASR.
Browse files: browse the repository at this point in the history
  • Loading branch information
Narsil committed Jan 3, 2022
1 parent e381fd7 commit 3f8eb3e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
20 changes: 13 additions & 7 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -190,7 +190,7 @@ def preprocess(self, inputs, chunk_length_ms=0, stride_length_ms=None):
stride_right = max(stop - (i + chunk_len), 0)
is_last = i + step > inputs_len

yield {"is_last": is_last, "stride": (stride_left, stride_right), **processed}
yield {"is_last": is_last, "stride": (stop - start, stride_left, stride_right), **processed}
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
Expand All @@ -214,12 +214,18 @@ def _forward(self, model_inputs):
outputs = self.model(**model_inputs)
tokens = outputs.logits.argmax(dim=-1)
if stride is not None:
left, right = stride
input_n = model_inputs["input_values"].shape[-1]
token_n = tokens.shape[-1]
left_token = int(left / input_n * token_n)
right_token = int((input_n - right) / input_n * token_n) + 1
tokens = tokens[:, left_token:right_token]
if isinstance(stride, tuple):
stride = [stride]

max_token_n = tokens.shape[-1]
max_input_n = max(input_n for input_n, _, _ in stride)
ratio = max_token_n / max_input_n
for i, (input_n, left, right) in enumerate(stride):
token_n = int(input_n * ratio) + 1
left_token = int(left / input_n * token_n)
right_token = int((input_n - right) / input_n * token_n) + 1
tokens[i, :left_token] = self.tokenizer.pad_token_id
tokens[i, right_token:] = self.tokenizer.pad_token_id
else:
logger.warning("This is an unknown class, treating it as CTC.")
outputs = self.model(**model_inputs)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/pipelines/pt_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -275,7 +275,7 @@ def __next__(self):
return accumulator
else:
item = processed
is_last = item.pop("is_last")
is_last = item.pop("is_last")
accumulator.append(item)
return accumulator

Expand Down
5 changes: 3 additions & 2 deletions tests/test_pipelines_automatic_speech_recognition.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -255,6 +255,7 @@ def test_chunking(self):

n_repeats = 100
audio = np.tile(audio, n_repeats)
output = speech_recognizer(audio, batch_size=2)
output = speech_recognizer([audio], batch_size=2)
expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
self.assertEqual(output, {"text": expected_text.strip()})
expected = [{"text": expected_text.strip()}]
self.assertEqual(output, expected)

0 comments on commit 3f8eb3e

Please sign in to comment.