Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,32 @@ jobs:
- name: Test with pytest
run: |
python -m pytest src/hyperactive -p no:warnings

test-sklearn-versions:
name: test-sklearn-${{ matrix.sklearn-version }} python-${{ matrix.python-version }}
runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
sklearn-version: ["1.5", "1.6", "1.7"]
python-version: ["3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v4

- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies for scikit-learn ${{ matrix.sklearn-version }}
run: |
python -m pip install --upgrade pip
python -m pip install build pytest
make install
python -m pip install scikit-learn==${{ matrix.sklearn-version }}

- name: Run sklearn integration tests for ${{ matrix.sklearn-version }}
run: |
python -m pytest -x -p no:warnings tests/integrations/sklearn/
104 changes: 104 additions & 0 deletions src/hyperactive/integrations/sklearn/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
Internal helpers that bridge behavioural differences between
scikit-learn versions. Import *private* scikit-learn symbols **only**
here and nowhere else.

Copyright: Hyperactive contributors
License: MIT
"""

from __future__ import annotations

import warnings
from typing import Dict, Any

import sklearn
from packaging import version
from sklearn.utils.validation import indexable

_SK_VERSION = version.parse(sklearn.__version__)


def _safe_validate_X_y(estimator, X, y):
"""
Version-independent replacement for naive validate_data(X, y).

• Ensures X is 2-D.
• Allows y to stay 1-D (required by scikit-learn >=1.7 checks).
• Uses BaseEstimator._validate_data when available so that
estimator tags and sample-weight checks keep working.
"""
X, y = indexable(X, y)

if hasattr(estimator, "_validate_data"):
return estimator._validate_data(
X,
y,
validate_separately=(
{"ensure_2d": True}, # parameters for X
{"ensure_2d": False}, # parameters for y
),
)

# Fallback for very old scikit-learn versions (<0.23)
from sklearn.utils.validation import check_X_y

return check_X_y(X, y, ensure_2d=True)


def _safe_refit(estimator, X, y, fit_params):
if estimator.refit:
estimator._refit(X, y, **fit_params)

# make the wrapper itself expose n_features_in_
if hasattr(estimator.best_estimator_, "n_features_in_"):
estimator.n_features_in_ = estimator.best_estimator_.n_features_in_
else:
# Even when `refit=False` we must satisfy the contract
estimator.n_features_in_ = X.shape[1]


# Replacement for `_deprecate_Xt_in_inverse_transform`
if _SK_VERSION < version.parse("1.7"):
# Still exists → re-export
from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform
else:
# Removed in 1.7 → provide drop-in replacement
def _deprecate_Xt_in_inverse_transform( # noqa: N802 keep sklearn’s name
X: Any | None,
Xt: Any | None,
):
"""
scikit-learn ≤1.6 accepted both the old `Xt` parameter and the new
`X` parameter for `inverse_transform`. When only `Xt` is given we
return `Xt` and raise a deprecation warning (same behaviour that
scikit-learn had before 1.7); otherwise we return `X`.
"""
if Xt is not None:
warnings.warn(
"'Xt' was deprecated in scikit-learn 1.2 and has been "
"removed in 1.7; use the positional argument 'X' instead.",
FutureWarning,
stacklevel=2,
)
return Xt
return X


# Replacement for `_check_method_params`
try:
from sklearn.utils.validation import _check_method_params # noqa: F401
except ImportError: # fallback for future releases

def _check_method_params( # type: ignore[override] # noqa: N802
X,
params: Dict[str, Any],
):
# passthrough – rely on estimator & indexable for validation
return params


__all__ = [
"_deprecate_Xt_in_inverse_transform",
"_check_method_params",
]
2 changes: 1 addition & 1 deletion src/hyperactive/integrations/sklearn/best_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@


from sklearn.utils.metaestimators import available_if
from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

from .utils import _estimator_has
from ._compat import _deprecate_Xt_in_inverse_transform


# NOTE Implementations of following methods from:
Expand Down
15 changes: 5 additions & 10 deletions src/hyperactive/integrations/sklearn/hyperactive_search_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sklearn.base import BaseEstimator, clone
from sklearn.metrics import check_scoring
from sklearn.utils.validation import indexable, _check_method_params


from sklearn.base import BaseEstimator as SklearnBaseEstimator

Expand All @@ -18,6 +18,8 @@
from ...optimizers import RandomSearchOptimizer
from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment

from ._compat import _check_method_params, _safe_validate_X_y, _safe_refit


class HyperactiveSearchCV(BaseEstimator, _BestEstimator_, Checks):
"""
Expand Down Expand Up @@ -86,13 +88,7 @@ def _refit(self, X, y=None, **fit_params):
return self

def _check_data(self, X, y):
X, y = indexable(X, y)
if hasattr(self, "_validate_data"):
validate_data = self._validate_data
else:
from sklearn.utils.validation import validate_data

return validate_data(X, y)
return _safe_validate_X_y(self, X, y)

@Checks.verify_fit
def fit(self, X, y, **fit_params):
Expand Down Expand Up @@ -141,8 +137,7 @@ def fit(self, X, y, **fit_params):
self.best_score_ = hyper.best_score(objective_function)
self.search_data_ = hyper.search_data(objective_function)

if self.refit:
self._refit(X, y, **fit_params)
_safe_refit(self, X, y, fit_params)

return self

Expand Down
16 changes: 5 additions & 11 deletions src/hyperactive/integrations/sklearn/opt_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import Union

from sklearn.base import BaseEstimator, clone
from sklearn.utils.validation import indexable, _check_method_params

from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment
from hyperactive.integrations.sklearn.best_estimator import (
BestEstimator as _BestEstimator_
BestEstimator as _BestEstimator_,
)
from hyperactive.integrations.sklearn.checks import Checks

from ._compat import _check_method_params, _safe_validate_X_y, _safe_refit


class OptCV(BaseEstimator, _BestEstimator_, Checks):
"""Tuning via any optimizer in the hyperactive API.
Expand Down Expand Up @@ -92,13 +93,7 @@ def _refit(self, X, y=None, **fit_params):
return self

def _check_data(self, X, y):
X, y = indexable(X, y)
if hasattr(self, "_validate_data"):
validate_data = self._validate_data
else:
from sklearn.utils.validation import validate_data

return validate_data(X, y)
return _safe_validate_X_y(self, X, y)

@Checks.verify_fit
def fit(self, X, y, **fit_params):
Expand Down Expand Up @@ -138,8 +133,7 @@ def fit(self, X, y, **fit_params):
self.best_params_ = best_params
self.best_estimator_ = clone(self.estimator).set_params(**best_params)

if self.refit:
self._refit(X, y, **fit_params)
_safe_refit(self, X, y, fit_params)

return self

Expand Down