Skip to content

Commit

Permalink
Fix broadcast with sparse arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Feb 21, 2020
1 parent ac1e1b9 commit 31253fe
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ end
function isequal_canonical(x::LinearAlgebra.Tridiagonal, y::LinearAlgebra.Tridiagonal)
return isequal_canonical(x.dl, y.dl) && isequal_canonical(x.d, y.d) && isequal_canonical(x.du, y.du)
end
function isequal_canonical(x::SparseMat, y::SparseMat)
return x.m == y.m && x.n == y.n && isequal_canonical(x.colptr, y.colptr) && isequal_canonical(x.rowval, y.rowval) && isequal_canonical(x.nzval, y.nzval)
end

include("rewrite.jl")
include("dispatch.jl")
Expand Down
15 changes: 15 additions & 0 deletions src/Test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ function sparse_linear_test(X11, X23, Xd)
0 0 4X23
0 0 0]

function _test_broadcast(A, B)
@test_rewrite(A .* B)
@test_rewrite(B .* A)
@test_rewrite(A .+ B)
@test_rewrite(B .+ A)
@test_rewrite(A .- B)
@test_rewrite(B .- A)
end

_test_broadcast(Xd, X)
_test_broadcast(Xd, Y)
_test_broadcast(Yd, X)
_test_broadcast(Yd, Y)
_test_broadcast(X, Y)

add_test(Xd, Yd)
add_test(Xd, Y)
add_test(Xd, Xd)
Expand Down
12 changes: 9 additions & 3 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Type{Eltype}) where {N, Eltype}
function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Base.HasShape{N}, ::Type{Eltype}) where {N, Eltype}
return Array{Eltype, N}
end
function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Type{Bool}) where N
function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Base.HasShape{N}, ::Type{Bool}) where N
return BitArray{N}
end

Expand All @@ -11,12 +11,18 @@ combine_styles(c::Type) = Broadcast.result_style(Broadcast.BroadcastStyle(c))
combine_styles(c1::Type, c2::Type) = Broadcast.result_style(combine_styles(c1), combine_styles(c2))
@inline combine_styles(c1::Type, c2::Type, cs::Vararg{Type, N}) where N = Broadcast.result_style(combine_styles(c1), combine_styles(c2, cs...))

combine_shapes(s) = s
combine_2_shapes(s1::Base.HasShape{N}, s2::Base.HasShape{M}) where {N, M} = Base.HasShape{max(N, M)}()
combine_shapes(s1, s2, args::Vararg{Any, N}) where {N} = combine_shapes(combine_2_shapes(s1, s2), args...)
_shape(T) = Base.HasShape{ndims(T)}()
combine_sizes(args::Vararg{Any, N}) where {N} = combine_shapes(_shape.(args)...)

function promote_broadcast(op::Function, args::Vararg{Any, N}) where N
# FIXME we could use `promote_operation` instead as
# `combine_eltypes` uses `return_type` hence it may return a non-concrete type
# and we do not handle that case.
T = Base.Broadcast.combine_eltypes(op, args)
return broadcasted_type(combine_styles(args...), T)
return broadcasted_type(combine_styles(args...), combine_sizes(args...), T)
end

"""
Expand Down
9 changes: 9 additions & 0 deletions src/sparse_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,12 @@ function mutable_operate!(::typeof(add_mul), ret::SparseMat{T},
α::Vararg{Union{T, Scaling}, N}) where {T, N}
mutable_operate!(add_mul, ret, copy(A), B, α...)
end

# This `BroadcastStyle` is used when there is a mix of sparse arrays and dense arrays.
# The result is a sparse array.
function broadcasted_type(::SparseArrays.HigherOrderFns.PromoteToSparse, ::Base.HasShape{1}, ::Type{Eltype}) where Eltype
return SparseArrays.SparseVector{Eltype, Int}
end
function broadcasted_type(::SparseArrays.HigherOrderFns.PromoteToSparse, ::Base.HasShape{2}, ::Type{Eltype}) where Eltype
return SparseMat{Eltype, Int}
end

0 comments on commit 31253fe

Please sign in to comment.