
EarlyStoppingCallback early_stopping_patience_counter and Trainer.state not reset between hyperparameter search trials #502

Open
chrisaballard opened this issue Mar 21, 2024 · 1 comment

@chrisaballard

When using the EarlyStoppingCallback from the transformers package together with the Trainer.hyperparameter_search method, the early_stopping_patience_counter on the callback instance and Trainer.state are not reset between trials.

This means that in the second and subsequent trials, early_stopping_patience_counter is already greater than or equal to early_stopping_patience. As a result, Trainer.control.should_training_stop is set to True the first time on_evaluate is called on the callbacks, and the trial is terminated almost immediately.

To resolve this issue, I suggest that when Trainer.train(trial=...) is called with a trial, Trainer.callback_handler should be re-initialised and Trainer.state reset to a fresh TrainerState().
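A minimal sketch of what that could look like (the method body, the self.callbacks attribute, and the CallbackHandler arguments are illustrative assumptions, not the actual setfit implementation):

from transformers.trainer_callback import CallbackHandler, TrainerControl, TrainerState

class Trainer:
    # ... existing setfit Trainer ...

    def train(self, args=None, trial=None, **kwargs):
        if trial is not None:
            # Start each hyperparameter search trial from a clean slate so that
            # best_metric and the early stopping bookkeeping do not leak over
            # from the previous trial.
            self.state = TrainerState()
            self.control = TrainerControl()
            # Re-initialise the callback handler; self.callbacks is assumed to
            # hold the callback instances passed to __init__.
            self.callback_handler = CallbackHandler(
                self.callbacks, self.model, None, None, None
            )
        # ... rest of the existing training logic ...

Note that re-creating the handler with the same callback instances would still leave early_stopping_patience_counter untouched, so the callbacks themselves would also need to be re-instantiated or given a reset hook (see the workaround sketch further down).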

@chrisaballard
Author

The state and callback handler are defined in Trainer.__init__():

https://github.com/huggingface/setfit/blob/73d8646e8c49d7926a5bac0ff0eb9f0d3bf96f75/src/setfit/trainer.py#L235C9-L238C94

Create a new Trainer instance with an EarlyStoppingCallback:

trainer = Trainer(
    args=training_arguments,
    metric=metric,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    column_mapping=column_mapping,
    model_init=model_init(...),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

Run hyperparameter search:

trainer.hyperparameter_search(
    direction="maximize", hp_space=hp_space, n_trials=n_trials, **kwargs
)

The early_stopping_patience_counter is stored in the EarlyStoppingCallback instance and not reset to 0 between trials.

The on_evaluate method of EarlyStoppingCallback compares the current value of the monitored eval metric to state.best_metric:

https://github.com/huggingface/transformers/blob/f4364a6ff16e33186cb40f1d3fafd3792556d1b8/src/transformers/trainer_callback.py#L566

Because Trainer.state is not reset between trials, Trainer.state.best_metric still holds the best value recorded during the previous trial.

Because EarlyStoppingCallback.early_stopping_patience_counter is not reset between trials, it continues to accumulate across all runs, being incremented every time on_evaluate sees no sufficient improvement.
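As a user-side workaround until this is fixed, one option (a sketch, not part of either library) is to subclass EarlyStoppingCallback so that the counter is cleared whenever a new training run begins:

from transformers import EarlyStoppingCallback

class TrialResettingEarlyStoppingCallback(EarlyStoppingCallback):
    # Clears the patience counter at the start of every training run so that
    # counts do not leak across hyperparameter search trials.
    def on_train_begin(self, args, state, control, **kwargs):
        self.early_stopping_patience_counter = 0
        return super().on_train_begin(args, state, control, **kwargs)

Passing callbacks=[TrialResettingEarlyStoppingCallback(early_stopping_patience=2)] in place of the plain callback only addresses the counter; the stale Trainer.state.best_metric still needs the Trainer-side reset described above.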
