[FR] provide a public API for accessing/mapping over Union components #53193

Open
nsajko opened this issue Feb 5, 2024 · 8 comments
Labels
domain:types and dispatch Types, subtyping and method dispatch kind:feature Indicates new feature / enhancement requests


@nsajko
Contributor

nsajko commented Feb 5, 2024

Useful functionality like Base.typesplit or Base.promote_union may be implemented by recursively mapping over the components of a Union object. I wonder if it would make sense to provide a public API so users could do that themselves. I guess either of these would be nice:

  1. document and support the a and b fields of a Union, or

  2. provide some higher level API similar to mapreduce, that would take a type instead of a collection, interpreting union components as collection elements
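
As a rough illustration of what the second option could look like, here is a minimal sketch. The name `union_mapreduce` is hypothetical and not part of Base, and the sketch leans on the internal, undocumented `Base.uniontypes`, so it is only meant to show the shape of such an API, not a supported implementation.

```julia
# Hypothetical helper (not in Base): treat the components of a Union as the
# elements of a collection and map-reduce over them.
# Assumption: Base.uniontypes is internal and returns a Vector{Any} of components.
function union_mapreduce(f, op, T::Type; init)
    return mapreduce(f, op, Base.uniontypes(T); init = init)
end

# Largest size in bytes among the components.
widest = union_mapreduce(sizeof, max, Union{Int8, Int32, Float64}; init = 0)

# Something in the spirit of Base.promote_union: fold promote_type over the components.
promoted = union_mapreduce(identity, promote_type, Union{Int, Float64}; init = Union{})
```

Both reductions use operators whose result does not depend on the order of the components, which sidesteps any question about how a Union orders its members internally.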

@mikmoore
Contributor

mikmoore commented Feb 5, 2024

I think something like the current Base.uniontypes, but one that specializes and returns a Tuple (rather than the current @nospecialize-annotated function returning a Vector{Any}, which is mostly used for codegen and tests), could fit this bill.

I would even consider commandeering this function name for this purpose and replacing the internal one with something else (where it matters -- I'm sure letting it specialize in some places wouldn't be catastrophic).

@vtjnash
Sponsor Member

vtjnash commented Feb 5, 2024

Such as Base.uniontypes? Note that it is pretty tricky to work with the structure of a Union correctly, since its fields are not an independent collection in themselves, but rather only when taken as a whole (e.g. the problems with #48205).

@vtjnash
Sponsor Member

vtjnash commented Feb 5, 2024

If Base.uniontypes returned a Tuple, it would not be inferable concretely, which seems like a massive footgun (#48205).

@mikmoore
Contributor

mikmoore commented Feb 5, 2024

I don't think I properly understand the entire issue with concrete inferability. But I see the linked issue and how the difficulty of canonicalizing Unions can cause trouble, as well as the complication of shared typevars within a Union.

In any case, the following definition appears to only use external properties (plus the implementation detail that Unions of more than two types are made by nesting)

# ::Type overwrites ::Type{Union{...}}, so we need two different functions here
myuniontypes(x::Type) = (x,)
myuniontypes(x::Union) = _myuniontypes(x)
_myuniontypes(::Type{Union{A,B}}) where {A,B} = (myuniontypes(A)...,myuniontypes(B)...)

with the results

julia> myuniontypes(Base.IEEEFloat)
(Float16, Float32, Float64)

julia> @code_warntype myuniontypes(Base.IEEEFloat) # properly inferred as a constant
MethodInstance for myuniontypes(::Type{Union{Float16, Float32, Float64}})
  from myuniontypes(x::Union) @ Main REPL[2]:1
Arguments
  #self#::Core.Const(myuniontypes)
  x::Type{Union{Float16, Float32, Float64}}
Body::Tuple{DataType, DataType, DataType}
1 ─ %1 = Main._myuniontypes(x)::Core.Const((Float16, Float32, Float64))
└──      return %1

julia> myuniontypes(AbstractVecOrMat{Float64})
(AbstractVector{Float64}, AbstractMatrix{Float64})

julia> myuniontypes(AbstractVecOrMat{<:Real}) # can't reduce due to shared typevar
(AbstractVecOrMat{<:Real},)

julia> myuniontypes(Union{AbstractVector{<:Real},AbstractMatrix{<:Real}}) # distinct typevars
(AbstractVector{<:Real}, AbstractMatrix{<:Real})

julia> myuniontypes(StridedMatrix{Float64}) # actually comes out better than I would have expected
( # added line breaks for "readability"
	DenseMatrix{Float64},
	Base.ReinterpretArray{Float64, 2, S, A, IsReshaped} where {A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S},
	Base.ReshapedArray{Float64, 2, A} where A<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {T, N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S},	SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray},
	SubArray{Float64, 2, A, I} where {A<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {T, N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, Base.ReshapedArray{T, N, A} where {T, N, A<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {T, N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}, IsReshaped, S}, SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real}}, Tuple{AbstractUnitRange, Vararg{Any}}}}, DenseArray}}, DenseArray}, I<:Tuple{Vararg{Union{Base.AbstractCartesianIndex, AbstractRange{<:Union{Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8}}, Base.ReshapedArray{T, N, A, Tuple{}} where {T, N, A<:AbstractUnitRange}, Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8}}}}
)

There certainly are situations where a more specific answer doesn't exist (and maybe this function should be made to throw errors for those). But for "simple" results this appears to work. Is this the sort of implementation we should be directing people to for questions like this one? One could say this is still no better than internals like Base.uniontypes because it depends on the "nested binary unions" implementation detail. I can't say I love it for code outside of Base, but we need some kind of answer for these questions.

@vtjnash
Sponsor Member

vtjnash commented Feb 5, 2024

_myuniontypes relies on an implementation bug in subtyping and should be avoided in real code.

@Tokazama
Contributor

Tokazama commented Feb 6, 2024

> Such as Base.uniontypes? Note that it is pretty tricky to work with the structure of a Union correctly, since its fields are not an independent collection in themselves, but rather only when taken as a whole (e.g. the problems with #48205)

This is why we should have something like Base.uniontypes: to discourage people from mucking around with internals that may change in the future or are tricky to get right. Right now people just directly access the fields of a Union and iterate over them like a list.
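
For concreteness, the field-walking pattern referred to here tends to look something like the sketch below. The name `union_components` is hypothetical, and the `a` and `b` fields are undocumented internals, so this is shown as an example of what people write today, not as a recommendation.

```julia
# Hypothetical sketch of the "walk the fields" pattern.
# Assumption: a Union stores its members in the internal fields `a` and `b`,
# with unions of more than two types represented by nesting.
union_components(T::Union) = (union_components(T.a)..., union_components(T.b)...)
union_components(T::Type) = (T,)  # anything that is not a Union is its own single component

components = union_components(Union{Int8, Int16, Int32})
```

The order of the resulting tuple follows however the Union happens to be stored, so code built on this should not assume a particular ordering.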

> If Base.uniontypes returned a Tuple, it would not be inferable concretely, which seems like a massive footgun (#48205).

Couldn't we just return a tuple of types instead of a Tuple type?

@vtjnash
Sponsor Member

vtjnash commented Feb 7, 2024

Neither the big-T nor the little-t tuple is inferable, because of how the type system works (#48205).

@Tokazama
Contributor

The union_to_tuple function below manually unrolls into a tuple up to a length of 32. If the union is longer than that, it pushes the components to a locally created vector and then converts that vector into a tuple.

In the most naive case, where we only know we're working with a Union, this infers Tuple. If the union type is being lowered from some other type parameter, then we can completely infer everything (e.g., eltype_variants(::Type{<:Array{T}}) where {T} = isa(T, Union) ? union_to_tuple(T) : T will know the exact return type). I think this gets us to the same level of inference as fieldtypes.

implementation of `union_to_tuple`
function union_to_tuple(b0::Union)
    (a1 = getfield(b0, 1); b1 = getfield(b0, 2); isa(b1, Union)) || return (a1, b1)
    (a2 = getfield(b1, 1); b2 = getfield(b1, 2); isa(b2, Union)) || return (a1, a2, b2)
    (a3 = getfield(b2, 1); b3 = getfield(b2, 2); isa(b3, Union)) || return (a1, a2, a3, b3)
    (a4 = getfield(b3, 1); b4 = getfield(b3, 2); isa(b4, Union)) || return (a1, a2, a3, a4, b4)
    (a5 = getfield(b4, 1); b5 = getfield(b4, 2); isa(b5, Union)) || return (a1, a2, a3, a4, a5, b5)
    (a6 = getfield(b5, 1); b6 = getfield(b5, 2); isa(b6, Union)) || return (
        a1, a2, a3, a4, a5, a6, b6
    )
    (a7 = getfield(b6, 1); b7 = getfield(b6, 2); isa(b7, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, b7
    )
    (a8 = getfield(b7, 1); b8 = getfield(b7, 2); isa(b8, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, b8
    )
    (a9 = getfield(b8, 1); b9 = getfield(b8, 2); isa(b9, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, b9
    )
    (a10 = getfield(b9, 1); b10 = getfield(b9, 2); isa(b10, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, b10
    )
    (a11 = getfield(b10, 1); b11 = getfield(b10, 2); isa(b11, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, b11
    )
    (a12 = getfield(b11, 1); b12 = getfield(b11, 2); isa(b12, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, b12
    )
    (a13 = getfield(b12, 1); b13 = getfield(b12, 2); isa(b13, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, b13
    )
    (a14 = getfield(b13, 1); b14 = getfield(b13, 2); isa(b14, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, b14
    )
    (a15 = getfield(b14, 1); b15 = getfield(b14, 2); isa(b15, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, b15
    )
    (a16 = getfield(b15, 1); b16 = getfield(b15, 2); isa(b16, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, b16
    )
    (a17 = getfield(b16, 1); b17 = getfield(b16, 2); isa(b17, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, b17
    )
    (a18 = getfield(b17, 1); b18 = getfield(b17, 2); isa(b18, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        b18
    )
    (a19 = getfield(b18, 1); b19 = getfield(b18, 2); isa(b19, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, b19
    )
    (a20 = getfield(b19, 1); b20 = getfield(b19, 2); isa(b20, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, b20
    )
    (a21 = getfield(b20, 1); b21 = getfield(b20, 2); isa(b21, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, b21
    )
    (a22 = getfield(b21, 1); b22 = getfield(b21, 2); isa(b22, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, b22
    )
    (a23 = getfield(b22, 1); b23 = getfield(b22, 2); isa(b23, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, b23
    )
    (a24 = getfield(b23, 1); b24 = getfield(b23, 2); isa(b24, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, b24
    )
    (a25 = getfield(b24, 1); b25 = getfield(b24, 2); isa(b25, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, a25, b25
    )
    (a26 = getfield(b25, 1); b26 = getfield(b25, 2); isa(b26, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, a25, a26, b26
    )
    (a27 = getfield(b26, 1); b27 = getfield(b26, 2); isa(b27, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, a25, a26, a27, b27
    )
    (a28 = getfield(b27, 1); b28 = getfield(b27, 2); isa(b28, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, b28
    )
    (a29 = getfield(b28, 1); b29 = getfield(b28, 2); isa(b29, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, b29
    )
    (a30 = getfield(b29, 1); b30 = getfield(b29, 2); isa(b30, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, b30
    )
    (a31 = getfield(b30, 1); b31 = getfield(b30, 2); isa(b31, Union)) || return (
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31, b31
    )
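    # Any[...] keeps the element type open; a plain [...] literal of DataTypes
    # would make a Vector{DataType} and push! of a UnionAll component would fail.
    uvec = Any[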
        a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18,
        a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31
    ]
    utail = b31
    while true
        push!(uvec, getfield(utail, 1))
        tmp = getfield(utail, 2)
        if isa(tmp, Union)
            utail = tmp
        else
            push!(uvec, tmp)
            break
        end
    end
    return Tuple(uvec)
end
