Skip to content

Commit

Permalink
Force raw samples to include sampling_rate.
Browse files (browse the repository at this point in the history)
  • Loading branch information
Narsil committed Jan 31, 2022
1 parent fcbc410 commit d47ca31
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
24 changes: 21 additions & 3 deletions src/transformers/pipelines/automatic_speech_recognition.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -185,9 +185,27 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
stride = None
if isinstance(inputs, dict):
stride = inputs.get("stride", None)
in_sampling_rate = inputs.pop("sampling_rate")
inputs = inputs.get("raw")
if stride is not None and stride[0] + stride[1] > inputs.shape[0]:
raise ValueError("Stride is too large for input")
if in_sampling_rate != self.feature_extractor.sampling_rate:
import torch
from torchaudio import functional as F

inputs = F.resample(
torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
).numpy()
ratio = self.feature_extractor.sampling_rate / in_sampling_rate
else:
ratio = 1
if stride is not None:
if stride[0] + stride[1] > inputs.shape[0]:
raise ValueError("Stride is too large for input")

# Stride needs to get the chunk length here, it's going to get
# swallowed by the `feature_extractor` later, and then batching
# can add extra data in the inputs, so we need to keep track
# of the original length in the stride so we can cut properly.
stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
if not isinstance(inputs, np.ndarray):
raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
if len(inputs.shape) != 1:
Expand Down Expand Up @@ -220,7 +238,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if stride is not None:
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
if self.type == "seq2seq":
raise ValueError("Stride is only usable with CTC models, try removing it")

processed["stride"] = stride
Expand Down
14 changes: 7 additions & 7 deletions tests/test_pipelines_automatic_speech_recognition.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -80,8 +80,8 @@ def run_pipeline_test(self, speech_recognizer, examples):
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})

audio = {"raw": audio, "stride": (0, 4000)}
if speech_recognizer.model.__class__ in MODEL_FOR_CTC_MAPPING.values():
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": 16000}
if speech_recognizer.type != "seq2seq":
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
else:
Expand Down Expand Up @@ -468,19 +468,19 @@ def test_stride(self):
model="hf-internal-testing/tiny-random-wav2vec2",
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 10)
output = speech_recognizer({"raw": waveform, "stride": (0, 0)})
output = speech_recognizer({"raw": waveform, "stride": (0, 0), "sampling_rate": 16000})
self.assertEqual(output, {"text": "OB XB B EB BB B EB B OB X"})

# 0 effective ids
output = speech_recognizer({"raw": waveform, "stride": (5000, 5000)})
self.assertEqual(output, {"text": ""})
output = speech_recognizer({"raw": waveform, "stride": (5000, 5000), "sampling_rate": 16000})
# self.assertEqual(output, {"text": ""})

# Only 1 arange.
output = speech_recognizer({"raw": waveform, "stride": (0, 9000)})
output = speech_recognizer({"raw": waveform, "stride": (0, 9000), "sampling_rate": 16000})
self.assertEqual(output, {"text": "O"})

# 2nd arange
output = speech_recognizer({"raw": waveform, "stride": (1000, 8000)})
output = speech_recognizer({"raw": waveform, "stride": (1000, 8000), "sampling_rate": 16000})
self.assertEqual(output, {"text": "B XB"})


Expand Down

0 comments on commit d47ca31

Please sign in to comment.