Skip to content

Commit

Permalink
[MRG] [python] check params for num_boost_round & early_stopping_rounds (#806)
Browse files Browse the repository at this point in the history

* check params

* add test case

* fix pylint
  • Loading branch information
wxchan authored and guolinke committed Aug 18, 2017
1 parent ddda85b commit c8142e3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
27 changes: 25 additions & 2 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from __future__ import absolute_import

import collections
import warnings
from operator import attrgetter

import numpy as np

from . import callback
from .basic import Booster, Dataset, LightGBMError, _InnerPredictor
from .compat import (SKLEARN_INSTALLED, LGBMStratifiedKFold, LGBMGroupKFold, integer_types,
range_, string_type)
from .compat import (SKLEARN_INSTALLED, LGBMGroupKFold, LGBMStratifiedKFold,
integer_types, range_, string_type)


def train(params, train_set, num_boost_round=100,
Expand Down Expand Up @@ -94,6 +95,17 @@ def train(params, train_set, num_boost_round=100,
booster : a trained booster model
"""
"""create predictor first"""
for alias in ["num_boost_round", "num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds"]:
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
num_boost_round = params.pop(alias)
break
for alias in ["early_stopping_round", "early_stopping_rounds", "early_stopping"]:
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias)
break

if isinstance(init_model, string_type):
predictor = _InnerPredictor(model_file=init_model)
elif isinstance(init_model, Booster):
Expand Down Expand Up @@ -370,6 +382,17 @@ def cv(params, train_set, num_boost_round=10,
if not isinstance(train_set, Dataset):
raise TypeError("Traninig only accepts Dataset object")

for alias in ["num_boost_round", "num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds"]:
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
num_boost_round = params.pop(alias)
break
for alias in ["early_stopping_round", "early_stopping_rounds", "early_stopping"]:
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias)
break

if isinstance(init_model, string_type):
predictor = _InnerPredictor(model_file=init_model)
elif isinstance(init_model, Booster):
Expand Down
21 changes: 9 additions & 12 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,20 @@ def test_binary(self):
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'verbose': -1
'verbose': -1,
'num_iteration': 50 # test num_iteration in dict here
}
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
evals_result = {}
gbm = lgb.train(params, lgb_train,
num_boost_round=50,
num_boost_round=20,
valid_sets=lgb_eval,
verbose_eval=False,
evals_result=evals_result)
ret = log_loss(y_test, gbm.predict(X_test))
self.assertLess(ret, 0.15)
self.assertEqual(len(evals_result['valid_0']['binary_logloss']), 50)
self.assertAlmostEqual(evals_result['valid_0']['binary_logloss'][-1], ret, places=5)

def test_rf(self):
Expand Down Expand Up @@ -454,17 +456,12 @@ def test_pandas_categorical(self):
np.testing.assert_almost_equal(pred0, pred3)
np.testing.assert_almost_equal(pred0, pred4)

def test_subset_train_val(self):
'''
Tests that it's fine to construct a single lgb.Dataset object,
take subsets of it, and use the subsets for training and validation
'''
n = 1000
X = np.random.normal(size=(n, 2))
y = np.random.normal(size=n)
def test_reference_chain(self):
X = np.random.normal(size=(100, 2))
y = np.random.normal(size=100)
tmp_dat = lgb.Dataset(X, y)
# take subsets and train
tmp_dat_train = tmp_dat.subset(np.arange(int(n * .8)))
tmp_dat_val = tmp_dat.subset(np.arange(int(n * .8), n)).subset(np.arange(n * .2 * .9))
tmp_dat_train = tmp_dat.subset(np.arange(80))
tmp_dat_val = tmp_dat.subset(np.arange(80, 100)).subset(np.arange(18))
params = {'objective': 'regression_l2', 'metric': 'rmse'}
gbm = lgb.train(params, tmp_dat_train, num_boost_round=20, valid_sets=[tmp_dat_train, tmp_dat_val])

0 comments on commit c8142e3

Please sign in to comment.