Merge pull request #327 from jseabold/sklearn-eval-set
ENH: Allow early stopping through scikit-learn API
tqchen committed Jul 26, 2015
2 parents b1dec91 + b76db01 commit eee0d5b
Showing 2 changed files with 164 additions and 15 deletions.
22 changes: 21 additions & 1 deletion demo/guide-python/sklearn_examples.py
@@ -8,7 +8,7 @@
import xgboost as xgb

import numpy as np
-from sklearn.cross_validation import KFold
+from sklearn.cross_validation import KFold, train_test_split
from sklearn.metrics import confusion_matrix, mean_squared_error
from sklearn.grid_search import GridSearchCV
from sklearn.datasets import load_iris, load_digits, load_boston
@@ -65,3 +65,23 @@
pickle.dump(clf, open("best_boston.pkl", "wb"))
clf2 = pickle.load(open("best_boston.pkl", "rb"))
print(np.allclose(clf.predict(X), clf2.predict(X)))

# Early-stopping

X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = xgb.XGBClassifier()
clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
eval_set=[(X_test, y_test)])

# Custom evaluation function
from sklearn.metrics import log_loss


def log_loss_eval(y_pred, y_true):
return "log-loss", log_loss(y_true.get_label(), y_pred)
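# Note: xgboost hands a custom metric the raw predictions and the
# evaluation DMatrix, so get_label() is needed to recover y_true.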


clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric=log_loss_eval,
eval_set=[(X_test, y_test)])
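
For reference, a hedged sketch of inspecting the fitted model afterwards; the attribute names (best_score_, best_iteration_, eval_results_) come from the wrapper changes below, and the printed values are data-dependent.

print(clf.best_score_)                    # best score seen on the eval set
print(clf.best_iteration_)                # boosting round that achieved it
print(clf.eval_results_["validation_0"])  # per-round metric history as a numpy array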
157 changes: 143 additions & 14 deletions wrapper/xgboost.py
@@ -6,7 +6,7 @@
Authors: Tianqi Chen, Bing Xu
Early stopping by Zygmunt Zając
"""
-# pylint: disable=too-many-arguments, too-many-locals, too-many-lines, invalid-name
+# pylint: disable=too-many-arguments, too-many-locals, too-many-lines, invalid-name, fixme
from __future__ import absolute_import

import os
@@ -738,7 +738,7 @@ def get_fscore(self, fmap=''):


def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
-          early_stopping_rounds=None, evals_result=None):
+          early_stopping_rounds=None, evals_result=None, verbose_eval=True):
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
"""Train a booster with given parameters.
@@ -767,12 +767,14 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
bst.best_score and bst.best_iteration.
evals_result: dict
This dictionary stores the evaluation results of all the items in watchlist
verbose_eval : bool
If `verbose_eval` is True, the evaluation metric on the validation set,
if given, is printed at each boosting stage.
Returns
-------
booster : a trained booster model
"""

evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals])

@@ -782,7 +784,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
else:
evals_name = [d[1] for d in evals]
evals_result.clear()
-evals_result.update({key:[] for key in evals_name})
+evals_result.update({key: [] for key in evals_name})

if not early_stopping_rounds:
for i in range(num_boost_round):
@@ -794,9 +796,10 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
else:
msg = bst_eval_set.decode()

-sys.stderr.write(msg + '\n')
+if verbose_eval:
+    sys.stderr.write(msg + '\n')
if evals_result is not None:
-res = re.findall(":([0-9.]+).", msg)
+res = re.findall(":-?([0-9.]+).", msg)
for key, val in zip(evals_name, res):
evals_result[key].append(val)
return bst
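
For intuition, the message being parsed is the tab-separated evaluation line xgboost emits each round. A hedged sketch with a made-up message shows what the pattern extracts; the `-?` also admits negative metric values, and the trailing `.` in the pattern consumes one extra character, which clips the last digit of the final value on the line:

import re

msg = "[3]\tvalidation_0-auc:0.943700\tvalidation_1-auc:0.921500"  # hypothetical
res = re.findall(":-?([0-9.]+).", msg)
print(res)  # ['0.943700', '0.92150']  (the trailing '.' clips the last digit)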
@@ -840,10 +843,11 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
else:
msg = bst_eval_set.decode()

-sys.stderr.write(msg + '\n')
+if verbose_eval:
+    sys.stderr.write(msg + '\n')

if evals_result is not None:
-res = re.findall(":([0-9.]+).", msg)
+res = re.findall(":-?([0-9.]+).", msg)
for key, val in zip(evals_name, res):
evals_result[key].append(val)
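
Putting the two branches together, a hedged sketch of driving train directly with the new flag; dtrain and dval stand in for user-built DMatrix objects, and the parameter values are illustrative only.

history = {}
bst = train({"objective": "binary:logistic", "eval_metric": "auc"},
            dtrain, num_boost_round=500, evals=[(dval, "validation_0")],
            early_stopping_rounds=10, evals_result=history,
            verbose_eval=False)             # train quietly
print(bst.best_iteration, bst.best_score)   # set when early stopping triggers
print(history["validation_0"])              # per-round values parsed above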

@@ -1074,6 +1078,8 @@ def get_params(self, deep=False):
params = super(XGBModel, self).get_params(deep=deep)
if params['missing'] is np.nan:
params['missing'] = None # sklearn doesn't handle nan. see #4725
if not params.get('eval_metric', True):
del params['eval_metric'] # don't give as None param to Booster
return params

def get_xgb_params(self):
@@ -1086,10 +1092,71 @@ def get_xgb_params(self):
xgb_params.pop('nthread', None)
return xgb_params

-def fit(self, data, y):
-    # pylint: disable=missing-docstring,invalid-name
-    train_dmatrix = DMatrix(data, label=y, missing=self.missing)
-    self._Booster = train(self.get_xgb_params(), train_dmatrix, self.n_estimators)
def fit(self, X, y, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
"""
Fit the gradient boosting model
Parameters
----------
X : array_like
Feature matrix
y : array_like
Labels
eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for
early-stopping
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use (see
doc/parameter.md). If callable, a custom evaluation metric with the
signature func(y_predicted, y_true), where y_true is a DMatrix, so
you may need to call its get_label method. It must return a
(str, value) pair, where the str names the evaluation and value is
its result. The metric is always minimized.
early_stopping_rounds : int, optional
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in eval_set. If there's more than one,
the last one is used. Returns the model from the last iteration
(not the best one). If early stopping occurs, the model will have
two additional attributes: best_score_ and best_iteration_.
verbose : bool
If `verbose` is True and an evaluation set is used, the evaluation
metric measured on the validation set is written to stderr.
"""
train_dmatrix = DMatrix(X, label=y, missing=self.missing)

eval_results = {}
if eval_set is not None:
evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
evals = list(zip(evals, ["validation_{}".format(i) for i in
range(len(evals))]))
else:
evals = ()

params = self.get_xgb_params()

feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({'eval_metric': eval_metric})
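# A string metric reaches the Booster through params['eval_metric'];
# a callable was captured as feval above and is cleared here so it is
# not also passed down as a Booster parameter.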

self._Booster = train(params, train_dmatrix,
self.n_estimators, evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=eval_results, feval=feval,
verbose_eval=verbose)
if eval_results:
-eval_results = {k: np.array(v) for k, v in eval_results.items()}
+eval_results = {k: np.array(v, dtype=float)
+                for k, v in eval_results.items()}
self.eval_results_ = eval_results
self.best_score_ = self._Booster.best_score
self.best_iteration_ = self._Booster.best_iteration
return self
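
As a usage sketch of the two dispatch paths (estimator and data names are placeholders, XGBRegressor is the sklearn-style wrapper defined elsewhere in this file, and "rmse" is one of the built-in metrics listed in doc/parameter.md):

reg = xgb.XGBRegressor(n_estimators=300)

# Built-in metric: the string rides along in the Booster params.
reg.fit(X_train, y_train, eval_set=[(X_test, y_test)],
        eval_metric="rmse", early_stopping_rounds=10)

# Custom metric: the callable is forwarded to train() as feval instead.
def mse_eval(y_pred, y_true):  # y_true arrives as a DMatrix
    return "custom-mse", ((y_true.get_label() - y_pred) ** 2).mean()

reg.fit(X_train, y_train, eval_set=[(X_test, y_test)],
        eval_metric=mse_eval, early_stopping_rounds=10)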

def predict(self, data):
@@ -1117,8 +1184,43 @@ def __init__(self, max_depth=3, learning_rate=0.1,
colsample_bytree,
base_score, seed, missing)

-def fit(self, X, y, sample_weight=None):
+def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
+        early_stopping_rounds=None, verbose=True):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting classifier
Parameters
----------
X : array_like
Feature matrix
y : array_like
Labels
sample_weight : array_like
Weight for each instance
eval_set : list, optional
A list of (X, y) pairs to use as a validation set for
early-stopping
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use (see
doc/parameter.md). If callable, a custom evaluation metric with the
signature func(y_predicted, y_true), where y_true is a DMatrix, so
you may need to call its get_label method. It must return a
(str, value) pair, where the str names the evaluation and value is
its result. The metric is always minimized.
early_stopping_rounds : int, optional
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in eval_set. If there's more than one,
the last one is used. Returns the model from the last iteration
(not the best one). If early stopping occurs, the model will have
two additional attributes: best_score_ and best_iteration_.
verbose : bool
If `verbose` is True and an evaluation set is used, the evaluation
metric measured on the validation set is written to stderr.
"""
eval_results = {}
self.classes_ = list(np.unique(y))
self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2:
@@ -1129,6 +1231,22 @@ def fit(self, X, y, sample_weight=None):
else:
xgb_options = self.get_xgb_params()

feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
xgb_options.update({"eval_metric": eval_metric})
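# Same dispatch as in XGBModel.fit above: a string metric goes to the
# Booster params, a callable was already captured as feval.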

if eval_set is not None:
# TODO: use sample_weight if given?
evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
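# NOTE: eval_set labels are used as-is; they are not passed through the
# LabelEncoder fitted below, so they should already be numeric.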
nevals = len(evals)
eval_names = ["validation_{}".format(i) for i in range(nevals)]
evals = list(zip(evals, eval_names))
else:
evals = ()

self._le = LabelEncoder().fit(y)
training_labels = self._le.transform(y)

@@ -1139,7 +1257,18 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
train_dmatrix = DMatrix(X, label=training_labels,
missing=self.missing)

-self._Booster = train(xgb_options, train_dmatrix, self.n_estimators)
+self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
+                      evals=evals,
+                      early_stopping_rounds=early_stopping_rounds,
+                      evals_result=eval_results, feval=feval,
+                      verbose_eval=verbose)

if eval_results:
eval_results = {k: np.array(v, dtype=float)
for k, v in eval_results.items()}
self.eval_results_ = eval_results
self.best_score_ = self._Booster.best_score
self.best_iteration_ = self._Booster.best_iteration

return self
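
A hedged sketch of the classifier path end-to-end; data names are placeholders and "mlogloss" is the built-in multiclass metric from doc/parameter.md:

clf = xgb.XGBClassifier(n_estimators=100)
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)],
        eval_metric="mlogloss", early_stopping_rounds=5)
print(clf.classes_)         # original label values, before encoding
print(clf.best_iteration_)  # round where validation mlogloss stopped improving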

