Fix learning rate scheduler with cv.
trivialfis committed Feb 27, 2021
1 parent 9c85234 commit 0734044
Showing 3 changed files with 59 additions and 36 deletions.
13 changes: 9 additions & 4 deletions python-package/xgboost/callback.py
@@ -20,6 +20,8 @@ def _get_callback_context(env):
         context = 'train'
     elif env.model is None and env.cvfolds is not None:
         context = 'cv'
+    else:
+        raise ValueError("Unexpected input with both model and cvfolds.")
     return context
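The added `else` branch makes the contract explicit: with the deprecated `CallbackEnv`, exactly one of `env.model` and `env.cvfolds` is set. A minimal sketch of a legacy-style callback relying on that contract (the function name and prints are illustrative only):

    def log_context(env):
        # Deprecated CallbackEnv contract: exactly one of model/cvfolds is set.
        if env.model is not None:
            print(f"[train] iteration {env.iteration}")
        elif env.cvfolds is not None:
            print(f"[cv] iteration {env.iteration} over {len(env.cvfolds)} folds")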


@@ -452,9 +454,11 @@ class LearningRateScheduler(TrainingCallback):
     rounds.
     '''

     def __init__(self, learning_rates) -> None:
-        assert callable(learning_rates) or \
-            isinstance(learning_rates, collections.abc.Sequence)
+        assert callable(learning_rates) or isinstance(
+            learning_rates, collections.abc.Sequence
+        )
         if callable(learning_rates):
             self.learning_rates = learning_rates
         else:
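The reformatted assertion documents the two accepted forms of `learning_rates`. As a hedged illustration, both of these construct a valid scheduler:

    import xgboost as xgb

    # A fixed schedule: one learning rate per boosting round.
    from_list = xgb.callback.LearningRateScheduler([0.3, 0.2, 0.1, 0.05])

    # A computed schedule: any callable taking the round index.
    from_fn = xgb.callback.LearningRateScheduler(lambda epoch: 0.3 * 0.9 ** epoch)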
@@ -751,7 +755,7 @@ def before_iteration(self, model, epoch, dtrain, evals):
         '''Called before each iteration.'''
         for cb in self.callbacks_before_iter:
             rank = rabit.get_rank()
-            cb(CallbackEnv(model=model,
+            cb(CallbackEnv(model=None if self.cvfolds is not None else model,
                            cvfolds=self.cvfolds,
                            iteration=epoch,
                            begin_iteration=self.start_iteration,
@@ -764,6 +768,7 @@ def after_iteration(self, model, epoch, dtrain, evals):
         '''Called after each iteration.'''
         evaluation_result_list = []
         if self.cvfolds is not None:
+            # dtrain is not used here.
             scores = model.eval(epoch, self.feval)
             self.aggregated_cv = _aggcv(scores)
             evaluation_result_list = self.aggregated_cv
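Under cv, `model` here is the `_PackedBooster`, whose `eval` returns one eval string per fold; `_aggcv` reduces them to per-metric mean and standard deviation. A rough sketch of that reduction, assuming the real `_aggcv` differs in details:

    import statistics

    def aggcv_sketch(eval_strings):
        # Reduce per-fold eval strings such as
        # "[0]\ttrain-logloss:0.48\ttest-logloss:0.52"
        # to (metric, mean, std) tuples across folds.
        per_metric = {}
        for line in eval_strings:
            for token in line.split()[1:]:  # skip the "[iteration]" prefix
                key, value = token.rsplit(":", 1)
                per_metric.setdefault(key, []).append(float(value))
        return [
            (key, statistics.mean(vs), statistics.pstdev(vs))
            for key, vs in per_metric.items()
        ]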
@@ -782,7 +787,7 @@ def after_iteration(self, model, epoch, dtrain, evals):
         try:
             for cb in self.callbacks_after_iter:
                 rank = rabit.get_rank()
-                cb(CallbackEnv(model=model,
+                cb(CallbackEnv(model=None if self.cvfolds is not None else model,
                                cvfolds=self.cvfolds,
                                iteration=epoch,
                                begin_iteration=self.start_iteration,
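This is the heart of the fix: during `cv`, old-style callbacks now receive `model=None` and must reach the per-fold boosters through `env.cvfolds`, matching what `_get_callback_context` expects. A minimal sketch of a callback written against that convention (the parameter being set is arbitrary; the `bst` attribute on each fold comes from `CVPack` below):

    def set_seed_per_round(env):
        if env.model is not None:        # plain training
            env.model.set_param("seed", env.iteration)
        else:                            # cv: env.model is None by design now
            for cvpack in env.cvfolds:
                cvpack.bst.set_param("seed", env.iteration)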
51 changes: 32 additions & 19 deletions python-package/xgboost/training.py
@@ -17,9 +17,11 @@ def _configure_deprecated_callbacks(
     warnings.warn(f'Old style callback is deprecated. See: {link}', UserWarning)
     # Most of legacy advanced options becomes callbacks
     if early_stopping_rounds is not None:
-        callbacks.append(callback.early_stop(early_stopping_rounds,
-                                             maximize=maximize,
-                                             verbose=bool(verbose_eval)))
+        callbacks.append(
+            callback.early_stop(
+                early_stopping_rounds, maximize=maximize, verbose=bool(verbose_eval)
+            )
+        )
     if isinstance(verbose_eval, bool) and verbose_eval:
         callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
     else:
@@ -180,7 +182,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     .. code-block:: python

-        [xgb.callback.reset_learning_rate(custom_rates)]
+        [xgb.callback.LearningRateScheduler(custom_rates)]

     Returns
     -------
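The docstring now points at the new-style callback. For instance, a `train` call using it might look like this (synthetic data, illustrative parameters):

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(100, 10), np.random.randint(2, size=100)
    dtrain = xgb.DMatrix(X, label=y)
    custom_rates = [0.3, 0.2, 0.1, 0.05]
    booster = xgb.train(
        {"objective": "binary:logistic"},
        dtrain,
        num_boost_round=len(custom_rates),
        callbacks=[xgb.callback.LearningRateScheduler(custom_rates)],
    )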
@@ -207,6 +209,11 @@ def __init__(self, dtrain, dtest, param):
         self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
         self.bst = Booster(param, [dtrain, dtest])

+    def __getattr__(self, name):
+        def _inner(*args, **kwargs):
+            return getattr(self.bst, name)(*args, **kwargs)
+        return _inner
+
     def update(self, iteration, fobj):
         """Update the boosters for one iteration"""
         self.bst.update(self.dtrain, iteration, fobj)
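The new `__getattr__` forwards any method not defined on `CVPack` itself (e.g. `set_param`, `num_boosted_rounds`) to the wrapped `Booster`. The same pattern in isolation, as a hypothetical stand-alone class:

    class Delegate:
        """Forward unknown method lookups to a wrapped object as bound calls."""

        def __init__(self, inner):
            self.inner = inner

        def __getattr__(self, name):
            # Only reached when normal attribute lookup fails, so methods
            # defined on Delegate itself still take precedence.
            def _inner(*args, **kwargs):
                return getattr(self.inner, name)(*args, **kwargs)
            return _inner

Because `__getattr__` fires only after normal lookup fails, explicitly defined methods such as `update` and `eval` keep their `CVPack` implementations.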
@@ -216,14 +223,13 @@ def eval(self, iteration, feval):
         return self.bst.eval_set(self.watchlist, iteration, feval)


-class _PackedBooster:
+class _PackedBooster:  # pylint: disable=too-few-public-methods
     def __init__(self, cvfolds) -> None:
         self.cvfolds = cvfolds

-    def update(self, iteration, obj):
-        '''Iterate through folds for update'''
-        for fold in self.cvfolds:
-            fold.update(iteration, obj)
+    def update(self, iteration, fobj):
+        for booster in self.cvfolds:
+            booster.update(iteration, fobj)

     def eval(self, iteration, feval):
         '''Iterate through folds for eval'''
@@ -239,15 +245,20 @@ def attr(self, key):
         '''Redirect to booster attr.'''
         return self.cvfolds[0].bst.attr(key)

+    def set_param(self, params, value=None):
+        for f in self.cvfolds:
+            f.bst.set_param(params, value)
+
+    def num_boosted_rounds(self):
+        return self.cvfolds[0].num_boosted_rounds()

     @property
     def best_iteration(self):
         '''Get best_iteration'''
-        ret = self.cvfolds[0].bst.attr('best_iteration')
-        return int(ret)
+        return int(self.cvfolds[0].bst.attr("best_iteration"))

-    def num_boosted_rounds(self) -> int:
-        '''Number of boosted rounds.'''
-        return self.cvfolds[0].bst.num_boosted_rounds()
+    @property
+    def best_score(self):
+        return float(self.cvfolds[0].bst.attr("best_score"))


 def groups_to_rows(groups, boundaries):
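`_PackedBooster` combines two access patterns: writes (`update`, `set_param`) fan out to every fold, while reads (`attr`, `num_boosted_rounds`, `best_iteration`, `best_score`) come from the first fold, since all folds advance in lockstep. A hypothetical mirror of that shape, for illustration only:

    class PackedLike:
        def __init__(self, folds):
            self.folds = folds

        def set_param(self, name, value):
            # Writes fan out: every fold must see the same update, which is
            # what lets the learning rate scheduler take effect under cv.
            for fold in self.folds:
                fold.set_param(name, value)

        def num_boosted_rounds(self):
            # Reads use the first fold; all folds are trained in lockstep.
            return self.folds[0].num_boosted_rounds()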
@@ -419,7 +430,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
     .. code-block:: python

-        [xgb.callback.reset_learning_rate(custom_rates)]
+        [xgb.callback.LearningRateScheduler(custom_rates)]

     shuffle : bool
         Shuffle data before creating folds.
@@ -464,14 +475,16 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
             callbacks.append(callback.EvaluationMonitor(period=verbose_eval,
                                                         show_stdv=show_stdv))
         if early_stopping_rounds:
-            callbacks.append(callback.EarlyStopping(
-                rounds=early_stopping_rounds, maximize=maximize))
+            callbacks.append(
+                callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
+            )
         callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
     else:
         callbacks = _configure_deprecated_callbacks(
             verbose_eval, early_stopping_rounds, maximize, 0,
-            num_boost_round, feval, None, callbacks,
-            show_stdv=show_stdv, cvfolds=cvfolds)
+            num_boost_round, feval, None, callbacks,
+            show_stdv=show_stdv, cvfolds=cvfolds
+        )
     booster = _PackedBooster(cvfolds)
     callbacks.before_training(booster)
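Putting the pieces together, the repaired code path can be exercised like this; a hedged sketch with synthetic data, initializing `eta` to 0 so only the scheduled rates matter (mirroring the test below):

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(200, 10), np.random.randint(2, size=200)
    dtrain = xgb.DMatrix(X, label=y)
    results = xgb.cv(
        {"objective": "binary:logistic", "eta": 0.0},
        dtrain,
        num_boost_round=4,
        nfold=3,
        callbacks=[xgb.callback.LearningRateScheduler([0.3, 0.2, 0.1, 0.05])],
    )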
31 changes: 18 additions & 13 deletions tests/python/test_callback.py
@@ -205,7 +205,18 @@ def test_early_stopping_continuation(self):
         assert booster.num_boosted_rounds() == \
             booster.best_iteration + early_stopping_rounds + 1

-    def run_eta_decay(self, tree_method, deprecated_callback):
+    @pytest.mark.parametrize(
+        "tree_method, deprecated_callback",
+        [
+            ("hist", True),
+            ("hist", False),
+            ("approx", True),
+            ("approx", False),
+            ("exact", True),
+            ("exact", False),
+        ],
+    )
+    def test_eta_decay(self, tree_method, deprecated_callback):
         if deprecated_callback:
             scheduler = xgb.callback.reset_learning_rate
         else:
@@ -217,7 +228,10 @@ def run_eta_decay(self, tree_method, deprecated_callback):
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 4

-        warning_check = pytest.warns(UserWarning) if deprecated_callback else tm.noop_context()
+        if deprecated_callback:
+            warning_check = pytest.warns(UserWarning)
+        else:
+            warning_check = tm.noop_context()

         # learning_rates as a list
         # init eta with 0 to check whether learning_rates work
@@ -288,17 +302,8 @@ def eta_decay(ithround, num_boost_round=num_round):
         for i in range(1, len(eval_errors_0)):
             assert eval_errors_3[i] != eval_errors_2[i]

-    def test_eta_decay_hist(self):
-        self.run_eta_decay('hist', True)
-        self.run_eta_decay('hist', False)
-
-    def test_eta_decay_approx(self):
-        self.run_eta_decay('approx', True)
-        self.run_eta_decay('approx', False)
-
-    def test_eta_decay_exact(self):
-        self.run_eta_decay('exact', True)
-        self.run_eta_decay('exact', False)
+        with warning_check:
+            xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])

     def test_check_point(self):
         from sklearn.datasets import load_breast_cancer
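The three hand-written per-`tree_method` wrappers become a single parametrized test. The same pytest pattern in isolation (hypothetical function under test):

    import pytest

    @pytest.mark.parametrize(
        "tree_method, deprecated_callback",
        [(m, d) for m in ("hist", "approx", "exact") for d in (True, False)],
    )
    def test_all_combinations(tree_method, deprecated_callback):
        # pytest generates one case per (tree_method, flag) pair, replacing
        # the three wrapper methods that called run_eta_decay by hand.
        assert tree_method in {"hist", "approx", "exact"}
        assert isinstance(deprecated_callback, bool)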
