Add sample weights to KPrototypes. #171
Conversation
Hi @nicodv! Happy to hear your thoughts. :)
Wonderful contribution, @kklein. Thank you!
I left some comments and questions.
kmodes/kprototypes.py
Outdated
@@ -130,13 +130,17 @@ def __init__(self, n_clusters=8, max_iter=100, num_dissim=euclidean_dissim,
                           "Setting n_init to 1.")
             self.n_init = 1

-    def fit(self, X, y=None, categorical=None):
+    def fit(self, X, y=None, categorical=None, sample_weights=None):
I think we can also add it to KModes, since you've done most of the legwork for that already. This can be a follow-up PR.
It also seemed to me that it would be fitting and consistent to enable the functionality for KModes as well. If possible, I would appreciate it if this could be done in a follow-up PR.
Sure, that would be great
kmodes/kprototypes.py
Outdated
@@ -513,3 +527,16 @@ def _split_num_cat(X, categorical):
                          if ii not in categorical]]).astype(np.float64)
    Xcat = np.asanyarray(X[:, categorical])
    return Xnum, Xcat


+def _validate_sample_weights(sample_weights, n_samples):
If we're enabling this for both KModes and KPrototypes, I suggest moving this method to the former.
Noted (for a possible follow-up PR)!
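For readers curious what such a validation helper typically does, here is a hypothetical sketch. The name _validate_sample_weights matches the diff above, but the body, the error messages, and the uniform-weight fallback are assumptions rather than the PR's actual code.

import numpy as np

def _validate_sample_weights(sample_weights, n_samples):
    """Return a validated 1-d weight array, or uniform weights if none are given."""
    if sample_weights is None:
        # No weights supplied: every sample counts equally.
        return np.ones(n_samples)
    sample_weights = np.asarray(sample_weights, dtype=np.float64)
    if sample_weights.ndim != 1 or sample_weights.shape[0] != n_samples:
        raise ValueError("sample_weights must be a 1-d sequence with one weight per sample.")
    if not np.all(np.isfinite(sample_weights)) or np.any(sample_weights < 0):
        raise ValueError("sample_weights must be finite and non-negative.")
    return sample_weights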
kmodes/kprototypes.py
Outdated
     # Initial assignment to clusters
     clust = np.argmin(
         num_dissim(centroids[0], Xnum[ipoint]) + gamma *
         cat_dissim(centroids[1], Xcat[ipoint], X=Xcat, membship=membship)
     )
     membship[clust, ipoint] = 1
-    cl_memb_sum[clust] += 1
+    cl_memb_sum[clust] += sample_weight
Ultimately, we calculate the mean by dividing cl_attr_sum by this cl_memb_sum (see https://github.com/nicodv/kmodes/blob/master/kmodes/kprototypes.py#L471). If we apply the weight to both the numerator and the denominator, they cancel out, no? Shouldn't we solely apply the weight to cl_attr_sum?
Most definitely! This slipped through the cracks.
978d973
I do wonder now how the unit test that tests for a single overweighted sample was able to pass. 🤔
Yeah, so do I. I think the problem affected only one of the numerical/categorical feature types. Maybe the other one being correct was enough to push the centroid to the right point?
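As a reference for the thread above, a toy illustration (not the library's code) of how the placement of the weight changes the numeric centroid, which is the ratio cl_attr_sum / cl_memb_sum; the data and weights are made up.

import numpy as np

x = np.array([1.0, 2.0, 9.0])   # numeric attribute of three points in one cluster
w = np.array([1.0, 1.0, 5.0])   # hypothetical sample weights

unweighted = x.sum() / len(x)             # 4.0: plain mean
numerator_only = (w * x).sum() / len(x)   # 16.0: weight accumulated only in cl_attr_sum
both_sums = (w * x).sum() / w.sum()       # ~6.86: weight in both sums (standard weighted mean)
print(unweighted, numerator_only, both_sums)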
LGTM
Thanks a bunch for your fast and very useful feedback! :)
As of now, every data point contributes equally to the loss function and the derived cluster updates. Yet, in some use cases, it might be desirable to attach weights to data points.
This PR introduces sample_weights, a sequence of numeric values, as an optional parameter of KPrototypes' fit method as well as of all downstream functions. Some basic input validation and tests are provided.
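A minimal usage sketch of the new parameter; the data, weights, and constructor arguments below are made up for illustration, and sample_weights is the keyword argument added by this PR.

import numpy as np
from kmodes.kprototypes import KPrototypes

# Toy mixed data: column 0 is numeric, column 1 is categorical.
X = np.array([[1.0, 'a'], [2.0, 'a'], [100.0, 'b'], [101.0, 'b']], dtype=object)
# Hypothetical weights: the last two points count five times as much as the first two.
weights = [1, 1, 5, 5]

kproto = KPrototypes(n_clusters=2, init='Cao', n_init=1, random_state=42)
labels = kproto.fit(X, categorical=[1], sample_weights=weights).labels_
print(labels)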