Addition of kmeanscrossvalidator #61
Maybe it's a bit nitpicky, but currently this implementation is fully built on KMeans, whereas any clustering method should work, right? Maybe something for the future: make the CV take any clusterer as a parameter.
@MBrouns I agree. Related: I think
@kayhoogland what's your opinion on this? It feels weird to merge this knowing that we will add a more custom thing as well. (Also, clustering algorithms are typically very prone to overfitting on a column if there's no standardisation, so I think we don't just want to accept multiple clustering algorithms, we also want to allow for pipelines.)
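To make the pipeline point concrete, here is a minimal sketch, assuming scikit-learn's StandardScaler and KMeans and made-up data with two columns on very different scales. It only illustrates that a Pipeline ending in a clusterer still exposes fit_predict, so a cross validator that calls fit_predict could accept it unchanged. It is not code from this PR.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.RandomState(0)

# Illustrative data: one small-scale column with two groups,
# one large-scale column that is pure noise.
small = np.concatenate([rng.normal(0, 0.1, 50), rng.normal(5, 0.1, 50)])
large = rng.normal(0, 1000, 100)
X = np.column_stack([small, large])

# Standardise first so the large-scale column does not dominate the distances.
cluster_pipeline = make_pipeline(
    StandardScaler(),
    KMeans(n_clusters=2, n_init=10, random_state=0),
)

# Pipeline.fit_predict is available because the final step implements it.
clusters = cluster_pipeline.fit_predict(X)
```

Because the pipeline has the same fit_predict interface as a bare clusterer, supporting it would not need a separate code path.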
I agree, it seems the logical thing to do. Also, the name is indeed catchy ;)
I made some changes to enable more clustering methods. A difficult thing to take into account is the n_splits from _BaseKFold. Some (most) clustering methods (for example DBSCAN) determine the number of clusters, and therefore the number of splits, from the data itself.
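One possible way around a fixed n_splits, sketched under assumptions: the class name ClusterFoldSketch is hypothetical, it does not inherit from _BaseKFold, and it delegates the actual splitting to scikit-learn's LeaveOneGroupOut with the cluster labels as groups. The number of folds is then only known after clustering the data.

```python
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.model_selection import LeaveOneGroupOut


class ClusterFoldSketch:
    """Hypothetical splitter: one test fold per cluster found by `cluster_method`."""

    def __init__(self, cluster_method):
        self.cluster_method = cluster_method

    def _clusters(self, X):
        # Refit on every call; caching is left out to keep the sketch short.
        return self.cluster_method.fit_predict(np.asarray(X))

    def get_n_splits(self, X, y=None, groups=None):
        # The fold count depends on the data, so X is required here.
        return len(np.unique(self._clusters(X)))

    def split(self, X, y=None, groups=None):
        clusters = self._clusters(X)
        # Every distinct label (including DBSCAN's -1 noise label) is one test fold.
        yield from LeaveOneGroupOut().split(X, y, groups=clusters)


# Three well-separated groups of three points each.
X = np.array([[0.0], [0.1], [0.2], [5.0], [5.1], [5.2], [10.0], [10.1], [10.2]])
cv = ClusterFoldSketch(DBSCAN(eps=0.5, min_samples=2))
folds = list(cv.split(X))
```

Note that LeaveOneGroupOut raises an error when the clusterer finds fewer than two distinct labels, so a real implementation would need to handle that case explicitly.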
Closed it because I hadn't heard anything about it in a while. I'll gladly re-open if there's more attention for it.
I just discussed this with @kayhoogland and we see a good way forward. I'm putting my feedback in a review.
sklego/model_selection.py
Outdated
if isinstance(X, pd.DataFrame):
    X = X.values

clusters = self.cluster_method.fit_predict(X)
To avoid refitting if self.cluster_method is already fitted:

from sklearn.exceptions import NotFittedError

try:
    clusters = self.cluster_method.predict(X)
except NotFittedError:
    clusters = self.cluster_method.fit_predict(X)
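A standalone sketch of the same idea, with one caveat added: clusterers such as DBSCAN have no predict method at all, so the attempt raises AttributeError rather than NotFittedError. The helper name predict_or_fit is hypothetical, and catching AttributeError this broadly is an assumption made for brevity.

```python
import numpy as np
from sklearn.cluster import DBSCAN, KMeans
from sklearn.exceptions import NotFittedError


def predict_or_fit(cluster_method, X):
    """Hypothetical helper: reuse a fitted clusterer, otherwise fit it on X."""
    try:
        return cluster_method.predict(X)
    except (NotFittedError, AttributeError):
        # NotFittedError: the clusterer supports predict but is not fitted yet.
        # AttributeError: the clusterer (e.g. DBSCAN) has no predict method.
        return cluster_method.fit_predict(X)


X = np.array([[0.0], [0.1], [5.0], [5.1]])

kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
first = predict_or_fit(kmeans, X)   # not fitted yet, falls back to fit_predict
second = predict_or_fit(kmeans, X)  # already fitted, uses predict without refitting

dbscan_labels = predict_or_fit(DBSCAN(eps=0.5, min_samples=2), X)
```

A stricter variant could test hasattr(cluster_method, "predict") up front instead of catching AttributeError, which would avoid masking unrelated attribute errors raised inside predict.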
super(KlusterFoldValidation, self).__init__(n_splits=3,
                                            shuffle=False,
                                            random_state=random_state)
It's probably nice to set self.n_splits = None here.
w00t.
First version of the kmeans cross validator discussed here #5