Skip to content

Commit

Permalink
Tweak order of operations to get nnz to infer as Int return type
Browse files Browse the repository at this point in the history
If the sparse array does not have a concrete index type, then union
splitting occurs over the possible `<:Integer` types permitted by
`SparseMatrixCSC`:

```julia
julia> code_warntype(nnz, (SparseMatrixCSC{Float64,<:Integer},), optimize=true, debuginfo=:none)
Variables
  #self#::Core.Const(SparseArrays.nnz)
  S::SparseMatrixCSC{Float64, var"#s96"} where var"#s96"<:Integer

Body::Any
1 ── %1  = SparseArrays.getfield(S, :colptr)::Vector{var"#s96"} where var"#s96"<:Integer
│    %2  = SparseArrays.getfield(S, :n)::Int64
│    %3  = Base.add_int(%2, 1)::Int64
│    %4  = Base.getindex(%1, %3)::Integer
│    %5  = (isa)(%4, Int64)::Bool
└───       goto JuliaLang#3 if not %5
2 ── %7  = π (%4, Int64)
│    %8  = Base.sub_int(%7, 1)::Int64
└───       goto JuliaLang#15
3 ── %10 = (isa)(%4, BigInt)::Bool
└───       goto JuliaLang#14 if not %10
4 ── %12 = π (%4, BigInt)
│    %13 = Base.slt_int(1, 0)::Bool
└───       goto JuliaLang#6 if not %13
5 ── %15 = Base.bitcast(UInt64, 1)::UInt64
│    %16 = Base.neg_int(%15)::UInt64
│    %17 = Base.GMP.MPZ.add_ui::typeof(Base.GMP.MPZ.add_ui)
│    %18 = invoke %17(%12::BigInt, %16::UInt64)::BigInt
└───       goto JuliaLang#13
6 ── %20 = Core.lshr_int(1, 63)::Int64
│    %21 = Core.trunc_int(Core.UInt8, %20)::UInt8
│    %22 = Core.eq_int(%21, 0x01)::Bool
└───       goto JuliaLang#8 if not %22
7 ──       invoke Core.throw_inexacterror(:check_top_bit::Symbol, UInt64::Type{UInt64}, 1::Int64)
└───       unreachable
8 ──       goto JuliaLang#9
9 ── %27 = Core.bitcast(Core.UInt64, 1)::UInt64
└───       goto JuliaLang#10
10 ─       goto JuliaLang#11
11 ─       goto JuliaLang#12
12 ─ %31 = Base.GMP.MPZ.sub_ui::typeof(Base.GMP.MPZ.sub_ui)
│    %32 = invoke %31(%12::BigInt, %27::UInt64)::BigInt
└───       goto JuliaLang#13
13 ┄ %34 = φ (JuliaLang#5 => %18, JuliaLang#12 => %32)::Any
└───       goto JuliaLang#15
14 ─ %36 = (%4 - 1)::Any
└───       goto JuliaLang#15
15 ┄ %38 = φ (JuliaLang#2 => %8, JuliaLang#13 => %34, JuliaLang#14 => %36)::Any
│    %39 = SparseArrays.Int(%38)::Any
└───       return %39
```

It appears that union splitting over the subtraction by one includes
an `Any` branch that widens the return type of `nnz`. By instead
converting the index type to `Int` before subtracting, type inference
is able to infer that all paths give an `Int` result:

```julia
julia> code_warntype(nnz, (SparseMatrixCSC{Float64,<:Integer},), optimize=true, debuginfo=:none)
Variables
  #self#::Core.Const(SparseArrays.nnz)
  S::SparseMatrixCSC{Float64, var"#s96"} where var"#s96"<:Integer

Body::Int64
1 ── %1  = SparseArrays.getfield(S, :colptr)::Vector{var"#s96"} where var"#s96"<:Integer
│    %2  = SparseArrays.getfield(S, :n)::Int64
│    %3  = Base.add_int(%2, 1)::Int64
│    %4  = Base.getindex(%1, %3)::Integer
│    %5  = (isa)(%4, BigInt)::Bool
└───       goto JuliaLang#14 if not %5
2 ── %7  = π (%4, BigInt)
│    %8  = Base.getfield(%7, :size)::Int32
│    %9  = Base.flipsign_int(%8, %8)::Int32
│    %10 = Core.sext_int(Core.Int64, %9)::Int64
│    %11 = Base.sle_int(0, %10)::Bool
└───       goto JuliaLang#4 if not %11
3 ── %13 = Core.sext_int(Core.Int64, %9)::Int64
│    %14 = Base.sle_int(%13, 1)::Bool
└───       goto JuliaLang#5
4 ──       nothing
5 ┄─ %17 = φ (JuliaLang#3 => %14, JuliaLang#4 => false)::Bool
└───       goto JuliaLang#12 if not %17
6 ── %19 = Base.getfield(%7, :size)::Int32
│    %20 = Core.sext_int(Core.Int64, %19)::Int64
│    %21 = (%20 === 0)::Bool
└───       goto JuliaLang#8 if not %21
7 ──       goto JuliaLang#9
8 ── %24 = Base.getfield(%7, :d)::Ptr{UInt64}
│    %25 = Base.pointerref(%24, 1, 1)::UInt64
│    %26 = Base.bitcast(Int64, %25)::Int64
│    %27 = Base.getfield(%7, :size)::Int32
│    %28 = Core.sext_int(Core.Int64, %27)::Int64
│    %29 = Base.flipsign_int(%26, %28)::Int64
└───       goto JuliaLang#9
9 ┄─ %31 = φ (JuliaLang#7 => 0, JuliaLang#8 => %29)::Int64
│    %32 = Base.getfield(%7, :size)::Int32
│    %33 = Core.sext_int(Core.Int64, %32)::Int64
│    %34 = Base.slt_int(0, %33)::Bool
│    %35 = Base.slt_int(0, %31)::Bool
│    %36 = (%34 === %35)::Bool
│    %37 = Base.not_int(%36)::Bool
└───       goto JuliaLang#11 if not %37
10 ─ %39 = Base.GMP.nameof(Int64)::Any
│    %40 = Base.GMP.InexactError(%39, Int64, %7)::Any
│          Base.GMP.throw(%40)
└───       unreachable
11 ─       goto JuliaLang#13
12 ─ %44 = Base.GMP.nameof(Int64)::Any
│    %45 = Base.GMP.InexactError(%44, Int64, %7)::Any
│          Base.GMP.throw(%45)
└───       unreachable
13 ─       goto JuliaLang#15
14 ─ %49 = SparseArrays.Int(%4)::Int64
└───       goto JuliaLang#15
15 ┄ %51 = φ (JuliaLang#13 => %31, JuliaLang#14 => %49)::Int64
│    %52 = Base.sub_int(%51, 1)::Int64
└───       return %52
```
  • Loading branch information
jmert committed Nov 4, 2020
1 parent aa8ca02 commit 3672a3a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/sparsematrix.jl
Expand Up @@ -108,7 +108,7 @@ julia> nnz(A)
3
```
"""
nnz(S::AbstractSparseMatrixCSC) = Int(getcolptr(S)[size(S, 2) + 1] - 1)
nnz(S::AbstractSparseMatrixCSC) = Int(getcolptr(S)[size(S, 2) + 1]) - 1
nnz(S::ReshapedArray{<:Any,1,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
nnz(S::UpperTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
nnz(S::LowerTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
Expand Down

0 comments on commit 3672a3a

Please sign in to comment.