
Add early stopping callback to pytorch trainer #8581

Merged (20 commits) on Nov 23, 2020

Conversation

@cbrochtrup commented Nov 17, 2020

Summary

Address the PyTorch half of #4894 by adding early stopping patience and a minimum threshold by which the monitored metric must improve to avoid triggering early stopping. I piggybacked heavily off #7431, since the two features are very similar.

Since #4186 seems to be abandoned and behind master, I figured I'd take a crack at this.
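
For reference, here is a minimal usage sketch of what this adds, written against the argument names the callback ended up with after the review below (the model and datasets are assumed to be defined elsewhere):

from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",        # evaluate periodically so the callback can fire
    eval_steps=500,
    load_best_model_at_end=True,        # required by the callback (see the review discussion)
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # stop if eval_loss fails to improve by at least 0.01 for 3 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)],
)
trainer.train()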

Who can review?

Anyone! But @julien-c and @sgugger seem the most appropriate.

@sgugger (Collaborator) commented Nov 17, 2020

Hi there, thanks for your PR! When I designed the callbacks, the idea was for them to be small, independent pieces of code. I would prefer early stopping to have its own callback that the user can then choose to add or not. Do you think you could amend your PR in that direction?

@cbrochtrup (Author) commented Nov 17, 2020

Hello, thank you for your feedback! I will amend the PR in that direction.

Could you clarify which pieces of early stopping should live in TrainerState and which should live in the callback? I'm grappling with the overlap between best_model_checkpoint and the early stopping attributes:

from typing import Optional

class EarlyStoppingCallback(TrainerCallback):
    best_metric: Optional[float] = None  # maybe not this
    best_model_checkpoint: Optional[str] = None  # maybe not this either
    early_stopping_patience: Optional[int] = None
    early_stopping_patience_counter: int = 0

    def on_evaluate(self, args, state, control, **kwargs):
        # Keep track of patience here (bump the counter whenever the
        # monitored metric fails to improve), then end training once
        # patience runs out:
        if (
            self.early_stopping_patience is not None
            and self.early_stopping_patience_counter >= self.early_stopping_patience
        ):
            control.should_training_stop = True
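
For the patience bookkeeping placeholder above, here is a sketch of one way to fill it in, close to the logic the merged callback eventually used (early_stopping_threshold is the minimum-improvement argument from the summary, and the evaluation metrics arrive through the on_evaluate arguments):

import numpy as np

def on_evaluate(self, args, state, control, metrics, **kwargs):
    metric_to_check = args.metric_for_best_model
    if not metric_to_check.startswith("eval_"):
        metric_to_check = f"eval_{metric_to_check}"
    metric_value = metrics.get(metric_to_check)
    # (a fuller version should also warn and bail out when metric_value is None)

    # Reset patience when the metric improves on state.best_metric by more
    # than the threshold; otherwise count one more evaluation without improvement.
    operator = np.greater if args.greater_is_better else np.less
    if state.best_metric is None or (
        operator(metric_value, state.best_metric)
        and abs(metric_value - state.best_metric) > self.early_stopping_threshold
    ):
        self.early_stopping_patience_counter = 0
    else:
        self.early_stopping_patience_counter += 1

    if self.early_stopping_patience_counter >= self.early_stopping_patience:
        control.should_training_stop = True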

@cbrochtrup (Author):

Or do you mean I should just move the if statement I added into its own callback and keep TrainerState as is?

@sgugger (Collaborator) commented Nov 17, 2020

The TrainerState shouldn't change, so the callback you are writing above sounds fine, minus the attributes marked with # maybe not this, which should already be in the TrainerState, I think.
Does that sound right to you?

@cbrochtrup (Author):

That makes sense. I think this block of code (up to line 933) could be a callback, since it's all about the best metric; then users could customize the best-model calculation. Is that desirable?

If you think that's out of scope, I'll keep the early stopping callback simple and separate from the best-metric calculation.

@sgugger (Collaborator) commented Nov 17, 2020

I had put it in Trainer because I thought multiple callbacks could need it, and it's used by load_best_model_at_end, which is kind of a core feature.

@cbrochtrup (Author):

Sounds good, you know best! I'll keep load_best_model_at_end in the Trainer and push up an early stopping callback sometime this week.

@cbrochtrup changed the title from "Add early stopping patience to pytorch trainer" to "Add early stopping callback to pytorch trainer" on Nov 19, 2020
@sgugger (Collaborator) left a comment

A few more things to change, but we're close to getting this into a good state. Thanks a lot for your work on this!

    metric_value = metrics.get(metric_to_check)

    if metric_value is None:
        logger.warning(
@sgugger (Collaborator), on the snippet above:
Good warning!

            self.early_stopping_patience_counter += 1

    def on_train_begin(self, args, state, control, **kwargs):
        assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
@sgugger (Collaborator), on the snippet above:
I still don't understand why this line is necessary. Shouldn't we be able to use this callback without the load_best_model_at_end option? The other sanity checks are perfectly fine.

@cbrochtrup (Author):

This is necessary because we require control.should_save=True for _save_checkpoint to update the best metric. Should I move the best metric calculation into its own function and place it in the should_evaluate block?
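
In other words, the suggestion is something like the following hypothetical sketch (the helper name _update_best_metric is invented here, and this refactor was not part of this PR): pull the bookkeeping out of _save_checkpoint so it runs on every evaluation rather than only when a checkpoint is saved.

# Hypothetical Trainer helper: track the best metric on every evaluation,
# so the bookkeeping no longer depends on control.should_save being True.
def _update_best_metric(self, metrics):
    metric = self.args.metric_for_best_model
    if not metric.startswith("eval_"):
        metric = f"eval_{metric}"
    value = metrics.get(metric)
    if value is None:
        return
    if self.state.best_metric is None:
        self.state.best_metric = value
    elif self.args.greater_is_better and value > self.state.best_metric:
        self.state.best_metric = value
    elif not self.args.greater_is_better and value < self.state.best_metric:
        self.state.best_metric = value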

@cbrochtrup (Author):

I agree that needing load_best_model_at_end is not fully intuitive, but it makes sense to me: without it, early stopping will halt training, but the model we get back from training will not be the one early stopping deemed best.

@sgugger (Collaborator):

Ok let's leave it as is for now then, and we will re-evaluate if some users complain!

A later comment on the same thread:

Saw this issue while debugging something. It doesn't seem intuitive how these two are related, so can we please do what @cbrochtrup suggested above?

@cbrochtrup (Author):

Thanks for your thorough and affable review!

@LysandreJik (Member) left a comment

Great addition, LGTM!

@sgugger merged commit 8ffc01a into huggingface:master on Nov 23, 2020
@cbrochtrup deleted the early-stopping-patience branch on November 23, 2020