Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ dependencies = [
"pandas <3.0.0",
"gradient-free-optimizers >=1.2.4, <2.0.0",
"scikit-base <1.0.0",
"scikit-learn <1.8.0",
]

[project.optional-dependencies]
sklearn-integration = [
"scikit-learn == 1.6.1",
"scikit-learn <1.8.0",
]
build = [
"setuptools",
Expand Down
8 changes: 8 additions & 0 deletions src/hyperactive/integrations/sklearn/_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# copyright: hyperactive developers, MIT License (see LICENSE file)


from hyperactive.integrations.sklearn._adapter._sklearnadapter import _SklearnAdapter

__all__ = [
"_SklearnAdapter",
]
32 changes: 32 additions & 0 deletions src/hyperactive/integrations/sklearn/_adapter/_sklearnadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Adapter for sklearn regressors and classifiers for Hyperactive optimizers."""

from sklearn.base import clone
from sklearn.utils.validation import indexable, _check_method_params


class _SklearnAdapter:

_required_parameters = ["estimator", "optimizer", "params_config"]

def _refit(self, X, y=None, **fit_params):
self.best_estimator_ = clone(self.estimator).set_params(
**clone(self.best_params_, safe=False)
)

self.best_estimator_.fit(X, y, **fit_params)
return self

def _check_data(self, X, y):
X, y = indexable(X, y)
if hasattr(X, "ndim") and X.ndim == 1:
X = X.reshape(-1, 1)
if hasattr(self, "_validate_data"):
validate_data = self._validate_data
else:
from sklearn.utils.validation import validate_data

return validate_data(X, y)

@property
def fit_successful(self):
self._fit_successful
5 changes: 1 addition & 4 deletions src/hyperactive/integrations/sklearn/best_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@


from sklearn.utils.metaestimators import available_if
from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

from .utils import _estimator_has
Expand Down Expand Up @@ -47,8 +45,7 @@ def transform(self, X):
return self.best_estimator_.transform(X)

@available_if(_estimator_has("inverse_transform"))
def inverse_transform(self, X=None, Xt=None):
X = _deprecate_Xt_in_inverse_transform(X, Xt)
def inverse_transform(self, X=None):
check_is_fitted(self)
return self.best_estimator_.inverse_transform(X)

Expand Down
20 changes: 2 additions & 18 deletions src/hyperactive/integrations/sklearn/hyperactive_search_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from .checks import Checks
from ...optimizers import RandomSearchOptimizer
from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment
from hyperactive.integrations.sklearn._adapter import _SklearnAdapter


class HyperactiveSearchCV(BaseEstimator, _BestEstimator_, Checks):
class HyperactiveSearchCV(_SklearnAdapter, BaseEstimator, _BestEstimator_, Checks):
"""
HyperactiveSearchCV class for hyperparameter tuning using cross-validation with sklearn estimators.

Expand Down Expand Up @@ -77,23 +78,6 @@ def __init__(
self.refit = refit
self.cv = cv

def _refit(self, X, y=None, **fit_params):
self.best_estimator_ = clone(self.estimator).set_params(
**clone(self.best_params_, safe=False)
)

self.best_estimator_.fit(X, y, **fit_params)
return self

def _check_data(self, X, y):
X, y = indexable(X, y)
if hasattr(self, "_validate_data"):
validate_data = self._validate_data
else:
from sklearn.utils.validation import validate_data

return validate_data(X, y)

@Checks.verify_fit
def fit(self, X, y, **fit_params):
"""
Expand Down
20 changes: 2 additions & 18 deletions src/hyperactive/integrations/sklearn/opt_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from hyperactive.integrations.sklearn.best_estimator import (
BestEstimator as _BestEstimator_
)
from hyperactive.integrations.sklearn._adapter import _SklearnAdapter
from hyperactive.integrations.sklearn.checks import Checks


class OptCV(BaseEstimator, _BestEstimator_, Checks):
class OptCV(_SklearnAdapter, BaseEstimator, _BestEstimator_, Checks):
"""Tuning via any optimizer in the hyperactive API.

Parameters
Expand Down Expand Up @@ -83,23 +84,6 @@ def __init__(
self.refit = refit
self.cv = cv

def _refit(self, X, y=None, **fit_params):
self.best_estimator_ = clone(self.estimator).set_params(
**clone(self.best_params_, safe=False)
)

self.best_estimator_.fit(X, y, **fit_params)
return self

def _check_data(self, X, y):
X, y = indexable(X, y)
if hasattr(self, "_validate_data"):
validate_data = self._validate_data
else:
from sklearn.utils.validation import validate_data

return validate_data(X, y)

@Checks.verify_fit
def fit(self, X, y, **fit_params):
"""Fit the model.
Expand Down
Loading