/
search.jl
115 lines (100 loc) · 3.81 KB
/
search.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
search(graph, queries, n_neighbors; max_candidates) -> indices, distances
Search the kNN `graph` for the nearest neighbors of the points in `queries`.
`max_candidates` controls how large the candidate queue should be (min `n_neighbors`);
larger values increase accuracy at the cost of speed.
"""
function search(graph::G,
                queries::AbstractVector,
                n_neighbors::Integer;
                max_candidates=max(n_neighbors, 20),
                ) where {V, U, G <: ApproximateKNNGraph{V, U}}
    # Greedy best-first search over the kNN graph, run independently per query:
    #   1. seed a bounded max-heap of candidates via `init_candidates!`,
    #   2. repeatedly expand the closest not-yet-visited candidate,
    #   3. replace the current worst candidate whenever an out-neighbor is
    #      at least as close,
    #   4. stop once every candidate in the heap has been visited.
    # The per-query heaps are finally converted to index/distance arrays by
    # `deheap_knns`.
    length(queries) ≥ 1 || error("queries must have at least 1 element")
    n_neighbors ≥ 1 || error("n_neighbors must be at least 1")
    max_candidates ≥ 5 || error("max_candidates must be at least 5")

    # The queue has to hold at least `n_neighbors` entries (as documented),
    # otherwise `deheap_knns` could not fill every row of its result arrays.
    queue_size = max(max_candidates, n_neighbors)

    data = graph.data
    metric = graph.metric
    n_queries = length(queries)
    # `candidates` is 1-based; translate positions to the indices of `queries`
    query_offset = firstindex(queries) - 1

    # one candidate heap per query, ordered by distance (maximum on top)
    candidates = [BinaryMaxHeap{Tuple{U, V, Bool}}() for _ in 1:n_queries]

    # Split the queries into contiguous chunks, one per available thread. Every
    # chunk owns its `seen` buffer, so correctness does not depend on
    # `Threads.threadid()` staying constant during an iteration or being
    # bounded by `Threads.nthreads()`.
    chunk_len = cld(n_queries, Threads.nthreads())
    chunks = collect(Iterators.partition(1:n_queries, chunk_len))

    Threads.@threads for chunk in chunks
        seen = falses(length(data))
        for i in chunk
            # reset the scratch buffer before handling the next query
            fill!(seen, false)
            query = queries[i + query_offset]
            heap = candidates[i]

            # initial candidates come from `init_candidates!`
            init_candidates!(heap, seen, graph, query, queue_size)

            while true
                next_candidate = get_next_candidate!(heap)
                # `nothing` means every candidate has been expanded already
                isnothing(next_candidate) && break

                for v in outneighbors(graph, next_candidate[2])
                    seen[v] && continue
                    seen[v] = true
                    dist = evaluate(metric, query, data[v])
                    if dist ≤ top(heap)[1]
                        pop!(heap) # drop the current farthest candidate
                        push!(heap, (dist, v, false))
                    end
                end
            end
        end
    end
    return deheap_knns(candidates, n_neighbors)
end
"""
    search(graph, queries::AbstractMatrix, n_neighbors; max_candidates) -> indices, distances

Matrix convenience method: every column of `queries` is treated as one query
point and the work is forwarded to the vector method of `search`.
"""
function search(graph::ApproximateKNNGraph,
                queries::AbstractMatrix,
                n_neighbors::Integer;
                max_candidates=max(n_neighbors, 20),
                )
    # materialise the column views so the vector method can index them
    return search(
        graph,
        collect(eachcol(queries)),
        n_neighbors;
        max_candidates=max_candidates,
    )
end
"""
    get_next_candidate!(candidates) -> Union{Nothing, Tuple}

Scan the heap's backing vector for the not-yet-visited entry with the smallest
distance, flag it as visited in place and return the updated
`(distance, node, true)` tuple. Return `nothing` when no entry qualifies.

Entries are compared with a strict `<` against `typemax(U)`, so an entry whose
distance equals `typemax(U)` is never selected.
"""
function get_next_candidate!(candidates::BinaryMaxHeap{Tuple{U, V, Bool}}) where {U <: Real, V}
    tree = candidates.valtree
    best_pos = 0 # 0 == nothing found yet
    best_dist = typemax(U)
    for pos in eachindex(tree)
        entry = tree[pos]
        if !entry[3] && entry[1] < best_dist
            best_pos = pos
            best_dist = entry[1]
        end
    end

    # every entry has been visited already
    best_pos == 0 && return nothing

    found_dist, found_node, _ = tree[best_pos]
    marked = (found_dist, found_node, true)
    tree[best_pos] = marked # overwrite in place with the visited flag set
    return marked
end
"""
    init_candidates!(candidates, seen, graph, query, max_candidates) -> candidates

Seed `candidates` with the vertices returned by
`KNNGraphs.sample_neighbors(nv(graph), max_candidates)`. Each one is pushed as
an unvisited `(distance, vertex, false)` entry, with the distance measured from
`query` using the graph's metric, and is flagged in `seen`.
"""
function init_candidates!(candidates, seen, graph, query, max_candidates)
    sampled = KNNGraphs.sample_neighbors(nv(graph), max_candidates)
    for vertex in sampled
        entry = (evaluate(graph.metric, query, graph.data[vertex]), vertex, false)
        push!(candidates, entry)
        seen[vertex] = true
    end
    return candidates
end
"""
Empty every heap in `heaps`, keeping the `k` nearest neighbors stored in each.
Return two `k × length(heaps)` arrays holding the indices of and the
distances to each point's kNN, sorted from nearest to farthest.
"""
function deheap_knns(heaps::Vector{BinaryMaxHeap{Tuple{U, V, Bool}}}, k) where {U, V}
    n_heaps = length(heaps)
    # Column `c` receives the neighbors taken from `heaps[c]`. Rows beyond the
    # number of entries in a heap are never written and stay uninitialised.
    ids = Array{V}(undef, k, n_heaps)
    dists = Array{U}(undef, k, n_heaps)
    for (col, heap) in enumerate(heaps)
        # A max heap pops its farthest entry first, so the rank of each popped
        # entry counts down from the heap size to 1; only ranks 1..k are kept.
        for rank in length(heap):-1:1
            node_dist, node_idx, _ = pop!(heap)
            rank > k && continue
            ids[rank, col] = node_idx
            dists[rank, col] = node_dist
        end
    end
    return ids, dists
end