In [1]:
using BenchmarkTools
using Distances
using MLDatasets
using NNDescent
using NNDescent: brute_knn, brute_search, NNTuple

In [2]:
"""
    recall(nn, true_nn)

Return the mean per-column recall of `nn` measured against `true_nn`.

Column `i` of `nn` is scored against column `i` of `true_nn` with `_recall`, and the
scores are averaged over the columns.

# Throws
- `DimensionMismatch`: if `nn` and `true_nn` do not have the same number of columns.
"""
function recall(nn, true_nn)
    ncols = size(nn, 2)
    ncols == size(true_nn, 2) || throw(DimensionMismatch(
        "nn has $ncols columns but true_nn has $(size(true_nn, 2))"))
    # Views avoid allocating a copy of every column just to intersect it.
    return sum(_recall(view(nn, :, i), view(true_nn, :, i)) for i in 1:ncols) / ncols
end

"""
    _recall(found, truth)

Return the number of distinct values shared by `found` and `truth`, divided by
`length(truth)`.
"""
_recall(found, truth) = length(intersect(found, truth)) / length(truth)

_recall (generic function with 1 method)

In [3]:
train_x, train_y = FashionMNIST.traindata()
test_x, test_y = FashionMNIST.testdata()

# Collapse the first two dimensions of each 3-D array into one, so every sample becomes a
# single column; `:` lets `reshape` work out the flattened length itself.
train_x = reshape(train_x, :, size(train_x, 3))
test_x = reshape(test_x, :, size(test_x, 3))

# Convert only the columns that are actually used (the first 5000 training and the first
# 500 test samples) instead of converting every column and discarding most of them.
data = [convert.(Float32, view(train_x, :, i)) for i in 1:5000]
queries = [convert.(Float32, view(test_x, :, i)) for i in 1:500];

In [4]:
# Build the NNDescent graph over `data`. The 10 is presumably the number of neighbors
# kept per point — TODO confirm against the NNDescent API.
knn_graph = DescentGraph(data, 10)
# Keep only the first component of every graph entry (presumably the neighbor index of an
# NNTuple — TODO confirm). The trailing `;` suppresses the cell output.
nn = getindex.(knn_graph.graph, 1);

In [5]:
# Reference graph from `brute_knn` with the Euclidean metric; the 10 matches the value
# passed to `DescentGraph` (presumably the neighbor count — TODO confirm).
brute_graph = brute_knn(data, Euclidean(), 10)
# Same extraction as for the approximate graph: first component of every entry.
true_nn = getindex.(brute_graph, 1);

In [6]:
# Recall of the `DescentGraph` neighbors measured against the `brute_knn` neighbors.
recall(nn, true_nn)

0.9967400000000002

In [7]:
# Reference answers for the queries from `brute_search`, unpacked as two values that are
# named here as indices and distances. The 5 is presumably the number of neighbors
# returned per query — TODO confirm.
true_idx, true_dist = brute_search(data, queries, 5, Euclidean());

In [8]:
# Answer the same queries through the NN-descent graph. NOTE(review): the 5 presumably
# matches the neighbor count used for `brute_search`; the meaning of the trailing 10 is
# not visible here — confirm against the `search` signature.
idx, dist = search(knn_graph, queries, 5, 10);

In [9]:
# Recall of the graph-based `search` results measured against the `brute_search` results.
recall(idx, true_idx)

0.6680000000000005

In [10]:
# Queries answered per second: the number of queries divided by the time reported by
# `@belapsed`. The globals are interpolated with `$` so the measurement covers `search`
# itself rather than access to untyped global variables.
q_per_sec = length(queries) / (@belapsed search($knn_graph, $queries, 5, 10))

2320.4369650119456

In [22]:
using Profile  # sampling profiler; comments are kept trailing so line numbers do not move
Profile.clear()  # drop samples collected by any earlier profiling run
@profile search(knn_graph, queries, 10, 100)  # the report below cites this as In[22]:3
Profile.print(combine=true,maxdepth=20,noisefloor=2.0)  # print the collected call tree

4     ./sort.jl:452; sort!(::Array{NNTuple{Int64,Float...
15    ./sort.jl:465; sort!(::Array{NNTuple{Int64,Float...
18    ./sort.jl:544; sort!(::Array{NNTuple{Int64,Float...
1     ./sort.jl:548; sort!(::Array{NNTuple{Int64,Float...
5     ./sort.jl:578; sort!(::Array{NNTuple{Int64,Float...
11760 ./task.jl:259; (::getfield(IJulia, Symbol("##12#...
 11760 ...L02A/src/eventloop.jl:8; eventloop(::ZMQ.Socket)
  11760 ./essentials.jl:696; invokelatest
   11760 ./essentials.jl:697; #invokelatest#1
    11760 ...c/execute_request.jl:65; execute_request(::ZMQ.Socket, :...
     11760 .../SoftGlobalScope.jl:206; softscope_include_string(::Mod...
      11760 ./boot.jl:319; eval
       11760 ./In[22]:3; top-level scope
        11760 ...ile/src/Profile.jl:25; macro expansion
         5996 .../src/nn_descent.jl:133; search(::DescentGraph{Array{...
          5968 ...src/nn_descent.jl:161; unexpanded(::DataStructures....
           2022 ./array.jl:2352; filter
            1866 ./abstractset.jl:336; mapfi

In [12]:
# Time graph construction; `$data` interpolates the global so the benchmark does not
# include untyped-global access.
@benchmark DescentGraph($data, 10)

BenchmarkTools.Trial: 
  memory estimate:  533.45 MiB
  allocs estimate:  11895303
  --------------
  minimum time:     1.041 s (12.05% GC)
  median time:      1.079 s (13.88% GC)
  mean time:        1.078 s (14.61% GC)
  maximum time:     1.125 s (17.45% GC)
  --------------
  samples:          5
  evals/sample:     1

In [13]:
# Time the graph-based search with the globals interpolated via `$`.
# NOTE(review): this passes 20 as the last argument, while the recall and queries-per-second
# cells pass 10 — confirm that the difference is intended before comparing the numbers.
@benchmark search($knn_graph, $queries, 5, 20)

BenchmarkTools.Trial: 
  memory estimate:  261.09 MiB
  allocs estimate:  6421635
  --------------
  minimum time:     454.121 ms (6.97% GC)
  median time:      478.940 ms (10.80% GC)
  mean time:        477.467 ms (10.46% GC)
  maximum time:     500.144 ms (12.10% GC)
  --------------
  samples:          11
  evals/sample:     1

In [14]:
# Time the brute-force graph construction as the baseline for `DescentGraph`; `$data`
# interpolates the global into the benchmark expression.
@benchmark brute_knn($data, Euclidean(), 10)

BenchmarkTools.Trial: 
  memory estimate:  1.98 GiB
  allocs estimate:  70208364
  --------------
  minimum time:     8.502 s (10.16% GC)
  median time:      8.502 s (10.16% GC)
  mean time:        8.502 s (10.16% GC)
  maximum time:     8.502 s (10.16% GC)
  --------------
  samples:          1
  evals/sample:     1

In [15]:
# Time the brute-force search as the baseline for the graph-based `search`; the globals
# are interpolated with `$` into the benchmark expression.
@benchmark brute_search($data, $queries, 5, Euclidean())

BenchmarkTools.Trial: 
  memory estimate:  151.08 MiB
  allocs estimate:  4799945
  --------------
  minimum time:     719.649 ms (10.73% GC)
  median time:      755.787 ms (17.22% GC)
  mean time:        752.226 ms (16.47% GC)
  maximum time:     766.788 ms (18.17% GC)
  --------------
  samples:          7
  evals/sample:     1