diff --git a/src/MutableArithmetics.jl b/src/MutableArithmetics.jl index 3538466..8cc8b0b 100644 --- a/src/MutableArithmetics.jl +++ b/src/MutableArithmetics.jl @@ -35,14 +35,11 @@ include("linear_algebra.jl") include("sparse_arrays.jl") isequal_canonical(a, b) = a == b -function isequal_canonical(a::Array{T, N}, b::Array{T, N}) where {T, N} +function isequal_canonical(a::AT, b::AT) where AT <: Union{Array, LinearAlgebra.Symmetric} return all(zip(a, b)) do elements return isequal_canonical(elements...) end end -function isequal_canonical(a::LinearAlgebra.Symmetric, b::LinearAlgebra.Symmetric) - return isequal_canonical(parent(a), parent(b)) -end include("rewrite.jl") include("dispatch.jl") diff --git a/src/Test/array.jl b/src/Test/array.jl index 7784d37..9bb9f39 100644 --- a/src/Test/array.jl +++ b/src/Test/array.jl @@ -313,6 +313,14 @@ function symmetric_unary_test(x) end end +function symmetric_add_test(x) + if x isa AbstractMatrix && size(x, 1) == size(x, 2) + y = LinearAlgebra.Symmetric(x) + add_test(y, y) + add_test(x, y) + end +end + function matrix_uniform_scaling_test(x) if !(x isa AbstractMatrix && size(x, 1) == size(x, 2)) return @@ -347,6 +355,7 @@ const array_tests = Dict( "broadcast_division" => broadcast_division_test, "unary" => unary_test, "symmetric_unary" => symmetric_unary_test, + "symmetric_add" => symmetric_add_test, "matrix_uniform_scaling" => matrix_uniform_scaling_test, "symmetric_matrix_uniform_scaling" => symmetric_matrix_uniform_scaling_test ) diff --git a/src/interface.jl b/src/interface.jl index 65fb131..bed8378 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -10,6 +10,11 @@ Returns the type returned to the call `operate(op, args...)` where the types of the arguments `args` are `ArgsTypes`. """ function promote_operation end +function promote_operation(op::Function, x::Type{<:AbstractArray}, y::Type{<:AbstractArray}) + # `zero` is not defined for `AbstractArray` so the fallback would fail with a cryptic MethodError. + # We replace it by a more helpful error here. + error("`promote_operation($op, $x, $y)` not implemented yet, please report this.") +end # Julia v1.0.x has trouble with inference with the `Vararg` method, see # https://travis-ci.org/JuliaOpt/JuMP.jl/jobs/617606373 function promote_operation(op::Function, x::Type, y::Type) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 7159e8e..cc8223a 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -12,6 +12,9 @@ end function promote_operation(op::Union{typeof(+), typeof(-)}, ::Type{Matrix{T}}, ::Type{LinearAlgebra.UniformScaling{S}}) where {S, T} return Matrix{promote_operation(op, S, T)} end +function promote_operation(op::Union{typeof(+), typeof(-)}, ::Type{Matrix{T}}, ::Type{<:LinearAlgebra.Symmetric{S}}) where {S, T} + return Matrix{promote_operation(op, S, T)} +end # Only `Scaling` function mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Matrix, B::LinearAlgebra.UniformScaling) @@ -32,23 +35,35 @@ mul_rhs(::typeof(+)) = add_mul mul_rhs(::typeof(-)) = sub_mul # `Scaling` and `Array` -function _mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N}, B::Array{T, N}, left_factors::Tuple, right_factors::Tuple) where {S, T, N} +function _mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N}, + B::Union{Array{T, N}, LinearAlgebra.Symmetric{T}}, + left_factors::Tuple, right_factors::Tuple) where {S, T, N} for i in eachindex(A) A[i] = operate!(mul_rhs(op), A[i], left_factors..., B[i], right_factors...) end return A end +function _check_dims(A, B) + if size(A) != size(B) + throw(DimensionMismatch("Cannot sum matrices of size `$(size(A))` and size `$(size(B))`, the size of the two matrices must be equal.")) + end +end + function mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N}, B::AbstractArray{T, N}) where {S, T, N} + _check_dims(A, B) return _mutable_operate!(op, A, B, tuple(), tuple()) end function mutable_operate!(::typeof(add_mul), A::Array{S, N}, B::AbstractArray{T, N}, α::Vararg{Scaling, M}) where {S, T, N, M} + _check_dims(A, B) return _mutable_operate!(+, A, B, tuple(), α) end function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α::Scaling, B::AbstractArray{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} + _check_dims(A, B) return _mutable_operate!(+, A, B, (α,), β) end function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α1::Scaling, α2::Scaling, B::AbstractArray{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} + _check_dims(A, B) return _mutable_operate!(+, A, B, (α1, α2), β) end @@ -59,6 +74,7 @@ end # Product +similar_array_type(::Type{LinearAlgebra.Symmetric{T, MT}}, ::Type{S}) where {S, T, MT} = LinearAlgebra.Symmetric{S, similar_array_type(MT, S)} similar_array_type(::Type{Array{T, N}}, ::Type{S}) where {S, T, N} = Array{S, N} function promote_operation(op::typeof(*), A::Type{<:AbstractArray{T}}, ::Type{S}) where {S, T} return similar_array_type(A, promote_operation(op, T, S)) diff --git a/src/shortcuts.jl b/src/shortcuts.jl index eac3604..ef9237a 100644 --- a/src/shortcuts.jl +++ b/src/shortcuts.jl @@ -30,6 +30,9 @@ mul!(args::Vararg{Any, N}) where {N} = operate!(*, args...) function promote_operation(::typeof(add_mul), T::Type, x::Type, y::Type) return promote_operation(+, T, promote_operation(*, x, y)) end +function promote_operation(::typeof(add_mul), x::Type{<:AbstractArray}, y::Type{<:AbstractArray}) + return promote_operation(+, x, y) +end function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) where N return promote_operation(+, T, promote_operation(*, args...)) end diff --git a/test/matmul.jl b/test/matmul.jl index febdba0..3467ab2 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -4,6 +4,23 @@ const MA = MutableArithmetics include("utilities.jl") +struct CustomArray{T, N} <: AbstractArray{T, N} end + +@testset "Errors" begin + @testset "`promote_op` error" begin + AT = CustomArray{Int, 3} + err = ErrorException("`promote_operation(+, CustomArray{Int64,3}, CustomArray{Int64,3})` not implemented yet, please report this.") + @test_throws err MA.promote_operation(+, AT, AT) + end + + @testset "Dimension mismatch" begin + A = zeros(1, 1) + B = zeros(2, 2) + err = DimensionMismatch("Cannot sum matrices of size `(1, 1)` and size `(2, 2)`, the size of the two matrices must be equal.") + @test_throws err MA.@rewrite A + B + end +end + @testset "Matrix multiplication" begin @testset "matrix-vector product" begin A = [1 1 1; 1 1 1; 1 1 1]