Avoid looping when data exhausted #14413
Conversation
When using an iterable dataset, num_epochs is set to sys.maxsize to make sure all data is consumed. Likewise, we want to set max_steps high enough but still stop when all data is consumed. (cherry picked from commit 6f0e1d6)
Thanks for adding this, it's a nice addition!
Make sure you run `make style` on your branch to fix the quality issue.
tests/test_trainer.py (Outdated)

```python
batch_size = 1
gradient_accumulation_steps = 1
num_samples = 10

available_steps = num_samples // (batch_size * gradient_accumulation_steps)
```
We're not really using `gradient_accumulation_steps` here, so let's remove it.
```diff
 batch_size = 1
-gradient_accumulation_steps = 1
 num_samples = 10
-available_steps = num_samples // (batch_size * gradient_accumulation_steps)
+available_steps = num_samples // batch_size
```
Right, I wrongly assumed its default value was something other than 1. I removed it.
reformat training_args docstring
Thanks again for fixing this! :-)
What does this PR do?
This fix avoids running into a virtually infinite loop when using a finite iterable dataset.
When using an iterable dataset, `num_epochs` is set to `sys.maxsize` to make sure all data is consumed (see #12561). Likewise, I'd like to set `max_steps` large enough to consume all data but still stop when the data is exhausted. If we don't know how many samples there will be and the iterator stops, we can run into a virtually infinite loop (iterating the int range until `sys.maxsize`).

See this code snippet to reproduce the behavior:
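A minimal sketch of such a reproduction (illustrative, not the original snippet): a one-shot iterable dataset that yields fewer samples than `max_steps`, fed to `Trainer`. The tiny BERT config and dummy tensors are just placeholders.

```python
import torch
from torch.utils.data import IterableDataset
from transformers import BertConfig, BertForSequenceClassification, Trainer, TrainingArguments


class OneShotIterableDataset(IterableDataset):
    """Yields num_samples items once; later passes yield nothing (the stream is exhausted)."""

    def __init__(self, num_samples: int = 10):
        self.num_samples = num_samples
        self.generated = 0

    def __iter__(self):
        while self.generated < self.num_samples:
            self.generated += 1
            yield {"input_ids": torch.tensor([101, 102]), "labels": torch.tensor(0)}


model = BertForSequenceClassification(
    BertConfig(num_hidden_layers=1, hidden_size=32, num_attention_heads=2, intermediate_size=64)
)
args = TrainingArguments(output_dir="out", max_steps=100, per_device_train_batch_size=1)
trainer = Trainer(model=model, args=args, train_dataset=OneShotIterableDataset(10))

# max_steps is far larger than the 10 samples the dataset can ever provide.
# Without the fix, the Trainer keeps iterating empty epochs (up to sys.maxsize)
# instead of stopping; with the fix it stops after 10 steps.
trainer.train()
```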
With this fix, it is checked whether `epoch_iterator` produced any samples; if it did not, `control.should_training_stop` is set to `True`.

I don't know if changing the flow control this way is acceptable, since it is otherwise only changed through callback handlers; I'm happy to take suggestions on how to do this properly.
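For illustration, a simplified, self-contained sketch of that control flow (not the actual Trainer code; `should_training_stop` here stands in for `TrainerControl.should_training_stop`):

```python
import sys


def train_loop(get_epoch_iterator, max_steps):
    """Toy training loop mirroring the relevant control flow."""
    completed_steps = 0
    should_training_stop = False  # stands in for TrainerControl.should_training_stop
    for _epoch in range(sys.maxsize):  # num_train_epochs for iterable datasets
        samples_in_epoch = 0
        for _batch in get_epoch_iterator():
            samples_in_epoch += 1
            completed_steps += 1  # a real loop would run the training step here
            if completed_steps >= max_steps:
                should_training_stop = True
                break
        if samples_in_epoch == 0:
            # Data is exhausted: without this check the outer loop would spin
            # through (nearly) sys.maxsize empty epochs.
            should_training_stop = True
        if should_training_stop:
            break
    return completed_steps


stream = iter(range(10))  # a one-shot data source: empty after the first pass
print(train_loop(lambda: stream, max_steps=100))  # prints 10, then stops
```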
I tried coming up with a test case that checks the logs to see when training was stopped in this case. Another option would be to measure how long training takes and time out after a while, but that wouldn't make a nice test as the run time may be affected by other circumstances.
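One way such a test could look (a sketch reusing `OneShotIterableDataset` and the tiny model from the reproduction above; asserting on `trainer.state.global_step` is my simplification rather than the log check discussed here):

```python
batch_size = 1
num_samples = 10
available_steps = num_samples // batch_size

args = TrainingArguments(
    output_dir="out",
    max_steps=available_steps + 5,  # deliberately more steps than the data can provide
    per_device_train_batch_size=batch_size,
)
trainer = Trainer(model=model, args=args, train_dataset=OneShotIterableDataset(num_samples))
trainer.train()

# Training should have stopped once the iterator was exhausted, i.e. after
# available_steps optimization steps rather than after max_steps.
assert trainer.state.global_step == available_steps
```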
Before submitting
Who can review?