diff --git a/src/MutableArithmetics.jl b/src/MutableArithmetics.jl index 0b3c78b..6c3b493 100644 --- a/src/MutableArithmetics.jl +++ b/src/MutableArithmetics.jl @@ -20,7 +20,20 @@ Return `a + *(args...)`. Note that `add_mul(a, b, c) = muladd(b, c, a)`. function add_mul end add_mul(a, b) = a + b add_mul(a, b, c) = muladd(b, c, a) -add_mul(a, b, c::Vararg{Any, N}) where {N} = add_mul(a, b *(c...)) +add_mul(a, b, c, d, args::Vararg{Any, N}) where {N} = add_mul(a, b, *(c, d, args...)) + +""" + sub_mul(a, args...) + +Return `a + *(args...)`. Note that `sub_mul(a, b, c) = muladd(b, c, a)`. +""" +function sub_mul end +sub_mul(a, b) = a - b +sub_mul(a, b, c, args::Vararg{Any, N}) where {N} = a - *(b, c, args...) + +const AddSubMul = Union{typeof(add_mul), typeof(sub_mul)} +add_sub_op(::typeof(add_mul)) = + +add_sub_op(::typeof(sub_mul)) = - """ iszero!(x) diff --git a/src/Test/int.jl b/src/Test/int.jl index 60e4d0c..4f115a7 100644 --- a/src/Test/int.jl +++ b/src/Test/int.jl @@ -14,8 +14,30 @@ function int_add_test(::Type{T}) where T a = t(165) b = t(255) - @test MA.isequal_canonical(MA.add!(a, b), t(420)) - @test MA.isequal_canonical(a, t(420)) + expected = t(420) + @test MA.isequal_canonical(MA.add!(a, b), expected) + @test MA.isequal_canonical(a, expected) + end +end +function int_sub_test(::Type{T}) where T + @testset "sub_to! / sub!" begin + @test MA.mutability(T, -, T, T) isa MA.IsMutable + + t(n) = convert(T, n) + a = t(5) + b = t(28) + c = t(41) + expected = t(-13) + @test MA.isequal_canonical(MA.sub_to!(a, b, c), expected) + @test MA.isequal_canonical(a, expected) + @test MA.isequal_canonical(MA.sub!(b, c), expected) + @test MA.isequal_canonical(b, expected) + + a = t(165) + b = t(255) + expected = t(-90) + @test MA.isequal_canonical(MA.sub!(a, b), expected) + @test MA.isequal_canonical(a, expected) end end function int_mul_test(::Type{T}) where T @@ -76,6 +98,47 @@ function int_add_mul_test(::Type{T}) where T @test MA.isequal_canonical(a, t(420)) end end +function int_sub_mul_test(::Type{T}) where T + @testset "sub_mul_to! / sub_mul! / sub_mul_buf_to! / sub_mul_buf!" begin + @test MA.mutability(T, MA.sub_mul, T, T) isa MA.IsMutable + @test MA.mutability(T, MA.sub_mul, T, T, T) isa MA.IsMutable + @test MA.mutability(T, MA.sub_mul, T, T, T, T) isa MA.IsMutable + + t(n) = convert(T, n) + a = t(5) + b = t(9) + c = t(3) + d = t(20) + buf = t(24) + + expected = t(-51) + @test MA.isequal_canonical(MA.sub_mul_to!(a, b, c, d), expected) + @test MA.isequal_canonical(a, expected) + a = t(5) + @test MA.isequal_canonical(MA.sub_mul!(b, c, d), expected) + @test MA.isequal_canonical(b, expected) + b = t(9) + + @test MA.isequal_canonical(MA.sub_mul_buf_to!(buf, a, b, c, d), expected) + @test MA.isequal_canonical(a, expected) + @test MA.isequal_canonical(MA.sub_mul_buf!(buf, b, c, d), expected) + @test MA.isequal_canonical(b, expected) + + a = t(148) + b = t(16) + c = t(17) + d = t(42) + buf = t(56) + expected = t(-124) + @test MA.isequal_canonical(MA.sub_mul!(a, b, c), expected) + @test MA.isequal_canonical(a, expected) + a = t(148) + @test MA.isequal_canonical(MA.sub_mul_buf_to!(buf, d, a, b, c), expected) + @test MA.isequal_canonical(d, expected) + @test MA.isequal_canonical(MA.sub_mul_buf!(buf, a, b, c), expected) + @test MA.isequal_canonical(a, expected) + end +end function int_zero_test(::Type{T}) where T @testset "zero!" begin diff --git a/src/bigfloat.jl b/src/bigfloat.jl index 44ba20a..e891757 100644 --- a/src/bigfloat.jl +++ b/src/bigfloat.jl @@ -29,6 +29,13 @@ end # return mutable_operate_to!(output, op, a, b.λ) #end +# - +promote_operation(::typeof(-), ::Vararg{Type{BigFloat}, N}) where {N} = BigFloat +function mutable_operate_to!(output::BigFloat, ::typeof(-), a::BigFloat, b::BigFloat) + ccall((:mpfr_sub, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Ref{BigFloat}, MPFRRoundingMode), output, a, b, Base.MPFR.ROUNDING_MODE[]) + return output +end + # * promote_operation(::typeof(*), ::Vararg{Type{BigFloat}, N}) where {N} = BigFloat function mutable_operate_to!(output::BigFloat, ::typeof(*), a::BigFloat, b::BigFloat) @@ -36,7 +43,7 @@ function mutable_operate_to!(output::BigFloat, ::typeof(*), a::BigFloat, b::BigF return output end -function mutable_operate_to!(output::BigFloat, op::Union{typeof(*), typeof(+)}, +function mutable_operate_to!(output::BigFloat, op::Union{typeof(+), typeof(-), typeof(*)}, a::BigFloat, b::BigFloat, c::Vararg{BigFloat, N}) where N mutable_operate_to!(output, op, a, b) return mutable_operate!(op, output, c...) @@ -45,34 +52,34 @@ function mutable_operate!(op::Function, x::BigFloat, args::Vararg{Any, N}) where mutable_operate_to!(x, op, x, args...) end -# add_mul +# add_mul and sub_mul # Buffer to hold the product -buffer_for(::typeof(add_mul), args::Vararg{Type{BigFloat}, N}) where {N} = BigFloat() -function mutable_operate_to!(output::BigFloat, ::typeof(add_mul), x::BigFloat, y::BigFloat, z::BigFloat, args::Vararg{BigFloat, N}) where N - return mutable_buffered_operate_to!(BigFloat(), output, add_mul, x, y, z, args...) +buffer_for(::AddSubMul, args::Vararg{Type{BigFloat}, N}) where {N} = BigFloat() +function mutable_operate_to!(output::BigFloat, op::AddSubMul, x::BigFloat, y::BigFloat, z::BigFloat, args::Vararg{BigFloat, N}) where N + return mutable_buffered_operate_to!(BigFloat(), output, op, x, y, z, args...) end -function mutable_buffered_operate_to!(buffer::BigFloat, output::BigFloat, ::typeof(add_mul), +function mutable_buffered_operate_to!(buffer::BigFloat, output::BigFloat, op::AddSubMul, a::BigFloat, x::BigFloat, y::BigFloat, args::Vararg{BigFloat, N}) where N mutable_operate_to!(buffer, *, x, y, args...) - return mutable_operate_to!(output, +, a, buffer) + return mutable_operate_to!(output, add_sub_op(op), a, buffer) end -function mutable_buffered_operate!(buffer::BigFloat, op::typeof(add_mul), x::BigFloat, args::Vararg{Any, N}) where N +function mutable_buffered_operate!(buffer::BigFloat, op::AddSubMul, x::BigFloat, args::Vararg{Any, N}) where N return mutable_buffered_operate_to!(buffer, x, op, x, args...) end scaling_to_bigfloat(x::BigFloat) = x scaling_to_bigfloat(x::Number) = convert(BigFloat, x) scaling_to_bigfloat(J::LinearAlgebra.UniformScaling) = scaling_to_bigfloat(J.λ) -function mutable_operate_to!(output::BigFloat, op::Union{typeof(+), typeof(*)}, args::Vararg{Scaling, N}) where N +function mutable_operate_to!(output::BigFloat, op::Union{typeof(+), typeof(-), typeof(*)}, args::Vararg{Scaling, N}) where N return mutable_operate_to!(output, op, scaling_to_bigfloat.(args)...) end -function mutable_operate_to!(output::BigFloat, op::typeof(add_mul), x::Scaling, y::Scaling, z::Scaling, args::Vararg{Scaling, N}) where N +function mutable_operate_to!(output::BigFloat, op::AddSubMul, x::Scaling, y::Scaling, z::Scaling, args::Vararg{Scaling, N}) where N return mutable_operate_to!( output, op, scaling_to_bigfloat(x), scaling_to_bigfloat(y), scaling_to_bigfloat(z), scaling_to_bigfloat.(args)...) end # Called for instance if `args` is `(v', v)` for a vector `v`. -function mutable_operate_to!(output::BigFloat, op::typeof(add_mul), x, y, z, args::Vararg{Any, N}) where N - return mutable_operate_to!(output, +, x, *(y, z, args...)) +function mutable_operate_to!(output::BigFloat, op::AddSubMul, x, y, z, args::Vararg{Any, N}) where N + return mutable_operate_to!(output, add_sub_op(op), x, *(y, z, args...)) end diff --git a/src/bigint.jl b/src/bigint.jl index 86ff4cc..3a66b32 100644 --- a/src/bigint.jl +++ b/src/bigint.jl @@ -18,13 +18,19 @@ end # return mutable_operate_to!(output, op, a, b.λ) #end +# - +promote_operation(::typeof(-), ::Vararg{Type{BigInt}, N}) where {N} = BigInt +function mutable_operate_to!(output::BigInt, ::typeof(-), a::BigInt, b::BigInt) + return Base.GMP.MPZ.sub!(output, a, b) +end + # * promote_operation(::typeof(*), ::Vararg{Type{BigInt}, N}) where {N} = BigInt function mutable_operate_to!(output::BigInt, ::typeof(*), a::BigInt, b::BigInt) return Base.GMP.MPZ.mul!(output, a, b) end -function mutable_operate_to!(output::BigInt, op::Union{typeof(*), typeof(+)}, +function mutable_operate_to!(output::BigInt, op::Union{typeof(+), typeof(-), typeof(*)}, a::BigInt, b::BigInt, c::Vararg{BigInt, N}) where N mutable_operate_to!(output, op, a, b) return mutable_operate!(op, output, c...) @@ -33,34 +39,34 @@ function mutable_operate!(op::Function, x::BigInt, args::Vararg{Any, N}) where N mutable_operate_to!(x, op, x, args...) end -# add_mul +# add_mul and sub_mul # Buffer to hold the product -buffer_for(::typeof(add_mul), args::Vararg{Type{BigInt}, N}) where {N} = BigInt() -function mutable_operate_to!(output::BigInt, ::typeof(add_mul), x::BigInt, y::BigInt, z::BigInt, args::Vararg{BigInt, N}) where N - return mutable_buffered_operate_to!(BigInt(), output, add_mul, x, y, z, args...) +buffer_for(::AddSubMul, args::Vararg{Type{BigInt}, N}) where {N} = BigInt() +function mutable_operate_to!(output::BigInt, op::AddSubMul, x::BigInt, y::BigInt, z::BigInt, args::Vararg{BigInt, N}) where N + return mutable_buffered_operate_to!(BigInt(), output, op, x, y, z, args...) end -function mutable_buffered_operate_to!(buffer::BigInt, output::BigInt, ::typeof(add_mul), +function mutable_buffered_operate_to!(buffer::BigInt, output::BigInt, op::AddSubMul, a::BigInt, x::BigInt, y::BigInt, args::Vararg{BigInt, N}) where N mutable_operate_to!(buffer, *, x, y, args...) - return mutable_operate_to!(output, +, a, buffer) + return mutable_operate_to!(output, add_sub_op(op), a, buffer) end -function mutable_buffered_operate!(buffer::BigInt, op::typeof(add_mul), x::BigInt, args::Vararg{Any, N}) where N +function mutable_buffered_operate!(buffer::BigInt, op::AddSubMul, x::BigInt, args::Vararg{Any, N}) where N return mutable_buffered_operate_to!(buffer, x, op, x, args...) end scaling_to_bigint(x::BigInt) = x scaling_to_bigint(x::Number) = convert(BigInt, x) scaling_to_bigint(J::LinearAlgebra.UniformScaling) = scaling_to_bigint(J.λ) -function mutable_operate_to!(output::BigInt, op::Union{typeof(+), typeof(*)}, args::Vararg{Scaling, N}) where N +function mutable_operate_to!(output::BigInt, op::Union{typeof(+), typeof(-), typeof(*)}, args::Vararg{Scaling, N}) where N return mutable_operate_to!(output, op, scaling_to_bigint.(args)...) end -function mutable_operate_to!(output::BigInt, op::typeof(add_mul), x::Scaling, y::Scaling, z::Scaling, args::Vararg{Scaling, N}) where N +function mutable_operate_to!(output::BigInt, op::AddSubMul, x::Scaling, y::Scaling, z::Scaling, args::Vararg{Scaling, N}) where N return mutable_operate_to!( output, op, scaling_to_bigint(x), scaling_to_bigint(y), scaling_to_bigint(z), scaling_to_bigint.(args)...) end # Called for instance if `args` is `(v', v)` for a vector `v`. -function mutable_operate_to!(output::BigInt, op::typeof(add_mul), x, y, z, args::Vararg{Any, N}) where N - return mutable_operate_to!(output, +, x, *(y, z, args...)) +function mutable_operate_to!(output::BigInt, op::AddSubMul, x, y, z, args::Vararg{Any, N}) where N + return mutable_operate_to!(output, add_sub_op(op), x, *(y, z, args...)) end diff --git a/src/interface.jl b/src/interface.jl index f30a3d2..1a766e7 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -29,10 +29,10 @@ function promote_operation(::typeof(*), ::Type{S}, ::Type{T}, ::Type{U}, args::V end # Helpful error for common mistake -function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, A::Type{<:Array}, α::Type{<:Number}) +function promote_operation(op::Union{typeof(+), typeof(-), AddSubMul}, A::Type{<:Array}, α::Type{<:Number}) error("Operation `$op` between `$A` and `$α` is not allowed. You should use broadcast.") end -function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, α::Type{<:Number}, A::Type{<:Array}) +function promote_operation(op::Union{typeof(+), typeof(-), AddSubMul}, α::Type{<:Number}, A::Type{<:Array}) error("Operation `$op` between `$α` and `$A` is not allowed. You should use broadcast.") end @@ -85,7 +85,7 @@ function operate end # API without altering `x` and `y`. If it is not the case, implement a # custom `operate` method. operate(::typeof(-), x) = -x -operate(op::Union{typeof(+), typeof(-), typeof(*), typeof(add_mul)}, x, y, args::Vararg{Any, N}) where {N} = op(x, y, args...) +operate(op::Union{typeof(+), typeof(-), typeof(*), AddSubMul}, x, y, args::Vararg{Any, N}) where {N} = op(x, y, args...) operate(::Union{typeof(+), typeof(*)}, x) = copy_if_mutable(x) @@ -169,8 +169,8 @@ function mutable_operate_to_fallback(::NotMutable, output, op::Function, args... throw(ArgumentError("Cannot call `mutable_operate_to!(::$(typeof(output)), $op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(output))` cannot be modifed to equal the result of the operation. Use `operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.")) end -function mutable_operate_to_fallback(::IsMutable, output, op::typeof(add_mul), x, y) - return mutable_operate_to!(output, +, x, y) +function mutable_operate_to_fallback(::IsMutable, output, op::AddSubMul, x, y) + return mutable_operate_to!(output, add_sub_op(op), x, y) end function mutable_operate_to_fallback(::IsMutable, output, op::Function, args...) error("`mutable_operate_to!(::$(typeof(output)), $op, ::", join(typeof.(args), ", ::"), @@ -201,8 +201,8 @@ function mutable_operate_fallback(::NotMutable, op::Function, args...) throw(ArgumentError("Cannot call `mutable_operate!($op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(args[1]))` cannot be modifed to equal the result of the operation. Use `operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.")) end -function mutable_operate_fallback(::IsMutable, op::typeof(add_mul), x, y) - return mutable_operate!(+, x, y) +function mutable_operate_fallback(::IsMutable, op::AddSubMul, x, y) + return mutable_operate!(add_sub_op(op), x, y) end function mutable_operate_fallback(::IsMutable, op::Function, args...) error("`mutable_operate!($op, ::", join(typeof.(args), ", ::"), diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index dc1a82d..43cc3ec 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -24,13 +24,10 @@ function mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Matrix, B::LinearA end return A end -function mutable_operate!(::typeof(add_mul), A::Matrix, B::Scaling, C::Scaling, D::Vararg{Scaling, N}) where N - return mutable_operate!(+, A, *(B, C, D...)) +function mutable_operate!(op::AddSubMul, A::Matrix, B::Scaling, C::Scaling, D::Vararg{Scaling, N}) where N + return mutable_operate!(add_sub_op(op), A, *(B, C, D...)) end -function sub_mul end -operate!(::typeof(sub_mul), x, args::Vararg{Any, N}) where {N} = operate!(add_mul, x, -1, args...) - mul_rhs(::typeof(+)) = add_mul mul_rhs(::typeof(-)) = sub_mul @@ -54,22 +51,22 @@ function mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N}, B::Ab _check_dims(A, B) return _mutable_operate!(op, A, B, tuple(), tuple()) end -function mutable_operate!(::typeof(add_mul), A::Array{S, N}, B::AbstractArray{T, N}, α::Vararg{Scaling, M}) where {S, T, N, M} +function mutable_operate!(op::AddSubMul, A::Array{S, N}, B::AbstractArray{T, N}, α::Vararg{Scaling, M}) where {S, T, N, M} _check_dims(A, B) - return _mutable_operate!(+, A, B, tuple(), α) + return _mutable_operate!(add_sub_op(op), A, B, tuple(), α) end -function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α::Scaling, B::AbstractArray{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} +function mutable_operate!(op::AddSubMul, A::Array{S, N}, α::Scaling, B::AbstractArray{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} _check_dims(A, B) - return _mutable_operate!(+, A, B, (α,), β) + return _mutable_operate!(add_sub_op(op), A, B, (α,), β) end -function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α1::Scaling, α2::Scaling, B::AbstractArray{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} +function mutable_operate!(op::AddSubMul, A::Array{S, N}, α1::Scaling, α2::Scaling, B::AbstractArray{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} _check_dims(A, B) - return _mutable_operate!(+, A, B, (α1, α2), β) + return _mutable_operate!(add_sub_op(op), A, B, (α1, α2), β) end # Fallback, we may be able to be more efficient in more cases by adding more specialized methods -function mutable_operate!(::typeof(add_mul), A::Array, x, y, args::Vararg{Any, N}) where N - return mutable_operate!(add_mul, A, x, *(y, args...)) +function mutable_operate!(op::AddSubMul, A::Array, x, y, args::Vararg{Any, N}) where N + return mutable_operate!(op, A, x, *(y, args...)) end # Product @@ -87,14 +84,19 @@ function promote_operation(op::typeof(*), A::Type{<:AbstractArray{S}}, B::Type{< return promote_array_mul(A, B) end +function promote_sum_mul(T::Type, S::Type) + U = promote_operation(*, T, S) + return promote_operation(+, U, U) +end + function promote_array_mul(::Type{Matrix{S}}, ::Type{Vector{T}}) where {S, T} - return Vector{promote_operation(add_mul, S, S, T)} + return Vector{promote_sum_mul(S, T)} end function promote_array_mul(::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractMatrix{T}}) where {S, T} - return Matrix{promote_operation(add_mul, S, S, T)} + return Matrix{promote_sum_mul(S, T)} end function promote_array_mul(::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractVector{T}}) where {S, T} - return Vector{promote_operation(add_mul, S, S, T)} + return Vector{promote_sum_mul(S, T)} end ################################################################################ @@ -211,14 +213,14 @@ end # allocate the resulting array but it redirects to `mul_to!` instead of # `LinearAlgebra.mul!`. function operate(::typeof(*), A::AbstractMatrix{S}, B::AbstractVector{T}) where {T, S} - U = promote_operation(add_mul, S, S, T) + U = promote_sum_mul(S, T) # `similar` gives SparseMatrixCSC if `B` is SparseMatrixCSC #C = similar(B, U, axes(A, 1)) C = Vector{U}(undef, size(A, 1)) return mutable_operate_to!(C, *, A, B) end function operate(::typeof(*), A::AbstractMatrix{S}, B::AbstractMatrix{T}) where {T, S} - U = promote_operation(add_mul, S, S, T) + U = promote_sum_mul(S, T) # `similar` gives SparseMatrixCSC if `B` is SparseMatrixCSC #C = similar(B, U, axes(A, 1), axes(B, 2)) C = Matrix{U}(undef, size(A, 1), size(B, 2)) @@ -235,7 +237,7 @@ _mirror_transpose_or_adjoint(x, ::LinearAlgebra.Transpose) = LinearAlgebra.trans _mirror_transpose_or_adjoint(x, ::LinearAlgebra.Adjoint) = LinearAlgebra.adjoint(x) # dot product function promote_array_mul(::Type{<:TransposeOrAdjoint{S, <:AbstractVector}}, ::Type{<:AbstractVector{T}}) where {S, T} - return promote_operation(add_mul, S, S, T) + return promote_sum_mul(S, T) end function operate(::typeof(*), x::LinearAlgebra.Adjoint{<:Any, <:AbstractVector}, y::AbstractVector) return operate(LinearAlgebra.dot, parent(x), y) @@ -250,13 +252,15 @@ function operate(::typeof(*), x::LinearAlgebra.Transpose{<:Any, <:AbstractVector if lx != length(y) throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y)).")) end + + SumType = promote_sum_mul(eltype(x), eltype(y)) + if iszero(lx) - return zero(promote_operation(add_mul, eltype(x), eltype(y))) + return zero(SumType) end # We need a buffer to hold the intermediate multiplication. - SumType = promote_operation(add_mul, eltype(x), eltype(x), eltype(y)) mul_buffer = buffer_for(add_mul, SumType, eltype(x), eltype(y)) s = zero(SumType) @@ -272,13 +276,14 @@ function operate(::typeof(LinearAlgebra.dot), x::AbstractArray, y::AbstractArray if lx != length(y) throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y)).")) end + if iszero(lx) return LinearAlgebra.dot(zero(eltype(x)), zero(eltype(y))) end # We need a buffer to hold the intermediate multiplication. - SumType = promote_operation(add_mul, eltype(x), eltype(x), eltype(y)) + SumType = promote_sum_mul(eltype(x), eltype(y)) mul_buffer = buffer_for(add_mul, SumType, eltype(x), eltype(y)) s = zero(SumType) diff --git a/src/rewrite.jl b/src/rewrite.jl index c648b86..939c95a 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -34,6 +34,14 @@ struct Zero end function operate(::typeof(add_mul), ::Zero, args::Vararg{Any, N}) where {N} return operate(*, args...) end +function operate(::typeof(sub_mul), ::Zero, x) + # `operate(*, x)` would redirect to `copy_if_mutable(x)` which would be a + # useless copy. + return operate(-, x) +end +function operate(::typeof(sub_mul), ::Zero, x, y, args::Vararg{Any, N}) where {N} + return operate(-, operate(*, x, y, args...)) +end broadcast!(::Union{typeof(add_mul), typeof(+)}, ::Zero, x) = copy_if_mutable(x) broadcast!(::typeof(add_mul), ::Zero, x, y) = x * y @@ -125,24 +133,24 @@ end # See `JuMP._is_sum` _is_sum(s::Symbol) = (s == :sum) || (s == :∑) || (s == :Σ) -function _parse_generator(vectorized::Bool, inner_factor::Expr, current_sum::Union{Nothing, Symbol}, left_factors, right_factors, new_var=gensym()) +function _parse_generator(vectorized::Bool, minus::Bool, inner_factor::Expr, current_sum::Union{Nothing, Symbol}, left_factors, right_factors, new_var=gensym()) @assert isexpr(inner_factor, :call) @assert length(inner_factor.args) > 1 @assert isexpr(inner_factor.args[2], :generator) || isexpr(inner_factor.args[2], :flatten) header = inner_factor.args[1] if _is_sum(header) - _parse_generator_sum(vectorized, inner_factor.args[2], current_sum, left_factors, right_factors, new_var) + _parse_generator_sum(vectorized, minus, inner_factor.args[2], current_sum, left_factors, right_factors, new_var) else error("Expected `sum` outside generator expression; got `$header`.") end end -function _parse_generator_sum(vectorized::Bool, inner_factor::Expr, current_sum::Union{Nothing, Symbol}, left_factors, right_factors, new_var) +function _parse_generator_sum(vectorized::Bool, minus::Bool, inner_factor::Expr, current_sum::Union{Nothing, Symbol}, left_factors, right_factors, new_var) # We used to preallocate the expression at the lowest level of the loop. # When rewriting this some benchmarks revealed that it actually doesn't # seem to help anymore, so might as well keep the code simple. return _start_summing(current_sum, current_sum -> begin - code = rewrite_generator(inner_factor, t -> _rewrite(vectorized, t, current_sum, left_factors, right_factors, current_sum)[2]) + code = rewrite_generator(inner_factor, t -> _rewrite(vectorized, minus, t, current_sum, left_factors, right_factors, current_sum)[2]) return Expr(:block, code, :($new_var = $current_sum)) end) end @@ -179,7 +187,7 @@ Rewrite the expression `x` as specified in [`@rewrite`](@ref). Return the rewritten expression returning the result. """ function rewrite_and_return(x) - output_variable, code = _rewrite(false, x, nothing, [], []) + output_variable, code = _rewrite(false, false, x, nothing, [], []) # We need to use `let` because `rewrite(:(sum(i for i in 1:2))` return quote let @@ -215,13 +223,13 @@ function _has_assignment_in_ref(ex::Expr) end _has_assignment_in_ref(other) = false -function rewrite_sum(vectorized::Bool, terms, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, output::Symbol, block = Expr(:block)) +function rewrite_sum(vectorized::Bool, minus::Bool, terms, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, output::Symbol, block = Expr(:block)) var = current_sum for term in terms[1:(end-1)] - var, code = _rewrite(vectorized, term, var, left_factors, right_factors) + var, code = _rewrite(vectorized, minus, term, var, left_factors, right_factors) push!(block.args, code) end - new_output, code = _rewrite(vectorized, terms[end], var, left_factors, right_factors, output) + new_output, code = _rewrite(vectorized, minus, terms[end], var, left_factors, right_factors, output) @assert new_output == output push!(block.args, code) return output, block @@ -236,36 +244,34 @@ function _start_summing(current_sum::Symbol, first_term::Function) return first_term(current_sum) end -function _write_add_mul(vectorized, current_sum, left_factors, inner_factors, right_factors, new_var::Symbol) +function _write_add_mul(vectorized, minus, current_sum, left_factors, inner_factors, right_factors, new_var::Symbol) if vectorized f = :(MutableArithmetics.broadcast!) else f = :(MutableArithmetics.operate!) end + op = minus ? :(MutableArithmetics.sub_mul) : :(MutableArithmetics.add_mul) return _start_summing(current_sum, current_sum -> begin - call_expr = Expr(:call, f, :(MutableArithmetics.add_mul), current_sum, left_factors..., inner_factors..., reverse(right_factors)...) + call_expr = Expr(:call, f, op, current_sum, left_factors..., inner_factors..., reverse(right_factors)...) return :($new_var = $call_expr) end) end """ - _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym()) + _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym()) Return `new_var, code` such that `code` is equivalent to ```julia new_var = prod(left_factors) * inner_factor * prod(reverse(right_factors)) ``` -if `current_sum` is `nothing`, -```julia -new_var = current_sum + prod(left_factors) * inner_factor * prod(reverse(right_factors)) -``` -if `current_sum` is a `Symbol` and `vectorized` is `false` and +if `current_sum` is `nothing`, and is ```julia -new_var = current_sum .+ prod(left_factors) * inner_factor * prod(reverse(right_factors)) +new_var = current_sum op prod(left_factors) * inner_factor * prod(reverse(right_factors)) ``` -otherwise. +otherwise where `op` is `+` if `!vectorized` and `!minus`, `.+` if `vectorized` and `!minus`, +`-` if `!vectorized` and `minus` and `.-` if `vectorized` and `minus`. """ -function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Nothing}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym()) +function _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Union{Symbol, Nothing}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym()) if isexpr(inner_factor, :call) # We need to verfify that `left_factors` and `right_factors` are empty for broadcast, see `_is_decomposable_with_factors`. # We also need to verify that `current_sum` is `nothing` otherwise we are unsure that the elements in the containers have been copied, e.g., in @@ -274,7 +280,7 @@ function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Not (current_sum === nothing && isempty(left_factors) && isempty(right_factors) && (inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-))) block = Expr(:block) if length(inner_factor.args) > 2 # not unary addition or subtraction - next_sum, code = _rewrite(vectorized, inner_factor.args[2], current_sum, left_factors, right_factors) + next_sum, code = _rewrite(vectorized, minus, inner_factor.args[2], current_sum, left_factors, right_factors) push!(block.args, code) start = 3 else @@ -283,9 +289,9 @@ function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Not end vectorized = vectorized || inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-) if inner_factor.args[1] == :- || inner_factor.args[1] == :(.-) - left_factors = vcat(-1, left_factors) + minus = !minus end - return rewrite_sum(vectorized, inner_factor.args[start:end], next_sum, left_factors, right_factors, new_var, block) + return rewrite_sum(vectorized, minus, inner_factor.args[start:end], next_sum, left_factors, right_factors, new_var, block) elseif inner_factor.args[1] == :* # FIXME && !vectorized ? # we might need to recurse on multiple arguments, e.g., # (x+y)*(x+y) @@ -297,8 +303,7 @@ function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Not _is_decomposable_with_factors(inner_factor.args[i]) end return _rewrite( - vectorized, - inner_factor.args[which_idx], current_sum, + vectorized, minus, inner_factor.args[which_idx], current_sum, vcat(left_factors, [esc(inner_factor.args[i]) for i in 2:(which_idx - 1)]), vcat(right_factors, [esc(inner_factor.args[i]) for i in length(inner_factor.args):-1:(which_idx + 1)]), new_var) @@ -314,7 +319,7 @@ function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Not end end push!(blk.args, _write_add_mul( - vectorized, current_sum, left_factors, + vectorized, minus, current_sum, left_factors, inner_factor.args[2:end], right_factors, new_var )) return new_var, blk @@ -324,18 +329,18 @@ function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Not if inner_factor.args[3] == 2 new_var_, parsed = rewrite(inner_factor.args[2]) square_expr = _write_add_mul( - vectorized, current_sum, left_factors, + vectorized, minus, current_sum, left_factors, (new_var_, new_var_), right_factors, new_var ) return new_var, Expr(:block, parsed, square_expr) elseif inner_factor.args[3] == 1 - return _rewrite(vectorized, :(convert($MulType, $(inner_factor.args[2]))), current_sum, left_factors, right_factors, new_var) + return _rewrite(vectorized, minus, :(convert($MulType, $(inner_factor.args[2]))), current_sum, left_factors, right_factors, new_var) elseif inner_factor.args[3] == 0 - return _rewrite(vectorized, :(one($MulType)), current_sum, left_factors, right_factors, new_var) + return _rewrite(vectorized, minus, :(one($MulType)), current_sum, left_factors, right_factors, new_var) else new_var_, parsed = rewrite(inner_factor.args[2]) power_expr = _write_add_mul( - vectorized, current_sum, left_factors, + vectorized, minus, current_sum, left_factors, (Expr(:call, :^, new_var_, esc(inner_factor.args[3])),), right_factors, new_var ) @@ -345,9 +350,9 @@ function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Not @assert length(inner_factor.args) == 3 numerator = inner_factor.args[2] denom = inner_factor.args[3] - return _rewrite(vectorized, numerator, current_sum, left_factors, vcat(esc(:(1 / $denom)), right_factors), new_var) + return _rewrite(vectorized, minus, numerator, current_sum, left_factors, vcat(esc(:(1 / $denom)), right_factors), new_var) elseif length(inner_factor.args) >= 2 && (isexpr(inner_factor.args[2], :generator) || isexpr(inner_factor.args[2], :flatten)) - return new_var, _parse_generator(vectorized, inner_factor, current_sum, left_factors, right_factors, new_var) + return new_var, _parse_generator(vectorized, minus, inner_factor, current_sum, left_factors, right_factors, new_var) end elseif isexpr(inner_factor, :curly) Base.error("The curly syntax (sum{},prod{},norm2{}) is no longer supported. Expression: `$inner_factor`.") @@ -359,5 +364,5 @@ function _rewrite(vectorized::Bool, inner_factor, current_sum::Union{Symbol, Not error("Unexpected assignment in expression `$inner_factor`.") end # at the lowest level - return new_var, _write_add_mul(vectorized, current_sum, left_factors, (esc(inner_factor),), right_factors, new_var) + return new_var, _write_add_mul(vectorized, minus, current_sum, left_factors, (esc(inner_factor),), right_factors, new_var) end diff --git a/src/shortcuts.jl b/src/shortcuts.jl index 5e5c908..ab3a05a 100644 --- a/src/shortcuts.jl +++ b/src/shortcuts.jl @@ -12,6 +12,20 @@ Return the sum of `a`, `b`, ..., possibly modifying `a`. """ add!(args::Vararg{Any, N}) where {N} = operate!(+, args...) +""" + sub_to!(output, a, b) + +Return the `a - b`, possibly modifying `output`. +""" +sub_to!(output, a, b) = operate_to!(output, -, a, b) + +""" + sub!(a, b) + +Return `a - b`, possibly modifying `a`. +""" +sub!(a, b) = operate!(-, a, b) + """ mul_to!(a, b, c) @@ -35,14 +49,14 @@ mul(args::Vararg{Any, N}) where {N} = operate(*, args...) # `Vararg` gives extra allocations on Julia v1.3, see https://travis-ci.com/JuliaOpt/MutableArithmetics.jl/jobs/260666164#L215-L238 -function promote_operation(::typeof(add_mul), T::Type, x::Type, y::Type) - return promote_operation(+, T, promote_operation(*, x, y)) +function promote_operation(op::AddSubMul, T::Type, x::Type, y::Type) + return promote_operation(add_sub_op(op), T, promote_operation(*, x, y)) end -function promote_operation(::typeof(add_mul), x::Type{<:AbstractArray}, y::Type{<:AbstractArray}) - return promote_operation(+, x, y) +function promote_operation(op::AddSubMul, x::Type{<:AbstractArray}, y::Type{<:AbstractArray}) + return promote_operation(add_sub_op(op), x, y) end -function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) where N - return promote_operation(+, T, promote_operation(*, args...)) +function promote_operation(op::AddSubMul, T::Type, args::Vararg{Type, N}) where N + return promote_operation(add_sub_op(op), T, promote_operation(*, args...)) end """ @@ -77,6 +91,38 @@ function add_mul_buf!(buffer, args::Vararg{Any, N}) where {N} buffered_operate!(buffer, add_mul, args...) end +""" + sub_mul_to!(output, args...) + +Return `sub_mul(args...)`, possibly modifying `output`. +""" +sub_mul_to!(output, args::Vararg{Any, N}) where {N} = operate_to!(output, sub_mul, args...) + +""" + sub_mul!(args...) + +Return `sub_mul(args...)`, possibly modifying `args[1]`. +""" +sub_mul!(args::Vararg{Any, N}) where {N} = operate!(sub_mul, args...) + +""" + sub_mul_buf_to!(buffer, output, args...) + +Return `sub_mul(args...)`, possibly modifying `output` and `buffer`. +""" +function sub_mul_buf_to!(buffer, output, args::Vararg{Any, N}) where {N} + buffered_operate_to!(buffer, output, sub_mul, args...) +end + +""" + sub_mul_buf!(buffer, args...) + +Return `sub_mul(args...)`, possibly modifying `args[1]` and `buffer`. +""" +function sub_mul_buf!(buffer, args::Vararg{Any, N}) where {N} + buffered_operate!(buffer, sub_mul, args...) +end + """ zero!(a)