findmin(A; dims=1) is much slower than manually looping over. #510

Closed
pratyai opened this issue Feb 13, 2024 · 1 comment · Fixed by #511

pratyai (Contributor) commented Feb 13, 2024

Some demo:

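```julia
using SparseArrays

# setup
n = 10000
# Subtract 0.5 from every stored value (implicit zeros stay zero).
halfhalf = a -> SparseMatrixCSC(a.m, a.n, a.colptr, a.rowval, a.nzval .- 0.5);
symmetrize = a -> (a + a')/2;
A = halfhalf(symmetrize(sprand(n, n, 0.1)));  # nnz(A) == 19004033

# For each column, return the row index and value of its smallest *stored* entry.
# Implicit zeros are ignored, so this matches findmin(A; dims=1) when every
# column's smallest stored value is negative (which holds for this A).
function manualmincol(a)
    local minrows = zeros(Int, a.n)        # one entry per column
    local minvals = zeros(Float64, a.n)
    for c in 1:a.n                          # iterate over columns (a.n, not a.m)
        local rb, re = a.colptr[c], a.colptr[c+1]
        if rb == re                         # empty column: leave (0, 0.0)
            continue
        end
        local minval::Float64, row_::Int64 = Inf, 0
        for r in rb:(re-1)
            local val, row = a.nzval[r], a.rowval[r]
            if val < minval
                minval, row_ = val, row
            end
        end
        minrows[c], minvals[c] = row_, minval
    end
    return minrows, minvals
end

# test (note: the first @time call of a function also includes compilation time)
@time bob = manualmincol(A)[1];
# output: 0.138503 seconds (5 allocations: 156.375 KiB)

@time bob2 = findmin(A, dims=1)[2];
# output: 2.414430 seconds (9 allocations: 391.094 KiB)

# check that they are the same: take the row of each CartesianIndex
bob2 = vec(first.(Tuple.(bob2)));
bob2 == bob
# output: true
```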
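For context on the `true` above: findmin(A; dims=1) on a sparse matrix also counts implicit (unstored) zeros, while manualmincol only scans stored entries, so the two agree here because every column of this A has a negative stored minimum. A tiny illustration; the matrix S and the helper stored_min are made up for this example:

```julia
using SparseArrays

# 3×2 example: column 1 stores only positive values (rows 1 and 3),
# column 2 stores one negative value (row 2); every other entry is an implicit zero.
S = sparse([1, 3, 2], [1, 1, 2], [0.7, 0.2, -0.4], 3, 2)

# Hypothetical helper: minimum over the stored entries of column j only,
# i.e. what a loop over colptr/nzval computes.
stored_min(S, j) = minimum(nonzeros(S)[nzrange(S, j)])

vals, inds = findmin(S; dims=1)
# findmin sees the implicit zero at row 2 of column 1; stored_min does not.
@show vals[1, 1] inds[1, 1] stored_min(S, 1)
```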

I believe the reason is that, with dims=..., the call goes through the generic AbstractArray method of findmin(), which is not specialised for sparse matrices:

```julia-repl
julia> @which findmin(A, dims=1)
kwcall(::NamedTuple, ::typeof(findmin), A::AbstractArray)
     @ Base reducedim.jl:1130
```
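The kwcall frame reflects that, in recent Julia versions, a keyword call is dispatched through Core.kwcall to a keyword method (here one declared for any AbstractArray), which then forwards to positional code. A toy sketch of that forwarding pattern, using hypothetical names mymin/_mymin that only mimic the shape of findmin/_findmin (Base's actual internals may differ), shows that specialising the positional helper is enough to intercept calls that pass dims:

```julia
using SparseArrays

# Hypothetical names. The keyword method is defined once, for every AbstractArray,
# and forwards `dims` to a positional helper.
mymin(A::AbstractArray; dims=:) = _mymin(identity, A, dims)

# Generic fallback: what every array type gets by default.
_mymin(f, A::AbstractArray, dims) = (:generic, minimum(f, A; dims=dims))

# Specialising only the helper "intercepts" sparse calls with `dims`,
# without touching the keyword method. A real version would use a sparse-aware kernel.
_mymin(f, A::SparseMatrixCSC, dims) = (:sparse, minimum(f, A; dims=dims))
```

With these definitions, mymin(rand(3, 3); dims=1) takes the generic path, while mymin(sprand(3, 3, 0.5); dims=1) takes the sparse one.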

Does it make sense to "intercept" the dims argument and provide a faster implementation? (I understand that row-wise (dims=2) and column-wise (dims=1) reductions would perform differently on CSC storage, but currently both are much slower than spelling out the loops.)
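To make the row/column difference concrete, here is a rough sketch of a dims=2 reduction on CSC storage (sparse_findmin_dims2 is a hypothetical name, not the implementation that was eventually merged). Because each row's entries are spread across columns, it makes one pass over all stored entries while keeping per-row state, then fills in implicit zeros. It assumes real, NaN-free values and canonical CSC storage (sorted, duplicate-free row indices):

```julia
using SparseArrays

# Hypothetical helper: row-wise findmin (like findmin(A; dims=2)) for a CSC matrix.
# Assumes real, NaN-free values and canonical CSC storage.
function sparse_findmin_dims2(A::SparseMatrixCSC{Tv}) where {Tv<:Real}
    m, n = size(A)
    rv, nz = rowvals(A), nonzeros(A)
    minval = zeros(Tv, m)
    mincol = zeros(Int, m)   # 0 = no stored entry seen yet in this row
    nextcol = ones(Int, m)   # first column (so far) with no stored entry in this row
    for j in 1:n             # columns in increasing order, so ties keep the leftmost
        for k in nzrange(A, j)
            i, v = rv[k], nz[k]
            if mincol[i] == 0 || v < minval[i]
                minval[i], mincol[i] = v, j
            end
            # Extend the run of stored columns 1, 2, ... for row i.
            nextcol[i] == j && (nextcol[i] += 1)
        end
    end
    # Rows with nextcol[i] <= n contain at least one implicit zero.
    for i in 1:m
        nextcol[i] > n && continue
        if mincol[i] == 0 || minval[i] > zero(Tv)
            minval[i], mincol[i] = zero(Tv), nextcol[i]
        elseif minval[i] == zero(Tv)
            mincol[i] = min(mincol[i], nextcol[i])
        end
    end
    return minval, mincol
end
```

The writes into minval/mincol/nextcol are indexed by row and therefore scattered, whereas a dims=1 reduction can read each column's stored entries contiguously and finish that column before moving on.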

dkarrasch (Member) commented:

You may need to overload the three-arg function, _findmin(f, A, dims), but otherwise: yes, please, make a PR.
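A rough sketch of what a method with that (f, A, dims) shape could do for dims=1. sparse_findmin is a hypothetical name; wiring it into Base's internal _findmin is left out because that internal signature is not public API and may differ between Julia versions. It applies f to stored values, treats every implicit entry as f(0), breaks ties toward the lowest row as findmin does, and falls back to the generic findmin for other dims. It assumes NaN-free values, a type-stable f, at least one row, and canonical CSC storage:

```julia
using SparseArrays

# Hypothetical drop-in with the (f, A, dims) shape; only dims == 1 is specialised.
# Returns a 1×n matrix of minima and a 1×n matrix of CartesianIndex,
# mimicking findmin(f, A; dims=1).
function sparse_findmin(f, A::SparseMatrixCSC, dims)
    dims == 1 || return findmin(f, A; dims=dims)   # generic fallback
    m, n = size(A)
    rv, nz = rowvals(A), nonzeros(A)
    fz = f(zero(eltype(A)))                 # value of every implicit zero under f
    vals = Matrix{typeof(fz)}(undef, 1, n)
    inds = Matrix{CartesianIndex{2}}(undef, 1, n)
    for j in 1:n
        rng = nzrange(A, j)
        best, row = fz, 0                   # row 0: nothing chosen yet
        for k in rng                        # stored entries, in increasing row order
            fv = f(nz[k])
            if row == 0 || fv < best
                best, row = fv, rv[k]
            end
        end
        if length(rng) < m                  # column contains implicit zeros
            zrow = 1                        # first row index that is not stored
            for k in rng
                rv[k] == zrow || break
                zrow += 1
            end
            if row == 0 || best > fz
                best, row = fz, zrow
            elseif best == fz
                row = min(row, zrow)        # keep the first occurrence
            end
        end
        vals[1, j], inds[1, j] = best, CartesianIndex(row, j)
    end
    return vals, inds
end
```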
