[python-package] reorganize early stopping callback #6114

Merged
merged 8 commits on Oct 5, 2023
64 changes: 45 additions & 19 deletions python-package/lightgbm/callback.py
@@ -229,7 +229,12 @@ def __call__(self, env: CallbackEnv) -> None:
             if new_param != env.params.get(key, None):
                 new_parameters[key] = new_param
         if new_parameters:
-            env.model.reset_parameter(new_parameters)
+            if isinstance(env.model, Booster):
+                env.model.reset_parameter(new_parameters)
+            else:
+                # CVBooster holds a list of Booster objects, each needs to be updated
+                for booster in env.model.boosters:
+                    booster.reset_parameter(new_parameters)
         env.params.update(new_parameters)
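The hunk above makes the `reset_parameter` callback aware of `lgb.cv()`, where `env.model` is a `CVBooster` holding one `Booster` per fold rather than a single `Booster`. A minimal sketch of the path this enables; the synthetic data and the decay schedule are illustrative, not taken from this PR:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
train_set = lgb.Dataset(rng.random((500, 5)), label=rng.random(500))

# Under lgb.cv(), env.model is a CVBooster, so the callback loops over
# env.model.boosters and resets the parameter on each fold's Booster.
lgb.cv(
    params={"objective": "regression", "verbosity": -1},
    train_set=train_set,
    num_boost_round=10,
    stratified=False,
    callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 * (0.99 ** i))],
)
```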


@@ -267,6 +272,10 @@ def __init__(
         verbose: bool = True,
         min_delta: Union[float, List[float]] = 0.0
     ) -> None:
+
+        if stopping_rounds <= 0:
+            raise ValueError(f"stopping_rounds should be greater than zero. got: {stopping_rounds}")
+
         self.order = 30
         self.before_iteration = False
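Moving the `stopping_rounds` check into `__init__()` (the old check is among the removed lines in the next hunk) means an invalid value now fails at callback construction instead of at the first boosting round. A quick sketch of the resulting behavior:

```python
import lightgbm as lgb

# The ValueError is now raised here, before any training starts:
try:
    lgb.early_stopping(stopping_rounds=0)
except ValueError as err:
    print(err)  # stopping_rounds should be greater than zero. got: 0
```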

@@ -291,32 +300,45 @@ def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
     def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
         return curr_score < best_score - delta

-    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
-        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
+    def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
+        """Check, by name, if a given Dataset is the training data."""
+        # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
+        # and those metrics are considered for early stopping
+        if ds_name == "cv_agg" and eval_name == "train":
+            return True
+
+        # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
+        if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
+            return True
+
+        return False

     def _init(self, env: CallbackEnv) -> None:
         if env.evaluation_result_list is None or env.evaluation_result_list == []:
             raise ValueError(
                 "For early stopping, at least one dataset and eval metric is required for evaluation"
             )

         is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
-        only_train_set = (
-            len(env.evaluation_result_list) == 1
-            and self._is_train_set(
-                ds_name=env.evaluation_result_list[0][0],
-                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
-                train_name=env.model._train_data_name)
-        )
-        self.enabled = not is_dart and not only_train_set
-        if not self.enabled:
-            if is_dart:
-                _log_warning('Early stopping is not available in dart mode')
-            elif only_train_set:
-                _log_warning('Only training set found, disabling early stopping.')
+        if is_dart:
+            self.enabled = False
+            _log_warning('Early stopping is not available in dart mode')
             return

-        if self.stopping_rounds <= 0:
-            raise ValueError("stopping_rounds should be greater than zero.")
+        # validation sets are guaranteed to not be identical to the training data in cv()
+        if isinstance(env.model, Booster):
+            only_train_set = (
+                len(env.evaluation_result_list) == 1
+                and self._is_train_set(
+                    ds_name=env.evaluation_result_list[0][0],
+                    eval_name=env.evaluation_result_list[0][1].split(" ")[0],
+                    env=env
+                )
+            )
+            if only_train_set:
+                self.enabled = False
+                _log_warning('Only training set found, disabling early stopping.')
+                return

         if self.verbose:
             _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
@@ -395,7 +417,11 @@ def __call__(self, env: CallbackEnv) -> None:
             eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
             if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
-            if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
+            if self._is_train_set(
+                ds_name=env.evaluation_result_list[i][0],
+                eval_name=eval_name_splitted[0],
+                env=env
+            ):
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
             elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                 if self.verbose:
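With `_is_train_set()` also consulted in `__call__()`, a metric computed on the training data can never trigger the stop; only held-out sets can. An end-to-end sketch, again with synthetic data and illustrative names:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X, y = rng.random((500, 5)), rng.random(500)
train_set = lgb.Dataset(X[:400], label=y[:400])
valid_set = lgb.Dataset(X[400:], label=y[400:], reference=train_set)

# Both sets are evaluated every round, but only "valid" may trigger early
# stopping; the "training" entries are skipped by _is_train_set().
booster = lgb.train(
    {"objective": "regression", "verbosity": -1},
    train_set,
    num_boost_round=500,
    valid_sets=[train_set, valid_set],
    valid_names=["training", "valid"],
    callbacks=[lgb.early_stopping(stopping_rounds=10)],
)
print(booster.best_iteration)  # chosen by the "valid" split alone
```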