-
Notifications
You must be signed in to change notification settings - Fork 18
/
utils.jl
94 lines (78 loc) · 2.85 KB
/
utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#=
Utilities used by UMAP.jl
=#
# Both curve parameters were supplied by the caller: hand them back unchanged and
# ignore the first two arguments.
@inline fit_ab(::Any, ::Any, a, b) = (a, b)

"""
    fit_ab(min_dist, spread, _a, _b) -> a, b

Find a smooth approximation to the membership function of points embedded in ℜᵈ.

The target membership is 1 for distances below `min_dist` and decays as
`exp(-(d - min_dist)/spread)` beyond it. It is sampled at 300 evenly spaced
distances on `[0, 3*spread]`, and the curve `1 / (1 + a*x^(2b))` is fitted to those
samples with `curve_fit`, starting from `a = b = 1` with `a` bounded below by 0.
The fitted parameters `(a, b)` are returned.
"""
function fit_ab(min_dist, spread, ::Nothing, ::Nothing)
    # piecewise target: flat at 1 up to `min_dist`, exponential decay afterwards
    target(d) = d >= min_dist ? exp(-(d - min_dist)/spread) : 1.
    grid = LinRange(0., spread*3, 300)
    samples = map(target, grid)
    # model evaluated elementwise over `x`; `p` holds the parameters `[a, b]`
    model(x, p) = @. (1. + p[1]*x^(2*p[2]))^(-1)
    fit = curve_fit(model, grid, samples, [1., 1.], lower=[0., -Inf])
    a, b = fit.param
    return a, b
end
# Lift a `Symbol` metric into the type domain so that dispatch can select the
# matching `Val`-specialized `knn_search` method.
function knn_search(dist_mat, k, metric::Symbol)
    return knn_search(dist_mat, k, Val(metric))
end

"""
    knn_search(dist_mat, k, :precomputed) -> knns, dists

Find the `k` nearest neighbors of each point in a precomputed distance
matrix.
"""
function knn_search(dist_mat, k, ::Val{:precomputed})
    return _knn_from_dists(dist_mat, k)
end

"""
    knn_search(X, k, metric) -> knns, dists

Find the `k` nearest neighbors of each point in `X` by `metric`.

When `X` has fewer than 4096 columns the search is delegated to the
`Val(:pairwise)` method; otherwise it is delegated to the `Val(:approximate)`
method.
"""
function knn_search(X, k, metric::SemiMetric)
    size(X, 2) < 4096 && return knn_search(X, k, metric, Val(:pairwise))
    return knn_search(X, k, metric, Val(:approximate))
end
# Exact search: fill the full matrix of pairwise distances between the columns of
# `X`, then return the nearest `k` to each point, other than the point itself.
function knn_search(X::AbstractMatrix{S},
                    k,
                    metric,
                    ::Val{:pairwise}) where {S <: Real}
    n = size(X, 2)
    distances = Matrix{S}(undef, n, n)
    # `pairwise!` fills `distances` in place with the column-to-column distances
    pairwise!(distances, metric, X, dims=2)
    return _knn_from_dists(distances, k)
end
# Approximate search: build a k-nearest-neighbor graph of `X` with `nndescent`
# (NNDescent) and return what `knn_matrices` extracts from that graph.
function knn_search(X::AbstractMatrix{S},
                    k,
                    metric,
                    ::Val{:approximate}) where {S <: Real}
    return knn_matrices(nndescent(X, k, metric))
end
# Extract the `k` nearest neighbors of every column of `dist_mat`.
#
# Column `j` of `dist_mat` holds the distances from point `j` to each candidate
# (one candidate per row). For every column the entries of rank `2:k+1` are kept,
# i.e. the single smallest entry is skipped; for a pairwise self-distance matrix
# that entry is the point itself.
#
# Returns `(knns, dists)`, two `k × size(dist_mat, 2)` matrices: `knns[:, j]` are
# the row indices of the selected neighbors of point `j`, in increasing order of
# distance, and `dists[:, j]` are the corresponding distances.
function _knn_from_dists(dist_mat::AbstractMatrix{S}, k) where {S <: Real}
    # One output column per *column* of `dist_mat`; sizing from the column axis
    # keeps non-square and empty inputs correct.
    num_points = size(dist_mat, 2)
    knns = Matrix{Int}(undef, k, num_points)
    dists = Matrix{S}(undef, k, num_points)
    for (j, col) in enumerate(axes(dist_mat, 2))
        column = view(dist_mat, :, col)
        # ranks 2:k+1 skip the smallest entry of the column
        nearest = partialsortperm(column, 2:k+1)
        knns[:, j] = nearest
        dists[:, j] = view(column, nearest)
    end
    return knns, dists
end
# Combine local fuzzy simplicial sets: blend the fuzzy union and the fuzzy
# intersection of `fs_set`, weighting the union by `set_op_ratio` and the
# intersection by `one(T) - set_op_ratio`.
@inline function combine_fuzzy_sets(fs_set::AbstractMatrix{T},
                                    set_op_ratio) where {T}
    union_weight = set_op_ratio
    intersection_weight = one(T) - set_op_ratio
    return union_weight .* fuzzy_set_union(fs_set) .+
           intersection_weight .* fuzzy_set_intersection(fs_set)
end
# Elementwise fuzzy union of `fs_set` with its adjoint: A + A' - A .* A'.
@inline function fuzzy_set_union(fs_set::AbstractMatrix)
    flipped = fs_set'
    return fs_set .+ flipped .- (fs_set .* flipped)
end
# Elementwise fuzzy intersection of `fs_set` with its adjoint: A .* A'.
@inline fuzzy_set_intersection(fs_set::AbstractMatrix) = fs_set .* fs_set'