From 7ee6d2892fbd6187d55d0717a2da9fe4784065ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 18 Nov 2019 15:56:42 +0100 Subject: [PATCH] Add @rewrite macro --- Project.toml | 4 +- src/MutableArithmetics.jl | 6 + src/dispatch.jl | 9 + src/interface.jl | 7 + src/rewrite.jl | 230 +++++++++++++++++ src/shortcuts.jl | 2 + test/rewrite.jl | 501 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 8 files changed, 759 insertions(+), 1 deletion(-) create mode 100644 src/dispatch.jl create mode 100644 src/rewrite.jl create mode 100644 test/rewrite.jl diff --git a/Project.toml b/Project.toml index a676ef04..26ee5c54 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" julia = "1.1" [extras] +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "OffsetArrays", "SparseArrays"] diff --git a/src/MutableArithmetics.jl b/src/MutableArithmetics.jl index 6d6c197a..a818c536 100644 --- a/src/MutableArithmetics.jl +++ b/src/MutableArithmetics.jl @@ -22,4 +22,10 @@ include("Test/Test.jl") include("bigint.jl") include("linear_algebra.jl") +isequal_canonical(a, b) = a == b + +include("rewrite.jl") + +include("dispatch.jl") + end # module diff --git a/src/dispatch.jl b/src/dispatch.jl new file mode 100644 index 00000000..91fc802d --- /dev/null +++ b/src/dispatch.jl @@ -0,0 +1,9 @@ +abstract type AbstractMutable end + +# Special-case because the the base version wants to do fill!(::Array{AbstractVariableRef}, zero(GenericAffExpr{Float64,eltype(x)})) +_one_indexed(A) = all(x -> isa(x, Base.OneTo), axes(A)) +function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable}) + @assert _one_indexed(x) # Base.diagm doesn't work for non-one-indexed arrays in general. + ZeroType = promote_operation(zero, eltype(x)) + return diagm(0 => copyto!(similar(x, ZeroType), x)) +end diff --git a/src/interface.jl b/src/interface.jl index 0d4961b4..4b22d8f1 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -13,6 +13,13 @@ function promote_operation end function promote_operation(op::Function, args::Vararg{Type, N}) where N return typeof(op(zero.(args)...)) end +#promote_operation(::typeof(*), ::Type{T}) where {T} = T +#function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N} +# return Array{promote_operation(op, T, S), N} +#end +#function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N} +# return Array{promote_operation(op, S, T), N} +#end # Define Traits abstract type MutableTrait end diff --git a/src/rewrite.jl b/src/rewrite.jl new file mode 100644 index 00000000..573f92d4 --- /dev/null +++ b/src/rewrite.jl @@ -0,0 +1,230 @@ +# Heavily inspired from `JuMP/src/parse_expr.jl` code. + +export @rewrite +macro rewrite(expr) + return rewrite(expr) +end + +struct Zero end +# We need to copy `x` as it will be used as might be given by the user and be +# given as first argument of `operate!`. +Base.:(+)(zero::Zero, x) = copy(x) +# `add_mul(zero, ...)` redirects to `muladd(..., zero)` which calls `... + zero`. +Base.:(+)(x, zero::Zero) = copy(x) + +using Base.Meta + +function _parse_idx_set(arg::Expr) + parse_done, idxvar, idxset = Containers._try_parse_idx_set(arg) + if parse_done + return idxvar, idxset + end + error("Invalid syntax: $arg") +end + +# takes a generator statement and returns a properly nested for loop +# with nested filters as specified +function _parse_gen(ex, atleaf) + if isexpr(ex, :flatten) + return _parse_gen(ex.args[1], atleaf) + end + if !isexpr(ex, :generator) + return atleaf(ex) + end + function itrsets(sets) + if isa(sets, Expr) + return sets + elseif length(sets) == 1 + return sets[1] + else + return Expr(:block, sets...) + end + end + + idxvars = [] + if isexpr(ex.args[2], :filter) # if condition + loop = Expr(:for, esc(itrsets(ex.args[2].args[2:end])), + Expr(:if, esc(ex.args[2].args[1]), + _parse_gen(ex.args[1], atleaf))) + for idxset in ex.args[2].args[2:end] + idxvar, s = _parse_idx_set(idxset) + push!(idxvars, idxvar) + end + else + loop = Expr(:for, esc(itrsets(ex.args[2:end])), + _parse_gen(ex.args[1], atleaf)) + for idxset in ex.args[2:end] + idxvar, s = _parse_idx_set(idxset) + push!(idxvars, idxvar) + end + end + return loop +end + +function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff=gensym()) + @assert isexpr(x,:call) + @assert length(x.args) > 1 + @assert isexpr(x.args[2],:generator) || isexpr(x.args[2],:flatten) + header = x.args[1] + if _is_sum(header) + _parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, newaff) + else + error("Expected sum outside generator expression; got $header") + end +end + +function _parse_generator_sum(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff) + # We used to preallocate the expression at the lowest level of the loop. + # When rewriting this some benchmarks revealed that it actually doesn't + # seem to help anymore, so might as well keep the code simple. + code = _parse_gen(x, t -> _rewrite(t, aff, lcoeffs, rcoeffs, aff)[2]) + return :($code; $newaff=$aff) +end + +_is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref) + +function rewrite(x) + variable = gensym() + new_variable, code = _rewrite_toplevel(x, variable) + return quote + $variable = Zero() + $code + $new_variable + end +end + +_rewrite_toplevel(x, variable::Symbol) = _rewrite(x, variable, [], []) + +function _is_comparison(ex::Expr) + if isexpr(ex, :comparison) + return true + elseif isexpr(ex, :call) + if ex.args[1] in (:<=, :≤, :>=, :≥, :(==)) + return true + else + return false + end + else + return false + end +end + +# x[i=1] <= 2 is a somewhat common user error. Catch it here. +function _has_assignment_in_ref(ex::Expr) + if isexpr(ex, :ref) + return any(x -> isexpr(x, :(=)), ex.args) + else + return any(_has_assignment_in_ref, ex.args) + end +end +_has_assignment_in_ref(other) = false + +# output is assigned to newaff +function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symbol=gensym()) + if isexpr(x, :call) + if x.args[1] == :+ + b = Expr(:block) + aff_ = aff + for arg in x.args[2:(end-1)] + aff_, code = _rewrite(arg, aff_, lcoeffs, rcoeffs) + push!(b.args, code) + end + newaff, code = _rewrite(x.args[end], aff_, lcoeffs, rcoeffs, newaff) + push!(b.args, code) + return newaff, b + elseif x.args[1] == :- + if length(x.args) == 2 # unary subtraction + return _rewrite(x.args[2], aff, vcat(-1.0, lcoeffs), rcoeffs, newaff) + else # a - b - c ... + b = Expr(:block) + aff_, code = _rewrite(x.args[2], aff, lcoeffs, rcoeffs) + push!(b.args, code) + for arg in x.args[3:(end-1)] + aff_,code = _rewrite(arg, aff_, vcat(-1.0, lcoeffs), rcoeffs) + push!(b.args, code) + end + newaff,code = _rewrite(x.args[end], aff_, vcat(-1.0, lcoeffs), rcoeffs, newaff) + push!(b.args, code) + return newaff, b + end + elseif x.args[1] == :* + # we might need to recurse on multiple arguments, e.g., + # (x+y)*(x+y) + n_expr = mapreduce(_is_complex_expr, +, x.args) + if n_expr == 1 # special case, only recurse on one argument and don't create temporary objects + which_idx = 0 + for i in 2:length(x.args) + if _is_complex_expr(x.args[i]) + which_idx = i + end + end + return _rewrite( + x.args[which_idx], aff, + vcat(lcoeffs, [esc(x.args[i]) for i in 2:(which_idx - 1)]), + vcat(rcoeffs, [esc(x.args[i]) for i in (which_idx + 1):length(x.args)]), + newaff) + else + blk = Expr(:block) + for i in 2:length(x.args) + if _is_complex_expr(x.args[i]) + s = gensym() + newaff_, parsed = _rewrite_toplevel(x.args[i], s) + push!(blk.args, :($s = 0.0; $parsed)) + x.args[i] = newaff_ + else + x.args[i] = esc(x.args[i]) + end + end + callexpr = Expr(:call, :operate!, add_mul, aff, + lcoeffs..., x.args[2:end]..., rcoeffs...) + push!(blk.args, :($newaff = $callexpr)) + return newaff, blk + end + elseif x.args[1] == :^ && _is_complex_expr(x.args[2]) + MulType = :(MA.promote_operation(*, typeof($(x.args[2])), typeof($(x.args[2])))) + if x.args[3] == 2 + blk = Expr(:block) + s = gensym() + newaff_, parsed = _rewrite_toplevel(x.args[2], s) + push!(blk.args, :($s = Zero(); $parsed)) + push!(blk.args, :($newaff = operate!(add_mul, + $aff, $(Expr(:call, :*, lcoeffs..., newaff_, newaff_, + rcoeffs...))))) + return newaff, blk + elseif x.args[3] == 1 + return _rewrite(:(convert($MulType, $(x.args[2]))), aff, lcoeffs, rcoeffs) + elseif x.args[3] == 0 + return _rewrite(:(one($MulType)), aff, lcoeffs, rcoeffs) + else + blk = Expr(:block) + s = gensym() + newaff_, parsed = _rewrite_toplevel(x.args[2], s) + push!(blk.args, :($s = Zero(); $parsed)) + push!(blk.args, :($newaff = _destructive_add_with_reorder!( + $aff, $(Expr(:call, :*, lcoeffs..., + Expr(:call, :^, newaff_, esc(x.args[3])), + rcoeffs...))))) + return newaff, blk + end + elseif x.args[1] == :/ + @assert length(x.args) == 3 + numerator = x.args[2] + denom = x.args[3] + return _rewrite(numerator, aff, lcoeffs, vcat(esc(:(1 / $denom)), rcoeffs), newaff) + elseif length(x.args) >= 2 && (isexpr(x.args[2], :generator) || isexpr(x.args[2], :flatten)) + return newaff, _parse_generator(x,aff,lcoeffs,rcoeffs,newaff) + end + elseif isexpr(x, :curly) + _error_curly(x) + end + if isa(x, Expr) && _is_comparison(x) + error("Unexpected comparison in expression $x.") + end + if isa(x, Expr) && _has_assignment_in_ref(x) + @warn "Unexpected assignment in expression $x. This will" * + " become a syntax error in a future release." + end + # at the lowest level + callexpr = Expr(:call, :operate!, add_mul, aff, lcoeffs..., esc(x), rcoeffs...) + return newaff, :($newaff = $callexpr) +end diff --git a/src/shortcuts.jl b/src/shortcuts.jl index 1724060f..c164afe7 100644 --- a/src/shortcuts.jl +++ b/src/shortcuts.jl @@ -32,7 +32,9 @@ mul!(args::Vararg{Any, N}) where {N} = operate!(*, args...) Return `a + *(args...)`. Note that `add_mul(a, b, c) = muladd(b, c, a)`. """ function add_mul end +add_mul(a, b) = a + b add_mul(a, b, c) = muladd(b, c, a) +add_mul(a, b, c::Vararg{Any, N}) where {N} = add_mul(a, b *(c...)) function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) where N return promote_operation(+, T, promote_operation(*, args...)) diff --git a/test/rewrite.jl b/test/rewrite.jl new file mode 100644 index 00000000..6469879f --- /dev/null +++ b/test/rewrite.jl @@ -0,0 +1,501 @@ +using SparseArrays, Test +import MutableArithmetics +const MA = MutableArithmetics + +macro test_rewrite(expr) + esc(quote + @test MA.isequal_canonical(MA.@rewrite($expr), $expr) + end) +end + +function basic_operators_test(w, x, y, z) + aff = @inferred 7.1 * x + 2.5 + @test_rewrite 7.1 * x + 2.5 + aff2 = @inferred 1.2 * y + 1.2 + @test_rewrite 1.2 * y + 1.2 + q = @inferred 2.5 * y * z + aff + @test_rewrite 2.5 * y * z + aff + q2 = @inferred 8 * x * z + aff2 + @test_rewrite 8 * x * z + aff2 + @test_rewrite 2 * x * x + 1 * y * y + z + 3 + + @testset "Comparison" begin + @testset "iszero" begin + @test !iszero(x) + @test !iszero(aff) + @test iszero(zero(aff)) + @test !iszero(q) + @test iszero(zero(q)) + end + + @testset "isequal_canonical" begin + @test MA.isequal_canonical((@inferred 3w + 2y), @inferred 2y + 3w) + @test !MA.isequal_canonical((@inferred 3w + 2y + 1), @inferred 3w + 2y) + @test !MA.isequal_canonical((@inferred 3w + 2y), @inferred 3y + 2w) + @test !MA.isequal_canonical((@inferred 3w + 2y), @inferred 3w + y) + + @test !MA.isequal_canonical(aff, aff2) + @test !MA.isequal_canonical(aff2, aff) + + @test MA.isequal_canonical(q, @inferred 2.5z*y + aff) + @test !MA.isequal_canonical(q, @inferred 2.5y*z + aff2) + @test !MA.isequal_canonical(q, @inferred 2.5x*z + aff) + @test !MA.isequal_canonical(q, @inferred 2.5y*x + aff) + @test !MA.isequal_canonical(q, @inferred 1.5y*z + aff) + @test MA.isequal_canonical(q2, @inferred 8z*x + aff2) + @test !MA.isequal_canonical(q2, @inferred 8x*z + aff) + @test !MA.isequal_canonical(q2, @inferred 7x*z + aff2) + @test !MA.isequal_canonical(q2, @inferred 8x*y + aff2) + @test !MA.isequal_canonical(q2, @inferred 8y*z + aff2) + end + end + + # Different objects that must all interact: + # 1. Number + # 2. Variable + # 3. AffExpr + # 4. QuadExpr + + # 1. Number tests + @testset "Number--???" begin + # 1-1 Number--Number - nope! + # 1-2 Number--Variable + @test_rewrite 4.13 + w + @test_rewrite 3.16 - w + @test_rewrite 5.23 * w + # 1-3 Number--AffExpr + @test_rewrite 1.5 + aff + @test_rewrite 1.5 - aff + @test_rewrite 2 * aff + # 1-4 Number--QuadExpr + @test_rewrite 1.5 + q + @test_rewrite 1.5 - q + @test_rewrite 2 * q + end + + # 2. Variable tests + @testset "Variable--???" begin + # 2-0 Variable unary + @test (+x) === x + @test_rewrite -x + # 2-1 Variable--Number + @test_rewrite w + 4.13 + @test_rewrite w - 4.13 + @test_rewrite w * 4.13 + @test_rewrite w / 2.00 + @test w == w + @test_rewrite x*y - 1 + @test_rewrite x^2 + @test_rewrite x^1 + @test_rewrite x^0 + # 2-2 Variable--Variable + @test_rewrite w + x + @test_rewrite w - x + @test_rewrite w * x + @test_rewrite x - x + @test_rewrite y*z - x + # 2-3 Variable--AffExpr + @test_rewrite z + aff + @test_rewrite z - aff + @test_rewrite z * aff + @test_rewrite 7.1 * x - aff + # 2-4 Variable--QuadExpr + @test_rewrite w + q + @test_rewrite w - q + end + + # 3. AffExpr tests + @testset "AffExpr--???" begin + # 3-0 AffExpr unary + @test_rewrite +aff + @test_rewrite -aff + # 3-1 AffExpr--Number + @test_rewrite aff + 1.5 + @test_rewrite aff - 1.5 + @test_rewrite aff * 2 + @test_rewrite aff / 2 + @test aff == aff + @test_rewrite aff - 1 + @test_rewrite aff^2 + @test_rewrite (7.1*x + 2.5)^2 + @test_rewrite aff^1 + @test_rewrite (7.1*x + 2.5)^1 + @test_rewrite aff^0 + @test_rewrite (7.1*x + 2.5)^0 + # 3-2 AffExpr--Variable + @test_rewrite aff + z + @test_rewrite aff - z + @test_rewrite aff * z + @test_rewrite aff - 7.1 * x + # 3-3 AffExpr--AffExpr + @test_rewrite aff + aff2 + @test_rewrite aff - aff2 + @test_rewrite aff * aff2 + @test string((x+x)*(x+3)) == string((x+3)*(x+x)) # Issue #288 + @test_rewrite aff-aff + # 4-4 AffExpr--QuadExpr + @test_rewrite aff2 + q + @test_rewrite aff2 - q + end + + # 4. QuadExpr + # TODO: This test block and others above should be rewritten to be + # self-contained. The definitions of q, w, and aff2 are too far to + # easily check correctness of the tests. + @testset "QuadExpr--???" begin + # 4-0 QuadExpr unary + @test_rewrite +q + @test_rewrite -q + # 4-1 QuadExpr--Number + @test_rewrite q + 1.5 + @test_rewrite q - 1.5 + @test_rewrite q * 2 + @test_rewrite q / 2 + @test q == q + @test_rewrite aff2 - q + # 4-2 QuadExpr--Variable + @test_rewrite q + w + @test_rewrite q - w + # 4-3 QuadExpr--AffExpr + @test_rewrite q + aff2 + @test_rewrite q - aff2 + # 4-4 QuadExpr--QuadExpr + @test_rewrite q + q2 + @test_rewrite q - q2 + end +end + +function sum_test(matrix) + @testset "sum(j::DenseAxisArray{Variable})" begin + @test_rewrite sum(matrix) + end + @testset "sum(affs::Array{AffExpr})" begin + @test_rewrite sum([2matrix[i, j] for i in 1:size(matrix, 1), j in 1:size(matrix, 2)]) + end + @testset "sum(quads::Array{QuadExpr})" begin + @test_rewrite sum([2matrix[i, j]^2 for i in 1:size(matrix, 1), j in 1:size(matrix, 2)]) + end +end + +function dot_test(x, y, z) + @test_rewrite dot(x[1], x[1]) + @test_rewrite dot(2, x[1]) + @test_rewrite dot(x[1], 2) + + c = vcat(1:3) + @test_rewrite dot(c, x) + @test_rewrite dot(x, c) + + A = [1 3 ; 2 4] + @test_rewrite dot(A, y) + @test_rewrite dot(y, A) + + B = ones(2, 2, 2) + @test_rewrite dot(B, z) + @test_rewrite dot(z, B) + + @test_rewrite dot(x, ones(3)) - dot(y, ones(2,2)) +end + +# JuMP issue #656 +function issue_656(x) + floats = Float64[i for i in 1:2] + anys = Array{Any}(undef, 2) + anys[1] = 10 + anys[2] = 20 + x + @test dot(floats, anys) == 10 + 40 + 2x +end + +function transpose_test(x, y, z) + @test MA.isequal_canonical(x', [x[1] x[2] x[3]]) + @test MA.isequal_canonical(copy(transpose(x)), [x[1] x[2] x[3]]) + @test MA.isequal_canonical(y', [y[1,1] y[2,1] + y[1,2] y[2,2] + y[1,3] y[2,3]]) + @test MA.isequal_canonical(copy(transpose(y)), + [y[1,1] y[2,1] + y[1,2] y[2,2] + y[1,3] y[2,3]]) + @test (z')' == z + @test transpose(transpose(z)) == z +end + +function vectorized_test(x, X11, X23, Xd) + A = [2 1 0 + 1 2 1 + 0 1 2] + B = sparse(A) + X = sparse([1, 2], [1, 3], [X11, X23], 3, 3) # for testing Variable + # FIXME + #@test MA.isequal_canonical([X11 0. 0.; 0. 0. X23; 0. 0. 0.], @inferred MA._densify_with_jump_eltype(X)) + Y = sparse([1, 2], [1, 3], [2X11, 4X23], 3, 3) # for testing GenericAffExpr + Yd = [2X11 0 0 + 0 0 4X23 + 0 0 0] + Z = sparse([1, 2], [1, 3], [X11^2, 2X23^2], 3, 3) # for testing GenericQuadExpr + Zd = [X11^2 0 0 + 0 0 2X23^2 + 0 0 0] + v = [4, 5, 6] + + @testset "Sum of matrices" begin + @test_rewrite(Xd + Yd) + @test_rewrite(Xd + 2Yd) + @test_rewrite(Xd + Yd * 2) + @test_rewrite(Yd + Xd) + @test_rewrite(Yd + 2Xd) + @test_rewrite(Yd + Xd * 2) + @test_rewrite(Yd + Zd) + @test_rewrite(Yd + 2Zd) + @test_rewrite(Yd + Zd * 2) + @test_rewrite(Zd + Yd) + @test_rewrite(Zd + 2Yd) + @test_rewrite(Zd + Yd * 2) + @test_rewrite(Zd + Xd) + @test_rewrite(Zd + 2Xd) + @test_rewrite(Zd + Xd * 2) + @test_rewrite(Xd + Zd) + @test_rewrite(Xd + 2Zd) + @test_rewrite(Xd + Zd * 2) + end + + @test MA.isequal_canonical(A*x, [2x[1] + x[2] + 2x[2] + x[1] + x[3] + x[2] + 2x[3]]) + @test MA.isequal_canonical(A*x, B*x) + @test MA.isequal_canonical(A*x, MA.@rewrite(B*x)) + @test MA.isequal_canonical(MA.@rewrite(A*x), MA.@rewrite(B*x)) + @test MA.isequal_canonical(x'*A, [2x[1]+x[2]; 2x[2]+x[1]+x[3]; x[2]+2x[3]]') + @test MA.isequal_canonical(x'*A, x'*B) + @test MA.isequal_canonical(x'*A, MA.@rewrite(x'*B)) + @test MA.isequal_canonical(MA.@rewrite(x'*A), MA.@rewrite(x'*B)) + @test MA.isequal_canonical(x'*A*x, 2x[1]*x[1] + 2x[1]*x[2] + 2x[2]*x[2] + 2x[2]*x[3] + 2x[3]*x[3]) + @test MA.isequal_canonical(x'A*x, x'*B*x) + @test MA.isequal_canonical(x'*A*x, MA.@rewrite(x'*B*x)) + @test MA.isequal_canonical(MA.@rewrite(x'*A*x), MA.@rewrite(x'*B*x)) + + y = A*x + @test MA.isequal_canonical(-x, [-x[1], -x[2], -x[3]]) + @test MA.isequal_canonical(-y, [-2x[1] - x[2] + -x[1] - 2x[2] - x[3] + -x[2] - 2x[3]]) + @test MA.isequal_canonical(y .+ 1, [2x[1] + x[2] + 1 + x[1] + 2x[2] + x[3] + 1 + x[2] + 2x[3] + 1]) + @test MA.isequal_canonical(y .- 1, [ + 2x[1] + x[2] - 1 + x[1] + 2x[2] + x[3] - 1 + x[2] + 2x[3] - 1]) + @test MA.isequal_canonical(y .+ 2ones(3), [2x[1] + x[2] + 2 + x[1] + 2x[2] + x[3] + 2 + x[2] + 2x[3] + 2]) + @test MA.isequal_canonical(y .- 2ones(3), [2x[1] + x[2] - 2 + x[1] + 2x[2] + x[3] - 2 + x[2] + 2x[3] - 2]) + @test MA.isequal_canonical(2ones(3) .+ y, [2x[1] + x[2] + 2 + x[1] + 2x[2] + x[3] + 2 + x[2] + 2x[3] + 2]) + @test MA.isequal_canonical(2ones(3) .- y, [-2x[1] - x[2] + 2 + -x[1] - 2x[2] - x[3] + 2 + -x[2] - 2x[3] + 2]) + @test MA.isequal_canonical(y .+ x, [3x[1] + x[2] + x[1] + 3x[2] + x[3] + x[2] + 3x[3]]) + @test MA.isequal_canonical(x .+ y, [3x[1] + x[2] + x[1] + 3x[2] + x[3] + x[2] + 3x[3]]) + @test MA.isequal_canonical(2y .+ 2x, [6x[1] + 2x[2] + 2x[1] + 6x[2] + 2x[3] + 2x[2] + 6x[3]]) + @test MA.isequal_canonical(y .- x, [ x[1] + x[2] + x[1] + x[2] + x[3] + x[2] + x[3]]) + @test MA.isequal_canonical(x .- y, [-x[1] - x[2] + -x[1] - x[2] - x[3] + -x[2] - x[3]]) + @test MA.isequal_canonical(y .+ x[:], [3x[1] + x[2] + x[1] + 3x[2] + x[3] + x[2] + 3x[3]]) + @test MA.isequal_canonical(x[:] .+ y, [3x[1] + x[2] + x[1] + 3x[2] + x[3] + x[2] + 3x[3]]) + + @test MA.isequal_canonical(MA.@rewrite(A*x/2), A*x/2) + @test MA.isequal_canonical(X*v, [4X11; 6X23; 0]) + @test MA.isequal_canonical(v'*X, [4X11 0 5X23]) + @test MA.isequal_canonical(copy(transpose(v))*X, [4X11 0 5X23]) + @test MA.isequal_canonical(X'*v, [4X11; 0; 5X23]) + @test MA.isequal_canonical(copy(transpose(X))*v, [4X11; 0; 5X23]) + @test MA.isequal_canonical(X*A, [2X11 X11 0 + 0 X23 2X23 + 0 0 0 ]) + @test MA.isequal_canonical(A*X, [2X11 0 X23 + X11 0 2X23 + 0 0 X23]) + @test MA.isequal_canonical(A*X', [2X11 0 0 + X11 X23 0 + 0 2X23 0]) + @test MA.isequal_canonical(X'*A, [2X11 X11 0 + 0 0 0 + X23 2X23 X23]) + @test MA.isequal_canonical(copy(transpose(X))*A, [2X11 X11 0 + 0 0 0 + X23 2X23 X23]) + @test MA.isequal_canonical(A'*X, [2X11 0 X23 + X11 0 2X23 + 0 0 X23]) + @test MA.isequal_canonical(copy(transpose(X))*A, X'*A) + @test MA.isequal_canonical(copy(transpose(A))*X, A'*X) + @test MA.isequal_canonical(X*A, X*B) + @test MA.isequal_canonical(Y'*A, copy(transpose(Y))*A) + @test MA.isequal_canonical(A*Y', A*copy(transpose(Y))) + @test MA.isequal_canonical(Z'*A, copy(transpose(Z))*A) + @test MA.isequal_canonical(Xd'*Y, copy(transpose(Xd))*Y) + @test MA.isequal_canonical(Y'*Xd, copy(transpose(Y))*Xd) + @test MA.isequal_canonical(Xd'*Xd, copy(transpose(Xd))*Xd) + @test MA.isequal_canonical(A*X, B*X) + @test MA.isequal_canonical(A*X', B*X') + @test MA.isequal_canonical(A'*X, B'*X) +end + +function broadcast_test(x) + A = [1 2; + 3 4] + B = sparse(A) + y = SparseMatrixCSC(2, 2, copy(B.colptr), copy(B.rowval), vec(x)) + @test MA.isequal_canonical(A.+x, [1+x[1,1] 2+x[1,2]; + 3+x[2,1] 4+x[2,2]]) + @test MA.isequal_canonical(A.+x, B.+x) + @test MA.isequal_canonical(A.+x, A.+y) + @test MA.isequal_canonical(A.+y, B.+y) + @test MA.isequal_canonical(x.+A, [1+x[1,1] 2+x[1,2]; + 3+x[2,1] 4+x[2,2]]) + @test MA.isequal_canonical(x.+A, x.+B) + @test MA.isequal_canonical(x.+A, y.+A) + @test MA.isequal_canonical(x .+ x, [2x[1,1] 2x[1,2]; 2x[2,1] 2x[2,2]]) + @test MA.isequal_canonical(y.+A, y.+B) + @test MA.isequal_canonical(A.-x, [1-x[1,1] 2-x[1,2]; + 3-x[2,1] 4-x[2,2]]) + @test MA.isequal_canonical(A.-x, B.-x) + @test MA.isequal_canonical(A.-x, A.-y) + @test MA.isequal_canonical(x .- x, [zero(x[1] - x[1]) for _1 in 1:2, _2 in 1:2]) + @test MA.isequal_canonical(A.-y, B.-y) + @test MA.isequal_canonical(x.-A, [-1+x[1,1] -2+x[1,2]; + -3+x[2,1] -4+x[2,2]]) + @test MA.isequal_canonical(x.-A, x.-B) + @test MA.isequal_canonical(x.-A, y.-A) + @test MA.isequal_canonical(y.-A, y.-B) + @test MA.isequal_canonical(A.*x, [1*x[1,1] 2*x[1,2]; + 3*x[2,1] 4*x[2,2]]) + @test MA.isequal_canonical(A.*x, B.*x) + @test MA.isequal_canonical(A.*x, A.*y) + @test MA.isequal_canonical(A.*y, B.*y) + + @test MA.isequal_canonical(x.*A, [1*x[1,1] 2*x[1,2]; + 3*x[2,1] 4*x[2,2]]) + @test MA.isequal_canonical(x.*A, x.*B) + @test MA.isequal_canonical(x.*A, y.*A) + @test MA.isequal_canonical(y.*A, y.*B) + + @test MA.isequal_canonical(x .* x, [x[1,1]^2 x[1,2]^2; x[2,1]^2 x[2,2]^2]) + @test MA.isequal_canonical(x./A, [1/1*x[1,1] 1/2*x[1,2]; + 1/3*x[2,1] 1/4*x[2,2]]) + @test MA.isequal_canonical(x./A, x./B) + @test MA.isequal_canonical(x./A, y./A) + + # TODO: Refactor to avoid calling the internal JuMP function + # `_densify_with_jump_eltype`. + #z = JuMP._densify_with_jump_eltype((2 .* y) ./ 3) + #@test MA.isequal_canonical((2 .* x) ./ 3, z) + #z = JuMP._densify_with_jump_eltype(2 * (y ./ 3)) + #@test MA.isequal_canonical(2 .* (x ./ 3), z) + #z = JuMP._densify_with_jump_eltype((x[1,1],) .* B) + #@test MA.isequal_canonical((x[1,1],) .* A, z) +end + +function non_array_test(x, x2) + # This is needed to compare arrays that have nonstandard indexing + elements_equal(A::AbstractArray{T, N}, B::AbstractArray{T, N}) where {T, N} = all(a == b for (a, b) in zip(A, B)) + + for x2 in (OffsetArray(x, -length(x)), view(x, :), sparse(x)) + @test elements_equal(+x, +x2) + @test elements_equal(-x, -x2) + @test elements_equal(x .+ first(x), x2 .+ first(x2)) + @test elements_equal(x .- first(x), x2 .- first(x2)) + @test elements_equal(first(x) .- x, first(x2) .- x2) + @test elements_equal(first(x) .+ x, first(x2) .+ x2) + @test elements_equal(2 .* x, 2 .* x2) + @test elements_equal(first(x) .+ x2, first(x2) .+ x) + @test sum(x) == sum(x2) + if !MA._one_indexed(x2) + @test_throws DimensionMismatch x + x2 + end + @testset "diagm" begin + if !MA._one_indexed(x2) && eltype(x2) isa MA.AbstractMutable + @test_throws AssertionError diagm(x2) + else + @test diagm(x) == diagm(x2) + end + end + end +end + +function unary_matrix(Q) + @test_rewrite 2Q + # See https://github.com/JuliaLang/julia/issues/32374 for `Symmetric` + @test_rewrite -Q +end + +function scalar_uniform_scaling(x) + @test_rewrite x + 2I + @test_rewrite (x + 1) + I + @test_rewrite x - 2I + @test_rewrite (x - 1) - I + @test_rewrite 2I + x + @test_rewrite I + (x + 1) + @test_rewrite 2I - x + @test_rewrite I - (x - 1) + @test_rewrite I * x + @test_rewrite I * (x + 1) + @test_rewrite (x + 1) * I +end + +function matrix_uniform_scaling(x) + @test_rewrite x + 2I + @test_rewrite (x .+ 1) + I + @test_rewrite x - 2I + @test_rewrite (x .- 1) - I + @test_rewrite 2I + x + @test_rewrite I + (x .+ 1) + @test_rewrite 2I - x + @test_rewrite I - (x .- 1) + @test_rewrite I * x + @test_rewrite I * (x .+ 1) + @test_rewrite (x .+ 1) * I + @test_rewrite (x .+ 1) + I * I + @test_rewrite (x .+ 1) + 2 * I + @test_rewrite (x .+ 1) + I * 2 +end + +using LinearAlgebra +using OffsetArrays + +@testset "@rewrite with Int" begin + basic_operators_test(1, 2, 3, 4) + sum_test(rand(Int, 3, 3)) + dot_test(rand(Int, 3), rand(Int, 2, 2), rand(Int, 2, 2, 2)) + issue_656(3) + transpose_test(rand(Int, 3), rand(Int, 2, 3), rand(Int, 5)) + vectorized_test([3, 2, 6], 4, 5, [8 1 9; 4 3 1; 2 0 8]) + broadcast_test([2 4; 1 3]) + x = [2, 4, 3] + non_array_test(x, x) + non_array_test(x, OffsetArray(x, -length(x))) + non_array_test(x, view(x, :)) + non_array_test(x, sparse(x)) + unary_matrix([1 2; 3 4]) + unary_matrix(Symmetric([1 2; 2 4])) + scalar_uniform_scaling(3) + matrix_uniform_scaling([1 2; 3 4]) + matrix_uniform_scaling(Symmetric([1 2; 2 4])) +end diff --git a/test/runtests.jl b/test/runtests.jl index 5326c520..788ba72b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,3 +14,4 @@ end include("bigint.jl") end include("matmul.jl") +include("rewrite.jl")