
Add ASR CTC streaming example #15309

Merged
merged 16 commits into huggingface:master from anton-l:add-ctc-streaming-script on Feb 7, 2022

Conversation

@anton-l (Member) commented Jan 24, 2022

What does this PR do?

This enables the use of streamable datasets to fine-tune CTC models.
Example args:

--dataset_name="common_voice"
--model_name_or_path="ntu-spml/distilhubert"
--tokenizer_name_or_path="infinitejoy/wav2vec2-large-xls-r-300m-abkhaz"
--dataset_config_name="ab"
--output_dir="./dummy_ctc"
--overwrite_output_dir
--per_device_train_batch_size="4"
--gradient_accumulation_steps="1"
--learning_rate="5e-5"
--max_steps="3000"
--warmup_steps="500"
--evaluation_strategy="steps"
--text_column_name="sentence"
--save_steps="500"
--eval_steps="5"
--logging_steps="1"
--layerdrop="0.0"
--save_total_limit="1"
--mask_time_prob="0.3"
--mask_time_length="10"
--mask_feature_prob="0.1"
--mask_feature_length="64"
--freeze_feature_encoder
--chars_to_ignore=", ? . ! - \; \: \" % ‘ �"
--fp16
--do_train
--do_eval
--gradient_checkpointing

@HuggingFaceDocBuilder commented Jan 24, 2022

The documentation is not available anymore as the PR was closed or merged.

@anton-l (Member, Author) commented Jan 24, 2022

One thing left to add: support for infinite streaming if we reach the end of the dataset, so that Trainer doesn't stop until max_steps is reached.
Reference:

class ConstantLengthDataset(IterableDataset):

Otherwise this example yields:

There seems to be not a single sample in your epoch_iterator, stopping training at step 6! This is expected if you're using an IterableDataset and set num_steps (3000) higher than the number of available samples.
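For illustration, a minimal sketch of such an infinite wrapper, much simplified from the ConstantLengthDataset referenced above (the class name here is hypothetical, not what was merged):

```python
from torch.utils.data import IterableDataset


class InfiniteStreamingDataset(IterableDataset):
    """Hypothetical wrapper: restarts the underlying streaming dataset
    whenever it is exhausted, so training continues until max_steps."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        while True:
            # re-iterate the underlying stream each "epoch" so the
            # Trainer never runs out of samples before max_steps
            for example in self.dataset:
                yield example
```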

prepare_dataset,
)
.remove_columns(raw_column_names[split] + ["target_text"])
.shuffle(buffer_size=10, seed=training_args.seed)
Contributor

buffer_size=10 is a bit small IMO, no? @lhoestq what do you think?

Contributor

Also, nit: maybe split the line into 2-3 lines to make it more readable.

Member

The buffer_size is a tradeoff: if it's lower, the shuffling quality is worse; if it's greater, the shuffling is better, but at the cost of memory, and it takes longer for the dataset to start iterating.

For audio data I would say buffer_size=100 is maybe more reasonable? Feel free to try it though and see if it's OK in terms of RAM and starting delay.
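For reference, the knob under discussion is the buffer_size argument of datasets' streaming shuffle; a quick sketch using the dataset from this PR's example args (the seed is a placeholder):

```python
from datasets import load_dataset

# Streaming mode: examples are shuffled through an in-memory buffer.
# A larger buffer improves shuffling quality but costs more RAM and
# delays the first yielded example.
stream = load_dataset("common_voice", "ab", split="train", streaming=True)
shuffled = stream.shuffle(buffer_size=100, seed=42)
```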

@patrickvonplaten (Contributor) left a comment

Very nice! Looks good to me. I've run it on a TITAN RTX GPU and it is indeed a bit slow compared to the non-streaming script.

@anton-l could you re-run the following experiment in streaming mode to benchmark both perf and speed:

https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft

There is a run.sh script inside, and you can also use the repo's ready-made tokenizer.

@patrickvonplaten (Contributor)

Maybe run the script for just 5 epochs, both with streaming and without, so that we have some exact numbers and can compare performance as well.

anton-l and others added 2 commits January 27, 2022 13:33
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@patrickvonplaten (Contributor)

Feel free to merge and announce this @anton-l once there has been some testing :-)

@anton-l (Member, Author) commented Jan 31, 2022

The ASR README now includes a streaming example+benchmark :)

A couple of last notes:

  1. streaming will run in distributed mode only with the Trainer fixes from this PR (to support IterableDatasetShard)
  2. https://huggingface.co/datasets/mozilla-foundation/common_voice_8_0 is very flaky (connection drops after a couple of epochs, possibly due to a temporary IP ban), so the benchmark only succeeded with common_voice

@@ -52,7 +52,7 @@
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
Contributor

@sgugger - could you take a look here as well? :-)

Contributor

Those changes seem to be needed to make Trainer + DDP with streaming work

Member (Author)

For context: len(IterableDatasetShard()) raises an error (as expected) if its underlying dataset (datasets.TorchIterableDataset in our case) is not sized. So for a deeper Sized check we have to get to IterableDatasetShard.dataset, hence the dataloader.dataset.dataset below.
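For illustration, the nested-attribute check being described might look roughly like this sketch (IterableDatasetShard is from transformers.trainer_pt_utils; the helper name is hypothetical):

```python
import collections.abc

from transformers.trainer_pt_utils import IterableDatasetShard


def dataloader_dataset_is_sized(dataloader):
    """Hypothetical helper mirroring the dataloader.dataset.dataset dive."""
    dataset = dataloader.dataset
    if isinstance(dataset, IterableDatasetShard):
        # len(IterableDatasetShard) raises when the wrapped dataset is
        # unsized, so inspect the underlying dataset rather than the shard
        dataset = dataset.dataset
    return isinstance(dataset, collections.abc.Sized)
```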

Collaborator

I don't understand those changes. IterableDatasetShard implements __len__ specifically to support iterable datasets with length (usually coming from a Dataset with streaming), with the method throwing an error when the underlying dataset does not implement __len__.

Or is it that IterableDatasetShard is always considered a collections.abc.Sized since the method is implemented, and you need it to not be considered sized in this case? If that's the use case, we should replace the abc check with a try/except to get the length.

@anton-l (Member, Author) Jan 31, 2022

@sgugger yes, the second case is exactly what we have here! I wasn't sure how you feel about try/except blocks code-style-wise, but if that's OK I'll replace the checks 🙂

Collaborator

OK, so then can you add a has_length util that tries to access the length of the dataset and returns False in case of error, then use it instead of diving into three attributes? It'd be cleaner :-)
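For reference, such a has_length helper could be as simple as the following sketch (close in spirit to the util this PR ended up importing, though the exact merged code may differ):

```python
def has_length(dataset):
    """Checks whether the dataset implements __len__() without raising."""
    try:
        return len(dataset) is not None
    except TypeError:
        # TypeError: len() of unsized object
        return False
```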

@@ -1,4 +1,4 @@
-datasets >= 1.13.3
+datasets >= 1.18.0
Member (Author)

The upgrade is needed to support accessing dl_manager.download_config https://huggingface.co/datasets/mozilla-foundation/common_voice_8_0/blob/main/common_voice_8_0.py#L144

Saw some reports about that on Discord last week :)

features = dataset_splits[0].features
# make sure that the dataset decodes audio with a correct sampling rate
dataset_splits = [
    dataset.cast_column(data_args.audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))
Member (Author)

Have to cast_column before interleave_datasets(), because only the first stage of the streaming pipeline has access to the column names. After that, the column names are None due to streaming ambiguity (we're not Spark and can't pre-compile pipelines to infer the output column names 😅).
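A sketch of the ordering constraint described above (the dataset, column name, and sampling rate are placeholders matching this PR's example args, not the exact script code):

```python
import datasets
from datasets import interleave_datasets, load_dataset

dataset_splits = [
    load_dataset("common_voice", "ab", split=s, streaming=True)
    for s in ("train", "validation")
]

# Cast the audio column on each split *before* interleaving: after
# interleave_datasets() the streaming pipeline no longer exposes column
# names, so a later cast would not know which column to target.
dataset_splits = [
    dataset.cast_column("audio", datasets.features.Audio(sampling_rate=16_000))
    for dataset in dataset_splits
]
interleaved = interleave_datasets(dataset_splits)
```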

@@ -101,6 +100,7 @@
 distributed_concat,
 find_batch_size,
 get_parameter_names,
+has_length,
Member (Author)

@sgugger gentle ping to check if you're ok with these updates to the Trainer :)

Collaborator

Yes, that works for me!
Could you extract those changes from this PR and make them in a separate PR? We prefer having smaller PRs laser-focused on one thing rather than big ones that touch every file.

@anton-l (Member, Author) commented Feb 7, 2022

TODO: Merge this after #15539

@anton-l anton-l changed the title [WIP] Add ASR CTC streaming example Add ASR CTC streaming example Feb 7, 2022
@anton-l anton-l merged commit a459f7f into huggingface:master Feb 7, 2022
@anton-l anton-l deleted the add-ctc-streaming-script branch February 7, 2022 15:35
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Feb 18, 2022
* Single-epoch run

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Infinite dataset

* Trainer fix + distributed benchmark

* Benchmark fix

* unused import

* interleaved splits

* interleaved splits

* has_length util

* Move to research projects

* Leftover Sized checks

* Bump min version

* Unused import

* Revert trainer changes

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
ManuelFay pushed a commit to ManuelFay/transformers that referenced this pull request Mar 31, 2022