Skip to content

Commit

Permalink
[python][sklearn] add n_estimators_ and n_iter_ post-fit attribut…
Browse files Browse the repository at this point in the history
…es (#4753)

* add n_estimators_ and n_iter_ post-fit attributes

* address review comments
  • Loading branch information
StrikerRUS committed Nov 5, 2021
1 parent 99f0f3e commit aab212a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python-package/lightgbm/sklearn.py
Expand Up @@ -847,6 +847,28 @@ def objective_(self):
raise LGBMNotFittedError('No objective found. Need to call fit beforehand.')
return self._objective

@property
def n_estimators_(self) -> int:
""":obj:`int`: True number of boosting iterations performed.
This might be less than parameter ``n_estimators`` if early stopping was enabled or
if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
"""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_estimators found. Need to call fit beforehand.')
return self._Booster.current_iteration()

@property
def n_iter_(self) -> int:
""":obj:`int`: True number of boosting iterations performed.
This might be less than parameter ``n_estimators`` if early stopping was enabled or
if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
"""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_iter found. Need to call fit beforehand.')
return self._Booster.current_iteration()

@property
def booster_(self):
"""Booster: The underlying Booster of this model."""
Expand Down
11 changes: 11 additions & 0 deletions tests/python_package_test/test_sklearn.py
Expand Up @@ -1158,6 +1158,17 @@ def test_continue_training_with_model():
assert gbm.evals_result_['valid_0']['multi_logloss'][-1] < init_gbm.evals_result_['valid_0']['multi_logloss'][-1]


def test_actual_number_of_trees():
X = [[1, 2, 3], [1, 2, 3]]
y = [1, 1]
n_estimators = 5
gbm = lgb.LGBMRegressor(n_estimators=n_estimators).fit(X, y)
assert gbm.n_estimators == n_estimators
assert gbm.n_estimators_ == 1
assert gbm.n_iter_ == 1
np.testing.assert_array_equal(gbm.predict(np.array(X) * 10), y)


# sklearn < 0.22 requires passing "attributes" argument
@pytest.mark.skipif(sk_version < parse_version('0.22'), reason='scikit-learn version is less than 0.22')
def test_check_is_fitted():
Expand Down

0 comments on commit aab212a

Please sign in to comment.