Skip to content

Commit

Permalink
Added distance_threshold parameter to hierarchical clustering (scikit…
Browse files Browse the repository at this point in the history
  • Loading branch information
VathsalaAchar authored and koenvandevelde committed Jul 12, 2019
1 parent d5bcf86 commit 29134a0
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 26 deletions.
4 changes: 2 additions & 2 deletions doc/modules/clustering.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ Overview of clustering methods
- Graph distance (e.g. nearest-neighbor graph)

* - :ref:`Ward hierarchical clustering <hierarchical_clustering>`
- number of clusters
- number of clusters or distance threshold
- Large ``n_samples`` and ``n_clusters``
- Many clusters, possibly connectivity constraints
- Distances between points

* - :ref:`Agglomerative clustering <hierarchical_clustering>`
- number of clusters, linkage type, distance
- number of clusters or distance threshold, linkage type, distance
- Large ``n_samples`` and ``n_clusters``
- Many clusters, possibly connectivity constraints, non Euclidean
distances
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new/v0.21.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ Support for Python 3.4 and below has been officially dropped.
``n_connected_components_``.
:pr:`13427` by :user:`Stephane Couvreur <scouvreur>`.

- |Enhancement| :class:`cluster.AgglomerativeClustering` and
:class:`cluster.FeatureAgglomeration` now accept a ``distance_threshold``
parameter which can be used to find the clusters instead of ``n_clusters``.
:issue:`9069` by :user:`Vathsala Achar <VathsalaAchar>` and `Adrin Jalali`_.

:mod:`sklearn.datasets`
.......................

Expand Down
107 changes: 84 additions & 23 deletions sklearn/cluster/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,9 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin):
Parameters
----------
n_clusters : int, default=2
The number of clusters to find.
n_clusters : int or None, optional (default=2)
The number of clusters to find. It must be ``None`` if
``distance_threshold`` is not ``None``.
affinity : string or callable, default: "euclidean"
Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
Expand Down Expand Up @@ -688,7 +689,8 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin):
not small compared to the number of samples. This option is
useful only when specifying a connectivity matrix. Note also that
when varying the number of clusters and using caching, it may
be advantageous to compute the full tree.
be advantageous to compute the full tree. It must be ``True`` if
``distance_threshold`` is not ``None``.
linkage : {"ward", "complete", "average", "single"}, optional \
(default="ward")
Expand All @@ -711,8 +713,20 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin):
``pooling_func`` has been deprecated in 0.20 and will be removed
in 0.22.
distance_threshold : float, optional (default=None)
The linkage distance threshold above which, clusters will not be
merged. If not ``None``, ``n_clusters`` must be ``None`` and
``compute_full_tree`` must be ``True``.
.. versionadded:: 0.21
Attributes
----------
n_clusters_ : int
The number of clusters found by the algorithm. If
``distance_threshold=None``, it will be equal to the given
``n_clusters``.
labels_ : array [n_samples]
cluster labels for each point
Expand All @@ -739,8 +753,9 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin):
>>> clustering = AgglomerativeClustering().fit(X)
>>> clustering # doctest: +NORMALIZE_WHITESPACE
AgglomerativeClustering(affinity='euclidean', compute_full_tree='auto',
connectivity=None, linkage='ward', memory=None, n_clusters=2,
pooling_func='deprecated')
connectivity=None, distance_threshold=None,
linkage='ward', memory=None, n_clusters=2,
pooling_func='deprecated')
>>> clustering.labels_
array([1, 1, 1, 0, 0, 0])
Expand All @@ -749,8 +764,10 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin):
def __init__(self, n_clusters=2, affinity="euclidean",
memory=None,
connectivity=None, compute_full_tree='auto',
linkage='ward', pooling_func='deprecated'):
linkage='ward', pooling_func='deprecated',
distance_threshold=None):
self.n_clusters = n_clusters
self.distance_threshold = distance_threshold
self.memory = memory
self.connectivity = connectivity
self.compute_full_tree = compute_full_tree
Expand Down Expand Up @@ -788,10 +805,20 @@ def fit(self, X, y=None):
X = check_array(X, ensure_min_samples=2, estimator=self)
memory = check_memory(self.memory)

if self.n_clusters <= 0:
if self.n_clusters is not None and self.n_clusters <= 0:
raise ValueError("n_clusters should be an integer greater than 0."
" %s was provided." % str(self.n_clusters))

if not ((self.n_clusters is None) ^ (self.distance_threshold is None)):
raise ValueError("Exactly one of n_clusters and "
"distance_threshold has to be set, and the other "
"needs to be None.")

if (self.distance_threshold is not None
and not self.compute_full_tree):
raise ValueError("compute_full_tree must be True if "
"distance_threshold is set.")

if self.linkage == "ward" and self.affinity != "euclidean":
raise ValueError("%s was provided as affinity. Ward can only "
"work with euclidean distances." %
Expand All @@ -815,10 +842,13 @@ def fit(self, X, y=None):
if self.connectivity is None:
compute_full_tree = True
if compute_full_tree == 'auto':
# Early stopping is likely to give a speed up only for
# a large number of clusters. The actual threshold
# implemented here is heuristic
compute_full_tree = self.n_clusters < max(100, .02 * n_samples)
if self.distance_threshold is not None:
compute_full_tree = True
else:
# Early stopping is likely to give a speed up only for
# a large number of clusters. The actual threshold
# implemented here is heuristic
compute_full_tree = self.n_clusters < max(100, .02 * n_samples)
n_clusters = self.n_clusters
if compute_full_tree:
n_clusters = None
Expand All @@ -828,14 +858,29 @@ def fit(self, X, y=None):
if self.linkage != 'ward':
kwargs['linkage'] = self.linkage
kwargs['affinity'] = self.affinity
(self.children_, self.n_connected_components_, self.n_leaves_,
parents) = memory.cache(tree_builder)(X, connectivity,
n_clusters=n_clusters,
**kwargs)

distance_threshold = self.distance_threshold

return_distance = distance_threshold is not None
out = memory.cache(tree_builder)(X, connectivity,
n_clusters=n_clusters,
return_distance=return_distance,
**kwargs)
(self.children_,
self.n_connected_components_,
self.n_leaves_,
parents) = out[:4]

if distance_threshold is not None:
distances = out[-1]
self.n_clusters_ = np.count_nonzero(
distances >= distance_threshold) + 1
else:
self.n_clusters_ = self.n_clusters

# Cut the tree
if compute_full_tree:
self.labels_ = _hc_cut(self.n_clusters, self.children_,
self.labels_ = _hc_cut(self.n_clusters_, self.children_,
self.n_leaves_)
else:
labels = _hierarchical.hc_get_heads(parents, copy=False)
Expand All @@ -856,8 +901,9 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
Parameters
----------
n_clusters : int, default 2
The number of clusters to find.
n_clusters : int or None, optional (default=2)
The number of clusters to find. It must be ``None`` if
``distance_threshold`` is not ``None``.
affinity : string or callable, default "euclidean"
Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
Expand All @@ -883,7 +929,8 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
not small compared to the number of features. This option is
useful only when specifying a connectivity matrix. Note also that
when varying the number of clusters and using caching, it may
be advantageous to compute the full tree.
be advantageous to compute the full tree. It must be ``True`` if
``distance_threshold`` is not ``None``.
linkage : {"ward", "complete", "average", "single"}, optional\
(default="ward")
Expand All @@ -904,8 +951,20 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
value, and should accept an array of shape [M, N] and the keyword
argument `axis=1`, and reduce it to an array of size [M].
distance_threshold : float, optional (default=None)
The linkage distance threshold above which, clusters will not be
merged. If not ``None``, ``n_clusters`` must be ``None`` and
``compute_full_tree`` must be ``True``.
.. versionadded:: 0.21
Attributes
----------
n_clusters_ : int
The number of clusters found by the algorithm. If
``distance_threshold=None``, it will be equal to the given
``n_clusters``.
labels_ : array-like, (n_features,)
cluster labels for each feature.
Expand Down Expand Up @@ -933,8 +992,9 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
>>> agglo = cluster.FeatureAgglomeration(n_clusters=32)
>>> agglo.fit(X) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
FeatureAgglomeration(affinity='euclidean', compute_full_tree='auto',
connectivity=None, linkage='ward', memory=None, n_clusters=32,
pooling_func=...)
connectivity=None, distance_threshold=None, linkage='ward',
memory=None, n_clusters=32,
pooling_func=...)
>>> X_reduced = agglo.transform(X)
>>> X_reduced.shape
(1797, 32)
Expand All @@ -943,11 +1003,12 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
def __init__(self, n_clusters=2, affinity="euclidean",
memory=None,
connectivity=None, compute_full_tree='auto',
linkage='ward', pooling_func=np.mean):
linkage='ward', pooling_func=np.mean,
distance_threshold=None):
super().__init__(
n_clusters=n_clusters, memory=memory, connectivity=connectivity,
compute_full_tree=compute_full_tree, linkage=linkage,
affinity=affinity)
affinity=affinity, distance_threshold=distance_threshold)
self.pooling_func = pooling_func

def fit(self, X, y=None, **params):
Expand Down
116 changes: 115 additions & 1 deletion sklearn/cluster/tests/test_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from scipy import sparse
from scipy.cluster import hierarchy

from sklearn.metrics.cluster.supervised import adjusted_rand_score
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_almost_equal
Expand Down Expand Up @@ -597,7 +598,120 @@ def increment(self, *args, **kwargs):

linkage_tree(X, connectivity=connectivity, affinity=fa.increment)

assert_equal(fa.counter, 3)
assert fa.counter == 3


@pytest.mark.parametrize('linkage', ['ward', 'complete', 'average'])
def test_agglomerative_clustering_with_distance_threshold(linkage):
# Check that we obtain the correct number of clusters with
# agglomerative clustering with distance_threshold.
rng = np.random.RandomState(0)
mask = np.ones([10, 10], dtype=np.bool)
n_samples = 100
X = rng.randn(n_samples, 50)
connectivity = grid_to_graph(*mask.shape)
# test when distance threshold is set to 10
distance_threshold = 10
for conn in [None, connectivity]:
clustering = AgglomerativeClustering(
n_clusters=None,
distance_threshold=distance_threshold,
connectivity=conn, linkage=linkage)
clustering.fit(X)
clusters_produced = clustering.labels_
num_clusters_produced = len(np.unique(clustering.labels_))
# test if the clusters produced match the point in the linkage tree
# where the distance exceeds the threshold
tree_builder = _TREE_BUILDERS[linkage]
children, n_components, n_leaves, parent, distances = \
tree_builder(X, connectivity=conn, n_clusters=None,
return_distance=True)
num_clusters_at_threshold = np.count_nonzero(
distances >= distance_threshold) + 1
# test number of clusters produced
assert num_clusters_at_threshold == num_clusters_produced
# test clusters produced
clusters_at_threshold = _hc_cut(n_clusters=num_clusters_produced,
children=children,
n_leaves=n_leaves)
assert np.array_equiv(clusters_produced,
clusters_at_threshold)


def test_small_distance_threshold():
rng = np.random.RandomState(0)
n_samples = 10
X = rng.randint(-300, 300, size=(n_samples, 3))
# this should result in all data in their own clusters, given that
# their pairwise distances are bigger than .1 (which may not be the case
# with a different random seed).
clustering = AgglomerativeClustering(
n_clusters=None,
distance_threshold=1.,
linkage="single").fit(X)
# check that the pairwise distances are indeed all larger than .1
all_distances = pairwise_distances(X, metric='minkowski', p=2)
np.fill_diagonal(all_distances, np.inf)
assert np.all(all_distances > .1)
assert clustering.n_clusters_ == n_samples


def test_cluster_distances_with_distance_threshold():
rng = np.random.RandomState(0)
n_samples = 100
X = rng.randint(-10, 10, size=(n_samples, 3))
# check the distances within the clusters and with other clusters
distance_threshold = 4
clustering = AgglomerativeClustering(
n_clusters=None,
distance_threshold=distance_threshold,
linkage="single").fit(X)
labels = clustering.labels_
D = pairwise_distances(X, metric="minkowski", p=2)
# to avoid taking the 0 diagonal in min()
np.fill_diagonal(D, np.inf)
for label in np.unique(labels):
in_cluster_mask = labels == label
max_in_cluster_distance = (D[in_cluster_mask][:, in_cluster_mask]
.min(axis=0).max())
min_out_cluster_distance = (D[in_cluster_mask][:, ~in_cluster_mask]
.min(axis=0).min())
# single data point clusters only have that inf diagonal here
if in_cluster_mask.sum() > 1:
assert max_in_cluster_distance < distance_threshold
assert min_out_cluster_distance >= distance_threshold


@pytest.mark.parametrize('linkage', ['ward', 'complete', 'average'])
@pytest.mark.parametrize(('threshold', 'y_true'),
[(0.5, [1, 0]), (1.0, [1, 0]), (1.5, [0, 0])])
def test_agglomerative_clustering_with_distance_threshold_edge_case(
linkage, threshold, y_true):
# test boundary case of distance_threshold matching the distance
X = [[0], [1]]
clusterer = AgglomerativeClustering(
n_clusters=None,
distance_threshold=threshold,
linkage=linkage)
y_pred = clusterer.fit_predict(X)
assert adjusted_rand_score(y_true, y_pred) == 1


def test_dist_threshold_invalid_parameters():
X = [[0], [1]]
with pytest.raises(ValueError, match="Exactly one of "):
AgglomerativeClustering(n_clusters=None,
distance_threshold=None).fit(X)

with pytest.raises(ValueError, match="Exactly one of "):
AgglomerativeClustering(n_clusters=2,
distance_threshold=1).fit(X)

X = [[0], [1]]
with pytest.raises(ValueError, match="compute_full_tree must be True if"):
AgglomerativeClustering(n_clusters=None,
distance_threshold=1,
compute_full_tree=False).fit(X)


def test_n_components_deprecation():
Expand Down

0 comments on commit 29134a0

Please sign in to comment.