diff --git a/stubs/sklearn/_config.pyi b/stubs/sklearn/_config.pyi index f8e5c8d1..ab99fa7d 100644 --- a/stubs/sklearn/_config.pyi +++ b/stubs/sklearn/_config.pyi @@ -16,6 +16,8 @@ def set_config( enable_cython_pairwise_dist: None | bool = None, array_api_dispatch: None | bool = None, transform_output: None | str = None, + enable_metadata_routing: None | bool = None, + skip_parameter_validation: None | bool = None, ) -> None: ... def config_context( *, @@ -27,4 +29,6 @@ def config_context( enable_cython_pairwise_dist: None | bool = None, array_api_dispatch: None | bool = None, transform_output: None | str = None, + enable_metadata_routing: None | bool = None, + skip_parameter_validation: None | bool = None, ) -> Iterator[None]: ... diff --git a/stubs/sklearn/base.pyi b/stubs/sklearn/base.pyi index a2ab07f1..234f7433 100644 --- a/stubs/sklearn/base.pyi +++ b/stubs/sklearn/base.pyi @@ -15,6 +15,7 @@ from ._config import get_config as get_config from ._typing import ArrayLike, Float, Int, MatrixLike from .metrics import accuracy_score as accuracy_score, r2_score as r2_score from .utils._estimator_html_repr import estimator_html_repr as estimator_html_repr +from .utils._metadata_requests import _MetadataRequester from .utils._param_validation import validate_parameter_constraints as validate_parameter_constraints from .utils._set_output import _SetOutputMixin from .utils.validation import check_array as check_array, check_is_fitted as check_is_fitted, check_X_y as check_X_y @@ -24,7 +25,7 @@ from .utils.validation import check_array as check_array, check_is_fitted as che def clone(estimator: BaseEstimator | Iterable[BaseEstimator], *, safe: bool = True) -> Any: ... -class BaseEstimator: +class BaseEstimator(_MetadataRequester): def get_params(self, deep: bool = True) -> dict: ... def set_params(self, **params) -> Self: ... def __repr__(self, N_CHAR_MAX: int = 700) -> str: ... diff --git a/stubs/sklearn/stubtest_allowlist.txt b/stubs/sklearn/stubtest_allowlist.txt index 98ab6a3d..b7d62874 100644 --- a/stubs/sklearn/stubtest_allowlist.txt +++ b/stubs/sklearn/stubtest_allowlist.txt @@ -78,7 +78,6 @@ sklearn.base._UnstableArchMixin.__sklearn_tags__ sklearn.base.is_clusterer sklearn.calibration sklearn.calibration.CalibratedClassifierCV.__sklearn_tags__ -sklearn.calibration.CalibratedClassifierCV.get_metadata_routing sklearn.calibration.CalibratedClassifierCV.set_fit_request sklearn.calibration.CalibratedClassifierCV.set_score_request sklearn.calibration.LabelEncoder.__sklearn_tags__ @@ -147,15 +146,11 @@ sklearn.cluster.tests.test_spectral sklearn.compose sklearn.compose.ColumnTransformer.__getitem__ sklearn.compose.ColumnTransformer.__sklearn_tags__ -sklearn.compose.ColumnTransformer.get_metadata_routing sklearn.compose.TransformedTargetRegressor.__sklearn_tags__ -sklearn.compose.TransformedTargetRegressor.get_metadata_routing sklearn.compose.TransformedTargetRegressor.set_score_request sklearn.compose._column_transformer.ColumnTransformer.__getitem__ sklearn.compose._column_transformer.ColumnTransformer.__sklearn_tags__ -sklearn.compose._column_transformer.ColumnTransformer.get_metadata_routing sklearn.compose._target.TransformedTargetRegressor.__sklearn_tags__ -sklearn.compose._target.TransformedTargetRegressor.get_metadata_routing sklearn.compose._target.TransformedTargetRegressor.set_score_request sklearn.compose.tests sklearn.compose.tests.test_column_transformer @@ -170,8 +165,6 @@ sklearn.conftest.raccoon_face_fxt sklearn.conftest.raccoon_face_or_skip sklearn.conftest.scipy_datasets_require_network sklearn.covariance -sklearn.covariance.GraphicalLassoCV.get_metadata_routing -sklearn.covariance._graph_lasso.GraphicalLassoCV.get_metadata_routing sklearn.covariance.tests sklearn.covariance.tests.test_covariance sklearn.covariance.tests.test_elliptic_envelope @@ -343,7 +336,6 @@ sklearn.ensemble._bagging.BaggingClassifier.set_score_request sklearn.ensemble._bagging.BaggingRegressor.set_fit_request sklearn.ensemble._bagging.BaggingRegressor.set_score_request sklearn.ensemble._bagging.BaseBagging.__sklearn_tags__ -sklearn.ensemble._bagging.BaseBagging.get_metadata_routing sklearn.ensemble._bagging.BaseBagging.set_fit_request sklearn.ensemble._base._BaseHeterogeneousEnsemble.__sklearn_tags__ sklearn.ensemble._forest.BaseForest.__sklearn_tags__ @@ -420,13 +412,11 @@ sklearn.ensemble._stacking.StackingClassifier.set_fit_request sklearn.ensemble._stacking.StackingClassifier.set_score_request sklearn.ensemble._stacking.StackingRegressor.set_fit_request sklearn.ensemble._stacking.StackingRegressor.set_score_request -sklearn.ensemble._stacking._BaseStacking.get_metadata_routing sklearn.ensemble._voting.VotingClassifier.__sklearn_tags__ sklearn.ensemble._voting.VotingClassifier.set_fit_request sklearn.ensemble._voting.VotingClassifier.set_score_request sklearn.ensemble._voting.VotingRegressor.set_fit_request sklearn.ensemble._voting.VotingRegressor.set_score_request -sklearn.ensemble._voting._BaseVoting.get_metadata_routing sklearn.ensemble._weight_boosting.AdaBoostClassifier.set_fit_request sklearn.ensemble._weight_boosting.AdaBoostClassifier.set_score_request sklearn.ensemble._weight_boosting.AdaBoostRegressor.set_fit_request @@ -448,7 +438,6 @@ sklearn.exceptions.EstimatorCheckFailedWarning sklearn.exceptions.UnsetMetadataPassedError sklearn.exceptions.__all__ sklearn.experimental -sklearn.experimental.enable_iterative_imputer.IterativeImputer.get_metadata_routing sklearn.experimental.tests sklearn.experimental.tests.test_enable_hist_gradient_boosting sklearn.experimental.tests.test_enable_iterative_imputer @@ -478,22 +467,14 @@ sklearn.feature_extraction.text.TfidfVectorizer.__sklearn_tags__ sklearn.feature_selection sklearn.feature_selection.GenericUnivariateSelect.__sklearn_tags__ sklearn.feature_selection.RFE.__sklearn_tags__ -sklearn.feature_selection.RFE.get_metadata_routing -sklearn.feature_selection.RFECV.get_metadata_routing sklearn.feature_selection.SelectFromModel.__sklearn_tags__ -sklearn.feature_selection.SelectFromModel.get_metadata_routing sklearn.feature_selection.SelectKBest.__sklearn_tags__ sklearn.feature_selection.SelectPercentile.__sklearn_tags__ sklearn.feature_selection.SequentialFeatureSelector.__sklearn_tags__ -sklearn.feature_selection.SequentialFeatureSelector.get_metadata_routing sklearn.feature_selection.VarianceThreshold.__sklearn_tags__ sklearn.feature_selection._from_model.SelectFromModel.__sklearn_tags__ -sklearn.feature_selection._from_model.SelectFromModel.get_metadata_routing sklearn.feature_selection._rfe.RFE.__sklearn_tags__ -sklearn.feature_selection._rfe.RFE.get_metadata_routing -sklearn.feature_selection._rfe.RFECV.get_metadata_routing sklearn.feature_selection._sequential.SequentialFeatureSelector.__sklearn_tags__ -sklearn.feature_selection._sequential.SequentialFeatureSelector.get_metadata_routing sklearn.feature_selection._univariate_selection.GenericUnivariateSelect.__sklearn_tags__ sklearn.feature_selection._univariate_selection.SelectKBest.__sklearn_tags__ sklearn.feature_selection._univariate_selection.SelectPercentile.__sklearn_tags__ @@ -525,14 +506,12 @@ sklearn.gaussian_process.tests.test_gpc sklearn.gaussian_process.tests.test_gpr sklearn.gaussian_process.tests.test_kernels sklearn.impute -sklearn.impute.IterativeImputer.get_metadata_routing sklearn.impute.MissingIndicator.__sklearn_tags__ sklearn.impute.SimpleImputer.__sklearn_tags__ sklearn.impute.__all__ sklearn.impute._base.MissingIndicator.__sklearn_tags__ sklearn.impute._base.SimpleImputer.__sklearn_tags__ sklearn.impute._base._BaseImputer.__sklearn_tags__ -sklearn.impute._iterative.IterativeImputer.get_metadata_routing sklearn.impute.tests sklearn.impute.tests.test_base sklearn.impute.tests.test_common @@ -579,7 +558,6 @@ sklearn.linear_model.HuberRegressor.set_score_request sklearn.linear_model.Lars.set_fit_request sklearn.linear_model.Lars.set_score_request sklearn.linear_model.LarsCV.__sklearn_tags__ -sklearn.linear_model.LarsCV.get_metadata_routing sklearn.linear_model.LarsCV.parameter sklearn.linear_model.LarsCV.set_score_request sklearn.linear_model.Lasso.set_fit_request @@ -600,7 +578,6 @@ sklearn.linear_model.LogisticRegression.__sklearn_tags__ sklearn.linear_model.LogisticRegression.set_fit_request sklearn.linear_model.LogisticRegression.set_score_request sklearn.linear_model.LogisticRegressionCV.__sklearn_tags__ -sklearn.linear_model.LogisticRegressionCV.get_metadata_routing sklearn.linear_model.LogisticRegressionCV.param sklearn.linear_model.LogisticRegressionCV.set_fit_request sklearn.linear_model.LogisticRegressionCV.set_score_request @@ -611,7 +588,6 @@ sklearn.linear_model.MultiTaskElasticNetCV.set_score_request sklearn.linear_model.MultiTaskLassoCV.__sklearn_tags__ sklearn.linear_model.MultiTaskLassoCV.set_score_request sklearn.linear_model.OrthogonalMatchingPursuit.set_score_request -sklearn.linear_model.OrthogonalMatchingPursuitCV.get_metadata_routing sklearn.linear_model.OrthogonalMatchingPursuitCV.set_score_request sklearn.linear_model.PassiveAggressiveClassifier.set_fit_request sklearn.linear_model.PassiveAggressiveClassifier.set_partial_fit_request @@ -627,7 +603,6 @@ sklearn.linear_model.QuantileRegressor.__sklearn_tags__ sklearn.linear_model.QuantileRegressor.set_fit_request sklearn.linear_model.QuantileRegressor.set_score_request sklearn.linear_model.RANSACRegressor.__sklearn_tags__ -sklearn.linear_model.RANSACRegressor.get_metadata_routing sklearn.linear_model.RANSACRegressor.set_fit_request sklearn.linear_model.Ridge.__sklearn_tags__ sklearn.linear_model.Ridge.set_fit_request @@ -673,7 +648,6 @@ sklearn.linear_model._coordinate_descent.Lasso.set_score_request sklearn.linear_model._coordinate_descent.LassoCV.set_fit_request sklearn.linear_model._coordinate_descent.LassoCV.set_score_request sklearn.linear_model._coordinate_descent.LinearModelCV.__sklearn_tags__ -sklearn.linear_model._coordinate_descent.LinearModelCV.get_metadata_routing sklearn.linear_model._coordinate_descent.LinearModelCV.set_fit_request sklearn.linear_model._coordinate_descent.MultiTaskElasticNet.__sklearn_tags__ sklearn.linear_model._coordinate_descent.MultiTaskElasticNet.param @@ -707,7 +681,6 @@ sklearn.linear_model._huber.HuberRegressor.set_score_request sklearn.linear_model._least_angle.Lars.set_fit_request sklearn.linear_model._least_angle.Lars.set_score_request sklearn.linear_model._least_angle.LarsCV.__sklearn_tags__ -sklearn.linear_model._least_angle.LarsCV.get_metadata_routing sklearn.linear_model._least_angle.LarsCV.parameter sklearn.linear_model._least_angle.LarsCV.set_score_request sklearn.linear_model._least_angle.LassoLars.set_fit_request @@ -722,12 +695,10 @@ sklearn.linear_model._logistic.LogisticRegression.__sklearn_tags__ sklearn.linear_model._logistic.LogisticRegression.set_fit_request sklearn.linear_model._logistic.LogisticRegression.set_score_request sklearn.linear_model._logistic.LogisticRegressionCV.__sklearn_tags__ -sklearn.linear_model._logistic.LogisticRegressionCV.get_metadata_routing sklearn.linear_model._logistic.LogisticRegressionCV.param sklearn.linear_model._logistic.LogisticRegressionCV.set_fit_request sklearn.linear_model._logistic.LogisticRegressionCV.set_score_request sklearn.linear_model._omp.OrthogonalMatchingPursuit.set_score_request -sklearn.linear_model._omp.OrthogonalMatchingPursuitCV.get_metadata_routing sklearn.linear_model._omp.OrthogonalMatchingPursuitCV.set_score_request sklearn.linear_model._passive_aggressive.PassiveAggressiveClassifier.set_fit_request sklearn.linear_model._passive_aggressive.PassiveAggressiveClassifier.set_partial_fit_request @@ -741,7 +712,6 @@ sklearn.linear_model._quantile.QuantileRegressor.__sklearn_tags__ sklearn.linear_model._quantile.QuantileRegressor.set_fit_request sklearn.linear_model._quantile.QuantileRegressor.set_score_request sklearn.linear_model._ransac.RANSACRegressor.__sklearn_tags__ -sklearn.linear_model._ransac.RANSACRegressor.get_metadata_routing sklearn.linear_model._ransac.RANSACRegressor.set_fit_request sklearn.linear_model._ridge.Ridge.__sklearn_tags__ sklearn.linear_model._ridge.Ridge.set_fit_request @@ -757,7 +727,6 @@ sklearn.linear_model._ridge.RidgeClassifierCV.set_score_request sklearn.linear_model._ridge._BaseRidge.set_fit_request sklearn.linear_model._ridge._BaseRidgeCV.__sklearn_tags__ sklearn.linear_model._ridge._BaseRidgeCV.cv_values_ -sklearn.linear_model._ridge._BaseRidgeCV.get_metadata_routing sklearn.linear_model._ridge._BaseRidgeCV.set_fit_request sklearn.linear_model._ridge._IdentityClassifier.set_decision_function_request sklearn.linear_model._ridge._IdentityClassifier.set_score_request @@ -929,7 +898,6 @@ sklearn.model_selection._classification_threshold.FixedThresholdClassifier.set_s sklearn.model_selection._classification_threshold.TunedThresholdClassifierCV.set_score_request sklearn.model_selection._plot.LearningCurveDisplay.from_estimator sklearn.model_selection._search.BaseSearchCV.__sklearn_tags__ -sklearn.model_selection._search.BaseSearchCV.get_metadata_routing sklearn.model_selection._split.GroupKFold.set_split_request sklearn.model_selection._split.GroupShuffleSplit.set_split_request sklearn.model_selection._split.LeaveOneGroupOut.set_split_request @@ -947,19 +915,15 @@ sklearn.model_selection.tests.test_successive_halving sklearn.model_selection.tests.test_validation sklearn.multiclass sklearn.multiclass.OneVsOneClassifier.__sklearn_tags__ -sklearn.multiclass.OneVsOneClassifier.get_metadata_routing sklearn.multiclass.OneVsOneClassifier.set_partial_fit_request sklearn.multiclass.OneVsOneClassifier.set_score_request sklearn.multiclass.OneVsRestClassifier.__sklearn_tags__ -sklearn.multiclass.OneVsRestClassifier.get_metadata_routing sklearn.multiclass.OneVsRestClassifier.set_partial_fit_request sklearn.multiclass.OneVsRestClassifier.set_score_request sklearn.multiclass.OutputCodeClassifier.__sklearn_tags__ -sklearn.multiclass.OutputCodeClassifier.get_metadata_routing sklearn.multiclass.OutputCodeClassifier.set_score_request sklearn.multioutput sklearn.multioutput.ClassifierChain.__sklearn_tags__ -sklearn.multioutput.ClassifierChain.get_metadata_routing sklearn.multioutput.ClassifierChain.predict_log_proba sklearn.multioutput.ClassifierChain.set_score_request sklearn.multioutput.MultiOutputClassifier.__sklearn_tags__ @@ -969,11 +933,9 @@ sklearn.multioutput.MultiOutputRegressor.set_fit_request sklearn.multioutput.MultiOutputRegressor.set_partial_fit_request sklearn.multioutput.MultiOutputRegressor.set_score_request sklearn.multioutput.RegressorChain.__sklearn_tags__ -sklearn.multioutput.RegressorChain.get_metadata_routing sklearn.multioutput.RegressorChain.set_score_request sklearn.multioutput._BaseChain.__sklearn_tags__ sklearn.multioutput._MultiOutputEstimator.__sklearn_tags__ -sklearn.multioutput._MultiOutputEstimator.get_metadata_routing sklearn.multioutput._MultiOutputEstimator.set_fit_request sklearn.multioutput._MultiOutputEstimator.set_partial_fit_request sklearn.naive_bayes @@ -1073,10 +1035,8 @@ sklearn.pipeline sklearn.pipeline.FeatureUnion.__getitem__ sklearn.pipeline.FeatureUnion.__sklearn_tags__ sklearn.pipeline.FeatureUnion.feature_names_in_ -sklearn.pipeline.FeatureUnion.get_metadata_routing sklearn.pipeline.FunctionTransformer.__sklearn_tags__ sklearn.pipeline.Pipeline.__sklearn_tags__ -sklearn.pipeline.Pipeline.get_metadata_routing sklearn.pipeline.Pipeline.set_score_request sklearn.preprocessing sklearn.preprocessing.Binarizer.__sklearn_tags__ @@ -1144,13 +1104,11 @@ sklearn.semi_supervised sklearn.semi_supervised.LabelPropagation.set_score_request sklearn.semi_supervised.LabelSpreading.set_score_request sklearn.semi_supervised.SelfTrainingClassifier.__sklearn_tags__ -sklearn.semi_supervised.SelfTrainingClassifier.get_metadata_routing sklearn.semi_supervised._label_propagation.BaseLabelPropagation.__sklearn_tags__ sklearn.semi_supervised._label_propagation.BaseLabelPropagation.set_score_request sklearn.semi_supervised._label_propagation.LabelPropagation.set_score_request sklearn.semi_supervised._label_propagation.LabelSpreading.set_score_request sklearn.semi_supervised._self_training.SelfTrainingClassifier.__sklearn_tags__ -sklearn.semi_supervised._self_training.SelfTrainingClassifier.get_metadata_routing sklearn.semi_supervised.tests sklearn.semi_supervised.tests.test_label_propagation sklearn.semi_supervised.tests.test_self_training @@ -1447,8 +1405,6 @@ sklearn._build_utils.openmp_helpers sklearn._build_utils.parse sklearn._build_utils.pre_build_helpers sklearn._config._threadlocal -sklearn._config.config_context -sklearn._config.set_config sklearn._loss._loss.CyAbsoluteError.cy_grad_hess sklearn._loss._loss.CyAbsoluteError.cy_gradient sklearn._loss._loss.CyAbsoluteError.cy_loss @@ -1528,7 +1484,6 @@ sklearn.compose.ColumnTransformer.__init__ sklearn.compose._column_transformer.ColumnTransformer.__init__ sklearn.compose._column_transformer.make_column_transformer sklearn.compose.make_column_transformer -sklearn.config_context sklearn.conftest._SKIP32_MARK sklearn.conftest.fetch_20newsgroups sklearn.conftest.fetch_20newsgroups_fxt @@ -2045,7 +2000,6 @@ sklearn.metrics.mean_squared_log_error sklearn.metrics.pairwise.DistanceMetric sklearn.metrics.pairwise.check_array sklearn.metrics.pairwise.check_pairwise_arrays -sklearn.metrics.pairwise.config_context sklearn.metrics.pairwise.manhattan_distances sklearn.metrics.pairwise.pairwise_distances sklearn.metrics.pairwise_distances @@ -2126,7 +2080,6 @@ sklearn.random_projection.check_array sklearn.random_projection.sample_without_replacement sklearn.semi_supervised.SelfTrainingClassifier.__init__ sklearn.semi_supervised._self_training.SelfTrainingClassifier.__init__ -sklearn.set_config sklearn.svm.LinearSVC.__init__ sklearn.svm.LinearSVR.__init__ sklearn.svm.NuSVC.coef_ @@ -2275,7 +2228,6 @@ sklearn.utils.estimator_checks.check_estimators_pickle sklearn.utils.estimator_checks.check_global_ouptut_transform_pandas sklearn.utils.estimator_checks.check_parameters_default_constructible sklearn.utils.estimator_checks.check_sample_weights_invariance -sklearn.utils.estimator_checks.config_context sklearn.utils.estimator_checks.create_memmap_backed_data sklearn.utils.estimator_checks.generate_invalid_param_val sklearn.utils.estimator_checks.ignore_warnings @@ -2312,7 +2264,6 @@ sklearn.utils.multiclass.dok_matrix sklearn.utils.multiclass.get_namespace sklearn.utils.multiclass.lil_matrix sklearn.utils.multiclass.type_of_target -sklearn.utils.parallel.config_context sklearn.utils.parallel_backend sklearn.utils.random.sample_without_replacement sklearn.utils.register_parallel_backend diff --git a/stubs/sklearn/utils/_metadata_requests.pyi b/stubs/sklearn/utils/_metadata_requests.pyi index 36e66ce9..c636b4ec 100644 --- a/stubs/sklearn/utils/_metadata_requests.pyi +++ b/stubs/sklearn/utils/_metadata_requests.pyi @@ -70,8 +70,8 @@ class RequestMethod: def __get__(self, instance, owner): ... class _MetadataRequester: - def __init_subclass__(cls, **kwargs): ... - def get_metadata_routing(self): ... + def __init_subclass__(cls, **kwargs) -> None: ... + def get_metadata_routing(self) -> MetadataRequest: ... # This code is never run in runtime, but it's here for type checking. # Type checkers fail to understand that the `set_{method}_request` # methods are dynamically generated, and they complain that they are