Add ASR CTC streaming example #15309
Conversation
The documentation is not available anymore as the PR was closed or merged.
One thing left to add: support for infinite streaming if we reach the end of the dataset, similar to what is done in
transformers/examples/research_projects/codeparrot/scripts/codeparrot_training.py (line 20 in 48bf7e4).
Otherwise this example yields:
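For reference, the approach in the codeparrot script is to restart iteration over the underlying stream once it is exhausted. Below is a minimal sketch of that idea, assuming a hypothetical wrapper class rather than the actual code added in this PR:

```python
from torch.utils.data import IterableDataset


class InfinitelyStreamedDataset(IterableDataset):
    """Hypothetical wrapper: keep yielding examples by restarting the stream when it ends."""

    def __init__(self, dataset, infinite=True):
        # `dataset` is assumed to be an iterable, e.g. a `datasets` dataset loaded with streaming=True.
        self.dataset = dataset
        self.infinite = infinite

    def __iter__(self):
        iterator = iter(self.dataset)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                if not self.infinite:
                    return
                # Restart the exhausted stream so the Trainer can keep drawing
                # batches until it reaches max_steps.
                iterator = iter(self.dataset)
```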
examples/research_projects/robust-speech-event/run_speech_recognition_ctc_streaming.py
    prepare_dataset,
)
.remove_columns(raw_column_names[split] + ["target_text"])
.shuffle(buffer_size=10, seed=training_args.seed)
buffer_size=10 is a bit small IMO no? @lhoestq what do you think?
Also, nit: maybe split the line into 2-3 lines to make it more readable.
The buffer_size is a tradeoff: if it's lower, you get lower shuffling quality; if it's higher, the shuffling is better, but at the cost of memory, and it takes longer for the dataset to start iterating.
For audio data I would say buffer_size=100 is maybe more reasonable? Feel free to try it though and see if it's OK in terms of RAM and starting delay.
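To make the tradeoff concrete, here is a small sketch of shuffling a streamed dataset with a buffer; the dataset name and buffer size are illustrative, not taken from the PR:

```python
from datasets import load_dataset

# Stream the data instead of downloading it all up front.
raw_train = load_dataset("common_voice", "tr", split="train", streaming=True)

# Shuffling a streamed dataset fills a buffer of `buffer_size` examples and samples from it:
# a larger buffer shuffles better, but uses more RAM and delays the first batch.
train_dataset = raw_train.shuffle(seed=42, buffer_size=100)

# Iteration only starts producing examples once the buffer has been filled.
first_example = next(iter(train_dataset))
```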
Very nice! Looks good to me. I've run it on a TITAN RTX GPU and it is indeed a bit slow compared to the non-streaming script.
@anton-l could you re-run the following experiment in streaming mode to benchmark both performance and speed:
https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft
There is a run.sh script inside, and you can also use the already-made tokenizer from the repo.
Maybe run the script for just 5 epochs, both with streaming and without, so that we have some exact numbers and can compare performance as well.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Feel free to merge and announce this @anton-l once there has been some testing :-)
The ASR README now includes a streaming example + benchmark :) A couple of last notes:
src/transformers/trainer.py
@@ -52,7 +52,7 @@
 import torch
 from packaging import version
 from torch import nn
-from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
@sgugger - could you take a look here as well? :-)
Those changes seem to be needed to make Trainer + DDP with streaming work
For context: len(IterableDatasetShard()) raises an error (as expected) if its underlying dataset (datasets.TorchIterableDataset in our case) is not sized. So for a deeper Sized check we have to get to IterableDatasetShard.dataset, hence the dataloader.dataset.dataset below.
I don't understand those changes. IterableDatasetShard implements __len__ specifically to support iterable datasets with a length (usually coming from a Dataset with streaming), with the method throwing an error when the underlying dataset does not implement __len__.
Or is it that IterableDatasetShard is always considered a collections.abc.Sized since the method is implemented, and you need to have it not considered sized in this case? If that's the use case, we should replace the check from abc with a try/except to get the length.
@sgugger yes, the second case is exactly what we have here! Wasn't sure how you feel about try/except blocks code-style-wise, but if it's OK I'll replace the checks 🙂
OK, so then can you add a util has_length that tries to access the length of the dataset and returns False in case of error, then use this instead of diving into three attributes? It'd be cleaner :-)
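A minimal sketch of what such a has_length helper could look like (the exact version that ends up in the Trainer utilities may differ):

```python
def has_length(dataset):
    """Return True if `dataset` exposes a usable __len__, without relying on a Sized check."""
    try:
        # Works for map-style datasets and for IterableDatasetShard wrapping a sized dataset;
        # raises TypeError when the underlying streaming dataset has no length.
        return len(dataset) is not None
    except TypeError:
        return False
```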
@@ -1,4 +1,4 @@
-datasets >= 1.13.3
+datasets >= 1.18.0
The upgrade is needed to support accessing dl_manager.download_config in
https://huggingface.co/datasets/mozilla-foundation/common_voice_8_0/blob/main/common_voice_8_0.py#L144
Saw some reports about that on Discord last week :)
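Purely for illustration, here is a stripped-down loading-script sketch of where dl_manager.download_config comes into play; the builder name, URL, fields, and the use_auth_token attribute are assumptions, not a copy of the linked common_voice_8_0.py:

```python
import datasets


class ExampleSpeechCorpus(datasets.GeneratorBasedBuilder):
    """Placeholder builder sketch; see the linked common_voice_8_0.py for the real script."""

    VERSION = datasets.Version("1.0.0")

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features(
                {"path": datasets.Value("string"), "sentence": datasets.Value("string")}
            )
        )

    def _split_generators(self, dl_manager):
        # Reading the download configuration off the manager is what requires datasets >= 1.18.0.
        download_config = dl_manager.download_config
        token = getattr(download_config, "use_auth_token", None)  # attribute name assumed
        archive = dl_manager.download("https://example.org/clips.tar.gz")  # placeholder URL
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"archive": archive, "token": token},
            )
        ]

    def _generate_examples(self, archive, token):
        # Placeholder generator; a real script would iterate over the downloaded archive.
        yield 0, {"path": archive, "sentence": ""}
```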
features = dataset_splits[0].features
# make sure that the dataset decodes audio with a correct sampling rate
dataset_splits = [
    dataset.cast_column(data_args.audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))
Have to cast_column before interleave_datasets(), because only the first stage of the streaming pipeline has access to the column names. The column names are None after that, due to streaming ambiguity (we're not Spark and can't pre-compile pipelines to infer the output column names 😅).
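As a rough sketch of that ordering constraint, with placeholder dataset, config, and column names rather than the exact code from the example script:

```python
import datasets
from datasets import interleave_datasets, load_dataset

sampling_rate = 16_000  # assumed target rate of the CTC feature extractor
audio_column_name = "audio"

# Two streamed splits that will be mixed into a single training stream.
dataset_splits = [
    load_dataset("common_voice", "tr", split=split_name, streaming=True)
    for split_name in ("train", "validation")
]

# Cast the audio column while each split still knows its column names,
# i.e. before interleave_datasets() erases them.
dataset_splits = [
    split.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))
    for split in dataset_splits
]

train_dataset = interleave_datasets(dataset_splits)
```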
…-script # Conflicts: examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
src/transformers/trainer.py
@@ -101,6 +100,7 @@
     distributed_concat,
     find_batch_size,
     get_parameter_names,
+    has_length,
@sgugger gentle ping to check if you're ok with these updates to the Trainer :)
Yes, that works for me!
Could you extract those changes from this PR and make them in a separate PR? We prefer having smaller PRs laser-focused on one thing rather than big ones that touch every file.
TODO: Merge this after #15539
* Single-epoch run
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Infinite dataset
* Trainer fix + distributed benchmark
* Benchmark fix
* unused import
* interleaved splits
* interleaved splits
* has_length util
* Move to research projects
* Leftover Sized checks
* Bump min version
* Unused import
* Revert trainer changes

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
What does this PR do?
This enables the use of streamable datasets to fine-tune CTC models.
Example args: