Skip to content

Commit

Permalink
refine early stopping and add a test case (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
wxchan authored and guolinke committed Mar 28, 2017
1 parent 1141ed9 commit 6ed335d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
1 change: 0 additions & 1 deletion python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def callback(env):
env.iteration + 1, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
best_msg[i] = best_msg_buffer
elif env.iteration - best_iter[i] >= stopping_rounds:
env.model.set_attr(best_iteration=str(best_iter[i]))
if verbose:
print('Early stopping, best iteration is:\n' + best_msg[i])
raise EarlyStopException(best_iter[i])
Expand Down
14 changes: 7 additions & 7 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def train(params, train_set, num_boost_round=100,
booster.set_train_data_name(train_data_name)
for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets):
booster.add_valid(valid_set, name_valid_set)
booster.best_iteration = -1

"""start training"""
for i in range_(init_iteration, init_iteration + num_boost_round):
Expand Down Expand Up @@ -192,19 +193,17 @@ def train(params, train_set, num_boost_round=100,
begin_iteration=init_iteration,
end_iteration=init_iteration + num_boost_round,
evaluation_result_list=evaluation_result_list))
except callback.EarlyStopException:
except callback.EarlyStopException as earlyStopException:
booster.best_iteration = earlyStopException.best_iteration + 1
break
if booster.attr('best_iteration') is not None:
booster.best_iteration = int(booster.attr('best_iteration')) + 1
else:
booster.best_iteration = -1
return booster


class CVBooster(object):
""""Auxiliary data struct to hold all boosters of CV."""
def __init__(self):
self.boosters = []
self.best_iteration = -1

def append(self, booster):
"""add a booster to CVBooster"""
Expand Down Expand Up @@ -408,8 +407,9 @@ def cv(params, train_set, num_boost_round=10,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=res))
except callback.EarlyStopException as e:
except callback.EarlyStopException as earlyStopException:
cvfolds.best_iteration = earlyStopException.best_iteration + 1
for k in results:
results[k] = results[k][:e.best_iteration + 1]
results[k] = results[k][:cvfolds.best_iteration]
break
return dict(results)
26 changes: 26 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,32 @@ def test_multiclass(self):
self.assertLess(ret, 0.2)
self.assertAlmostEqual(min(evals_result['eval']['multi_logloss']), ret, places=5)

def test_early_stopping(self):
X_y = load_breast_cancer(True)
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'verbose': -1,
'seed': 42
}
X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
# no early stopping
gbm = lgb.train(params, lgb_train,
num_boost_round=10,
valid_sets=lgb_eval,
verbose_eval=False,
early_stopping_rounds=5)
self.assertEqual(gbm.best_iteration, -1)
# early stopping occurs
gbm = lgb.train(params, lgb_train,
num_boost_round=100,
valid_sets=lgb_eval,
verbose_eval=False,
early_stopping_rounds=5)
self.assertLessEqual(gbm.best_iteration, 100)

def test_continue_train_and_other(self):
params = {
'objective': 'regression',
Expand Down

0 comments on commit 6ed335d

Please sign in to comment.