[breaking] Remove duplicated predict functions, Fix attributes IO.
* Fix attributes not being restored.
* Rename all `data` to `X`. [breaking]
trivialfis committed Jan 12, 2021
1 parent 03cd087 commit a2635af
Showing 4 changed files with 91 additions and 87 deletions.
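The rename of `data` to `X` is breaking for keyword callers of the sklearn wrappers. A minimal sketch of the impact (data and parameters are illustrative):

import numpy as np
import xgboost as xgb

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)

clf = xgb.XGBClassifier(n_estimators=4).fit(X, y)

# Before this commit the first parameter of predict() was named `data`:
#     clf.predict(data=X)   # raises TypeError after this change
# Keyword callers must now pass `X`; positional calls are unaffected:
preds = clf.predict(X=X)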
7 changes: 7 additions & 0 deletions python-package/xgboost/core.py
@@ -1741,6 +1741,13 @@ def load_model(self, fname):
else:
raise TypeError('Unknown file type: ', fname)

if self.attr("best_iteration") is not None:
self.best_iteration = int(self.attr("best_iteration"))
if self.attr("best_score") is not None:
self.best_score = float(self.attr("best_score"))
if self.attr("best_ntree_limit") is not None:
self.best_ntree_limit = int(self.attr("best_ntree_limit"))

def num_boosted_rounds(self) -> int:
'''Get number of boosted rounds. For gblinear this is reset to 0 after
serializing the model.
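The hunk above makes `load_model` restore the early-stopping attributes from the persisted booster. A minimal sketch of the round trip this fixes (the file name is illustrative):

import xgboost as xgb

# Assume a model trained with early stopping was saved to model.json.
bst = xgb.Booster()
bst.load_model("model.json")

# Previously these Python-side fields were lost on reload; now they are
# rebuilt from the stored booster attributes when present.
print(bst.best_iteration, bst.best_score, bst.best_ntree_limit)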
115 changes: 41 additions & 74 deletions python-package/xgboost/sklearn.py
@@ -682,10 +682,16 @@ def fit(self, X, y, *, sample_weight=None, base_margin=None,
self._set_evaluation_result(evals_result)
return self

def predict(self, data, output_margin=False, ntree_limit=None,
validate_features=True, base_margin=None):
def predict(
self,
X,
output_margin=False,
ntree_limit=None,
validate_features=True,
base_margin=None
):
"""
Predict with `data`.
Predict with `X`.
.. note:: This function is not thread safe.
@@ -699,7 +705,7 @@ def predict(self, data, output_margin=False, ntree_limit=None,
Parameters
----------
data : array_like
X : array_like
Data to predict with
output_margin : bool
Whether to output the raw untransformed margin value.
Expand All @@ -718,16 +724,21 @@ def predict(self, data, output_margin=False, ntree_limit=None,
prediction : numpy array
"""
# pylint: disable=missing-docstring,invalid-name
test_dmatrix = DMatrix(data, base_margin=base_margin,
test_dmatrix = DMatrix(X, base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs)
# get ntree_limit to use - if none specified, default to
# best_ntree_limit if defined, otherwise 0.
if ntree_limit is None:
ntree_limit = getattr(self, "best_ntree_limit", 0)
return self.get_booster().predict(test_dmatrix,
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features)
try:
ntree_limit = self.best_ntree_limit
except AttributeError:
ntree_limit = 0
return self.get_booster().predict(
test_dmatrix,
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features
)

def apply(self, X, ntree_limit=0):
"""Return the predicted leaf every tree for each sample.
@@ -1037,50 +1048,21 @@ def fit(self, X, y, *, sample_weight=None, base_margin=None,
'Fit gradient boosting model',
'Fit gradient boosting classifier', 1)

def predict(self, data, output_margin=False, ntree_limit=None,
validate_features=True, base_margin=None):
"""
Predict with `data`.
.. note:: This function is not thread safe.
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple threads, call
``xgb.copy()`` to make copies of the model object and then call
``predict()``.
.. code-block:: python
preds = bst.predict(dtest, ntree_limit=num_round)
Parameters
----------
data : array_like
Feature matrix.
output_margin : bool
Whether to output the raw untransformed margin value.
ntree_limit : int
Limit number of trees in the prediction; defaults to
best_ntree_limit if defined (i.e. it has been trained with early
stopping), otherwise 0 (use all trees).
validate_features : bool
When this is True, validate that the Booster's and data's
feature_names are identical. Otherwise, it is assumed that the
feature_names are the same.
Returns
-------
prediction : numpy array
"""
test_dmatrix = DMatrix(data, base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs)
if ntree_limit is None:
ntree_limit = getattr(self, "best_ntree_limit", 0)
class_probs = self.get_booster().predict(
test_dmatrix,
def predict(
self,
X,
output_margin=False,
ntree_limit=None,
validate_features=True,
base_margin=None
):
class_probs = super().predict(
X=X,
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features)
validate_features=validate_features,
base_margin=base_margin
)
if output_margin:
# If output_margin is active, simply return the scores
return class_probs
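With the duplicated body removed, the classifier override is now a thin wrapper: it forwards everything to `XGBModel.predict` and only post-processes the returned scores into class labels (the conversion itself sits in the lines elided from this hunk).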
@@ -1125,13 +1107,13 @@ def predict_proba(self, X, ntree_limit=None, validate_features=False,
a numpy array of shape (n_samples, n_classes) with the
probability of each data example being of a given class.
"""
test_dmatrix = DMatrix(X, base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs)
if ntree_limit is None:
ntree_limit = getattr(self, "best_ntree_limit", 0)
class_probs = self.get_booster().predict(test_dmatrix,
ntree_limit=ntree_limit,
validate_features=validate_features)
class_probs = super().predict(
X=X,
output_margin=False,
ntree_limit=ntree_limit,
validate_features=validate_features,
base_margin=base_margin
)
return _cls_predict_proba(self.objective, class_probs, np.vstack)

def evals_result(self):
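`predict_proba` likewise delegates to the shared `XGBModel.predict`, and `_cls_predict_proba` then shapes the scores into per-class probabilities. A short usage sketch, reusing `clf` and `X` from the sketches above:

proba = clf.predict_proba(X)                # shape (n_samples, n_classes)
assert np.allclose(proba.sum(axis=1), 1.0)  # rows are probability vectors
labels = proba.argmax(axis=1)               # matches clf.predict(X) here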
@@ -1493,18 +1475,3 @@ def fit(
self.objective = params["objective"]
self._set_evaluation_result(evals_result)
return self

def predict(self, data, output_margin=False,
ntree_limit=0, validate_features=True, base_margin=None):

test_dmatrix = DMatrix(data, base_margin=base_margin,
missing=self.missing)
if ntree_limit is None:
ntree_limit = getattr(self, "best_ntree_limit", 0)

return self.get_booster().predict(test_dmatrix,
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features)

predict.__doc__ = XGBModel.predict.__doc__
37 changes: 24 additions & 13 deletions python-package/xgboost/training.py
@@ -88,12 +88,9 @@ def _train_internal(params, dtrain,
if evals_result is not None and is_new_callback:
evals_result.update(callbacks.history)

if bst.attr('best_score') is not None:
bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration'))
else:
bst.best_iteration = bst.num_boosted_rounds() - 1

# These should be moved into callback functions `after_training`, but until old
# callbacks are removed, the train function is the only place for setting the
# attributes.
config = json.loads(bst.save_config())
booster = config['learner']['gradient_booster']['name']
if booster == 'gblinear':
@@ -114,7 +111,20 @@

num_groups = int(config['learner']['learner_model_param']['num_class'])
num_groups = 1 if num_groups == 0 else num_groups
bst.best_ntree_limit = ((bst.best_iteration + 1) * num_parallel_tree * num_groups)
if bst.attr('best_score') is not None:
bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration'))
bst.set_attr(
best_ntree_limit=str(
(bst.best_iteration + 1) * num_parallel_tree * num_groups
)
)
bst.best_ntree_limit = int(bst.attr("best_ntree_limit"))
else:
# For compatibility with versions older than 1.4, these attributes are added
# to the Python object even if early stopping is not used.
bst.best_iteration = bst.num_boosted_rounds() - 1
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree * num_groups

# Copy to serialise and unserialise booster to reset state and free
# training memory
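The block above sets `best_ntree_limit = (best_iteration + 1) * num_parallel_tree * num_groups` whether or not early stopping fired. A quick worked check of the arithmetic with illustrative values:

# e.g. a 3-class softprob forest with 4 parallel trees, best round index 9
best_iteration = 9
num_parallel_tree = 4
num_groups = 3

best_ntree_limit = (best_iteration + 1) * num_parallel_tree * num_groups
assert best_ntree_limit == 120  # 10 rounds x 4 trees/round x 3 classes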
@@ -148,15 +158,16 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **evals**.
The method returns the model from the last iteration (not the best one).
If there's more than one item in **evals**, the last entry will be used
for early stopping.
The method returns the model from the last iteration (not the best one). Use a
custom callback or model slicing if the best model is desired.
If there's more than one item in **evals**, the last entry will be used for early
stopping.
If there's more than one metric in the **eval_metric** parameter given in
**params**, the last metric will be used for early stopping.
If early stopping occurs, the model will have three additional fields:
``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
(Use ``bst.best_ntree_limit`` to get the correct value if
``num_parallel_tree`` and/or ``num_class`` appears in the parameters)
``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``. (Use
``bst.best_ntree_limit`` to get the correct value if ``num_parallel_tree`` and/or
``num_class`` appears in the parameters)
evals_result: dict
This dictionary stores the evaluation results of all the items in watchlist.
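The reworded docstring above points to model slicing as the way to keep the best model. A hedged sketch using Booster slicing, reusing `X` and `y` from the earlier sketches:

dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=100,
    evals=[(dtrain, "train")],
    early_stopping_rounds=5,
)
# train() returns the model from the last iteration; slice to keep only
# the rounds up to and including the best one.
best = bst[: bst.best_iteration + 1]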
19 changes: 19 additions & 0 deletions tests/python/test_basic_models.py
@@ -341,6 +341,25 @@ def validate_model(parameters):
'objective': 'multi:softmax'}
validate_model(parameters)

@pytest.mark.skipif(**tm.no_sklearn())
def test_attributes(self):
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
cls = xgb.XGBClassifier(n_estimators=2)
cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)])
assert cls.get_booster().best_ntree_limit == 2 * cls.n_classes_
assert cls.best_ntree_limit == cls.get_booster().best_ntree_limit

with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "cls.json")
cls.save_model(path)

cls = xgb.XGBClassifier(n_estimators=2)
cls.load_model(path)
assert cls.get_booster().best_ntree_limit == 2 * cls.n_classes_
assert cls.best_ntree_limit == cls.get_booster().best_ntree_limit
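# Note: the expected value follows from the formula in training.py:
# with n_estimators=2, no parallel trees, and one tree group per class,
# best_ntree_limit == (best_iteration + 1) * 1 * n_classes_ == 2 * n_classes_,
# assuming the final round is the best one on the training set.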

@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize('booster', ['gbtree', 'dart'])
def test_slice(self, booster):
from sklearn.datasets import make_classification
