In [1]:
using BenchmarkTools
using Distances
using MLDatasets
using NearestNeighborDescent
using NearestNeighborDescent: brute_knn, brute_search

In [2]:
"""
    recall(nn, true_nn)

Return the mean column-wise recall of `nn` against `true_nn`.

Column `i` of `nn` is compared with column `i` of `true_nn` via [`_recall`](@ref);
the per-column scores are summed and divided by the number of columns of `nn`.
"""
function recall(nn, true_nn)
    ncols = size(nn, 2)
    # Views avoid allocating a copy of every column just to intersect it.
    total = @views sum(_recall(nn[:, i], true_nn[:, i]) for i in 1:ncols)
    return total / ncols
end

"""
    _recall(found, truth)

Return the fraction `length(intersect(found, truth)) / length(truth)`, i.e. the share
of entries of `truth` that also occur in `found`.
"""
_recall(found, truth) = length(intersect(found, truth)) / length(truth)

_recall (generic function with 1 method)

In [3]:
"""
    get_fmnist(; ndata=5000, nqueries=500)

Load Fashion MNIST and return `(data, queries)`.

`data` holds the first `ndata` training samples and `queries` the first `nqueries`
test samples. Each sample is a `Float32` vector obtained by flattening the first two
dimensions of the loaded array.

# Keywords
- `ndata`: number of training samples to keep (default `5000`).
- `nqueries`: number of test samples to keep (default `500`).
"""
function get_fmnist(; ndata=5000, nqueries=500)
    # Labels are not needed for nearest-neighbor benchmarking.
    train_x, _ = FashionMNIST.traindata()
    test_x, _ = FashionMNIST.testdata()

    # Collapse the first two dimensions so that each sample becomes one column.
    train_x = reshape(train_x, :, size(train_x, 3))
    test_x = reshape(test_x, :, size(test_x, 3))

    # Convert only the columns that are actually kept, not the whole dataset.
    data = [convert.(Float32, @view train_x[:, i]) for i in 1:ndata]
    queries = [convert.(Float32, @view test_x[:, i]) for i in 1:nqueries]
    return data, queries
end

get_fmnist (generic function with 1 method)

In [4]:
"""
    bmark(name, data, queries, metric; k=10, queue_size=70)

Run one benchmark and print its results with `@show`.

Steps:
1. Build a `DescentGraph` over `data` with `k` neighbors and print its recall against
   the graph returned by `brute_knn`.
2. Run `search` for `queries` and print its recall against `brute_search`.
3. Time `search` with `@belapsed` and print the resulting queries per second.

Returns the queries-per-second figure.

# Keywords
- `k`: number of neighbors used for graph construction and for searching (default `10`).
- `queue_size`: forwarded unchanged as the fourth argument of `search` (default `70`).
  NOTE(review): presumably the candidate-queue setting of the search; confirm against
  the installed NearestNeighborDescent version.
"""
function bmark(name, data, queries, metric; k=10, queue_size=70)
    @show name

    # Graph-construction quality. Local names are kept as-is because `@show` prints them.
    knn_graph = DescentGraph(data, k, metric)
    # First component of every graph entry; it is what `recall` compares below.
    nn = getindex.(knn_graph.graph, 1)
    brute_graph = brute_knn(data, metric, k)
    true_nn = getindex.(brute_graph, 1)
    @show recall(nn, true_nn)

    # Query quality; the second return value of each call is not used here.
    true_idx, _ = brute_search(data, queries, k, metric)
    idx, _ = search(knn_graph, queries, k, queue_size)
    @show recall(idx, true_idx)

    # Throughput; `\$` interpolation keeps the arguments from being benchmarked as globals.
    elapsed = @belapsed search($knn_graph, $queries, $k, $queue_size)
    q_per_sec = length(queries) / elapsed
    @show q_per_sec
    return q_per_sec
end

bmark (generic function with 1 method)

In [5]:
# Fashion MNIST Tests
fm_data, fm_queries = get_fmnist()
bmark("Fashion MNIST", fm_data, fm_queries, Euclidean())

name = "Fashion MNIST"
recall(nn, true_nn) = 0.9966799999999998
recall(idx, true_idx) = 0.9099999999999993
q_per_sec = 306.35653137860936


306.35653137860936

In [6]:
# Cosine Tests
# Uniform random vectors of length 800: 5000 data points and 500 queries.
rn_data = map(_ -> rand(800), 1:5000)
rn_queries = map(_ -> rand(800), 1:500)
bmark("Cosine Random", rn_data, rn_queries, CosineDist())

name = "Cosine Random"
recall(nn, true_nn) = 0.6454999999999979
recall(idx, true_idx) = 0.8749999999999991
q_per_sec = 192.44264434039687


192.44264434039687

In [7]:
# Hamming Tests
# Random 0/1 vectors of length 800: 5000 data points and 500 queries.
ham_data = map(_ -> rand([0, 1], 800), 1:5000)
ham_queries = map(_ -> rand([0, 1], 800), 1:500)
bmark("Hamming Random", ham_data, ham_queries, Hamming())

name = "Hamming Random"
recall(nn, true_nn) = 0.13036000000000472
recall(idx, true_idx) = 0.8429999999999988
q_per_sec = 193.82943617775595


193.82943617775595