-
Notifications
You must be signed in to change notification settings - Fork 18
/
find.jl
119 lines (98 loc) · 4.24 KB
/
find.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
    compare(s1, s2, dist; min_score = 0.0)

Return a similarity score between 0 and 1 for the strings `s1` and `s2`, computed as one
minus the normalized distance `dist` between them.

# Keywords
- `min_score = 0.0`: passed to `Normalized(dist)` as `max_dist = 1 - min_score`;
  presumably this lets the distance computation give up early for pairs whose score
  would fall below `min_score` (see `Normalized` for the exact behavior).

### Examples
```julia-repl
julia> compare("martha", "marhta", Levenshtein())
0.6666666666666667
```
"""
function compare(s1, s2, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
    normalized_dist = Normalized(dist)
    distance_cap = 1 - min_score
    return 1 - normalized_dist(s1, s2; max_dist = distance_cap)
end
"""
    findnearest(s, itr, dist::Union{StringMetric, StringSemiMetric}; min_score = 0.0) -> (x, index)

Return the value and index of the element of `itr` that has the highest similarity
score (see [`compare`](@ref)) with `s` according to the distance `dist`, i.e. the
element with the lowest distance to `s`.

`index` is the position of `x` in `collect(itr)`. When several elements share the best
score, the first one is returned. If `itr` is empty, or if every element gets a score of
zero, `(nothing, nothing)` is returned.

It is particularly optimized for [`Levenshtein`](@ref) and [`DamerauLevenshtein`](@ref) distances
(as well as their modifications via [`Partial`](@ref), [`TokenSort`](@ref), [`TokenSet`](@ref), or [`TokenMax`](@ref)).

# Keywords
- `min_score = 0.0`: forwarded to [`compare`](@ref) for every element; scores below it
  are presumably reported as 0 by `compare`, so such elements cannot be returned.

### Examples
```julia-repl
julia> using StringDistances

julia> s = "Newark";

julia> iter = ["New York", "Princeton", "San Francisco"];

julia> findnearest(s, iter, Levenshtein())
("New York", 1)

julia> findnearest(s, iter, Levenshtein(); min_score = 0.9)
(nothing, nothing)
```
"""
function findnearest(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
    _citr = collect(itr)
    isempty(_citr) && return (nothing, nothing)
    _preprocessed_s = _preprocess(dist, s)
    # About two chunks per thread, with at least one element per chunk.
    chunk_size = max(1, length(_citr) ÷ (2 * Threads.nthreads()))
    data_chunks = Iterators.partition(_citr, chunk_size)
    # Every element is scored against the fixed caller-supplied `min_score`. A shared,
    # atomically raised threshold was previously updated here but never read. Reading
    # it would make ties resolve by thread scheduling rather than by index, so it is
    # dropped instead of paying for contended atomic writes.
    chunk_score_tasks = map(data_chunks) do chunk
        Threads.@spawn begin
            map(chunk) do x
                compare(_preprocessed_s, _preprocess(dist, x), dist; min_score = min_score)
            end
        end
    end
    # retrieve return type of `compare` for type stability in task
    _self_cmp = compare(_preprocessed_s, _preprocessed_s, dist; min_score = min_score)
    chunk_scores = fetch.(chunk_score_tasks)::Vector{Vector{typeof(_self_cmp)}}
    # The tasks are already fetched; concatenate the per-chunk score vectors in order.
    scores = reduce(vcat, chunk_scores)
    imax = argmax(scores)
    return iszero(scores) ? (nothing, nothing) : (_citr[imax], imax)
end
# Convert `s` into the form handed to `compare` for the distance `dist`.
# For q-gram distances, `missing` is returned as is; any other value becomes a
# `QGramSortedVector` built with the distance's `q`. Every other distance uses `s` unchanged.
function _preprocess(::AbstractQGramDistance, ::Missing)
    return missing
end

function _preprocess(dist::AbstractQGramDistance, s)
    return QGramSortedVector(s, dist.q)
end

function _preprocess(::Union{StringSemiMetric, StringMetric}, s)
    return s
end
# Deprecated alias of `findnearest(s, itr, dist; min_score)`, kept for backward compatibility.
function Base.findmax(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
    # `maxlog = 1` asks loggers that honor it (e.g. the default `ConsoleLogger`) to show the
    # deprecation once, instead of flooding the log when this is called in a loop.
    @warn "findmax(s, itr, dist; min_score) is deprecated. Use findnearest(s, itr, dist; min_score)" maxlog = 1
    return findnearest(s, itr, dist; min_score = min_score)
end
"""
    findall(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.8)

Return the vector of indices for elements of `itr` that have a similarity score
(see [`compare`](@ref)) greater than or equal to `min_score` according to the distance
`dist`. Indices are positions in `collect(itr)`. If there are no such elements, return
an empty array.

It is particularly optimized for [`Levenshtein`](@ref) and [`DamerauLevenshtein`](@ref) distances
(as well as their modifications via `Partial`, `TokenSort`, `TokenSet`, or `TokenMax`).

### Examples
```julia-repl
julia> using StringDistances

julia> s = "Newark";

julia> iter = ["Newwark", "Princeton", "San Francisco"];

julia> findall(s, iter, Levenshtein())
1-element Array{Int64,1}:
 1

julia> findall(s, iter, Levenshtein(); min_score = 0.9)
0-element Array{Int64,1}
```
"""
function Base.findall(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.8)
    # Materialize `itr` once and work only on that copy. Iterating `itr` a second time
    # yields nothing for one-shot (stateful) iterators, which previously produced an
    # empty result regardless of the data.
    _citr = collect(itr)
    isempty(_citr) && return empty(eachindex(_citr))
    _preprocessed_s = _preprocess(dist, s)
    # About two chunks per thread, with at least one element per chunk.
    chunk_size = max(1, length(_citr) ÷ (2 * Threads.nthreads()))
    data_chunks = Iterators.partition(_citr, chunk_size)
    chunk_score_tasks = map(data_chunks) do chunk
        Threads.@spawn begin
            map(chunk) do x
                compare(_preprocessed_s, _preprocess(dist, x), dist; min_score = min_score)
            end
        end
    end
    # retrieve return type of `compare` for type stability in task
    _self_cmp = compare(_preprocessed_s, _preprocessed_s, dist; min_score = min_score)
    chunk_scores::Vector{Vector{typeof(_self_cmp)}} = fetch.(chunk_score_tasks)
    # The tasks are already fetched; concatenate the per-chunk score vectors in order.
    scores = reduce(vcat, chunk_scores)
    return findall(>=(min_score), scores)
end