Skip to content

Commit

Permalink
[RFC][python] deprecate advanced args of train() and cv() functio…
Browse files Browse the repository at this point in the history
…ns and sklearn wrapper (#4574)

* deprecate advanced args of `train()` and `cv()`

* update Dask test

* improve deducing

* address review comments
  • Loading branch information
StrikerRUS committed Sep 12, 2021
1 parent a08c37f commit 86bda6f
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 29 deletions.
41 changes: 35 additions & 6 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,27 @@ def _format_eval_result(value: list, show_stdv: bool = True) -> str:


def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
"""Create a callback that prints the evaluation results.
"""Create a callback that logs the evaluation results.
By default, standard output resource is used.
Use ``register_logger()`` function to register a custom logger.
Note
----
Requires at least one validation data.
Parameters
----------
period : int, optional (default=1)
The period to print the evaluation results.
The period to log the evaluation results.
The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
show_stdv : bool, optional (default=True)
Whether to show stdv (if provided).
Whether to log stdv (if provided).
Returns
-------
callback : callable
The callback that prints the evaluation results every ``period`` iteration(s).
The callback that logs the evaluation results every ``period`` boosting iteration(s).
"""
def _callback(env: CallbackEnv) -> None:
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
Expand All @@ -82,6 +90,24 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
Any initial contents of the dictionary will be deleted.
.. rubric:: Example
With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
this dictionary after finishing a model training process will have the following structure:
.. code-block::
{
'train':
{
'logloss': [0.48253, 0.35953, ...]
},
'eval':
{
'logloss': [0.480385, 0.357756, ...]
}
}
Returns
-------
callback : callable
Expand Down Expand Up @@ -150,11 +176,12 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
Activates early stopping.
The model will train until the validation score stops improving.
Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
Validation score needs to improve at least every ``stopping_rounds`` round(s)
to continue training.
Requires at least one validation data and one metric.
If there's more than one, will check all of them. But the training data is ignored anyway.
To check only the first metric set ``first_metric_only`` to True.
The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.
Parameters
----------
Expand All @@ -163,7 +190,9 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
first_metric_only : bool, optional (default=False)
Whether to use only the first metric for early stopping.
verbose : bool, optional (default=True)
Whether to print message with early stopping information.
Whether to log message with early stopping information.
By default, standard output resource is used.
Use ``register_logger()`` function to register a custom logger.
Returns
-------
Expand Down
29 changes: 24 additions & 5 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def train(
categorical_feature: Union[List[str], List[int], str] = 'auto',
early_stopping_rounds: Optional[int] = None,
evals_result: Optional[Dict[str, Any]] = None,
verbose_eval: Union[bool, int] = True,
verbose_eval: Union[bool, int, str] = 'warn',
learning_rates: Optional[Union[List[float], Callable[[int], float]]] = None,
keep_training_booster: bool = False,
callbacks: Optional[List[Callable]] = None
Expand Down Expand Up @@ -121,7 +121,7 @@ def train(
To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
The index of iteration that has the best performance will be saved in the ``best_iteration`` field
if early stopping logic is enabled by setting ``early_stopping_rounds``.
evals_result: dict or None, optional (default=None)
evals_result : dict or None, optional (default=None)
Dictionary used to store all evaluation results of all the items in ``valid_sets``.
This should be initialized outside of your call to ``train()`` and should be empty.
Any initial contents of the dictionary will be deleted.
Expand Down Expand Up @@ -176,10 +176,13 @@ def train(
num_boost_round = params.pop(alias)
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
params["num_iterations"] = num_boost_round
# show deprecation warning only for early stop argument, setting early stop via global params should still be possible
if early_stopping_rounds is not None and early_stopping_rounds > 0:
_log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'early_stopping()' callback via 'callbacks' argument instead.")
for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params:
early_stopping_rounds = params.pop(alias)
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.get('first_metric_only', False)

Expand Down Expand Up @@ -233,6 +236,14 @@ def train(
callbacks = set(callbacks)

# Most of legacy advanced options becomes callbacks
if verbose_eval != "warn":
_log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
else:
if callbacks: # assume user has already specified print_evaluation callback
verbose_eval = False
else:
verbose_eval = True
if verbose_eval is True:
callbacks.add(callback.print_evaluation())
elif isinstance(verbose_eval, int):
Expand All @@ -242,9 +253,13 @@ def train(
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))

if learning_rates is not None:
_log_warning("'learning_rates' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'reset_parameter()' callback via 'callbacks' argument instead.")
callbacks.add(callback.reset_parameter(learning_rate=learning_rates))

if evals_result is not None:
_log_warning("'evals_result' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'record_evaluation()' callback via 'callbacks' argument instead.")
callbacks.add(callback.record_evaluation(evals_result))

callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
Expand Down Expand Up @@ -520,7 +535,6 @@ def cv(params, train_set, num_boost_round=100,
and returns transformed versions of those.
verbose_eval : bool, int, or None, optional (default=None)
Whether to display the progress.
If None, progress will be displayed when np.ndarray is returned.
If True, progress will be displayed at every boosting stage.
If int, progress will be displayed at every given ``verbose_eval`` boosting stage.
show_stdv : bool, optional (default=True)
Expand Down Expand Up @@ -560,9 +574,11 @@ def cv(params, train_set, num_boost_round=100,
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
num_boost_round = params.pop(alias)
params["num_iterations"] = num_boost_round
if early_stopping_rounds is not None and early_stopping_rounds > 0:
_log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'early_stopping()' callback via 'callbacks' argument instead.")
for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params:
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
early_stopping_rounds = params.pop(alias)
params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.get('first_metric_only', False)
Expand Down Expand Up @@ -601,6 +617,9 @@ def cv(params, train_set, num_boost_round=100,
callbacks = set(callbacks)
if early_stopping_rounds is not None and early_stopping_rounds > 0:
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))
if verbose_eval is not None:
_log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
if verbose_eval is True:
callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
elif isinstance(verbose_eval, int):
Expand Down
66 changes: 50 additions & 16 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

from .basic import Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning
from .callback import print_evaluation, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
Expand Down Expand Up @@ -570,7 +571,7 @@ def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None, verbose=True,
eval_metric=None, early_stopping_rounds=None, verbose='warn',
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is set after definition, using a template."""
Expand All @@ -587,7 +588,7 @@ def fit(self, X, y,
self._fobj = _ObjectiveFunctionWrapper(self._objective)
else:
self._fobj = None
evals_result = {}

params = self.get_params()
# user can set verbose with kwargs, it has higher priority
if self.silent != "warn":
Expand Down Expand Up @@ -718,18 +719,51 @@ def _get_meta_data(collection, name, i):
if isinstance(init_model, LGBMModel):
init_model = init_model.booster_

self._Booster = train(params, train_set,
self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, fobj=self._fobj, feval=eval_metrics_callable,
verbose_eval=verbose, feature_name=feature_name,
callbacks=callbacks, init_model=init_model)
if early_stopping_rounds is not None and early_stopping_rounds > 0:
_log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'early_stopping()' callback via 'callbacks' argument instead.")
params['early_stopping_rounds'] = early_stopping_rounds

if callbacks is None:
callbacks = []
else:
callbacks = copy.deepcopy(callbacks)

if verbose != 'warn':
_log_warning("'verbose' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
else:
if callbacks: # assume user has already specified print_evaluation callback
verbose = False
else:
verbose = True
callbacks.append(print_evaluation(int(verbose)))

evals_result = {}
callbacks.append(record_evaluation(evals_result))

self._Booster = train(
params=params,
train_set=train_set,
num_boost_round=self.n_estimators,
valid_sets=valid_sets,
valid_names=eval_names,
fobj=self._fobj,
feval=eval_metrics_callable,
init_model=init_model,
feature_name=feature_name,
callbacks=callbacks
)

if evals_result:
self._evals_result = evals_result
else: # reset after previous call to fit()
self._evals_result = None

if early_stopping_rounds is not None and early_stopping_rounds > 0:
if self._Booster.best_iteration != 0:
self._best_iteration = self._Booster.best_iteration
else: # reset after previous call to fit()
self._best_iteration = None

self._best_score = self._Booster.best_score

Expand Down Expand Up @@ -791,16 +825,16 @@ def n_features_in_(self):

@property
def best_score_(self):
""":obj:`dict` or :obj:`None`: The best score of fitted model."""
""":obj:`dict`: The best score of fitted model."""
if self._n_features is None:
raise LGBMNotFittedError('No best_score found. Need to call fit beforehand.')
return self._best_score

@property
def best_iteration_(self):
""":obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping_rounds`` has been specified."""
""":obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
if self._n_features is None:
raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping_rounds beforehand.')
raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.')
return self._best_iteration

@property
Expand All @@ -819,7 +853,7 @@ def booster_(self):

@property
def evals_result_(self):
""":obj:`dict` or :obj:`None`: The evaluation results if ``early_stopping_rounds`` has been specified."""
""":obj:`dict` or :obj:`None`: The evaluation results if validation sets have been specified."""
if self._n_features is None:
raise LGBMNotFittedError('No results found. Need to call fit with eval_set beforehand.')
return self._evals_result
Expand Down Expand Up @@ -852,7 +886,7 @@ def fit(self, X, y,
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, feature_name='auto', categorical_feature='auto',
verbose='warn', feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
super().fit(X, y, sample_weight=sample_weight, init_score=init_score,
Expand All @@ -878,7 +912,7 @@ def fit(self, X, y,
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
early_stopping_rounds=None, verbose='warn',
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
Expand Down Expand Up @@ -1006,7 +1040,7 @@ def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None, eval_metric=None,
eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose=True,
eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose='warn',
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
Expand Down
8 changes: 6 additions & 2 deletions tests/python_package_test/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@ def dummy_metric(_, __):
lgb_data = lgb.Dataset(X, y)

eval_records = {}
callbacks = [
lgb.record_evaluation(eval_records),
lgb.print_evaluation(2),
lgb.early_stopping(4)
]
lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']},
lgb_data, num_boost_round=10, feval=dummy_metric,
valid_sets=[lgb_data], evals_result=eval_records,
categorical_feature=[1], early_stopping_rounds=4, verbose_eval=2)
valid_sets=[lgb_data], categorical_feature=[1], callbacks=callbacks)

lgb.plot_metric(eval_records)

Expand Down

0 comments on commit 86bda6f

Please sign in to comment.