Skip to content

Commit

Permalink
MNT Use check_scalar in AdaBoostClassifier (scikit-learn#21442)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
  • Loading branch information
3 people committed Nov 2, 2021
1 parent 7dcee61 commit 1ea7289
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
23 changes: 22 additions & 1 deletion sklearn/ensemble/_weight_boosting.py
Expand Up @@ -25,6 +25,7 @@

from abc import ABCMeta, abstractmethod

import numbers
import numpy as np

import warnings
Expand All @@ -36,6 +37,7 @@

from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
from ..utils import check_random_state, _safe_indexing
from ..utils import check_scalar
from ..utils.extmath import softmax
from ..utils.extmath import stable_cumsum
from ..metrics import accuracy_score, r2_score
Expand Down Expand Up @@ -478,9 +480,28 @@ def fit(self, X, y, sample_weight=None):
self : object
Fitted estimator.
"""
check_scalar(
self.n_estimators,
"n_estimators",
target_type=numbers.Integral,
min_val=1,
include_boundaries="left",
)

check_scalar(
self.learning_rate,
"learning_rate",
target_type=numbers.Real,
min_val=0,
include_boundaries="neither",
)

# Check that algorithm is supported
if self.algorithm not in ("SAMME", "SAMME.R"):
raise ValueError("algorithm %s is not supported" % self.algorithm)
raise ValueError(
"Algorithm must be 'SAMME' or 'SAMME.R'."
f" Got {self.algorithm!r} instead."
)

# Fit
return super().fit(X, y, sample_weight)
Expand Down
32 changes: 26 additions & 6 deletions sklearn/ensemble/tests/test_weight_boosting.py
Expand Up @@ -273,12 +273,6 @@ def test_importances():
def test_error():
# Test that it gives proper exception on deficient input.

with pytest.raises(ValueError):
AdaBoostClassifier(learning_rate=-1).fit(X, y_class)

with pytest.raises(ValueError):
AdaBoostClassifier(algorithm="foo").fit(X, y_class)

with pytest.raises(ValueError):
AdaBoostClassifier().fit(X, y_class, sample_weight=np.asarray([-1]))

Expand Down Expand Up @@ -549,6 +543,32 @@ def test_adaboostregressor_sample_weight():
assert score_no_outlier == pytest.approx(score_with_weight)


@pytest.mark.parametrize(
"params, err_type, err_msg",
[
({"n_estimators": -1}, ValueError, "n_estimators == -1, must be >= 1"),
({"n_estimators": 0}, ValueError, "n_estimators == 0, must be >= 1"),
(
{"n_estimators": 1.5},
TypeError,
"n_estimators must be an instance of <class 'numbers.Integral'>,"
" not <class 'float'>",
),
({"learning_rate": -1}, ValueError, "learning_rate == -1, must be > 0."),
({"learning_rate": 0}, ValueError, "learning_rate == 0, must be > 0."),
(
{"algorithm": "unknown"},
ValueError,
"Algorithm must be 'SAMME' or 'SAMME.R'.",
),
],
)
def test_adaboost_classifier_params_validation(params, err_type, err_msg):
"""Check the parameters validation in `AdaBoostClassifier`."""
with pytest.raises(err_type, match=err_msg):
AdaBoostClassifier(**params).fit(X, y_class)


@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
def test_adaboost_consistent_predict(algorithm):
# check that predict_proba and predict give consistent results
Expand Down

0 comments on commit 1ea7289

Please sign in to comment.