diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 4981df873c7..05a73829271 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -847,6 +847,28 @@ def objective_(self): raise LGBMNotFittedError('No objective found. Need to call fit beforehand.') return self._objective + @property + def n_estimators_(self) -> int: + """:obj:`int`: True number of boosting iterations performed. + + This might be less than parameter ``n_estimators`` if early stopping was enabled or + if boosting stopped early due to limits on complexity like ``min_gain_to_split``. + """ + if not self.__sklearn_is_fitted__(): + raise LGBMNotFittedError('No n_estimators found. Need to call fit beforehand.') + return self._Booster.current_iteration() + + @property + def n_iter_(self) -> int: + """:obj:`int`: True number of boosting iterations performed. + + This might be less than parameter ``n_estimators`` if early stopping was enabled or + if boosting stopped early due to limits on complexity like ``min_gain_to_split``. + """ + if not self.__sklearn_is_fitted__(): + raise LGBMNotFittedError('No n_iter found. Need to call fit beforehand.') + return self._Booster.current_iteration() + @property def booster_(self): """Booster: The underlying Booster of this model.""" diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 152757c7963..4204ffb4ec0 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1158,6 +1158,17 @@ def test_continue_training_with_model(): assert gbm.evals_result_['valid_0']['multi_logloss'][-1] < init_gbm.evals_result_['valid_0']['multi_logloss'][-1] +def test_actual_number_of_trees(): + X = [[1, 2, 3], [1, 2, 3]] + y = [1, 1] + n_estimators = 5 + gbm = lgb.LGBMRegressor(n_estimators=n_estimators).fit(X, y) + assert gbm.n_estimators == n_estimators + assert gbm.n_estimators_ == 1 + assert gbm.n_iter_ == 1 + np.testing.assert_array_equal(gbm.predict(np.array(X) * 10), y) + + # sklearn < 0.22 requires passing "attributes" argument @pytest.mark.skipif(sk_version < parse_version('0.22'), reason='scikit-learn version is less than 0.22') def test_check_is_fitted():