From db581f26fb9c0eeb6fc8a885db294bb5f46a00b0 Mon Sep 17 00:00:00 2001 From: Nico de Vos Date: Tue, 6 Sep 2022 12:11:49 -0700 Subject: [PATCH] improve estimation of gamma --- kmodes/kprototypes.py | 2 +- kmodes/tests/test_kprototypes.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/kmodes/kprototypes.py b/kmodes/kprototypes.py index 3d4b57f..e8a0648 100644 --- a/kmodes/kprototypes.py +++ b/kmodes/kprototypes.py @@ -290,7 +290,7 @@ def k_prototypes(X, categorical, n_clusters, max_iter, num_dissim, cat_dissim, # Estimate a good value for gamma, which determines the weighing of # categorical values in clusters (see Huang [1997]). if gamma is None: - gamma = 0.5 * Xnum.std() + gamma = 0.5 * np.mean(Xnum.std(axis=0)) results = [] seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) diff --git a/kmodes/tests/test_kprototypes.py b/kmodes/tests/test_kprototypes.py index ce3a2c1..b2ed9fc 100644 --- a/kmodes/tests/test_kprototypes.py +++ b/kmodes/tests/test_kprototypes.py @@ -428,3 +428,23 @@ def test_pandas_numpy_equality(self): result_np = kproto.fit_predict(STOCKS, categorical=[1, 2]) result_pd = kproto.fit_predict(pd.DataFrame(STOCKS), categorical=[1, 2]) np.testing.assert_array_equal(result_np, result_pd) + + def test_gamma_estimation(self): + data = np.hstack([ + np.array([ + [0.0], + [0.0], + [0.0], + [1.0], + [1.0], + [1.0], + [2.0], + [2.0], + [2.0], + [3.0], + [4.0], + [5.0], + ]), STOCKS]) + kproto = kprototypes.KPrototypes(n_clusters=4, init='Cao', random_state=42) + kproto_fitted = kproto.fit(data, categorical=[2, 3]) + self.assertEqual(kproto_fitted.gamma, 35.33525036439546)