Skip to content

Commit

Permalink
Fix arg order and keyword the args; add tests for low_memory
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Sep 11, 2019
1 parent 439db74 commit 494024f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
39 changes: 39 additions & 0 deletions umap/tests/test_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,25 @@ def test_nn_descent_neighbor_accuracy():
"NN-descent did not get 99% " "accuracy on nearest neighbors",
)

def test_nn_descent_neighbor_accuracy_low_memory():
knn_indices, knn_dists, _ = nearest_neighbors(
nn_data, 10, "euclidean", {}, False, np.random, low_memory=True
)

tree = KDTree(nn_data)
true_indices = tree.query(nn_data, 10, return_distance=False)

num_correct = 0.0
for i in range(nn_data.shape[0]):
num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))

percent_correct = num_correct / (nn_data.shape[0] * 10)
assert_greater_equal(
percent_correct,
0.89,
"NN-descent did not get 99% " "accuracy on nearest neighbors",
)


def test_angular_nn_descent_neighbor_accuracy():
knn_indices, knn_dists, _ = nearest_neighbors(
Expand Down Expand Up @@ -336,6 +355,26 @@ def test_sparse_nn_descent_neighbor_accuracy():
)


def test_sparse_nn_descent_neighbor_accuracy_low_memory():
knn_indices, knn_dists, _ = nearest_neighbors(
sparse_nn_data, 20, "euclidean", {}, False, np.random, low_memory=True
)

tree = KDTree(sparse_nn_data.todense())
true_indices = tree.query(sparse_nn_data.todense(), 10, return_distance=False)

num_correct = 0.0
for i in range(sparse_nn_data.shape[0]):
num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))

percent_correct = num_correct / (sparse_nn_data.shape[0] * 10)
assert_greater_equal(
percent_correct,
0.90,
"Sparse NN-descent did not get " "99% accuracy on nearest " "neighbors",
)


def test_sparse_angular_nn_descent_neighbor_accuracy():
knn_indices, knn_dists, _ = nearest_neighbors(
sparse_nn_data, 20, "cosine", {}, True, np.random
Expand Down
8 changes: 4 additions & 4 deletions umap/umap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,9 @@ def nearest_neighbors(
X.shape[0],
n_neighbors,
rng_state,
distance_func,
tuple(metric_kwds.values()),
max_candidates=60,
dist=distance_func,
dist_args=tuple(metric_kwds.values()),
low_memory=low_memory,
rp_tree_init=True,
leaf_array=leaf_array,
Expand All @@ -349,9 +349,9 @@ def nearest_neighbors(
X,
n_neighbors,
rng_state,
distance_func,
tuple(metric_kwds.values()),
max_candidates=60,
dist=distance_func,
dist_args=tuple(metric_kwds.values()),
low_memory=low_memory,
rp_tree_init=True,
leaf_array=leaf_array,
Expand Down

0 comments on commit 494024f

Please sign in to comment.