From af65f13455c956093d21530c6846bb2bc18a90e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 19 Nov 2019 12:00:25 +0100 Subject: [PATCH] Implement mutability of Array --- src/dispatch.jl | 4 +-- src/interface.jl | 14 +++++------ src/linear_algebra.jl | 58 ++++++++++++++++++++++++++++++++++++++++++- src/rewrite.jl | 25 ++++++++++++++++--- src/shortcuts.jl | 2 ++ test/rewrite.jl | 4 +++ 6 files changed, 92 insertions(+), 15 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 91fc802..6f30ed5 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -3,7 +3,7 @@ abstract type AbstractMutable end # Special-case because the the base version wants to do fill!(::Array{AbstractVariableRef}, zero(GenericAffExpr{Float64,eltype(x)})) _one_indexed(A) = all(x -> isa(x, Base.OneTo), axes(A)) function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable}) - @assert _one_indexed(x) # Base.diagm doesn't work for non-one-indexed arrays in general. + @assert _one_indexed(x) # `LinearAlgebra.diagm` doesn't work for non-one-indexed arrays in general. ZeroType = promote_operation(zero, eltype(x)) - return diagm(0 => copyto!(similar(x, ZeroType), x)) + return LinearAlgebra.diagm(0 => copyto!(similar(x, ZeroType), x)) end diff --git a/src/interface.jl b/src/interface.jl index 4b22d8f..a931677 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -13,13 +13,10 @@ function promote_operation end function promote_operation(op::Function, args::Vararg{Type, N}) where N return typeof(op(zero.(args)...)) end -#promote_operation(::typeof(*), ::Type{T}) where {T} = T -#function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N} -# return Array{promote_operation(op, T, S), N} -#end -#function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N} -# return Array{promote_operation(op, S, T), N} -#end +promote_operation(::typeof(*), ::Type{T}) where {T} = T +function promote_operation(::typeof(*), ::Type{S}, ::Type{T}, ::Type{U}, args::Vararg{Type, N}) where {S, T, U, N} + return promote_operation(*, promote_operation(*, S, T), U, args...) +end # Define Traits abstract type MutableTrait end @@ -47,7 +44,8 @@ function mutable_operate_to_fallback(::NotMutable, output, op::Function, args... end function mutable_operate_to_fallback(::IsMutable, output, op::Function, args...) - error("`mutable_operate_to!($op, $(args...))` is not implemented yet.") + error("`mutable_operate_to!($(typeof(output)), $op, ", join(typeof.(args), ", "), + ")` is not implemented yet.") end """ diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 8a47170..96f19b4 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -1,6 +1,62 @@ import LinearAlgebra -mutability(::Type{<:Vector}) = IsMutable() +mutability(::Type{<:Array}) = IsMutable() + +# Sum + +function promote_operation(op::typeof(+), ::Type{Array{S, N}}, ::Type{Array{T, N}}) where {S, T, N} + return Array{promote_operation(op, S, T), N} +end +function mutable_operate!(::typeof(+), A::Array{S, N}, B::Array{T, N}) where{S, T, N} + for i in eachindex(A) + A[i] = operate!(+, A[i], B[i]) + end + return A +end + +# UniformScaling +const Scaling = Union{Number, LinearAlgebra.UniformScaling} +function promote_operation(op::typeof(+), ::Type{Array{T, 2}}, ::Type{LinearAlgebra.UniformScaling{S}}) where {S, T} + return Array{promote_operation(op, T, S), 2} +end +function promote_operation(op::typeof(+), ::Type{LinearAlgebra.UniformScaling{S}}, ::Type{Array{T, 2}}) where {S, T} + return Array{promote_operation(op, S, T), 2} +end +function mutable_operate!(::typeof(+), A::Matrix, B::LinearAlgebra.UniformScaling) + n = LinearAlgebra.checksquare(A) + for i in 1:n + A[i, i] = operate!(+, A[i, i], B) + end + return A +end +function mutable_operate!(::typeof(add_mul), A::Matrix, B::Scaling, C::Scaling, D::Vararg{Scaling, N}) where N + return mutable_operate!(+, A, *(B, C, D...)) +end +function mutable_operate!(::typeof(add_mul), A::Array{S, N}, B::Array{T, N}, α::Vararg{Scaling, M}) where {S, T, N, M} + for i in eachindex(A) + A[i] = operate!(add_mul, A[i], B[i], α...) + end + return A +end +function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α::Scaling, B::Array{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} + for i in eachindex(A) + A[i] = operate!(add_mul, A[i], α, B[i], β...) + end + return A +end + +# Product + +function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N} + return Array{promote_operation(op, T, S), N} +end +function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N} + return Array{promote_operation(op, S, T), N} +end + +function promote_operation(::typeof(*), ::Type{Matrix{S}}, ::Type{Vector{T}}) where {S, T} + return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)} +end function promote_operation(::typeof(*), ::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractVector{T}}) where {S, T} return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)} end diff --git a/src/rewrite.jl b/src/rewrite.jl index 573f92d..358fb0e 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -14,8 +14,22 @@ Base.:(+)(x, zero::Zero) = copy(x) using Base.Meta +# See `JuMP._try_parse_idx_set` +function _try_parse_idx_set(arg::Expr) + # [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and + # Expr(:ref, :x, Expr(:kw, :i, 1)) respectively. + if arg.head === :kw || arg.head === :(=) + @assert length(arg.args) == 2 + return true, arg.args[1], arg.args[2] + elseif isexpr(arg, :call) && arg.args[1] === :in + return true, arg.args[2], arg.args[3] + else + return false, nothing, nothing + end +end + function _parse_idx_set(arg::Expr) - parse_done, idxvar, idxset = Containers._try_parse_idx_set(arg) + parse_done, idxvar, idxset = _try_parse_idx_set(arg) if parse_done return idxvar, idxset end @@ -61,6 +75,9 @@ function _parse_gen(ex, atleaf) return loop end +# See `JuMP._is_sum` +_is_sum(s::Symbol) = (s == :sum) || (s == :∑) || (s == :Σ) + function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff=gensym()) @assert isexpr(x,:call) @assert length(x.args) > 1 @@ -175,7 +192,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb x.args[i] = esc(x.args[i]) end end - callexpr = Expr(:call, :operate!, add_mul, aff, + callexpr = Expr(:call, :(MutableArithmetics.operate!), add_mul, aff, lcoeffs..., x.args[2:end]..., rcoeffs...) push!(blk.args, :($newaff = $callexpr)) return newaff, blk @@ -187,7 +204,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb s = gensym() newaff_, parsed = _rewrite_toplevel(x.args[2], s) push!(blk.args, :($s = Zero(); $parsed)) - push!(blk.args, :($newaff = operate!(add_mul, + push!(blk.args, :($newaff = MutableArithmetics.operate!(add_mul, $aff, $(Expr(:call, :*, lcoeffs..., newaff_, newaff_, rcoeffs...))))) return newaff, blk @@ -225,6 +242,6 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb " become a syntax error in a future release." end # at the lowest level - callexpr = Expr(:call, :operate!, add_mul, aff, lcoeffs..., esc(x), rcoeffs...) + callexpr = Expr(:call, :(MutableArithmetics.operate!), add_mul, aff, lcoeffs..., esc(x), rcoeffs...) return newaff, :($newaff = $callexpr) end diff --git a/src/shortcuts.jl b/src/shortcuts.jl index c164afe..ca06a8a 100644 --- a/src/shortcuts.jl +++ b/src/shortcuts.jl @@ -40,6 +40,8 @@ function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) wh return promote_operation(+, T, promote_operation(*, args...)) end +mutable_operate!(::typeof(add_mul), x, y) = mutable_operate!(+, x, y) + """ add_mul_to!(output, args...) diff --git a/test/rewrite.jl b/test/rewrite.jl index 6469879..d65ce32 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -239,6 +239,10 @@ function vectorized_test(x, X11, X23, Xd) v = [4, 5, 6] @testset "Sum of matrices" begin + @test_rewrite(x + x) + @test_rewrite(x + 2x) + @test_rewrite(x + x * 2) + @test_rewrite(x + 2x * 2) @test_rewrite(Xd + Yd) @test_rewrite(Xd + 2Yd) @test_rewrite(Xd + Yd * 2)