# Orchard's algorithm for fast nearest neighbour calculation

### Imports

In [1]:
import functools
import operator
import time
from datetime import datetime, timedelta
from random import random

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from sklearn.neighbors import NearestNeighbors

from fast_nn import OrchardNN

### Utility functions

In [2]:
def random_point(dimensions=2):
    """Draw a single point uniformly at random from the unit hypercube [0, 1)^dimensions.

    Uses NumPy's global (legacy) random state, so ``np.random.seed`` makes it
    reproducible. Returns a 1-D array of length ``dimensions``.
    """
    shape = (dimensions,)
    return np.random.random_sample(size=shape)

def dist(x, y, _norm=np.linalg.norm):
    """Return ``_norm(x - y)``, the distance between ``x`` and ``y``.

    With the default ``np.linalg.norm`` and 1-D array inputs this is the
    Euclidean (2-norm) distance. ``_norm`` is bound as a default argument,
    presumably so the function lookup happens once at definition time rather
    than on every call.
    """
    displacement = x - y
    return _norm(displacement)


## Testing

### Create dummy data

In [3]:
# Benchmark configuration: dimensionality of the space, size of the candidate
# set that both methods search, and how many query points to look up.
dimensions = 2
num_candidates = 10000
num_queries = 10000

# Each point is drawn independently with random_point (candidates first, then
# queries); the list of 1-D arrays is stacked into a (count, dimensions) array.
points = np.asarray(
    [random_point(dimensions=dimensions) for _ in range(num_candidates)]
)
query_points = np.asarray(
    [random_point(dimensions=dimensions) for _ in range(num_queries)]
)

# The single-query demos below all use the first query point.
query_point = query_points[0]

### Initialise Orchard's method and precompute pairwise candidate distances

In [None]:
# Build the Orchard index over the candidate points, using `dist` as the metric.
# Construction is where the distance pairs are precomputed (per this cell's
# heading), so this is the one-off setup cost, reported separately from queries.
# time.perf_counter is monotonic and high-resolution, unlike wall-clock
# datetime.now(); the result is shown as a timedelta to keep the output format.
start = time.perf_counter()
orchard = OrchardNN(points, dist)
end = time.perf_counter()
print("pre-computation time: {}".format(timedelta(seconds=end - start)))

### Single query

In [None]:
# Find the nearest neighbour to `query_point` with Orchard's method and time it.
# A single query can take only microseconds, so use the monotonic,
# high-resolution time.perf_counter rather than wall-clock datetime.now(), whose
# resolution can be too coarse (and which can jump if the system clock changes).
start = time.perf_counter()
single_neighbour = orchard.nearest_neighbour(query_point, verbose=False)
end = time.perf_counter()
print("Time for 1 query: {}".format(timedelta(seconds=end - start)))

### Plot results

In [None]:
fig = plt.figure(figsize=(20,20))

# Candidates are drawn in blue, the nearest neighbour found in red and the
# query in green. 3-D data gets a 3-D axes; any other dimensionality is drawn
# using its first two coordinates.
if dimensions == 3:
    ax = fig.add_subplot(111, projection='3d')
    n_coords = 3
else:
    ax = fig.add_subplot(111)
    n_coords = 2

ax.scatter(*[points[:, i] for i in range(n_coords)], c='b')
ax.scatter(*[single_neighbour[i] for i in range(n_coords)], c='r', label='nearest neighbour')
ax.scatter(*[query_point[i] for i in range(n_coords)], c='g', label='query')
ax.legend()

plt.show()
plt.close()

In [None]:
# NOTE(review): `break` outside a loop is a SyntaxError, so this cell always
# fails. It appears to be a deliberate stop, so that "Run All" halts here and the
# scikit-learn comparison cells below are not executed automatically; delete or
# skip this cell to run them.
break

### Initialise scikit-learn nearest-neighbour search (brute force)

In [None]:
# Fit scikit-learn's NearestNeighbors (an unsupervised neighbour search, not a
# classifier) as an exact brute-force Euclidean 1-NN baseline over the same
# candidate points. Brute-force fitting is typically very fast, so time it with
# the monotonic, high-resolution time.perf_counter instead of datetime.now().
start = time.perf_counter()
sk_knn = NearestNeighbors(n_neighbors=1, algorithm='brute', metric='euclidean')
sk_knn.fit(points)
end = time.perf_counter()
print("Time to fit SKLearn KNN: {}".format(timedelta(seconds=end - start)))

### Single query

In [None]:
# Time a single 1-NN lookup with scikit-learn. The model was already fitted in
# the previous cell, so only the query itself is inside the timed region (the
# original re-ran `fit` here, which inflated the reported query time).
start = time.perf_counter()
sk_neighbour_distance, sk_neighbour = sk_knn.kneighbors([query_point])
end = time.perf_counter()

verbose = False

if verbose:
    # kneighbors returns (distances, indices), each shaped (n_queries, n_neighbors);
    # the indices refer to rows of the array passed to fit(), i.e. `points`.
    nearest_index = sk_neighbour[0][0]
    print("Query point: {}".format(query_point))
    print("Nearest neighbour to query is point {}: {}".format(nearest_index, points[nearest_index]))
    print("Distance: {}".format(sk_neighbour_distance[0]))
print("Time for 1 query: {}".format(timedelta(seconds=end - start)))


## Compare over a large number of queries

### Orchard

In [None]:
# Time every query individually with Orchard's method. Per-query durations are
# microsecond-scale, so use the monotonic, high-resolution time.perf_counter;
# each duration is stored as a timedelta so `orch_times` keeps its original type.
orch_times = []

for q in query_points:
    start = time.perf_counter()
    neighbour = orchard.nearest_neighbour(q, verbose=False)
    end = time.perf_counter()
    orch_times.append(timedelta(seconds=end - start))
# sum() with a timedelta start value also handles an empty query set, where
# functools.reduce without an initial value would raise TypeError.
print("Total time for {} queries: {}".format(len(query_points), sum(orch_times, timedelta())))

### SKLearn

In [None]:
# Time every query individually with scikit-learn, measured the same way as the
# Orchard loop: high-resolution perf_counter per query, stored as timedelta.
sk_times = []

for q in query_points:
    start = time.perf_counter()
    sk_neighbour_distance, sk_neighbour = sk_knn.kneighbors([q])
    end = time.perf_counter()
    sk_times.append(timedelta(seconds=end - start))
# sum() with a timedelta start value also handles an empty query set.
print("Total time for {} queries: {}".format(len(query_points), sum(sk_times, timedelta())))