Skip to content

Commit

Permalink
Merge 7462492 into df07940
Browse files Browse the repository at this point in the history
  • Loading branch information
nicodv committed Apr 13, 2022
2 parents df07940 + 7462492 commit 02070e6
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 8 deletions.
8 changes: 6 additions & 2 deletions kmodes/kmodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def fit(self, X, y=None, sample_weight=None, **kwargs):
X = pandas_to_numpy(X)

random_state = check_random_state(self.random_state)
_validate_sample_weight(sample_weight, n_samples=X.shape[0])
_validate_sample_weight(sample_weight, n_samples=X.shape[0],
n_clusters=self.n_clusters)

self._enc_cluster_centroids, self._enc_map, self.labels_, self.cost_, \
self.n_iter_, self.epoch_costs_ = k_modes(
Expand Down Expand Up @@ -407,7 +408,7 @@ def _move_point_cat(point, ipoint, to_clust, from_clust, cl_attr_freq,
return cl_attr_freq, membship, centroids


def _validate_sample_weight(sample_weight, n_samples):
def _validate_sample_weight(sample_weight, n_samples, n_clusters):
if sample_weight is not None:
if len(sample_weight) != n_samples:
raise ValueError("sample_weight should be of equal size as samples.")
Expand All @@ -418,3 +419,6 @@ def _validate_sample_weight(sample_weight, n_samples):
raise ValueError("sample_weight elements should either be int or floats.")
if any(sample < 0 for sample in sample_weight):
raise ValueError("sample_weight elements should be positive.")
if sum([abs(x) > 0 for x in sample_weight]) < n_clusters:
raise ValueError("Number of non-zero sample_weight elements should be "
"larger than the number of clusters.")
5 changes: 3 additions & 2 deletions kmodes/kprototypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def fit(self, X, y=None, categorical=None, sample_weight=None):
X = pandas_to_numpy(X)

random_state = check_random_state(self.random_state)
kmodes._validate_sample_weight(sample_weight, n_samples=X.shape[0])
kmodes._validate_sample_weight(sample_weight, n_samples=X.shape[0],
n_clusters=self.n_clusters)

# If self.gamma is None, gamma will be automatically determined from
# the data. The function below returns its value.
Expand All @@ -175,7 +176,7 @@ def fit(self, X, y=None, categorical=None, sample_weight=None):

return self

def predict(self, X, categorical=None):
def predict(self, X, categorical=None, **kwargs):
"""Predict the closest cluster each sample in X belongs to.
Parameters
Expand Down
8 changes: 8 additions & 0 deletions kmodes/tests/test_kmodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,11 @@ def test_k_modes_sample_weight_unchanged(self):
tuple_pairs = zip(sorted(expected), sorted(factual))
for tuple_expected, tuple_factual in tuple_pairs:
self.assertAlmostEqual(tuple_expected, tuple_factual)

def test_kmodes_fit_predict(self):
"""Test whether fit_predict interface works the same as fit and predict."""
kmodes = KModes(n_clusters=4, init='Cao', random_state=42)
sample_weight = [0.5] * TEST_DATA.shape[0]
data1 = kmodes.fit_predict(TEST_DATA, sample_weight=sample_weight)
data2 = kmodes.fit(TEST_DATA, sample_weight=sample_weight).predict(TEST_DATA)
assert_cluster_splits_equal(data1, data2)
39 changes: 35 additions & 4 deletions kmodes/tests/test_kprototypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,17 +337,26 @@ def test_kprototypes_ninit(self):
def test_kprototypes_sample_weights_validation(self):
kproto = kprototypes.KPrototypes(n_clusters=4, init='Cao', verbose=2)
sample_weight_too_few = [1] * 11
with self.assertRaisesRegex(ValueError, "sample_weight should be of equal size as samples."):
with self.assertRaisesRegex(
ValueError,
"sample_weight should be of equal size as samples."
):
kproto.fit_predict(
STOCKS, categorical=[1, 2], sample_weight=sample_weight_too_few
)
sample_weight_negative = [-1] + [1] * 11
with self.assertRaisesRegex(ValueError, "sample_weight elements should be positive."):
with self.assertRaisesRegex(
ValueError,
"sample_weight elements should be positive."
):
kproto.fit_predict(
STOCKS, categorical=[1, 2], sample_weight=sample_weight_negative
)
sample_weight_non_numerical = [None] + [1] * 11
with self.assertRaisesRegex(ValueError, "sample_weight elements should either be int or floats."):
with self.assertRaisesRegex(
ValueError,
"sample_weight elements should either be int or floats."
):
kproto.fit_predict(
STOCKS, categorical=[1, 2], sample_weight=sample_weight_non_numerical
)
Expand All @@ -362,7 +371,21 @@ def test_k_prototypes_sample_weight_all_but_one_zero(self):
model = kproto.fit(
STOCKS[:n_samples, :], categorical=[1, 2], sample_weight=sample_weight
)
self.assertTrue((model.cluster_centroids_[0, :] == STOCKS[indicator, :]).all())
np.testing.assert_array_equal(
model.cluster_centroids_[0, :],
STOCKS[indicator, :]
)

def test_k_prototypes_sample_weight_not_enough_non_zero(self):
kproto = kprototypes.KPrototypes(n_clusters=2, init='Cao', random_state=42)
sample_weight = np.zeros(STOCKS.shape[0])
sample_weight[0] = 1
with self.assertRaisesRegex(
ValueError,
"Number of non-zero sample_weight elements should be larger "
"than the number of clusters."
):
kproto.fit(STOCKS, categorical=[1, 2], sample_weight=sample_weight)

def test_k_prototypes_sample_weight_unchanged(self):
"""Test whether centroid definition remains unchanged when scaling uniformly."""
Expand Down Expand Up @@ -390,3 +413,11 @@ def test_k_prototypes_sample_weight_unchanged(self):
for index in categorical:
self.assertTrue(tuple_expected[index] == tuple_factual[index])

def test_kmodes_fit_predict_equality(self):
"""Test whether fit_predict interface works the same as fit and predict."""
kproto = kprototypes.KPrototypes(n_clusters=3, init='Cao', random_state=42)
sample_weight = [0.5] * STOCKS.shape[0]
model1 = kproto.fit(STOCKS, categorical=[1, 2], sample_weight=sample_weight)
data1 = model1.predict(STOCKS, categorical=[1, 2])
data2 = kproto.fit_predict(STOCKS, categorical=[1, 2], sample_weight=sample_weight)
assert_cluster_splits_equal(data1, data2)

0 comments on commit 02070e6

Please sign in to comment.