diff --git a/doc/whats_new.rst b/doc/whats_new.rst index a434ed82a11b8..cc481740c96f7 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -97,6 +97,10 @@ Bug fixes in R (lars library). :issue:`7849` by `Jair Montoya Martinez`_ + - Fix a bug regarding fitting :class:`sklearn.cluster.KMeans` with a + sparse array X and initial centroids, where X's means were unnecessarily + being subtracted from the centroids. :issue:`7872` by `Josh Karnofsky`_. + .. _changes_0_18_1: Version 0.18.1 diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index bd48a1c36224a..f33b3f65b714e 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -298,18 +298,11 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', ", but a value of %r was passed" % precompute_distances) - # subtract of mean of x for more accurate distance computations - if not sp.issparse(X) or hasattr(init, '__array__'): - X_mean = X.mean(axis=0) - if not sp.issparse(X): - # The copy was already done above - X -= X_mean - + # Validate init array if hasattr(init, '__array__'): init = check_array(init, dtype=X.dtype.type, copy=True) _validate_center_shape(X, n_clusters, init) - init -= X_mean if n_init != 1: warnings.warn( 'Explicit initial center position passed: ' @@ -317,6 +310,15 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', % n_init, RuntimeWarning, stacklevel=2) n_init = 1 + # subtract the mean of X for more accurate distance computations + if not sp.issparse(X): + X_mean = X.mean(axis=0) + # The copy was already done above + X -= X_mean + + if hasattr(init, '__array__'): + init -= X_mean + # precompute squared norms of data points x_squared_norms = row_norms(X, squared=True) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 4b23ab9cc1677..31307e55801a5 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -812,7 +812,7 @@ def 
test_float_precision():
                                   decimal=4)


-def test_KMeans_init_centers():
+def test_k_means_init_centers():
     # This test is used to check KMeans won't mutate the user provided input
     # array silently even if input data and init centers have the same type
     X_small = np.array([[1.1, 1.1], [-7.5, -7.5], [-1.1, -1.1], [7.5, 7.5]])
@@ -824,3 +824,52 @@ def test_KMeans_init_centers():
     km = KMeans(init=init_centers_test, n_clusters=3, n_init=1)
     km.fit(X_test)
     assert_equal(False, np.may_share_memory(km.cluster_centers_, init_centers))
+
+
+def test_sparse_k_means_init_centers():
+    # Check that fitting from explicit init centers converges to the same
+    # centers whether X is dense or sparse
+    from sklearn.datasets import load_iris
+
+    iris = load_iris()
+    X = iris.data
+
+    # Get a local optimum
+    centers = KMeans(n_clusters=3).fit(X).cluster_centers_
+
+    # Fit starting from a local optimum shouldn't change the solution
+    np.testing.assert_allclose(
+        centers,
+        KMeans(n_clusters=3,
+               init=centers,
+               n_init=1).fit(X).cluster_centers_
+    )
+
+    # The same should be true when X is sparse
+    X_sparse = sp.csr_matrix(X)
+    np.testing.assert_allclose(
+        centers,
+        KMeans(n_clusters=3,
+               init=centers,
+               n_init=1).fit(X_sparse).cluster_centers_
+    )
+
+
+def test_sparse_validate_centers():
+    # Check that init centers whose shape does not match n_clusters are
+    # rejected with a ValueError
+    from sklearn.datasets import load_iris
+
+    iris = load_iris()
+    X = iris.data
+
+    # Fit 4 centers so that they mismatch n_clusters=3 below
+    centers = KMeans(n_clusters=4).fit(X).cluster_centers_
+
+    # Test that a ValueError is raised for validate_center_shape
+    km = KMeans(n_clusters=3, init=centers, n_init=1)
+
+    # Raw strings keep the regex escapes (\( and \)) valid Python literals
+    msg = (r"The shape of the initial centers \(\(4L?, 4L?\)\) "
+           r"does not match the number of clusters 3")
+    assert_raises_regex(ValueError, msg, km.fit, X)