diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 116878f3..6b5d0002 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -84,3 +84,32 @@ jobs: - name: Test with pytest run: | python -m pytest src/hyperactive -p no:warnings + + test-sklearn-versions: + name: test-sklearn-${{ matrix.sklearn-version }} python-${{ matrix.python-version }} + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + sklearn-version: ["1.5", "1.6", "1.7"] + python-version: ["3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies for scikit-learn ${{ matrix.sklearn-version }} + run: | + python -m pip install --upgrade pip + python -m pip install build pytest + make install + python -m pip install scikit-learn==${{ matrix.sklearn-version }} + + - name: Run sklearn integration tests for ${{ matrix.sklearn-version }} + run: | + python -m pytest -x -p no:warnings tests/integrations/sklearn/ \ No newline at end of file diff --git a/src/hyperactive/integrations/sklearn/_compat.py b/src/hyperactive/integrations/sklearn/_compat.py new file mode 100644 index 00000000..528af594 --- /dev/null +++ b/src/hyperactive/integrations/sklearn/_compat.py @@ -0,0 +1,104 @@ +""" +Internal helpers that bridge behavioural differences between +scikit-learn versions. Import *private* scikit-learn symbols **only** +here and nowhere else. + +Copyright: Hyperactive contributors +License: MIT +""" + +from __future__ import annotations + +import warnings +from typing import Dict, Any + +import sklearn +from packaging import version +from sklearn.utils.validation import indexable + +_SK_VERSION = version.parse(sklearn.__version__) + + +def _safe_validate_X_y(estimator, X, y): + """ + Version-independent replacement for naive validate_data(X, y). + + • Ensures X is 2-D. + • Allows y to stay 1-D (required by scikit-learn >=1.7 checks). + • Uses BaseEstimator._validate_data when available so that + estimator tags and sample-weight checks keep working. + """ + X, y = indexable(X, y) + + if hasattr(estimator, "_validate_data"): + return estimator._validate_data( + X, + y, + validate_separately=( + {"ensure_2d": True}, # parameters for X + {"ensure_2d": False}, # parameters for y + ), + ) + + # Fallback for very old scikit-learn versions (<0.23) + from sklearn.utils.validation import check_X_y + + return check_X_y(X, y, ensure_2d=True) + + +def _safe_refit(estimator, X, y, fit_params): + if estimator.refit: + estimator._refit(X, y, **fit_params) + + # make the wrapper itself expose n_features_in_ + if hasattr(estimator.best_estimator_, "n_features_in_"): + estimator.n_features_in_ = estimator.best_estimator_.n_features_in_ + else: + # Even when `refit=False` we must satisfy the contract + estimator.n_features_in_ = X.shape[1] + + +# Replacement for `_deprecate_Xt_in_inverse_transform` +if _SK_VERSION < version.parse("1.7"): + # Still exists → re-export + from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform +else: + # Removed in 1.7 → provide drop-in replacement + def _deprecate_Xt_in_inverse_transform( # noqa: N802 keep sklearn’s name + X: Any | None, + Xt: Any | None, + ): + """ + scikit-learn ≤1.6 accepted both the old `Xt` parameter and the new + `X` parameter for `inverse_transform`. When only `Xt` is given we + return `Xt` and raise a deprecation warning (same behaviour that + scikit-learn had before 1.7); otherwise we return `X`. + """ + if Xt is not None: + warnings.warn( + "'Xt' was deprecated in scikit-learn 1.2 and has been " + "removed in 1.7; use the positional argument 'X' instead.", + FutureWarning, + stacklevel=2, + ) + return Xt + return X + + +# Replacement for `_check_method_params` +try: + from sklearn.utils.validation import _check_method_params # noqa: F401 +except ImportError: # fallback for future releases + + def _check_method_params( # type: ignore[override] # noqa: N802 + X, + params: Dict[str, Any], + ): + # passthrough – rely on estimator & indexable for validation + return params + + +__all__ = [ + "_deprecate_Xt_in_inverse_transform", + "_check_method_params", +] diff --git a/src/hyperactive/integrations/sklearn/best_estimator.py b/src/hyperactive/integrations/sklearn/best_estimator.py index def5f828..11d61e7b 100644 --- a/src/hyperactive/integrations/sklearn/best_estimator.py +++ b/src/hyperactive/integrations/sklearn/best_estimator.py @@ -4,11 +4,11 @@ from sklearn.utils.metaestimators import available_if -from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform from sklearn.exceptions import NotFittedError from sklearn.utils.validation import check_is_fitted from .utils import _estimator_has +from ._compat import _deprecate_Xt_in_inverse_transform # NOTE Implementations of following methods from: diff --git a/src/hyperactive/integrations/sklearn/hyperactive_search_cv.py b/src/hyperactive/integrations/sklearn/hyperactive_search_cv.py index 2acb414f..cf7fdd61 100644 --- a/src/hyperactive/integrations/sklearn/hyperactive_search_cv.py +++ b/src/hyperactive/integrations/sklearn/hyperactive_search_cv.py @@ -7,7 +7,7 @@ from sklearn.base import BaseEstimator, clone from sklearn.metrics import check_scoring -from sklearn.utils.validation import indexable, _check_method_params + from sklearn.base import BaseEstimator as SklearnBaseEstimator @@ -18,6 +18,8 @@ from ...optimizers import RandomSearchOptimizer from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment +from ._compat import _check_method_params, _safe_validate_X_y, _safe_refit + class HyperactiveSearchCV(BaseEstimator, _BestEstimator_, Checks): """ @@ -86,13 +88,7 @@ def _refit(self, X, y=None, **fit_params): return self def _check_data(self, X, y): - X, y = indexable(X, y) - if hasattr(self, "_validate_data"): - validate_data = self._validate_data - else: - from sklearn.utils.validation import validate_data - - return validate_data(X, y) + return _safe_validate_X_y(self, X, y) @Checks.verify_fit def fit(self, X, y, **fit_params): @@ -141,8 +137,7 @@ def fit(self, X, y, **fit_params): self.best_score_ = hyper.best_score(objective_function) self.search_data_ = hyper.search_data(objective_function) - if self.refit: - self._refit(X, y, **fit_params) + _safe_refit(self, X, y, fit_params) return self diff --git a/src/hyperactive/integrations/sklearn/opt_cv.py b/src/hyperactive/integrations/sklearn/opt_cv.py index eb83cf90..3e535885 100644 --- a/src/hyperactive/integrations/sklearn/opt_cv.py +++ b/src/hyperactive/integrations/sklearn/opt_cv.py @@ -4,14 +4,15 @@ from typing import Union from sklearn.base import BaseEstimator, clone -from sklearn.utils.validation import indexable, _check_method_params from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment from hyperactive.integrations.sklearn.best_estimator import ( - BestEstimator as _BestEstimator_ + BestEstimator as _BestEstimator_, ) from hyperactive.integrations.sklearn.checks import Checks +from ._compat import _check_method_params, _safe_validate_X_y, _safe_refit + class OptCV(BaseEstimator, _BestEstimator_, Checks): """Tuning via any optimizer in the hyperactive API. @@ -92,13 +93,7 @@ def _refit(self, X, y=None, **fit_params): return self def _check_data(self, X, y): - X, y = indexable(X, y) - if hasattr(self, "_validate_data"): - validate_data = self._validate_data - else: - from sklearn.utils.validation import validate_data - - return validate_data(X, y) + return _safe_validate_X_y(self, X, y) @Checks.verify_fit def fit(self, X, y, **fit_params): @@ -138,8 +133,7 @@ def fit(self, X, y, **fit_params): self.best_params_ = best_params self.best_estimator_ = clone(self.estimator).set_params(**best_params) - if self.refit: - self._refit(X, y, **fit_params) + _safe_refit(self, X, y, fit_params) return self