/
funres.jl
106 lines (91 loc) · 3.17 KB
/
funres.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
    FunctionResolver{T}()

Dict-like container mapping function call signatures (`Tuple` types whose first
parameter is the type of the called function) to values of type `T`.

Unlike a `Dict`, a lookup `rsv[sig]` does not require an exact key: it returns the value
of the first stored signature that `sig` is a subtype of, with stored signatures ordered
so that more specific ones are tried first. When nothing matches, `nothing` is returned.

# Examples
```julia
rsv = FunctionResolver{Symbol}()
rsv[Tuple{typeof(sin), Float64}] = :Float64
rsv[Tuple{typeof(sin), Real}] = :Real
rsv[Tuple{typeof(sin), Number}] = :Number
rsv[Tuple{typeof(sin), Float64}]  # ==> :Float64
rsv[Tuple{typeof(sin), Float32}]  # ==> :Real
```
"""
mutable struct FunctionResolver{T}
    # per-function bucket key => list of (signature => value) pairs for that function
    signatures::Dict{Symbol, Vector{Pair{Any, T}}}
    # set to `false` when a new signature is inserted; `true` once the buckets are sorted
    ordered::Bool

    function FunctionResolver{T}() where T
        no_buckets = Dict{Symbol, Vector{Pair{Any, T}}}()
        return new{T}(no_buckets, false)
    end
end
"""
    FunctionResolver{T}(entries::AbstractVector)

Build a resolver from a vector of `signature => value` pairs, inserting them in order and
sorting the buckets once at the end. A later entry whose signature equals an earlier one
replaces its value. Every value must already be of type `T` (no conversion is done).
"""
function FunctionResolver{T}(entries::AbstractVector) where T
    rsv = FunctionResolver{T}()
    for (sig, val) in entries
        rsv[sig] = val
    end
    order!(rsv)
    return rsv
end
# Compact one-line display. The number shown is the count of distinct function buckets
# (keys of `rsv.signatures`), not the total number of registered signatures.
function Base.show(io::IO, rsv::FunctionResolver)
    nbuckets = length(rsv.signatures)
    print(io, "FunctionResolver($nbuckets)")
    return nothing
end
# Return the first parameter of a signature tuple type (the type of the called function),
# stripping any number of enclosing `where` clauses first.
function function_type(@nospecialize sig)
    while sig isa UnionAll
        sig = sig.body
    end
    return sig.parameters[1]
end
# Bucket key for a function type: `Symbol("<parent module>.<type name>")`.
function function_type_key(fn_typ)
    owner = Base.parentmodule(fn_typ)
    tname = Base.nameof(fn_typ)
    return Symbol("$owner.$tname")
end
"""
    setindex!(rsv::FunctionResolver{T}, val::T, sig::Type{<:Tuple})

Associate `val` with the signature `sig` and return `val`. If a signature `== sig` is
already registered, its value is replaced in place and the ordering is left untouched;
otherwise the pair is appended and the resolver is marked as needing re-ordering.
"""
function Base.setindex!(rsv::FunctionResolver{T}, val::T, @nospecialize sig::Type{<:Tuple}) where T
    key = function_type_key(function_type(sig))
    # Create the bucket with exactly the Dict's value type so the stored vector is the
    # very one we mutate below (a mismatched eltype would be converted into a copy).
    entries = get!(() -> Pair{Any, T}[], rsv.signatures, key)
    idx = findfirst(p -> p.first == sig, entries)
    if idx === nothing
        push!(entries, sig => val)
        # a new signature may belong anywhere in the precedence order
        rsv.ordered = false
    else
        # replacing the value of an existing signature does not affect the ordering
        entries[idx] = sig => val
    end
    return val
end
"""
    getindex(rsv::FunctionResolver, sig::Type{<:Tuple})

Return the value of the first registered signature (in resolution order) that `sig` is
a subtype of, or `nothing` when no signature registered for the same function matches.
The buckets are re-ordered first if a signature was added since the last ordering.
"""
function Base.getindex(rsv::FunctionResolver{T}, @nospecialize sig::Type{<:Tuple}) where T
    if !rsv.ordered
        order!(rsv)
    end
    bucket = get(rsv.signatures, function_type_key(function_type(sig)), nothing)
    bucket === nothing && return nothing
    hit = findfirst(p -> sig <: p.first, bucket)
    return hit === nothing ? nothing : bucket[hit].second
end
# Return `true` if a tuple-type parameter `T` is a `Vararg`. Up to Julia 1.6 `Vararg` is
# itself a `Type`, so Vararg parameters are detected by subtyping; from Julia 1.7 on it is
# an instance of `typeof(Vararg)`, so membership is checked with `isa` instead. The branch
# condition depends only on the running Julia version.
is_Vararg(T) = Vararg isa Type ? (T isa Type && T <: Vararg) : (T isa typeof(Vararg))
"""
    isless_signature(sig1, sig2)

Precedence relation used when ordering signatures: return `true` if `sig1` must be tried
before `sig2` during lookup. Signatures without a `Vararg` parameter always come before
signatures with one; within the same group a signature comes before its supertypes
(`sig1 <: sig2`). The relation is reflexive for equal types, but a resolver never stores
two `==` signatures in the same bucket.

NOTE(review): `get_type_parameters` is defined elsewhere; presumably it returns the
parameters of the (unwrapped) signature tuple type — confirm.
"""
function isless_signature(sig1, sig2)
    sig1_vararg = any(is_Vararg, get_type_parameters(sig1))
    sig2_vararg = any(is_Vararg, get_type_parameters(sig2))
    if sig1_vararg != sig2_vararg
        # signatures with Varargs go last, whatever their subtyping relation; deciding on
        # both sides keeps the relation antisymmetric (two Vararg signatures used to be
        # "less" than each other)
        return sig2_vararg
    end
    return sig1 <: sig2
end
"""
    order!(rsv::FunctionResolver)

Reorder every bucket of `rsv` in place so that a signature always comes before every
signature it must precede according to `isless_signature` (more specific before less
specific, non-Vararg before Vararg), then mark `rsv` as ordered. Return `true`.

`isless_signature` is only a partial order (subtyping), so a comparison sort such as
`sort!` can stop at an incomparable neighbour and leave a signature behind one of its
supertypes (e.g. `Int` stuck behind `Real` when `String` sits between them). A stable
topological insertion is used instead; it is O(n^2) per bucket, which is fine for the
short per-function signature lists stored here.
"""
function order!(rsv::FunctionResolver)
    for entries in values(rsv.signatures)
        arranged = empty(entries)
        sizehint!(arranged, length(entries))
        for entry in entries
            # insert just before the first entry this signature must precede; everything
            # placed earlier cannot be required to come after it
            pos = findfirst(other -> isless_signature(entry.first, other.first), arranged)
            insert!(arranged, something(pos, lastindex(arranged) + 1), entry)
        end
        # write back into the same vector so references to the bucket stay valid
        copyto!(entries, arranged)
    end
    rsv.ordered = true
    return true
end
# `true` when `rsv[sig]` resolves to some registered signature (exact or a supertype).
# A stored value that is itself `nothing` is indistinguishable from "no match" here.
function Base.haskey(rsv::FunctionResolver, sig::Type{<:Tuple})
    return !isnothing(rsv[sig])
end
# `sig in rsv` is an alias for `haskey(rsv, sig)`.
function Base.in(sig::Type{<:Tuple}, rsv::FunctionResolver)
    return haskey(rsv, sig)
end
"""
    empty!(rsv::FunctionResolver)

Remove every registered signature from `rsv` and return `rsv`, following the
`Base.empty!` convention of returning the emptied collection itself.
"""
function Base.empty!(rsv::FunctionResolver)
    empty!(rsv.signatures)
    return rsv
end
"""
    find_signatures_for(rsv::FunctionResolver, f)

Return the bucket of `signature => value` pairs registered in `rsv` for `f`, keyed by
`function_type_key(typeof(f))`. Throws a `KeyError` when that key has no bucket.

The returned vector is the resolver's own storage, not a copy: mutating it mutates
`rsv`. Its order matches lookup order only once `order!` has run since the last insertion.

NOTE(review): for a `DataType` argument the key comes from `typeof(f)`, i.e. `DataType`,
while constructor signatures such as `Tuple{Type{X}, ...}` appear to be keyed by `Type`
— confirm which signatures callers expect to find here.
"""
function find_signatures_for(rsv::FunctionResolver, f::Union{Function, DataType})
    bucket_key = function_type_key(typeof(f))
    return rsv.signatures[bucket_key]
end