Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions stubs/sklearn/_config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def set_config(
enable_cython_pairwise_dist: None | bool = None,
array_api_dispatch: None | bool = None,
transform_output: None | str = None,
enable_metadata_routing: None | bool = None,
skip_parameter_validation: None | bool = None,
) -> None: ...
def config_context(
*,
Expand All @@ -27,4 +29,6 @@ def config_context(
enable_cython_pairwise_dist: None | bool = None,
array_api_dispatch: None | bool = None,
transform_output: None | str = None,
enable_metadata_routing: None | bool = None,
skip_parameter_validation: None | bool = None,
) -> Iterator[None]: ...
3 changes: 2 additions & 1 deletion stubs/sklearn/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ from ._config import get_config as get_config
from ._typing import ArrayLike, Float, Int, MatrixLike
from .metrics import accuracy_score as accuracy_score, r2_score as r2_score
from .utils._estimator_html_repr import estimator_html_repr as estimator_html_repr
from .utils._metadata_requests import _MetadataRequester
from .utils._param_validation import validate_parameter_constraints as validate_parameter_constraints
from .utils._set_output import _SetOutputMixin
from .utils.validation import check_array as check_array, check_is_fitted as check_is_fitted, check_X_y as check_X_y
Expand All @@ -24,7 +25,7 @@ from .utils.validation import check_array as check_array, check_is_fitted as che

def clone(estimator: BaseEstimator | Iterable[BaseEstimator], *, safe: bool = True) -> Any: ...

class BaseEstimator:
class BaseEstimator(_MetadataRequester):
def get_params(self, deep: bool = True) -> dict: ...
def set_params(self, **params) -> Self: ...
def __repr__(self, N_CHAR_MAX: int = 700) -> str: ...
Expand Down
49 changes: 0 additions & 49 deletions stubs/sklearn/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ sklearn.base._UnstableArchMixin.__sklearn_tags__
sklearn.base.is_clusterer
sklearn.calibration
sklearn.calibration.CalibratedClassifierCV.__sklearn_tags__
sklearn.calibration.CalibratedClassifierCV.get_metadata_routing
sklearn.calibration.CalibratedClassifierCV.set_fit_request
sklearn.calibration.CalibratedClassifierCV.set_score_request
sklearn.calibration.LabelEncoder.__sklearn_tags__
Expand Down Expand Up @@ -147,15 +146,11 @@ sklearn.cluster.tests.test_spectral
sklearn.compose
sklearn.compose.ColumnTransformer.__getitem__
sklearn.compose.ColumnTransformer.__sklearn_tags__
sklearn.compose.ColumnTransformer.get_metadata_routing
sklearn.compose.TransformedTargetRegressor.__sklearn_tags__
sklearn.compose.TransformedTargetRegressor.get_metadata_routing
sklearn.compose.TransformedTargetRegressor.set_score_request
sklearn.compose._column_transformer.ColumnTransformer.__getitem__
sklearn.compose._column_transformer.ColumnTransformer.__sklearn_tags__
sklearn.compose._column_transformer.ColumnTransformer.get_metadata_routing
sklearn.compose._target.TransformedTargetRegressor.__sklearn_tags__
sklearn.compose._target.TransformedTargetRegressor.get_metadata_routing
sklearn.compose._target.TransformedTargetRegressor.set_score_request
sklearn.compose.tests
sklearn.compose.tests.test_column_transformer
Expand All @@ -170,8 +165,6 @@ sklearn.conftest.raccoon_face_fxt
sklearn.conftest.raccoon_face_or_skip
sklearn.conftest.scipy_datasets_require_network
sklearn.covariance
sklearn.covariance.GraphicalLassoCV.get_metadata_routing
sklearn.covariance._graph_lasso.GraphicalLassoCV.get_metadata_routing
sklearn.covariance.tests
sklearn.covariance.tests.test_covariance
sklearn.covariance.tests.test_elliptic_envelope
Expand Down Expand Up @@ -343,7 +336,6 @@ sklearn.ensemble._bagging.BaggingClassifier.set_score_request
sklearn.ensemble._bagging.BaggingRegressor.set_fit_request
sklearn.ensemble._bagging.BaggingRegressor.set_score_request
sklearn.ensemble._bagging.BaseBagging.__sklearn_tags__
sklearn.ensemble._bagging.BaseBagging.get_metadata_routing
sklearn.ensemble._bagging.BaseBagging.set_fit_request
sklearn.ensemble._base._BaseHeterogeneousEnsemble.__sklearn_tags__
sklearn.ensemble._forest.BaseForest.__sklearn_tags__
Expand Down Expand Up @@ -420,13 +412,11 @@ sklearn.ensemble._stacking.StackingClassifier.set_fit_request
sklearn.ensemble._stacking.StackingClassifier.set_score_request
sklearn.ensemble._stacking.StackingRegressor.set_fit_request
sklearn.ensemble._stacking.StackingRegressor.set_score_request
sklearn.ensemble._stacking._BaseStacking.get_metadata_routing
sklearn.ensemble._voting.VotingClassifier.__sklearn_tags__
sklearn.ensemble._voting.VotingClassifier.set_fit_request
sklearn.ensemble._voting.VotingClassifier.set_score_request
sklearn.ensemble._voting.VotingRegressor.set_fit_request
sklearn.ensemble._voting.VotingRegressor.set_score_request
sklearn.ensemble._voting._BaseVoting.get_metadata_routing
sklearn.ensemble._weight_boosting.AdaBoostClassifier.set_fit_request
sklearn.ensemble._weight_boosting.AdaBoostClassifier.set_score_request
sklearn.ensemble._weight_boosting.AdaBoostRegressor.set_fit_request
Expand All @@ -448,7 +438,6 @@ sklearn.exceptions.EstimatorCheckFailedWarning
sklearn.exceptions.UnsetMetadataPassedError
sklearn.exceptions.__all__
sklearn.experimental
sklearn.experimental.enable_iterative_imputer.IterativeImputer.get_metadata_routing
sklearn.experimental.tests
sklearn.experimental.tests.test_enable_hist_gradient_boosting
sklearn.experimental.tests.test_enable_iterative_imputer
Expand Down Expand Up @@ -478,22 +467,14 @@ sklearn.feature_extraction.text.TfidfVectorizer.__sklearn_tags__
sklearn.feature_selection
sklearn.feature_selection.GenericUnivariateSelect.__sklearn_tags__
sklearn.feature_selection.RFE.__sklearn_tags__
sklearn.feature_selection.RFE.get_metadata_routing
sklearn.feature_selection.RFECV.get_metadata_routing
sklearn.feature_selection.SelectFromModel.__sklearn_tags__
sklearn.feature_selection.SelectFromModel.get_metadata_routing
sklearn.feature_selection.SelectKBest.__sklearn_tags__
sklearn.feature_selection.SelectPercentile.__sklearn_tags__
sklearn.feature_selection.SequentialFeatureSelector.__sklearn_tags__
sklearn.feature_selection.SequentialFeatureSelector.get_metadata_routing
sklearn.feature_selection.VarianceThreshold.__sklearn_tags__
sklearn.feature_selection._from_model.SelectFromModel.__sklearn_tags__
sklearn.feature_selection._from_model.SelectFromModel.get_metadata_routing
sklearn.feature_selection._rfe.RFE.__sklearn_tags__
sklearn.feature_selection._rfe.RFE.get_metadata_routing
sklearn.feature_selection._rfe.RFECV.get_metadata_routing
sklearn.feature_selection._sequential.SequentialFeatureSelector.__sklearn_tags__
sklearn.feature_selection._sequential.SequentialFeatureSelector.get_metadata_routing
sklearn.feature_selection._univariate_selection.GenericUnivariateSelect.__sklearn_tags__
sklearn.feature_selection._univariate_selection.SelectKBest.__sklearn_tags__
sklearn.feature_selection._univariate_selection.SelectPercentile.__sklearn_tags__
Expand Down Expand Up @@ -525,14 +506,12 @@ sklearn.gaussian_process.tests.test_gpc
sklearn.gaussian_process.tests.test_gpr
sklearn.gaussian_process.tests.test_kernels
sklearn.impute
sklearn.impute.IterativeImputer.get_metadata_routing
sklearn.impute.MissingIndicator.__sklearn_tags__
sklearn.impute.SimpleImputer.__sklearn_tags__
sklearn.impute.__all__
sklearn.impute._base.MissingIndicator.__sklearn_tags__
sklearn.impute._base.SimpleImputer.__sklearn_tags__
sklearn.impute._base._BaseImputer.__sklearn_tags__
sklearn.impute._iterative.IterativeImputer.get_metadata_routing
sklearn.impute.tests
sklearn.impute.tests.test_base
sklearn.impute.tests.test_common
Expand Down Expand Up @@ -579,7 +558,6 @@ sklearn.linear_model.HuberRegressor.set_score_request
sklearn.linear_model.Lars.set_fit_request
sklearn.linear_model.Lars.set_score_request
sklearn.linear_model.LarsCV.__sklearn_tags__
sklearn.linear_model.LarsCV.get_metadata_routing
sklearn.linear_model.LarsCV.parameter
sklearn.linear_model.LarsCV.set_score_request
sklearn.linear_model.Lasso.set_fit_request
Expand All @@ -600,7 +578,6 @@ sklearn.linear_model.LogisticRegression.__sklearn_tags__
sklearn.linear_model.LogisticRegression.set_fit_request
sklearn.linear_model.LogisticRegression.set_score_request
sklearn.linear_model.LogisticRegressionCV.__sklearn_tags__
sklearn.linear_model.LogisticRegressionCV.get_metadata_routing
sklearn.linear_model.LogisticRegressionCV.param
sklearn.linear_model.LogisticRegressionCV.set_fit_request
sklearn.linear_model.LogisticRegressionCV.set_score_request
Expand All @@ -611,7 +588,6 @@ sklearn.linear_model.MultiTaskElasticNetCV.set_score_request
sklearn.linear_model.MultiTaskLassoCV.__sklearn_tags__
sklearn.linear_model.MultiTaskLassoCV.set_score_request
sklearn.linear_model.OrthogonalMatchingPursuit.set_score_request
sklearn.linear_model.OrthogonalMatchingPursuitCV.get_metadata_routing
sklearn.linear_model.OrthogonalMatchingPursuitCV.set_score_request
sklearn.linear_model.PassiveAggressiveClassifier.set_fit_request
sklearn.linear_model.PassiveAggressiveClassifier.set_partial_fit_request
Expand All @@ -627,7 +603,6 @@ sklearn.linear_model.QuantileRegressor.__sklearn_tags__
sklearn.linear_model.QuantileRegressor.set_fit_request
sklearn.linear_model.QuantileRegressor.set_score_request
sklearn.linear_model.RANSACRegressor.__sklearn_tags__
sklearn.linear_model.RANSACRegressor.get_metadata_routing
sklearn.linear_model.RANSACRegressor.set_fit_request
sklearn.linear_model.Ridge.__sklearn_tags__
sklearn.linear_model.Ridge.set_fit_request
Expand Down Expand Up @@ -673,7 +648,6 @@ sklearn.linear_model._coordinate_descent.Lasso.set_score_request
sklearn.linear_model._coordinate_descent.LassoCV.set_fit_request
sklearn.linear_model._coordinate_descent.LassoCV.set_score_request
sklearn.linear_model._coordinate_descent.LinearModelCV.__sklearn_tags__
sklearn.linear_model._coordinate_descent.LinearModelCV.get_metadata_routing
sklearn.linear_model._coordinate_descent.LinearModelCV.set_fit_request
sklearn.linear_model._coordinate_descent.MultiTaskElasticNet.__sklearn_tags__
sklearn.linear_model._coordinate_descent.MultiTaskElasticNet.param
Expand Down Expand Up @@ -707,7 +681,6 @@ sklearn.linear_model._huber.HuberRegressor.set_score_request
sklearn.linear_model._least_angle.Lars.set_fit_request
sklearn.linear_model._least_angle.Lars.set_score_request
sklearn.linear_model._least_angle.LarsCV.__sklearn_tags__
sklearn.linear_model._least_angle.LarsCV.get_metadata_routing
sklearn.linear_model._least_angle.LarsCV.parameter
sklearn.linear_model._least_angle.LarsCV.set_score_request
sklearn.linear_model._least_angle.LassoLars.set_fit_request
Expand All @@ -722,12 +695,10 @@ sklearn.linear_model._logistic.LogisticRegression.__sklearn_tags__
sklearn.linear_model._logistic.LogisticRegression.set_fit_request
sklearn.linear_model._logistic.LogisticRegression.set_score_request
sklearn.linear_model._logistic.LogisticRegressionCV.__sklearn_tags__
sklearn.linear_model._logistic.LogisticRegressionCV.get_metadata_routing
sklearn.linear_model._logistic.LogisticRegressionCV.param
sklearn.linear_model._logistic.LogisticRegressionCV.set_fit_request
sklearn.linear_model._logistic.LogisticRegressionCV.set_score_request
sklearn.linear_model._omp.OrthogonalMatchingPursuit.set_score_request
sklearn.linear_model._omp.OrthogonalMatchingPursuitCV.get_metadata_routing
sklearn.linear_model._omp.OrthogonalMatchingPursuitCV.set_score_request
sklearn.linear_model._passive_aggressive.PassiveAggressiveClassifier.set_fit_request
sklearn.linear_model._passive_aggressive.PassiveAggressiveClassifier.set_partial_fit_request
Expand All @@ -741,7 +712,6 @@ sklearn.linear_model._quantile.QuantileRegressor.__sklearn_tags__
sklearn.linear_model._quantile.QuantileRegressor.set_fit_request
sklearn.linear_model._quantile.QuantileRegressor.set_score_request
sklearn.linear_model._ransac.RANSACRegressor.__sklearn_tags__
sklearn.linear_model._ransac.RANSACRegressor.get_metadata_routing
sklearn.linear_model._ransac.RANSACRegressor.set_fit_request
sklearn.linear_model._ridge.Ridge.__sklearn_tags__
sklearn.linear_model._ridge.Ridge.set_fit_request
Expand All @@ -757,7 +727,6 @@ sklearn.linear_model._ridge.RidgeClassifierCV.set_score_request
sklearn.linear_model._ridge._BaseRidge.set_fit_request
sklearn.linear_model._ridge._BaseRidgeCV.__sklearn_tags__
sklearn.linear_model._ridge._BaseRidgeCV.cv_values_
sklearn.linear_model._ridge._BaseRidgeCV.get_metadata_routing
sklearn.linear_model._ridge._BaseRidgeCV.set_fit_request
sklearn.linear_model._ridge._IdentityClassifier.set_decision_function_request
sklearn.linear_model._ridge._IdentityClassifier.set_score_request
Expand Down Expand Up @@ -929,7 +898,6 @@ sklearn.model_selection._classification_threshold.FixedThresholdClassifier.set_s
sklearn.model_selection._classification_threshold.TunedThresholdClassifierCV.set_score_request
sklearn.model_selection._plot.LearningCurveDisplay.from_estimator
sklearn.model_selection._search.BaseSearchCV.__sklearn_tags__
sklearn.model_selection._search.BaseSearchCV.get_metadata_routing
sklearn.model_selection._split.GroupKFold.set_split_request
sklearn.model_selection._split.GroupShuffleSplit.set_split_request
sklearn.model_selection._split.LeaveOneGroupOut.set_split_request
Expand All @@ -947,19 +915,15 @@ sklearn.model_selection.tests.test_successive_halving
sklearn.model_selection.tests.test_validation
sklearn.multiclass
sklearn.multiclass.OneVsOneClassifier.__sklearn_tags__
sklearn.multiclass.OneVsOneClassifier.get_metadata_routing
sklearn.multiclass.OneVsOneClassifier.set_partial_fit_request
sklearn.multiclass.OneVsOneClassifier.set_score_request
sklearn.multiclass.OneVsRestClassifier.__sklearn_tags__
sklearn.multiclass.OneVsRestClassifier.get_metadata_routing
sklearn.multiclass.OneVsRestClassifier.set_partial_fit_request
sklearn.multiclass.OneVsRestClassifier.set_score_request
sklearn.multiclass.OutputCodeClassifier.__sklearn_tags__
sklearn.multiclass.OutputCodeClassifier.get_metadata_routing
sklearn.multiclass.OutputCodeClassifier.set_score_request
sklearn.multioutput
sklearn.multioutput.ClassifierChain.__sklearn_tags__
sklearn.multioutput.ClassifierChain.get_metadata_routing
sklearn.multioutput.ClassifierChain.predict_log_proba
sklearn.multioutput.ClassifierChain.set_score_request
sklearn.multioutput.MultiOutputClassifier.__sklearn_tags__
Expand All @@ -969,11 +933,9 @@ sklearn.multioutput.MultiOutputRegressor.set_fit_request
sklearn.multioutput.MultiOutputRegressor.set_partial_fit_request
sklearn.multioutput.MultiOutputRegressor.set_score_request
sklearn.multioutput.RegressorChain.__sklearn_tags__
sklearn.multioutput.RegressorChain.get_metadata_routing
sklearn.multioutput.RegressorChain.set_score_request
sklearn.multioutput._BaseChain.__sklearn_tags__
sklearn.multioutput._MultiOutputEstimator.__sklearn_tags__
sklearn.multioutput._MultiOutputEstimator.get_metadata_routing
sklearn.multioutput._MultiOutputEstimator.set_fit_request
sklearn.multioutput._MultiOutputEstimator.set_partial_fit_request
sklearn.naive_bayes
Expand Down Expand Up @@ -1073,10 +1035,8 @@ sklearn.pipeline
sklearn.pipeline.FeatureUnion.__getitem__
sklearn.pipeline.FeatureUnion.__sklearn_tags__
sklearn.pipeline.FeatureUnion.feature_names_in_
sklearn.pipeline.FeatureUnion.get_metadata_routing
sklearn.pipeline.FunctionTransformer.__sklearn_tags__
sklearn.pipeline.Pipeline.__sklearn_tags__
sklearn.pipeline.Pipeline.get_metadata_routing
sklearn.pipeline.Pipeline.set_score_request
sklearn.preprocessing
sklearn.preprocessing.Binarizer.__sklearn_tags__
Expand Down Expand Up @@ -1144,13 +1104,11 @@ sklearn.semi_supervised
sklearn.semi_supervised.LabelPropagation.set_score_request
sklearn.semi_supervised.LabelSpreading.set_score_request
sklearn.semi_supervised.SelfTrainingClassifier.__sklearn_tags__
sklearn.semi_supervised.SelfTrainingClassifier.get_metadata_routing
sklearn.semi_supervised._label_propagation.BaseLabelPropagation.__sklearn_tags__
sklearn.semi_supervised._label_propagation.BaseLabelPropagation.set_score_request
sklearn.semi_supervised._label_propagation.LabelPropagation.set_score_request
sklearn.semi_supervised._label_propagation.LabelSpreading.set_score_request
sklearn.semi_supervised._self_training.SelfTrainingClassifier.__sklearn_tags__
sklearn.semi_supervised._self_training.SelfTrainingClassifier.get_metadata_routing
sklearn.semi_supervised.tests
sklearn.semi_supervised.tests.test_label_propagation
sklearn.semi_supervised.tests.test_self_training
Expand Down Expand Up @@ -1447,8 +1405,6 @@ sklearn._build_utils.openmp_helpers
sklearn._build_utils.parse
sklearn._build_utils.pre_build_helpers
sklearn._config._threadlocal
sklearn._config.config_context
sklearn._config.set_config
sklearn._loss._loss.CyAbsoluteError.cy_grad_hess
sklearn._loss._loss.CyAbsoluteError.cy_gradient
sklearn._loss._loss.CyAbsoluteError.cy_loss
Expand Down Expand Up @@ -1528,7 +1484,6 @@ sklearn.compose.ColumnTransformer.__init__
sklearn.compose._column_transformer.ColumnTransformer.__init__
sklearn.compose._column_transformer.make_column_transformer
sklearn.compose.make_column_transformer
sklearn.config_context
sklearn.conftest._SKIP32_MARK
sklearn.conftest.fetch_20newsgroups
sklearn.conftest.fetch_20newsgroups_fxt
Expand Down Expand Up @@ -2045,7 +2000,6 @@ sklearn.metrics.mean_squared_log_error
sklearn.metrics.pairwise.DistanceMetric
sklearn.metrics.pairwise.check_array
sklearn.metrics.pairwise.check_pairwise_arrays
sklearn.metrics.pairwise.config_context
sklearn.metrics.pairwise.manhattan_distances
sklearn.metrics.pairwise.pairwise_distances
sklearn.metrics.pairwise_distances
Expand Down Expand Up @@ -2126,7 +2080,6 @@ sklearn.random_projection.check_array
sklearn.random_projection.sample_without_replacement
sklearn.semi_supervised.SelfTrainingClassifier.__init__
sklearn.semi_supervised._self_training.SelfTrainingClassifier.__init__
sklearn.set_config
sklearn.svm.LinearSVC.__init__
sklearn.svm.LinearSVR.__init__
sklearn.svm.NuSVC.coef_
Expand Down Expand Up @@ -2275,7 +2228,6 @@ sklearn.utils.estimator_checks.check_estimators_pickle
sklearn.utils.estimator_checks.check_global_ouptut_transform_pandas
sklearn.utils.estimator_checks.check_parameters_default_constructible
sklearn.utils.estimator_checks.check_sample_weights_invariance
sklearn.utils.estimator_checks.config_context
sklearn.utils.estimator_checks.create_memmap_backed_data
sklearn.utils.estimator_checks.generate_invalid_param_val
sklearn.utils.estimator_checks.ignore_warnings
Expand Down Expand Up @@ -2312,7 +2264,6 @@ sklearn.utils.multiclass.dok_matrix
sklearn.utils.multiclass.get_namespace
sklearn.utils.multiclass.lil_matrix
sklearn.utils.multiclass.type_of_target
sklearn.utils.parallel.config_context
sklearn.utils.parallel_backend
sklearn.utils.random.sample_without_replacement
sklearn.utils.register_parallel_backend
Expand Down
4 changes: 2 additions & 2 deletions stubs/sklearn/utils/_metadata_requests.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ class RequestMethod:
def __get__(self, instance, owner): ...

class _MetadataRequester:
def __init_subclass__(cls, **kwargs): ...
def get_metadata_routing(self): ...
def __init_subclass__(cls, **kwargs) -> None: ...
def get_metadata_routing(self) -> MetadataRequest: ...
# This code is never run in runtime, but it's here for type checking.
# Type checkers fail to understand that the `set_{method}_request`
# methods are dynamically generated, and they complain that they are
Expand Down
Loading