
[SequenceFeatureExtractor] Rewrite padding logic from pure python to numpy #13650

Merged: 8 commits merged into huggingface:master from anton-l:numpify-speech-padding on Sep 21, 2021

Conversation

@anton-l (Member) commented Sep 20, 2021:

What does this PR do?

Resolves #13539

Since speech models universally use NumPy float32 arrays as input features (the standard way of representing waveforms), it was decided to rewrite SequenceFeatureExtractor from pure Python lists (akin to traditional tokenizers) to numpy arrays. This will also help solve some inconsistent normalization issues (#13538, #13585) caused by float -> np.float32 conversions.

The feature extractor itself is still dtype-agnostic (it could pad np.float64 in the future if needed), while the model-specific feature extractors were updated to only work with np.float32.
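To make the change concrete, here is a minimal sketch of numpy-based padding of variable-length float32 waveforms. This is illustrative only: pad_batch is a hypothetical helper, not the actual SequenceFeatureExtractor code.

```python
import numpy as np

def pad_batch(waveforms, padding_value=0.0):
    """Pad 1-D float32 waveforms to the length of the longest one (illustrative sketch)."""
    max_len = max(w.shape[0] for w in waveforms)
    batch = np.full((len(waveforms), max_len), padding_value, dtype=np.float32)
    attention_mask = np.zeros((len(waveforms), max_len), dtype=np.int32)
    for i, w in enumerate(waveforms):
        batch[i, : w.shape[0]] = w            # copy the real samples, leave the tail padded
        attention_mask[i, : w.shape[0]] = 1   # mark which positions are real audio
    return batch, attention_mask

padded, mask = pad_batch([np.zeros(8000, dtype=np.float32), np.ones(16000, dtype=np.float32)])
print(padded.shape, padded.dtype, mask.sum(axis=-1))  # (2, 16000) float32 [ 8000 16000]
```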

```diff
             x = np.subtract(x, mean)
         if normalize_vars:
-            var = square_sums / x[:input_length].shape[0] - mean ** 2
-            std = np.sqrt(np.maximum(var, 1e-10))
+            std = x[:input_length].std(axis=0)
```
anton-l (Member Author):

Switched this logic to pure numpy to squeeze out a bit more precision when working with np.float32, instead of values cast from float to np.float64.
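For context, a rough sketch of utterance-level mean/variance normalization that stays in float32 throughout. This is a simplified illustration, not the exact utterance_cmvn in transformers.

```python
import numpy as np

def cmvn(x, normalize_means=True, normalize_vars=True):
    # x: (frames, features) float32 array; all intermediates stay float32,
    # avoiding a round-trip through Python floats / np.float64
    x = np.asarray(x, dtype=np.float32)
    if normalize_means:
        x = x - x.mean(axis=0)
    if normalize_vars:
        x = x / np.maximum(x.std(axis=0), 1e-10)  # guard against zero variance
    return x
```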

```diff
-        if is_batched and not isinstance(raw_speech[0], np.ndarray):
-            raw_speech = [np.asarray(speech) for speech in raw_speech]
+        if is_batched:
+            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
```
anton-l (Member Author) commented Sep 20, 2021:

Forcing float32 here for consistency with other feature extractors
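A quick illustration of why the explicit dtype matters: converting plain Python floats without it silently produces float64 arrays.

```python
import numpy as np

speech = [0.1, 0.2, 0.3]  # e.g. samples decoded into plain Python floats
print(np.asarray(speech).dtype)                    # float64 (numpy's default for Python floats)
print(np.asarray(speech, dtype=np.float32).dtype)  # float32, matching the other feature extractors
```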

Contributor:

yes!

@anton-l anton-l changed the title Numpify speech padding [SequenceFeatureExtraction] Rewrite padding logic from pure python to numpy Sep 20, 2021
@anton-l anton-l changed the title [SequenceFeatureExtraction] Rewrite padding logic from pure python to numpy [SequenceFeatureExtractor] Rewrite padding logic from pure python to numpy Sep 20, 2021
@sgugger (Collaborator) left a comment:

Thanks for fixing!

@patil-suraj (Contributor) left a comment:

LGTM!

```diff
@@ -724,7 +724,7 @@ def map_to_array(batch):
             return batch

         ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
-        ds = ds.select(range(num_samples)).map(map_to_array)
+        ds = ds.sort("id").select(range(num_samples)).map(map_to_array)
```
anton-l (Member Author):

Unrelated to this PR: the new datasets library version re-shuffled this dataset, so sorting is needed for reproducibility.

```diff
         output = speech_recognizer(waveform)
         self.assertEqual(output, {"text": ""})

         from datasets import load_dataset

-        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation").sort("id")
         filename = ds[0]["file"]
```
anton-l (Member Author):

Unrelated to this PR: the new datasets library version re-shuffled this dataset, so sorting is needed for reproducibility.

```diff
@@ -42,9 +42,9 @@ def test_torch_small(self):
             tokenizer="facebook/s2t-small-mustc-en-fr-st",
             framework="pt",
         )
-        waveform = np.zeros((34000,))
+        waveform = np.linspace(0, 1, 34000, dtype=np.float32)
```
anton-l (Member Author):

Now that the slight variability due to type-casting is eliminated, this needs to be a specific non-zero input, because the model becomes unstable (non-reproducible across environments) with all-zero inputs.
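Roughly why an all-zero waveform is a degenerate test input (assuming the feature extractor applies zero-mean/unit-variance normalization): the standard deviation is zero, so the normalized features carry no signal and tiny numerical differences between environments dominate the model output. A short sketch:

```python
import numpy as np

waveform = np.zeros(34000, dtype=np.float32)
std = waveform.std()                                  # 0.0 for an all-zero input
normalized = (waveform - waveform.mean()) / np.maximum(std, 1e-10)
print(normalized.min(), normalized.max())             # 0.0 0.0 -> degenerate features

stable = np.linspace(0, 1, 34000, dtype=np.float32)   # non-trivial, fully deterministic test input
```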

@patrickvonplaten (Contributor) left a comment:

Looks good to me! Left a couple of nits:

  • I think we don't have to use np.int64 -> we are using tf.int32 everywhere
  • Would be nice to add jax support to the to_py_obj and to_numpy_obj functions (see the sketch after this list)
  • Do we still need a higher variance for the no-padding test for Speech2Text?
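A minimal sketch of what jax support in a to_py_obj-style helper could look like. This is hypothetical code, not the actual transformers implementation.

```python
import numpy as np

def to_py_obj(obj):
    """Recursively convert tensors/arrays to plain Python objects (illustrative sketch)."""
    if isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]
    # jax.numpy arrays (like torch/tf tensors) implement __array__,
    # so np.asarray converts them without importing jax here
    if hasattr(obj, "__array__"):
        return np.asarray(obj).tolist()
    return obj
```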

@patrickvonplaten (Contributor) commented:

Also did you notice a speed-up for larger inputs?

@LysandreJik (Member) left a comment:

Super cool, LGTM @anton-l!

@anton-l (Member Author) commented Sep 20, 2021:

@patrickvonplaten the benchmarking results are pretty promising:

1. Input lengths from 8000 to 16000 (1 sec max), batch size 64, feature_extractor only:
   • Python: 52.1 ms ± 2.35 ms
   • Numpy: 32.1 ms ± 1.13 ms
2. Input lengths from 8000 to 160000 (10 sec max), batch size 64, feature_extractor only:
   • Python: 276 ms ± 950 µs
   • Numpy: 68.2 ms ± 491 µs
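For reference, a rough sketch of how such a benchmark could be reproduced, assuming Wav2Vec2FeatureExtractor with its default settings; the exact setup behind the numbers above is not shown in this PR.

```python
import time
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor()  # any SequenceFeatureExtractor subclass would work here
rng = np.random.default_rng(0)
# 64 random waveforms between 8000 and 160000 samples long
batch = [rng.standard_normal(int(rng.integers(8000, 160000))).astype(np.float32) for _ in range(64)]

start = time.perf_counter()
for _ in range(10):
    extractor(batch, sampling_rate=16000, padding=True, return_tensors="np")
print(f"{(time.perf_counter() - start) / 10 * 1000:.1f} ms per call")
```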

@patrickvonplaten (Contributor) commented:

Great job @anton-l - feel free to merge!

@anton-l anton-l merged commit 1417978 into huggingface:master Sep 21, 2021
@anton-l anton-l deleted the numpify-speech-padding branch September 21, 2021 14:16
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 13, 2022
…numpy (huggingface#13650)

* Test np padding

* Pass feature extraction tests

* Update type hints

* Fix flaky integration tests

* Try a more stable waveform

* Add to_numpy jax support

* int32 attention masks

* Refactor normalization tests
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
…numpy (huggingface#13650)

Development

Successfully merging this pull request may close these issues.

[SequenceFeatureExtraction] Move padding logic from pure python to numpy
5 participants