Skip to content

Commit

Permalink
Enable angular/cosine RP forest initialisation (issue #15)
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Nov 19, 2017
1 parent 9a5bfa7 commit c98e882
Showing 1 changed file with 44 additions and 10 deletions.
54 changes: 44 additions & 10 deletions umap/umap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,29 @@ def random_projection_split(data, indices, rng_state):
'left_child', 'right_child'])


def make_tree(data, indices, rng_state, leaf_size=30):
def make_tree(data, indices, rng_state, leaf_size=30, angular=False):
# Make a tree recursively until we get below the leaf size
if indices.shape[0] > leaf_size:
left_indices, right_indices = random_projection_split(data,
indices,
rng_state)
left_node = make_tree(data, left_indices, rng_state, leaf_size)
right_node = make_tree(data, right_indices, rng_state, leaf_size)
if angular:
(left_indices,
right_indices) = random_projection_cosine_split(data,
indices,
rng_state)
else:
left_indices, right_indices = random_projection_split(data,
indices,
rng_state)
left_node = make_tree(data,
left_indices,
rng_state,
leaf_size,
angular)
right_node = make_tree(data,
right_indices,
rng_state,
leaf_size,
angular)

node = RandomProjectionTreeNode(indices, False, left_node, right_node)
else:
node = RandomProjectionTreeNode(indices, True, None, None)
Expand Down Expand Up @@ -286,15 +301,16 @@ def heap_push(heap, row, weight, index, flag):
return 1


def rptree_leaf_array(data, n_neighbors, rng_state, n_trees=10):
def rptree_leaf_array(data, n_neighbors, rng_state, n_trees=10, angular=False):
leaves = []
try:
leaf_size = max(10, n_neighbors)
for t in range(n_trees):
tree = make_tree(data,
np.arange(data.shape[0]),
rng_state,
leaf_size=leaf_size)
leaf_size=leaf_size,
angular=angular)
leaves += get_leaves(tree)

leaf_array = -1 * np.ones([len(leaves), leaf_size], dtype=np.int64)
Expand Down Expand Up @@ -446,7 +462,10 @@ def smooth_knn_dist(distances, k, n_iter=128):


@numba.jit(parallel=True)
def fuzzy_simplicial_set(X, n_neighbors, random_state, metric, metric_kwds={}, verbose=False):
def fuzzy_simplicial_set(X, n_neighbors, random_state,
metric, metric_kwds={}, angular=False,
verbose=False):

rows = np.zeros((X.shape[0] * n_neighbors), dtype=np.int64)
cols = np.zeros((X.shape[0] * n_neighbors), dtype=np.int64)
vals = np.zeros((X.shape[0] * n_neighbors), dtype=np.float64)
Expand All @@ -458,11 +477,16 @@ def fuzzy_simplicial_set(X, n_neighbors, random_state, metric, metric_kwds={}, v
else:
raise ValueError('Metric is neither callable, nor a recognised string')

if metric in ('cosine', 'correlation', 'dice', 'jaccard'):
angular=True

rng_state = random_state.randint(INT32_MIN, INT32_MAX, 3).astype(np.int64)

metric_nn_descent = make_nn_descent(distance_func,
tuple(metric_kwds.values()))
leaf_array = rptree_leaf_array(X, n_neighbors, rng_state, n_trees=10)
leaf_array = rptree_leaf_array(X, n_neighbors,
rng_state, n_trees=10,
angular=angular)
tmp_indices, knn_dists = metric_nn_descent(X,
n_neighbors,
rng_state,
Expand Down Expand Up @@ -850,6 +874,13 @@ class UMAP(BaseEstimator):
Arguments to pass on to the metric, such as the ``p`` value for
Minkowski distance.
angular_rp_forest: bool (optional, default False)
Whether to use an angular random projection forest to initialise
the approximate nearest neighbor search. This can be faster, but is
mostly on useful for metric that use an angular style distance such
as cosine, correlation etc. In the case of those metrics angular forests
will be chosen automatically.
verbose: bool (optional, default False)
Controls verbosity of logging.
"""
Expand All @@ -868,6 +899,7 @@ def __init__(self,
b=None,
random_state=None,
metric_kwds={},
angular_rp_forest=False,
verbose=False
):

Expand All @@ -884,6 +916,7 @@ def __init__(self,
self.spread = spread
self.min_dist = min_dist
self.random_state = random_state
self.angular_rp_forest = angular_rp_forest
self.verbose = verbose

if metric in dist.named_distances:
Expand Down Expand Up @@ -935,6 +968,7 @@ def fit(self, X, y=None):
random_state,
self._metric,
self.metric_kwds,
self.angular_rp_forest,
self.verbose
)

Expand Down

0 comments on commit c98e882

Please sign in to comment.