In [1]:
import numpy as np
from scipy.spatial import KDTree as KDTree

# The data

In [2]:
# Seed a dedicated generator so the point cloud (and therefore every
# neighbour result computed from it) is identical on each Restart & Run All.
# NOTE: outputs recorded in this notebook before the seed was introduced will
# differ from a fresh run; re-execute the notebook to refresh them.
SEED = 0
rng = np.random.default_rng(SEED)

# 1000 points in 3 dimensions, one point per row, drawn uniformly from [0, 1).
data = rng.random((1000, 3))
# Single query point at the centre of the unit cube.
query = np.array([0.5, 0.5, 0.5])

# Python version (SciPy `KDTree`)

In [3]:
def knn_py(data, query, k=3):
    """Look up the ``k`` nearest neighbours of ``query`` with a freshly built tree.

    A new ``KDTree`` is constructed from ``data`` on every call, so the cost
    of this function covers index construction as well as the lookup.

    Args:
        data: Points to index; forwarded unchanged to ``KDTree``.
        query: Point(s) to search for; forwarded unchanged to ``KDTree.query``.
        k: Number of neighbours to request (default 3).

    Returns:
        The value returned by ``KDTree.query``. The notebook reads element 0
        as the distances and element 1 as the indices.
    """
    return KDTree(data).query(query, k=k)

In [4]:
def knn_precomputed_py(tree, query, k=3):
    """Look up the ``k`` nearest neighbours of ``query`` in an existing tree.

    Unlike a build-and-query helper, this only performs the lookup; the
    caller is responsible for constructing ``tree`` beforehand.

    Args:
        tree: Object exposing ``query(point, k=...)``; the notebook passes a
            SciPy ``KDTree``.
        query: Point(s) to search for; forwarded unchanged to ``tree.query``.
        k: Number of neighbours to request (default 3).

    Returns:
        The value returned by ``tree.query``, unmodified.
    """
    neighbours = tree.query(query, k=k)
    return neighbours

In [5]:
# Build the reusable tree up front; it is also the tree timed in the
# precomputed benchmark.
tree_py = KDTree(data)

# Same inputs through both code paths: one builds its own tree internally,
# the other reuses the tree constructed above.
results_py = knn_py(data, query)
results_precomputed_py = knn_precomputed_py(tree_py, query)

# Show element 0 then element 1 of each result, followed by a check that the
# two code paths agree.
print(results_py[0], results_py[1])
print(results_precomputed_py[0], results_precomputed_py[1])
print(np.array_equal(results_py, results_precomputed_py))

[0.06855726 0.08288424 0.08613393] [921 207 511]
[0.06855726 0.08288424 0.08613393] [921 207 511]
True


In [6]:
# Times the full path: building a KDTree from `data` plus one query.
%timeit knn_py(data, query)

165 µs ± 368 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [7]:
# Times only the query against the already-built `tree_py`.
%timeit knn_precomputed_py(tree_py, query)

16.8 µs ± 552 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


# Using juliacall (Julia's NearestNeighbors.jl)

In [8]:
# `jl` is juliacall's handle to Julia's `Main` module; Julia functions are
# reached as attributes of it (e.g. `jl.knn`).
from juliacall import Main as jl
# Load the Julia package that supplies the `KDTree` and `knn` names used by
# the Julia-backed helpers in this notebook.
jl.seval("using NearestNeighbors")

Querying Julia versions from https://julialang-s3.julialang.org/bin/versions.json
Found Julia 1.7.1 at 'julia'


In [9]:
def knn_jl(data, query, k=3):
    """Look up the ``k`` nearest neighbours of ``query`` via Julia, building the tree each call.

    A new Julia ``KDTree`` is constructed from ``data`` on every call, so the
    cost of this function covers index construction as well as the lookup.

    Args:
        data: Points to index; forwarded unchanged to ``jl.KDTree``. The
            notebook passes the transposed NumPy array (one point per column).
        query: Point to search for; forwarded unchanged to ``jl.knn``.
        k: Number of neighbours to request (default 3).

    Returns:
        The value returned by ``jl.knn``. The notebook reads element 1 as the
        distances and element 0 as the (1-based) indices.
    """
    kdtree = jl.KDTree(data)
    # The trailing positional ``True`` is passed through to ``knn`` as its
    # fourth argument (presumably NearestNeighbors.jl's ``sortres`` flag —
    # confirm against the package docs).
    return jl.knn(kdtree, query, k, True)

In [10]:
def knn_precomputed_jl(kdtree, query, k=3):
    """Look up the ``k`` nearest neighbours of ``query`` in an existing Julia tree.

    Only the lookup is performed here; the caller constructs ``kdtree``
    beforehand (the notebook uses ``jl.KDTree``).

    Args:
        kdtree: Tree object forwarded unchanged to ``jl.knn``.
        query: Point to search for; forwarded unchanged to ``jl.knn``.
        k: Number of neighbours to request (default 3).

    Returns:
        The value returned by ``jl.knn``, unmodified.
    """
    # The trailing positional ``True`` is passed through to ``knn`` as its
    # fourth argument (presumably NearestNeighbors.jl's ``sortres`` flag —
    # confirm against the package docs).
    neighbours = jl.knn(kdtree, query, k, True)
    return neighbours

In [11]:
# The Julia side is given the transposed array, i.e. one point per column.
# The transpose is done once here so it stays out of the timed calls.
data_t = data.T  # Not timing this part
# Build the reusable Julia tree up front; it is also the tree timed in the
# precomputed benchmark.
tree_jl = jl.KDTree(data_t)

# Same inputs through both code paths: one builds its own tree internally,
# the other reuses the tree constructed above.
results_jl = knn_jl(data_t, query)
results_precomputed_jl = knn_precomputed_jl(tree_jl, query)

# Element 1 (distances) is printed before element 0 (indices) so the lines
# read in the same order as the Python results. Julia indices are 1-based.
print(results_jl[1], results_jl[0])
print(results_precomputed_jl[1], results_precomputed_jl[0])
print(np.array_equal(results_jl, results_precomputed_jl))

[0.0685572642370606, 0.08288424104821678, 0.08613393269163352] [922, 208, 512]
[0.0685572642370606, 0.08288424104821678, 0.08613393269163352] [922, 208, 512]
True


In [12]:
# Times Julia tree construction plus one query, called from Python through
# juliacall; the transpose that produced `data_t` is not included.
%timeit knn_jl(data_t, query)

228 µs ± 5.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [13]:
# Times only the query against the already-built `tree_jl`, called from
# Python through juliacall.
%timeit knn_precomputed_jl(tree_jl, query)

59.5 µs ± 3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
