[python] Use first_metric_only flag for early_stopping function. #2049

Merged
1 change: 1 addition & 0 deletions docs/Python-Intro.rst
@@ -200,6 +200,7 @@ Note that ``train()`` will return a model from the best iteration.

This works with both metrics to minimize (L2, log loss, etc.) and to maximize (NDCG, AUC, etc.).
Note that if you specify more than one evaluation metric, all of them will be used for early stopping.
However, you can change this behavior and make LightGBM check only the first metric for early stopping by creating the ``early_stopping`` callback with ``first_metric_only=True``.
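For illustration, a minimal sketch of the behavior this paragraph describes (the synthetic data, metric choice, and round counts are illustrative, not from the PR):

import numpy as np
import lightgbm as lgb

X, y = np.random.rand(200, 5), np.random.rand(200)
train_data = lgb.Dataset(X[:160], y[:160])
valid_data = lgb.Dataset(X[160:], y[160:], reference=train_data)

# Both 'l2' and 'l1' are reported on valid_data, but with first_metric_only=True
# only the first metric can trigger early stopping.
bst = lgb.train({'objective': 'regression', 'metric': ['l2', 'l1']},
                train_data, num_boost_round=100, valid_sets=[valid_data],
                callbacks=[lgb.early_stopping(10, first_metric_only=True)])
print(bst.best_iteration)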

Prediction
----------
7 changes: 6 additions & 1 deletion python-package/lightgbm/callback.py
@@ -150,7 +150,7 @@ def _callback(env):
return _callback


def early_stopping(stopping_rounds, verbose=True):
def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
"""Create a callback that activates early stopping.

Note
@@ -161,11 +161,14 @@
to continue training.
Requires at least one validation set and one metric.
If there is more than one, all of them will be checked; however, metrics evaluated on the training data are ignored.
To check only the first metric, set ``first_metric_only`` to True.

Parameters
----------
stopping_rounds : int
The number of rounds without any improvement after which training is stopped.
first_metric_only : bool, optional (default=False)
Whether to use only the first metric for early stopping.
verbose : bool, optional (default=True)
Whether to print message with early stopping information.

@@ -227,5 +230,7 @@ def _callback(env):
                    print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                raise EarlyStopException(best_iter[i], best_score_list[i])
            if first_metric_only:  # only the first metric is used for early stopping
                break
    _callback.order = 30
    return _callback
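As a hedged sketch of how the new flag interacts with a multi-metric ``feval`` (the custom metrics and data below are illustrative; which metric counts as "first" follows the order of the evaluation results):

import numpy as np
import lightgbm as lgb

X, y = np.random.rand(100, 5), np.random.rand(100)
dtrain = lgb.Dataset(X[:80], y[:80])
dvalid = lgb.Dataset(X[80:], y[80:], reference=dtrain)

def two_metrics(preds, data):
    # Each tuple is (name, value, is_higher_better).
    rmse = float(np.sqrt(np.mean((preds - data.get_label()) ** 2)))
    return [('rmse_custom', rmse, False), ('constant', 0.0, False)]

# Only 'rmse_custom' (the first returned metric) can trigger stopping;
# 'constant' is still evaluated and reported but skipped by the break above.
cb = lgb.early_stopping(stopping_rounds=5, first_metric_only=True, verbose=False)
bst = lgb.train({'objective': 'regression', 'metric': 'None', 'verbose': -1},
                dtrain, num_boost_round=30, valid_sets=[dvalid],
                feval=two_metrics, callbacks=[cb])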
4 changes: 4 additions & 0 deletions python-package/lightgbm/engine.py
@@ -66,6 +66,8 @@ def train(params, train_set, num_boost_round=100,
to continue training.
Requires at least one validation data and one metric.
If there is more than one, all of them will be checked; however, metrics evaluated on the training data are ignored.
To check only the first metric, pass the ``early_stopping`` callback with ``first_metric_only=True`` via ``callbacks``.
The index of iteration that has the best performance will be saved in the ``best_iteration`` field
if early stopping logic is enabled by setting ``early_stopping_rounds``.
evals_result: dict or None, optional (default=None)
@@ -391,6 +393,8 @@ def cv(params, train_set, num_boost_round=100,
CV score needs to improve at least every ``early_stopping_rounds`` round(s)
to continue.
Requires at least one metric. If there is more than one, all of them will be checked.
To check only the first metric, pass the ``early_stopping`` callback with ``first_metric_only=True`` via ``callbacks``.
Last entry in evaluation history is the one from the best iteration.
fpreproc : callable or None, optional (default=None)
Preprocessing function that takes (dtrain, dtest, params)
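A corresponding sketch for ``cv()``, under the same assumptions (synthetic data, built-in metrics):

import numpy as np
import lightgbm as lgb

X, y = np.random.rand(150, 5), np.random.rand(150)
dtrain = lgb.Dataset(X, y)

# Early stopping inside cv() is driven by the same callback; only 'l2'
# (the first metric) is checked when first_metric_only=True.
res = lgb.cv({'objective': 'regression', 'metric': ['l2', 'l1'], 'verbose': -1},
             dtrain, num_boost_round=50, nfold=3,
             callbacks=[lgb.early_stopping(5, first_metric_only=True, verbose=False)])
print(len(res['l2-mean']))  # number of rounds actually performed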
2 changes: 2 additions & 0 deletions python-package/lightgbm/sklearn.py
@@ -369,6 +369,8 @@ def fit(self, X, y,
to continue training.
Requires at least one validation data and one metric.
If there is more than one, all of them will be checked; however, metrics evaluated on the training data are ignored.
To check only the first metric, pass the ``early_stopping`` callback with ``first_metric_only=True`` via ``callbacks``.
verbose : bool or int, optional (default=True)
Requires at least one evaluation set.
If True, the eval metric on the eval set is printed at each boosting stage.
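And at the scikit-learn API level, a sketch assuming the same pattern (``LGBMRegressor`` and the synthetic split are illustrative):

import numpy as np
import lightgbm as lgb

X, y = np.random.rand(100, 5), np.random.rand(100)
model = lgb.LGBMRegressor(n_estimators=100)
model.fit(X[:80], y[:80],
          eval_set=[(X[80:], y[80:])],
          eval_metric=['l2', 'l1'],
          callbacks=[lgb.early_stopping(5, first_metric_only=True)],
          verbose=False)
print(model.best_iteration_)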
43 changes: 43 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -1,6 +1,7 @@
# coding: utf-8
# pylint: skip-file
import copy
import itertools
import math
import os
import psutil
@@ -1318,3 +1319,45 @@ def test_get_split_value_histogram(self):
        np.testing.assert_almost_equal(bin_edges[1:][mask], hist[:, 0])
        # test histogram is disabled for categorical features
        self.assertRaises(lgb.basic.LightGBMError, gbm.get_split_value_histogram, 2)

    def test_early_stopping_for_only_first_metric(self):
        X, y = load_boston(True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
        params = {
            'objective': 'regression',
            'metric': 'None',
            'verbose': -1
        }
        lgb_train = lgb.Dataset(X_train, y_train)
        lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

        decreasing_generator = itertools.count(0, -1)

        def decreasing_metric(preds, train_data):
            return ('decreasing_metric', next(decreasing_generator), False)

        def constant_metric(preds, train_data):
            return ('constant_metric', 0.0, False)

        # test that all metrics are checked (default behaviour)
        early_stop_callback = lgb.early_stopping(5, verbose=False)
        gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval],
                        feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
                                                         constant_metric(preds, train_data)],
                        callbacks=[early_stop_callback])
        self.assertEqual(gbm.best_iteration, 1)

        # test that only the first metric is checked
        early_stop_callback = lgb.early_stopping(5, first_metric_only=True, verbose=False)
        gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval],
                        feval=lambda preds, train_data: [decreasing_metric(preds, train_data),
                                                         constant_metric(preds, train_data)],
                        callbacks=[early_stop_callback])
        self.assertEqual(gbm.best_iteration, 20)
        # ... change the order of metrics
        early_stop_callback = lgb.early_stopping(5, first_metric_only=True, verbose=False)
        gbm = lgb.train(params, lgb_train, num_boost_round=20, valid_sets=[lgb_eval],
                        feval=lambda preds, train_data: [constant_metric(preds, train_data),
                                                         decreasing_metric(preds, train_data)],
                        callbacks=[early_stop_callback])
        self.assertEqual(gbm.best_iteration, 1)