
Avoid looping when data exhausted #14413

Merged

Conversation

@valentinkoe (Contributor) commented on Nov 16, 2021

What does this PR do?

This fix avoids running into a virtually infinite loop when using a finite iterable dataset.

When using an iterable dataset, num_epochs is set to sys.maxsize to make sure all data is consumed (see #12561).
Likewise, I'd like to set max_steps large enough to consume all data, but still stop once the data is exhausted. If we don't know in advance how many samples there will be and the iterator stops, we can end up in a virtually infinite loop (iterating over the integer range up to sys.maxsize).

See this code snippet to reproduce the behavior:

from torch.utils.data import IterableDataset

from transformers import BertForMaskedLM, BertConfig, TrainingArguments, Trainer

model = BertForMaskedLM(BertConfig())


class FiniteIterableDataset(IterableDataset):
    def __init__(self, num_samples: int):
        self.current_sample = 0
        self.num_samples = num_samples

    def __iter__(self):
        while self.current_sample < self.num_samples:
            yield {"input_ids": [0, 0, 0, self.current_sample], "labels": [0, 0, 0, 1]}
            self.current_sample += 1


batch_size = 1
gradient_accumulation_steps = 1
num_samples = 10

available_steps = num_samples // (batch_size * gradient_accumulation_steps)

data = FiniteIterableDataset(num_samples)
train_args = TrainingArguments(
    "tmp_dir",
    max_steps=available_steps,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
trainer = Trainer(model, train_dataset=data, args=train_args)
trainer.train()  # works

data = FiniteIterableDataset(num_samples)
train_args = TrainingArguments(
    "tmp_dir",
    max_steps=available_steps + 1,  # set a higher number than actually available
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
trainer = Trainer(model, train_dataset=data, args=train_args)
trainer.train()  # "hangs" at 91% after 10 steps, spinning through empty epochs (up to sys.maxsize)

With this fix, it is checked whether epoch_iterator produced any samples; if it did not, control.should_training_stop is set to True.
I don't know whether changing the control flow this way is acceptable, as control is otherwise only modified through callback handlers; I'm happy to take suggestions on how to do this properly.
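For illustration, here is a minimal, self-contained sketch of the control flow in question. This is not the actual Trainer code; the loop structure and names (train, samples_seen) are simplified stand-ins. It shows why an epoch that yields no samples must stop training:

import sys


class FiniteIterable:
    """Mimics an iterable dataset that keeps its position across epochs."""

    def __init__(self, num_samples):
        self.current_sample = 0
        self.num_samples = num_samples

    def __iter__(self):
        while self.current_sample < self.num_samples:
            yield self.current_sample
            self.current_sample += 1


def train(dataset, max_steps):
    completed_steps = 0
    for epoch in range(sys.maxsize):  # num_epochs when the dataset is iterable
        samples_seen = 0
        for sample in dataset:
            samples_seen += 1
            completed_steps += 1  # one optimizer step per sample here
            if completed_steps >= max_steps:
                return completed_steps
        if samples_seen == 0:
            # The idea of the fix: this epoch produced no samples, so the
            # data is exhausted; stop instead of looping on to sys.maxsize.
            break
    return completed_steps


print(train(FiniteIterable(10), max_steps=11))  # prints 10 instead of hanging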

I tried to come up with a test case that checks the logs for the point where training is stopped in this case. Another option would be to measure how long training takes and time out after a while, but that would not make a good test, as run time can be affected by other circumstances.
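For reference, one possible shape for such a test, reusing the reproduction above; asserting on trainer.state.global_step (rather than parsing logs or timing the run) is my assumption about the cleanest way to express "training stopped when the data ran out":

# Illustrative test sketch reusing FiniteIterableDataset, model, num_samples,
# batch_size, and available_steps from the reproduction above.
data = FiniteIterableDataset(num_samples)
train_args = TrainingArguments(
    "tmp_dir",
    max_steps=available_steps + 1,  # more steps than the data can provide
    per_device_train_batch_size=batch_size,
)
trainer = Trainer(model, train_dataset=data, args=train_args)
trainer.train()
# With the fix, training stops as soon as an epoch yields no samples, so the
# step counter should equal the number of batches the data could provide.
assert trainer.state.global_step == available_steps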


Valentin Deyringer added 3 commits November 16, 2021 11:04
when using an iterable dataset num_epochs is set to
sys.maxsize to make sure all data is consumed
likewise we want to set max_steps high enough
but still stop when all data is consumed

(cherry picked from commit 6f0e1d6)
@valentinkoe force-pushed the avoid-looping-when-data-exhausted branch from 1da4eb6 to 04ed756 on November 16, 2021
@sgugger (Collaborator) left a comment

Thanks for adding this, it's a nice addition!

Make sure you run make style on your branch to fix the quality issue.

Comment on lines 1075 to 1079
batch_size = 1
gradient_accumulation_steps = 1
num_samples = 10

available_steps = num_samples // (batch_size * gradient_accumulation_steps)
@sgugger (Collaborator)
We're not really using gradient_accumulation_steps here, so let's remove it.

Suggested change
-batch_size = 1
-gradient_accumulation_steps = 1
-num_samples = 10
-available_steps = num_samples // (batch_size * gradient_accumulation_steps)
+batch_size = 1
+num_samples = 10
+available_steps = num_samples // batch_size

@valentinkoe (Contributor, Author)

Right, I wrongly assumed it had a default value other than 1. I removed it.

tests/test_trainer.py: outdated review thread, resolved
Valentin Deyringer added 2 commits November 16, 2021 16:25
@sgugger merged commit a33168a into huggingface:master on Nov 16, 2021
@sgugger (Collaborator) commented on Nov 16, 2021

Thanks again for fixing this! :-)
