In [1]:
using BenchmarkTools
using Distances
using MLDatasets
using NNDescent
using NNDescent: brute_knn, brute_search, NNTuple

In [2]:
"""
    recall(nn, true_nn)

Return the mean per-column recall of the neighbor-index matrix `nn` against the
reference matrix `true_nn`. Column `i` of each matrix holds the neighbors of item `i`;
the per-column score is computed by `_recall(nn[:, i], true_nn[:, i])` and the scores
are averaged over all columns.

Throws `DimensionMismatch` if the two matrices have a different number of columns, and
`ArgumentError` if they have no columns (an average over nothing is undefined).
"""
function recall(nn, true_nn)
    ncols = size(nn, 2)
    ncols == size(true_nn, 2) || throw(DimensionMismatch(
        "nn has $ncols columns but true_nn has $(size(true_nn, 2))"))
    ncols > 0 || throw(ArgumentError("cannot compute recall over zero columns"))
    # eachcol yields views, so no column is copied just to be intersected.
    total = sum(_recall(found, ref) for (found, ref) in zip(eachcol(nn), eachcol(true_nn)))
    return total / ncols
end

"""
    _recall(found, reference)

Return the number of distinct elements shared by `found` and `reference`, divided by
`length(reference)`. Yields `NaN` when `reference` is empty.
"""
_recall(found, reference) = length(intersect(found, reference)) / length(reference)

_recall (generic function with 1 method)

In [3]:
# Load Fashion-MNIST and flatten each image so that every column is one sample.
train_x, train_y = FashionMNIST.traindata()
test_x, test_y = FashionMNIST.testdata()
train_x = reshape(train_x, :, size(train_x, 3))
test_x = reshape(test_x, :, size(test_x, 3))

# Number of training samples to index and number of test samples to use as queries.
n_data = 5000
n_queries = 500

# Convert only the columns that are actually used; converting the whole split to
# Float32 first and then truncating wastes time and memory on discarded samples.
data = [convert.(Float32, train_x[:, i]) for i in 1:n_data]
queries = [convert.(Float32, test_x[:, i]) for i in 1:n_queries];

In [4]:
# Approximate 10-nearest-neighbor graph of `data`, built with NNDescent's DescentGraph.
knn_graph = DescentGraph(data, 10)
# Take element 1 of every graph entry (presumably the neighbor index of each
# NNTuple — TODO confirm), producing an index array that `recall` can compare.
nn = map(entry -> entry[1], knn_graph.graph);

In [5]:
# Reference 10-NN graph of `data` under the Euclidean metric, from `brute_knn`.
brute_graph = brute_knn(data, Euclidean(), 10)
# Take element 1 of every entry (presumably the neighbor index — TODO confirm).
true_nn = map(entry -> entry[1], brute_graph);

In [6]:
# Mean recall of the DescentGraph neighbors (`nn`) against the brute_knn neighbors (`true_nn`).
recall(nn, true_nn)

0.9968400000000002

In [7]:
# Brute-force search for 5 neighbors of each query in `data` (Euclidean metric); the
# result is destructured into `true_idx` and `true_dist` and used as the reference.
true_idx, true_dist = brute_search(data, queries, 5, Euclidean());

In [8]:
# Approximate search for 5 neighbors of each query using `knn_graph`. The 4th argument
# (100) presumably controls the size of the search candidate queue — TODO confirm.
idx, dist = search(knn_graph, queries, 5, 100);

In [9]:
# Mean recall of the approximate search results against the brute-force results.
recall(idx, true_idx)

0.8800000000000003

In [10]:
# Queries answered per second by `search` (5 neighbors, 4th argument 100). Globals are
# `$`-interpolated so BenchmarkTools times the call itself, not untyped-global lookup
# and the dynamic dispatch it causes.
q_per_sec = length(queries) / (@belapsed search($knn_graph, $queries, 5, 100))

133.0004222109043

In [11]:
# Time to build the approximate 10-NN graph; `$data` keeps the global lookup out of
# the measurement.
@benchmark DescentGraph($data, 10)

BenchmarkTools.Trial: 
  memory estimate:  534.20 MiB
  allocs estimate:  11924267
  --------------
  minimum time:     1.034 s (11.75% GC)
  median time:      1.075 s (13.73% GC)
  mean time:        1.077 s (14.39% GC)
  maximum time:     1.140 s (17.22% GC)
  --------------
  samples:          5
  evals/sample:     1

In [12]:
# Search timing with 4th argument 20 — note the recall and throughput cells use 100,
# so these timings are not directly comparable with them. Globals are `$`-interpolated
# so only the call is measured.
@benchmark search($knn_graph, $queries, 5, 20)

BenchmarkTools.Trial: 
  memory estimate:  261.46 MiB
  allocs estimate:  6426462
  --------------
  minimum time:     443.178 ms (6.93% GC)
  median time:      471.562 ms (10.70% GC)
  mean time:        468.908 ms (10.30% GC)
  maximum time:     497.649 ms (11.62% GC)
  --------------
  samples:          11
  evals/sample:     1

In [13]:
# Time for the brute-force 10-NN graph; `$data` keeps the global lookup out of the
# measurement.
@benchmark brute_knn($data, Euclidean(), 10)

BenchmarkTools.Trial: 
  memory estimate:  1.98 GiB
  allocs estimate:  70208364
  --------------
  minimum time:     8.433 s (10.10% GC)
  median time:      8.433 s (10.10% GC)
  mean time:        8.433 s (10.10% GC)
  maximum time:     8.433 s (10.10% GC)
  --------------
  samples:          1
  evals/sample:     1

In [14]:
# Time for brute-force 5-NN search of all queries; globals are `$`-interpolated so
# only the call is measured.
@benchmark brute_search($data, $queries, 5, Euclidean())

BenchmarkTools.Trial: 
  memory estimate:  151.08 MiB
  allocs estimate:  4799945
  --------------
  minimum time:     715.664 ms (10.54% GC)
  median time:      749.121 ms (17.04% GC)
  mean time:        746.433 ms (16.26% GC)
  maximum time:     760.739 ms (18.01% GC)
  --------------
  samples:          7
  evals/sample:     1