### Approximate nearest neighbor search based on the NN-descent method.
**Paper:**  
Dong, Wei, Moses Charikar, and Kai Li. "Efficient k-nearest neighbor graph construction for generic similarity measures." Proceedings of the 20th international conference on World wide web. ACM, 2011.  
https://www.cs.princeton.edu/cass/papers/www11.pdf  

**Code:** https://github.com/lmcinnes/pynndescent  

Documentation of the NNDescent class can be found here:  
https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L462

Some important parameters are listed below:  
`n_neighbors`:  
Number of neighbors to use in the KNN graph data structure. Default is 15. 
Larger values give more accurate search results at the cost of computation time. The paper shows 
that more neighbors are needed as the intrinsic dimension of the data increases; the authors used 
a value of 50 for some data sets with high intrinsic dimensionality.

`rho`:  
The sample rate parameter of the algorithm. In the paper, they use a value of `1.0` to get high recall and 
`0.5` for faster results with lower recall. The default value is `0.5`.

`metric`:  
The metric to use for computing distances. Choose from the large list of available metrics or define 
a custom metric as a callable function that is numba JIT compiled.

`n_trees`:  
Number of trees in the random projection forest. A larger value gives more accurate neighbor 
computation at the cost of computation time. With the default of `None`, the value is chosen from the 
size of the data as `5 + int(round(data.shape[0] ** 0.5 / 20.0))`. 

`n_jobs`:  
The number of parallel jobs to run for the neighborhood index construction. Default value is `None` 
which sets it to 1. To use all the available processors, set it to `-1`.
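
As a quick check on the `n_trees` default, the formula quoted above can be evaluated for a few data sizes. This is a minimal sketch: `default_n_trees` is a hypothetical helper written for illustration, not a pynndescent function, and it assumes the installed version uses exactly the formula given in the description.

```python
def default_n_trees(n_samples: int) -> int:
    """Forest size from the formula quoted above: 5 + int(round(sqrt(N) / 20))."""
    return 5 + int(round(n_samples ** 0.5 / 20.0))


# Evaluate the formula for a few data set sizes.
for n in (100, 1_000, 1_000_000):
    print(n, default_n_trees(n))
```

The forest grows with the square root of the data size, so even a million points would use only a few dozen trees under this formula.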

For best results, the parameters `n_neighbors` and `rho` should be tuned using grid search.
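
The grid search mentioned above can be sketched as a loop over all parameter combinations. The code below is an illustration only: `grid_search` and `toy_recall` are hypothetical names, and `toy_recall` returns a made-up placeholder score rather than a measurement. In practice it would be replaced by a function that builds the index with the given `n_neighbors` and `rho` and returns the recall on held-out query points.

```python
from itertools import product


def grid_search(param_grid, score_fn):
    """Score every parameter combination; return (best_params, best_score, results)."""
    names = list(param_grid)
    results = []
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        results.append((params, score_fn(**params)))
    best_params, best_score = max(results, key=lambda item: item[1])
    return best_params, best_score, results


def toy_recall(n_neighbors, rho):
    # Placeholder score, NOT a measurement. Replace with a function that
    # builds the index with these parameters and returns the recall.
    return min(1.0, 0.5 + 0.005 * n_neighbors + 0.2 * rho)


param_grid = {"n_neighbors": [10, 20, 50], "rho": [0.5, 1.0]}
best_params, best_score, results = grid_search(param_grid, toy_recall)
print(best_params, best_score)
```

Since higher values of both parameters also increase the construction time, it may be worth recording the build time next to the recall and choosing the cheapest setting that meets a target recall.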

In [1]:
import numpy as np
from pprint import pprint
from pynndescent import NNDescent
from sklearn.neighbors import NearestNeighbors
from multiprocessing import cpu_count

In [2]:
num_proc = max(cpu_count() - 2, 1)
seed_rng = np.random.randint(1, high=1001)
np.random.seed(seed_rng)
N = 100
d = 5
data = np.random.randn(N, d)
n_trees_def = 5 + int(np.round((data.shape[0] ** 0.5) / 20.0))
params = {
    'metric': 'euclidean', 
    'n_neighbors': 20, 
    'rho': 0.5,
    'n_trees': None,
    'random_state': seed_rng, 
    'n_jobs': num_proc, 
    'verbose': True
}

In [3]:
%%time
index = NNDescent(data, **params)
pprint(index.__dict__.keys())
nn_indices_data, nn_distances_data = index._neighbor_graph
print(nn_indices_data.shape, nn_distances_data.shape)

Thu Nov 14 03:36:15 2019 Building RP forest with 5 trees
Thu Nov 14 03:36:16 2019 parallel NN descent for 7 iterations
	 0  /  7
	 1  /  7
dict_keys(['n_trees', 'n_neighbors', 'metric', 'metric_kwds', 'leaf_size', 'prune_level', 'max_candidates', 'low_memory', 'n_iters', 'delta', 'rho', 'dim', 'verbose', '_raw_data', 'tree_init', '_dist_args', 'random_state', '_distance_func', '_angular_trees', 'rng_state', '_rp_forest', '_is_sparse', '_neighbor_graph'])
(100, 20) (100, 20)


In [4]:
# Find the 5 nearest neighbors of query points
k = 5
x = np.random.randn(5, d)
nn_indices, nn_distances = index.query(x, k=k)
# `nn_indices` should be an array with the index of the nearest neighbors corresponding to each query point.
# Suppose `x` has shape `(m, d)`, then `nn_indices` will have shape `(m, k)`.
print(nn_indices)

# `nn_distances` has the same shape as `nn_indices` and it has the corresponding distances
# print(nn_distances)

[[32 13 99 67 14]
 [12 66  3 70 38]
 [ 4 21 96 69 30]
 [70  3 58 17 66]
 [55 58 13 44 68]]


In [5]:
# Use brute force nearest neighbor search for comparison with the ANN method
neigh = NearestNeighbors(n_neighbors=k, algorithm='brute', p=2, n_jobs=num_proc)
neigh.fit(data)
_, nn_indices_true = neigh.kneighbors(x)
print(nn_indices_true)

[[32 13 99 67 14]
 [12 66  3 70 38]
 [ 4 21 96 69 30]
 [70  3 58 17 66]
 [55 58 13 44 68]]
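
For reference, the exact search used as the baseline can be sketched in plain Python: compute the distance from the query to every data point and keep the `k` smallest. This is an illustration of the idea, not scikit-learn's implementation, and `brute_force_knn` is a hypothetical helper name.

```python
import heapq
import math


def brute_force_knn(data, query, k):
    """Return (indices, distances) of the k points in `data` closest to `query`."""
    # Euclidean distance from the query to every point, paired with the point index.
    dists = [(math.dist(point, query), i) for i, point in enumerate(data)]
    nearest = heapq.nsmallest(k, dists)
    return [i for _, i in nearest], [d for d, _ in nearest]


points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (3.0, 3.0)]
indices, distances = brute_force_knn(points, (0.9, 0.1), k=2)
print(indices, distances)
```

The cost is one distance computation per data point for every query, which is what the approximate index is meant to avoid on large data sets.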


In [6]:
# Generate more complex data using a mixture of factor analyzers (MFA) model
from generate_data import MFA_model

# Define the MFA model
n_components = 10
dim = 100
dim_latent = 2
dim_latent_range = (10, 20)
model = MFA_model(n_components, dim, dim_latent_range=dim_latent_range, seed_rng=seed_rng)

# Generate data from the model
N = 1000
N_test = 100
k = 5
data, labels = model.generate_data(N)
data_test, labels_test = model.generate_data(N_test)

In [7]:
%%time
# Construct the ANN index
params = {
    'metric': 'euclidean', 
    'n_neighbors': 20,
    'rho': 0.5,
    'n_trees': None,
    'random_state': seed_rng, 
    'n_jobs': num_proc, 
    'verbose': True
}
index = NNDescent(data, **params)

Thu Nov 14 03:36:21 2019 Building RP forest with 7 trees
Thu Nov 14 03:36:21 2019 parallel NN descent for 10 iterations
	 0  /  10
	 1  /  10
	 2  /  10
	 3  /  10


In [8]:
%%time
# Construct the exact KNN graph
neigh = NearestNeighbors(n_neighbors=k, algorithm='brute', p=2, n_jobs=num_proc)
neigh.fit(data)


NearestNeighbors(algorithm='brute', leaf_size=30, metric='minkowski',
         metric_params=None, n_jobs=10, n_neighbors=5, p=2, radius=1.0)

In [9]:
%%time
# Query the ANN index and compare with the exact nearest neighbors
nn_indices, _ = index.query(data_test, k=k)


In [10]:
%%time
# Query the exact nearest neighbors
_, nn_indices_true = neigh.kneighbors(data_test)


In [11]:
# Calculate the recall of the ANN method, i.e. the fraction of true nearest neighbors that are retrieved,
# averaged over all the query points. Neighbors are compared as sets, so their order does not matter.
arr1 = np.asarray(nn_indices, dtype=int)
arr2 = np.asarray(nn_indices_true, dtype=int)
recall_per_point = np.array([len(np.intersect1d(a, b)) / k for a, b in zip(arr1, arr2)])
recall = np.sum(recall_per_point) / N_test
print("Average recall over the query points = {}".format(recall))

Average recall over the query points = 1.0
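
Recall does not depend on the order of the neighbors: a retrieved neighbor counts if it appears anywhere in the exact neighbor list of the same query. A minimal sketch on made-up index lists, where `recall_at_k` is a hypothetical helper name:

```python
def recall_at_k(retrieved, exact):
    """Mean over queries of the fraction of exact neighbors that were retrieved."""
    per_query = [
        len(set(r) & set(e)) / len(e)
        for r, e in zip(retrieved, exact)
    ]
    return sum(per_query) / len(per_query)


exact = [[3, 7, 9], [1, 2, 4]]
# First query: same neighbors in a different order. Second query: one neighbor missed.
retrieved = [[7, 3, 9], [1, 2, 5]]
print(recall_at_k(retrieved, exact))
```

The first query contributes a recall of 1 despite the different ordering and the second contributes 2/3, so the average is 5/6.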
