Skip to content

Commit

Permalink
Implement __sklearn_is_fitted__. (#7230)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 15, 2021
1 parent d997c96 commit 037dd08
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python-package/xgboost/sklearn.py
Expand Up @@ -435,6 +435,9 @@ def _more_tags(self) -> Dict[str, bool]:
'''Tags used for scikit-learn data validation.'''
return {'allow_nan': True, 'no_validation': True}

def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")

def get_booster(self) -> Booster:
"""Get the underlying xgboost Booster of this model.
Expand All @@ -444,7 +447,7 @@ def get_booster(self) -> Booster:
-------
booster : a xgboost booster of underlying model
"""
if not hasattr(self, '_Booster'):
if not self.__sklearn_is_fitted__():
from sklearn.exceptions import NotFittedError
raise NotFittedError('need to call fit or load_model beforehand')
return self._Booster
Expand Down
2 changes: 1 addition & 1 deletion tests/python-gpu/test_gpu_with_sklearn.py
Expand Up @@ -19,7 +19,7 @@ def test_gpu_binary_classification():
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold

digits = load_digits(2)
digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
Expand Down

0 comments on commit 037dd08

Please sign in to comment.