[FEATURE] Pass additional parameters to fit underlying estimator in EstimatorTransformer #530

Closed
CarloLepelaars opened this issue Sep 12, 2022 · 14 comments
Labels
enhancement New feature or request

Comments

@CarloLepelaars
Contributor

In EstimatorTransformer the underlying estimator is being fitted without the ability to pass along additional arguments to self.estimator_.fit.

This limits use cases for EstimatorTransformer. For example, if the underlying estimator is an XGBClassifier we would like to be able to pass eval_set to monitor validation performance and enable early stopping. This is currently not possible. Adding *args, **kwargs should fix this issue.

self.estimator_.fit(X, y)
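
A minimal sketch of the idea, assuming a wrapper that forwards keyword arguments from its own fit to the wrapped estimator. `ForwardingEstimatorTransformer` and `ToyRegressor` are hypothetical names made up for illustration; `ToyRegressor` stands in for a third-party model such as XGBClassifier whose fit accepts eval_set. This is not the library's actual code.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, TransformerMixin, clone


class ToyRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical stand-in for a model whose fit() accepts eval_set."""

    def fit(self, X, y, eval_set=None):
        self.mean_ = float(np.mean(y))
        # Record whether a validation set actually reached fit().
        self.saw_eval_set_ = eval_set is not None
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


class ForwardingEstimatorTransformer(TransformerMixin, BaseEstimator):
    """Illustrative wrapper (an assumption, not the library's implementation)."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **kwargs):
        # Fit a fresh copy and pass any extra keyword arguments straight through.
        self.estimator_ = clone(self.estimator)
        self.estimator_.fit(X, y, **kwargs)
        return self

    def transform(self, X):
        # Expose the predictions as a single feature column.
        return np.asarray(self.estimator_.predict(X)).reshape(-1, 1)


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0.0, 1.0, 2.0, 3.0])

wrapper = ForwardingEstimatorTransformer(ToyRegressor())
wrapper.fit(X[:3], y[:3], eval_set=[(X[3:], y[3:])])
features = wrapper.transform(X)
```

With the original `self.estimator_.fit(X, y)` call, the eval_set argument would have nowhere to go; in this sketch it reaches the inner estimator unchanged.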

CarloLepelaars added the "enhancement" label on Sep 12, 2022
CarloLepelaars changed the title from "[FEATURE] Pass parameters to fit of underlying estimator in EstimatorTransformer" to "[FEATURE] Pass additional parameters to fit underlying estimator in EstimatorTransformer" on Sep 12, 2022
@CarloLepelaars
Contributor Author

@koaning

In the future, please don't make a PR until the direction of the solution has been discussed in the issues.

I'll pick up the conversation there.

Ok, no problem! Will keep that in mind.

@koaning
Owner

koaning commented Sep 12, 2022

It's a bit of an sklearn antipattern to pass lots of settings via .fit(). Is XGBoost part of sklearn or a 3rd party lib?

@CarloLepelaars
Contributor Author

CarloLepelaars commented Sep 12, 2022

XGBoost is a 3rd party library maintained by dmlc.

I agree hyperparameters shouldn't be passed via .fit(). Unfortunately, some parameters, like eval_set, can often only be passed through .fit(). I believe sample_weight for scikit-learn estimators can also only be passed through .fit(), but I'm not completely sure.

Other example use cases include CatBoost parameters and LightGBM parameters that can only be passed through .fit(). These libraries are also 3rd party, but they are very often used within scikit-learn Pipelines.

UPDATE: From looking at the scikit-learn source code, it seems sample_weight can only be passed through .fit() and never through class initialization parameters. Adding this as an optional parameter to EstimatorTransformer.fit() would not necessarily extend this library beyond scikit-learn. I understand that the case for generalizing to *args, **kwargs is a bit trickier.
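
The sample_weight observation can be checked for one concrete estimator. The sketch below uses scikit-learn's LinearRegression as an example (the choice of estimator and the toy data are assumptions for illustration; it does not prove the claim for every scikit-learn estimator).

```python
import inspect

import numpy as np
from sklearn.linear_model import LinearRegression

# sample_weight appears in the signature of fit(), not of __init__().
has_in_fit = "sample_weight" in inspect.signature(LinearRegression.fit).parameters
has_in_init = "sample_weight" in inspect.signature(LinearRegression.__init__).parameters

X = np.array([[0.0], [1.0], [2.0]])
y = np.array([0.0, 1.0, 5.0])

unweighted = LinearRegression().fit(X, y)
# A zero weight on the last row removes its influence on the fitted line.
weighted = LinearRegression().fit(X, y, sample_weight=np.array([1.0, 1.0, 0.0]))
```

Because the weights only exist as a fit argument, a wrapper whose fit accepts just X and y has no way to hand them to the inner estimator.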

@koaning
Owner

koaning commented Sep 12, 2022

I'm a bit uneasy about extending this library beyond scikit-learn because the dependencies quickly start to stack up. @MBrouns what's your opinion on this?

@MBrouns
Collaborator

MBrouns commented Sep 12, 2022

sklearn does describe the use of kwargs in fit methods: https://scikit-learn.org/stable/developers/develop.html#fitting. I'm not sure I like varargs, but I don't see a lot of problems with varkwargs. I would like to see a test added in the PR before accepting it, though.
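
The comment above does not spell out the objection to varargs. One plausible reason, offered here as an assumption rather than the commenter's stated rationale, is that forwarded positional arguments bind by position, so the same call can mean different things for estimators whose fit signatures order their optional parameters differently. The function names below are hypothetical.

```python
def fit_model_a(X, y, sample_weight=None, eval_set=None):
    """Hypothetical fit() with sample_weight as the third parameter."""
    return {"sample_weight": sample_weight, "eval_set": eval_set}


def fit_model_b(X, y, eval_set=None, sample_weight=None):
    """Hypothetical fit() with the optional parameters in the opposite order."""
    return {"sample_weight": sample_weight, "eval_set": eval_set}


def forward_positional(fit, X, y, *args):
    # Extra positional arguments land wherever the target signature puts them.
    return fit(X, y, *args)


def forward_keywords(fit, X, y, **kwargs):
    # Keyword arguments are matched by name, independent of parameter order.
    return fit(X, y, **kwargs)


X, y = [[0], [1]], [0, 1]
weights = [1.0, 2.0]

pos_a = forward_positional(fit_model_a, X, y, weights)
pos_b = forward_positional(fit_model_b, X, y, weights)  # silently binds to eval_set
kw_a = forward_keywords(fit_model_a, X, y, sample_weight=weights)
kw_b = forward_keywords(fit_model_b, X, y, sample_weight=weights)
```

With keyword forwarding, both hypothetical models receive the weights under the intended name, and a misspelled or unsupported name raises a TypeError instead of being misinterpreted.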

@koaning
Owner

koaning commented Sep 12, 2022

TIL.

Yeah so if scikit-learn supports **kwargs then I won't mind.

@CarloLepelaars
Contributor Author

Sounds great! I removed the *args option and added a test.

@koaning Can we reopen this PR or should I create a new one?

@koaning
Owner

koaning commented Sep 12, 2022

Either option is fine. Just as long as an issue is discussed before a solution is implemented.

@koaning
Owner

koaning commented Sep 12, 2022

Oh! And one more thing. If you're adding this behavior to the estimator transformer, could you also add it to the estimatorpredictor?

@CarloLepelaars
Contributor Author

Sure! Will check that out tomorrow. Do you mean pass **kwargs through the .transform method in EstimatorTransformer (+ a test case)?

@koaning
Owner

koaning commented Sep 13, 2022

This issue is about the .fit() method, no?

@CarloLepelaars
Contributor Author

Yes, but what exactly do you mean by estimatorpredictor otherwise? I don't see an EstimatorPredictor object in this repository (on the main branch).

@koaning
Owner

koaning commented Sep 13, 2022

Ah! Crud. My bad.

I was confused with the Grouped variant of the meta estimators. These come with a predictor variant.

Please ignore the previous comment.

@CarloLepelaars
Contributor Author

Aha, thanks for clearing that up! Then I think we are ready for the PR. Will create a fresh one.

3 participants