Skip to content

Commit

Permalink
Merge pull request #6637 from markotoplak/adaboost-base-estimator
Browse files Browse the repository at this point in the history
adaboost: adapt to scikit-learn's 1.4 deprecation of base_estimator
  • Loading branch information
lanzagar committed Nov 24, 2023
2 parents d9ce49d + ad23f30 commit 42f633f
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 43 deletions.
48 changes: 34 additions & 14 deletions Orange/ensembles/ada_boost.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import sklearn.ensemble as skl_ensemble

from Orange.base import SklLearner
Expand All @@ -7,6 +9,8 @@
from Orange.regression.base_regression import (
SklLearnerRegression, SklModelRegression
)
from Orange.util import OrangeDeprecationWarning


__all__ = ['SklAdaBoostClassificationLearner', 'SklAdaBoostRegressionLearner']

Expand All @@ -15,21 +19,32 @@ class SklAdaBoostClassifier(SklModelClassification):
pass


def base_estimator_deprecation():
warnings.warn(
"`base_estimator` is deprecated (to be removed in 3.39): use `estimator` instead.",
OrangeDeprecationWarning, stacklevel=3)


class SklAdaBoostClassificationLearner(SklLearnerClassification):
__wraps__ = skl_ensemble.AdaBoostClassifier
__returns__ = SklAdaBoostClassifier
supports_weights = True

def __init__(self, base_estimator=None, n_estimators=50, learning_rate=1.,
algorithm='SAMME.R', random_state=None, preprocessors=None):
def __init__(self, estimator=None, n_estimators=50, learning_rate=1.,
algorithm='SAMME.R', random_state=None, preprocessors=None,
base_estimator="deprecated"):
if base_estimator != "deprecated":
base_estimator_deprecation()
estimator = base_estimator
del base_estimator
from Orange.modelling import Fitter
# If fitter, get the appropriate Learner instance
if isinstance(base_estimator, Fitter):
base_estimator = base_estimator.get_learner(
base_estimator.CLASSIFICATION)
if isinstance(estimator, Fitter):
estimator = estimator.get_learner(
estimator.CLASSIFICATION)
# If sklearn learner, get the underlying sklearn representation
if isinstance(base_estimator, SklLearner):
base_estimator = base_estimator.__wraps__(**base_estimator.params)
if isinstance(estimator, SklLearner):
estimator = estimator.__wraps__(**estimator.params)
super().__init__(preprocessors=preprocessors)
self.params = vars()

Expand All @@ -43,15 +58,20 @@ class SklAdaBoostRegressionLearner(SklLearnerRegression):
__returns__ = SklAdaBoostRegressor
supports_weights = True

def __init__(self, base_estimator=None, n_estimators=50, learning_rate=1.,
loss='linear', random_state=None, preprocessors=None):
def __init__(self, estimator=None, n_estimators=50, learning_rate=1.,
loss='linear', random_state=None, preprocessors=None,
base_estimator="deprecated"):
if base_estimator != "deprecated":
base_estimator_deprecation()
estimator = base_estimator
del base_estimator
from Orange.modelling import Fitter
# If fitter, get the appropriate Learner instance
if isinstance(base_estimator, Fitter):
base_estimator = base_estimator.get_learner(
base_estimator.REGRESSION)
if isinstance(estimator, Fitter):
estimator = estimator.get_learner(
estimator.REGRESSION)
# If sklearn learner, get the underlying sklearn representation
if isinstance(base_estimator, SklLearner):
base_estimator = base_estimator.__wraps__(**base_estimator.params)
if isinstance(estimator, SklLearner):
estimator = estimator.__wraps__(**estimator.params)
super().__init__(preprocessors=preprocessors)
self.params = vars()
29 changes: 23 additions & 6 deletions Orange/tests/test_ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
# pylint: disable=missing-docstring

import unittest
from distutils.version import LooseVersion

import numpy as np

import Orange
from Orange.data import Table
from Orange.classification import SklTreeLearner
from Orange.regression import SklTreeRegressionLearner
Expand All @@ -11,6 +15,7 @@
SklAdaBoostRegressionLearner,
)
from Orange.evaluation import CrossValidation, CA, RMSE
from Orange.util import OrangeDeprecationWarning


class TestSklAdaBoostLearner(unittest.TestCase):
Expand All @@ -27,14 +32,14 @@ def test_adaboost(self):
self.assertGreater(ca, 0.9)
self.assertLess(ca, 0.99)

def test_adaboost_base_estimator(self):
def test_adaboost_estimator(self):
np.random.seed(0)
stump_estimator = SklTreeLearner(max_depth=1)
tree_estimator = SklTreeLearner()
stump = SklAdaBoostClassificationLearner(
base_estimator=stump_estimator, n_estimators=5)
estimator=stump_estimator, n_estimators=5)
tree = SklAdaBoostClassificationLearner(
base_estimator=tree_estimator, n_estimators=5)
estimator=tree_estimator, n_estimators=5)
cv = CrossValidation(k=4)
results = cv(self.iris, [stump, tree])
ca = CA(results)
Expand Down Expand Up @@ -68,12 +73,12 @@ def test_adaboost_reg(self):
results = cv(self.housing, [learn])
_ = RMSE(results)

def test_adaboost_reg_base_estimator(self):
def test_adaboost_reg_estimator(self):
np.random.seed(0)
stump_estimator = SklTreeRegressionLearner(max_depth=1)
tree_estimator = SklTreeRegressionLearner()
stump = SklAdaBoostRegressionLearner(base_estimator=stump_estimator)
tree = SklAdaBoostRegressionLearner(base_estimator=tree_estimator)
stump = SklAdaBoostRegressionLearner(estimator=stump_estimator)
tree = SklAdaBoostRegressionLearner(estimator=tree_estimator)
cv = CrossValidation(k=3)
results = cv(self.housing, [stump, tree])
rmse = RMSE(results)
Expand Down Expand Up @@ -103,3 +108,15 @@ def test_predict_numpy_reg(self):
def test_adaboost_adequacy_reg(self):
learner = SklAdaBoostRegressionLearner()
self.assertRaises(ValueError, learner, self.iris)

def test_remove_deprecation(self):
if LooseVersion(Orange.__version__) >= LooseVersion("3.39"):
self.fail(
"`base_estimator` was deprecated in "
"version 3.37. Please remove everything related to it."
)
stump_estimator = SklTreeLearner(max_depth=1)
with self.assertWarns(OrangeDeprecationWarning):
SklAdaBoostClassificationLearner(base_estimator=stump_estimator)
with self.assertWarns(OrangeDeprecationWarning):
SklAdaBoostClassificationLearner(base_estimator=stump_estimator)
3 changes: 0 additions & 3 deletions Orange/tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from unittest.mock import MagicMock

import numpy as np
from sklearn import __version__ as sklearn_version
from sklearn.utils import check_random_state

from Orange.data import Table, Domain
Expand Down Expand Up @@ -155,8 +154,6 @@ def test_improved_randomized_pca_sparse_data(self):
pca.singular_values_, rpca.singular_values_, decimal=8
)

@unittest.skipIf(sklearn_version.startswith('0.20'),
"https://github.com/scikit-learn/scikit-learn/issues/12234")
def test_incremental_pca(self):
data = self.ionosphere
self.__ipca_test_helper(data, n_com=3, min_xpl_var=0.49)
Expand Down
13 changes: 0 additions & 13 deletions Orange/widgets/evaluate/tests/test_owliftcurve.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# pylint: disable=protected-access,duplicate-code
import copy
import pkg_resources
import unittest
from unittest.mock import Mock

import numpy as np
import sklearn

from AnyQt.QtGui import QFont, QPen

Expand All @@ -23,14 +21,6 @@
from Orange.tests import test_filename


# scikit-learn==1.1.1 does not support read the docs, therefore
# we can not make it a requirement for now. When the minimum required
# version is >=1.1.1, delete these exceptions.
OK_SKLEARN = pkg_resources.parse_version(sklearn.__version__) >= \
pkg_resources.parse_version("1.1.1")
SKIP_REASON = "Only test precision-recall with scikit-learn>=1.1.1"


class TestOWLiftCurve(EvaluateTest):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -304,7 +294,6 @@ def test_cumulative_gains_from_results():
assert_almost_equal(thresholds, [])

@staticmethod
@unittest.skipUnless(OK_SKLEARN, SKIP_REASON)
def test_precision_recall_from_results():
y_true = np.array([1, 0, 1, 0, 0, 1])
y_scores = np.array([0.6, 0.5, 0.9, 0.4, 0.2, 0.4])
Expand All @@ -324,7 +313,6 @@ def test_precision_recall_from_results():
np.array([0.2, 0.4, 0.5, 0.6, 0.9, 1]))

@staticmethod
@unittest.skipUnless(OK_SKLEARN, SKIP_REASON)
def test_precision_recall_from_results_one():
y_true = np.array([1, 0, 1, 0, 0, 1])
y_scores = np.array([0.6, 0.5, 1, 0.4, 0.2, 0.4])
Expand All @@ -344,7 +332,6 @@ def test_precision_recall_from_results_one():
np.array([0.2, 0.4, 0.5, 0.6, 1]))

@staticmethod
@unittest.skipUnless(OK_SKLEARN, SKIP_REASON)
def test_precision_recall_from_results_multiclass():
y_true = np.array([1, 0, 1, 0, 2, 2])
y_scores = np.array([[0.3, 0.3, 0.4],
Expand Down
2 changes: 1 addition & 1 deletion Orange/widgets/model/owadaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def create_learner(self):
if self.base_estimator is None:
return None
return self.LEARNER(
base_estimator=self.base_estimator,
estimator=self.base_estimator,
n_estimators=self.n_estimators,
learning_rate=self.learning_rate,
random_state=self.random_seed,
Expand Down
4 changes: 2 additions & 2 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ requirements:
- catboost >=1.0.1
- chardet >=3.0.2
- httpx >=0.21
- joblib >=1.0.0
- joblib >=1.1.1
- keyring
- keyrings.alt
- networkx
Expand All @@ -64,7 +64,7 @@ requirements:
- python-louvain >=0.13
- pyyaml
- requests
- scikit-learn >=1.1.0,!=1.2.*,<1.4 # ignoring 1.2.*: scikit-learn/issues/26241
- scikit-learn >=1.3.0
- scipy >=1.9
- serverfiles
- setuptools >=51.0.0
Expand Down
4 changes: 4 additions & 0 deletions i18n/si.jaml
Original file line number Diff line number Diff line change
Expand Up @@ -2010,14 +2010,18 @@ distance/distance.py:
def `compute_distances`:
hamming: false
ensembles/ada_boost.py:
def `base_estimator_deprecation`:
'`base_estimator` is deprecated (to be removed in 3.39): use `estimator` instead.': false
SklAdaBoostClassificationLearner: false
SklAdaBoostRegressionLearner: false
class `SklAdaBoostClassificationLearner`:
def `__init__`:
SAMME.R: false
deprecated: false
class `SklAdaBoostRegressionLearner`:
def `__init__`:
linear: false
deprecated: false
ensembles/stack.py:
StackedLearner: false
StackedClassificationLearner: false
Expand Down
4 changes: 2 additions & 2 deletions requirements-core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ catboost>=1.0.1
chardet>=3.0.2
httpx>=0.21.0
# Multiprocessing abstraction
joblib>=1.0.0
joblib>=1.1.1
keyring
keyrings.alt # for alternative keyring implementations
networkx
Expand All @@ -17,7 +17,7 @@ pip>=18.0
python-louvain>=0.13
pyyaml
requests
scikit-learn>=1.1.0,!=1.2.*,<1.4 # ignoring 1.2.*: scikit-learn/issues/26241
scikit-learn>=1.3.0
scipy>=1.9
serverfiles # for Data Sets synchronization
setuptools>=51.0.0
Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ deps =
oldest: catboost==1.0.1
oldest: chardet==3.0.2
oldest: httpx==0.21.0
oldest: joblib==1.0.0
oldest: joblib==1.1.1
# oldest: keyring
# oldest: keyrings.alt
# oldest: networkx
Expand All @@ -62,7 +62,7 @@ deps =
oldest: python-louvain==0.13
# oldest: pyyaml
# oldest: requests
oldest: scikit-learn==1.1.0
oldest: scikit-learn==1.3.0
oldest: scipy==1.9
# oldest: serverfiles
oldest: setuptools==51.0.0
Expand Down

0 comments on commit 42f633f

Please sign in to comment.