Skip to content

Commit

Permalink
Further hdbscan coverage improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Aug 20, 2016
1 parent 81cc4ce commit 916b1cd
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion hdbscan/hdbscan_.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _hdbscan_generic(X, min_samples=5, alpha=1.0,
if gen_min_span_tree:
result_min_span_tree = min_spanning_tree.copy()
for index, row in enumerate(result_min_span_tree[1:], 1):
candidates = np.where(np.isclose(mutual_reachability_[row[1]], row[2]))[0]
candidates = np.where(np.isclose(mutual_reachability_[int(row[1])], row[2]))[0]
candidates = np.intersect1d(candidates, min_spanning_tree[:index, :2].astype(int))
candidates = candidates[candidates != row[1]]
assert (len(candidates) > 0)
Expand Down
13 changes: 11 additions & 2 deletions hdbscan/tests/test_hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_hdbscan_prims_balltree():
metric='cosine')

def test_hdbscan_boruvka_kdtree():
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='boruvka_kdtree')
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='boruvka_kdtree', leaf_size=5)
n_clusters_1 = len(set(labels)) - int(-1 in labels)
assert_equal(n_clusters_1, n_clusters)

Expand All @@ -172,7 +172,7 @@ def test_hdbscan_boruvka_kdtree():
metric='russelrao')

def test_hdbscan_boruvka_balltree():
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='boruvka_balltree')
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='boruvka_balltree', leaf_size=5)
n_clusters_1 = len(set(labels)) - int(-1 in labels)
assert_equal(n_clusters_1, n_clusters)

Expand All @@ -186,6 +186,15 @@ def test_hdbscan_boruvka_balltree():
algorithm='boruvka_balltree',
metric='cosine')

def test_hdbscan_generic():
labels, p, persist, ctree, ltree, mtree = hdbscan(X, algorithm='generic')
n_clusters_1 = len(set(labels)) - int(-1 in labels)
assert_equal(n_clusters_1, n_clusters)

labels = HDBSCAN(algorithm='generic', gen_min_span_tree=True).fit(X).labels_
n_clusters_2 = len(set(labels)) - int(-1 in labels)
assert_equal(n_clusters_2, n_clusters)


def test_hdbscan_high_dimensional():
H, y = make_blobs(n_samples=50, random_state=0, n_features=64)
Expand Down

0 comments on commit 916b1cd

Please sign in to comment.