In [1]:
using BenchmarkTools
using Distances
using MLDatasets
using NNDescent
using NNDescent: brute_knn, brute_search, NNTuple

In [2]:
"""
    recall(nn, true_nn)

Return the average recall of the neighbor matrix `nn` against the reference
matrix `true_nn`, comparing them column by column.

For each column `i`, the recall is the number of distinct entries of `nn[:, i]`
that also occur in `true_nn[:, i]`, divided by `length(true_nn[:, i])`. The
result is the mean of these per-column values.

Throws `DimensionMismatch` if `nn` and `true_nn` have different numbers of
columns, and `ArgumentError` if there are no columns to average over.
"""
function recall(nn, true_nn)
    ncols = size(nn, 2)
    ncols == size(true_nn, 2) || throw(DimensionMismatch(
        "nn has $ncols columns but true_nn has $(size(true_nn, 2))"))
    ncols > 0 || throw(ArgumentError("recall needs at least one column"))
    # `eachcol` yields column views, so no column is copied.
    columns = zip(eachcol(nn), eachcol(true_nn))
    total = sum(_recall(found, truth) for (found, truth) in columns)
    return total / ncols
end
# Fraction of `truth` entries present in `found`; `intersect` counts each match once.
_recall(found, truth) = length(intersect(found, truth)) / length(truth)

_recall (generic function with 1 method)

In [3]:
# NOTE(review): newer MLDatasets releases replace traindata()/testdata() with
# FashionMNIST(:train)/FashionMNIST(:test) — confirm against the pinned version.
train_x, train_y = FashionMNIST.traindata()
test_x, test_y = FashionMNIST.testdata()
# Flatten each 2-D image into a single column: (pixels, n_images).
train_x = reshape(train_x, :, size(train_x, 3))
test_x = reshape(test_x, :, size(test_x, 3))

# Convert only the columns that are kept (5000 indexed points, 500 queries)
# rather than converting every image and then discarding most of them.
data = [convert.(Float32, view(train_x, :, i)) for i in 1:5000]
queries = [convert.(Float32, view(test_x, :, i)) for i in 1:500];

In [4]:
# Approximate neighbor graph built by NNDescent (second argument presumably k = 10).
knn_graph = DescentGraph(data, 10)
# Keep only the first field of every graph entry (presumably the neighbor index).
nn = [entry[1] for entry in knn_graph.graph];

In [5]:
# Brute-force neighbor graph, used as the reference in `recall(nn, true_nn)`.
brute_graph = brute_knn(data, Euclidean(), 10)
# Extract the first field of each entry, matching how `nn` is built.
true_nn = [entry[1] for entry in brute_graph];

In [6]:
# Mean per-point fraction of brute-force neighbors that the NNDescent graph also found.
recall(nn, true_nn)

0.9963999999999998

In [7]:
# Reference search: presumably the 5 nearest `data` points for each query,
# returned as (indices, distances) — confirm against NNDescent.brute_search.
true_idx, true_dist = brute_search(data, queries, 5, Euclidean());

In [8]:
# Approximate search over the NNDescent graph for 5 neighbors per query.
# NOTE(review): the meaning of the last argument (10) is not visible here —
# confirm against NNDescent.search (possibly a candidate-queue size).
idx, dist = search(knn_graph, queries, 5, 10);

In [9]:
# Mean per-query fraction of brute-force results that `search` also returned.
recall(idx, true_idx)

0.6900000000000004

In [10]:
# Search throughput in queries per second. Globals are `$`-interpolated so the
# timing excludes the overhead of accessing non-constant global variables.
q_per_sec = length(queries)/(@belapsed search($knn_graph, $queries, 5, 10))

2390.546897844664

In [15]:
using Profile
# Discard samples from earlier runs so re-executing this cell reports only this run.
Profile.clear()
# `$` interpolation keeps global-variable access out of the profiled benchmark.
@profile (@belapsed search($knn_graph, $queries, 5, 10))
Profile.print(format=:flat)

 Count File                        Line Function                               
 10946 ./In[15]                       2 top-level scope                        
    20 ./abstractarray.jl            75 axes                                   
    21 ./abstractarray.jl            93 axes1                                  
   230 ./abstractarray.jl           823 copymutable                            
    21 ./abstractarray.jl           214 eachindex                              
   132 ./abstractarray.jl           635 empty                                  
     1 ./abstractarray.jl          1835 foreach                                
  1072 ./abstractarray.jl           905 getindex                               
    21 ./abstractarray.jl           249 lastindex                              
     2 ./abstractarray.jl           861 pointer                                
   126 ./abstractarray.jl           573 similar                                
    15 ./abstractarray.jl           617 

In [11]:
# `$data` benchmarks the construction on the value, not through a non-const global.
@benchmark DescentGraph($data, 10)

BenchmarkTools.Trial: 
  memory estimate:  533.43 MiB
  allocs estimate:  11893486
  --------------
  minimum time:     1.034 s (11.82% GC)
  median time:      1.052 s (13.49% GC)
  mean time:        1.068 s (14.42% GC)
  maximum time:     1.114 s (17.24% GC)
  --------------
  samples:          5
  evals/sample:     1

In [12]:
# Globals are `$`-interpolated so only `search` itself is timed.
# NOTE(review): the last argument is 20 here but 10 in the recall and throughput
# cells — confirm the difference is intentional.
@benchmark search($knn_graph, $queries, 5, 20)

BenchmarkTools.Trial: 
  memory estimate:  258.05 MiB
  allocs estimate:  6348815
  --------------
  minimum time:     442.534 ms (6.98% GC)
  median time:      463.900 ms (10.93% GC)
  mean time:        463.386 ms (10.49% GC)
  maximum time:     483.275 ms (12.34% GC)
  --------------
  samples:          11
  evals/sample:     1

In [13]:
# `$data` benchmarks on the value, not through a non-const global.
@benchmark brute_knn($data, Euclidean(), 10)

BenchmarkTools.Trial: 
  memory estimate:  1.98 GiB
  allocs estimate:  70208364
  --------------
  minimum time:     8.488 s (10.11% GC)
  median time:      8.488 s (10.11% GC)
  mean time:        8.488 s (10.11% GC)
  maximum time:     8.488 s (10.11% GC)
  --------------
  samples:          1
  evals/sample:     1

In [14]:
# Globals are `$`-interpolated so only `brute_search` itself is timed.
@benchmark brute_search($data, $queries, 5, Euclidean())

BenchmarkTools.Trial: 
  memory estimate:  151.08 MiB
  allocs estimate:  4799945
  --------------
  minimum time:     717.771 ms (10.66% GC)
  median time:      752.276 ms (17.11% GC)
  mean time:        749.349 ms (16.37% GC)
  maximum time:     764.047 ms (18.06% GC)
  --------------
  samples:          7
  evals/sample:     1