diff --git a/.travis.yml b/.travis.yml index e0fa795..148ae10 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,10 @@ os: - linux - osx julia: + - 1.0 - 1.1 + - 1.2 + - 1.3 - nightly matrix: allow_failures: diff --git a/Project.toml b/Project.toml index a676ef0..26ee5c5 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" julia = "1.1" [extras] +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "OffsetArrays", "SparseArrays"] diff --git a/src/MutableArithmetics.jl b/src/MutableArithmetics.jl index 6d6c197..6b6f3c9 100644 --- a/src/MutableArithmetics.jl +++ b/src/MutableArithmetics.jl @@ -12,14 +12,41 @@ module MutableArithmetics # slowdown because it compiles something that works for any `N`. See # https://github.com/JuliaLang/julia/issues/32761 for details. +# `copy(::BigInt)` and `copy(::Array)` does not copy its elements so we need `deepcopy`. +mutable_copy(x) = deepcopy(x) +mutable_copy(A::AbstractArray) = mutable_copy.(A) + +""" + add_mul(a, args...) + +Return `a + *(args...)`. Note that `add_mul(a, b, c) = muladd(b, c, a)`. +""" +function add_mul end +add_mul(a, b) = a + b +add_mul(a, b, c) = muladd(b, c, a) +add_mul(a, b, c::Vararg{Any, N}) where {N} = add_mul(a, b *(c...)) + include("interface.jl") include("shortcuts.jl") +include("broadcast.jl") # Test that can be used to test an implementation of the interface include("Test/Test.jl") # Implementation of the interface for Base types +import LinearAlgebra +const Scaling = Union{Number, LinearAlgebra.UniformScaling} +mutable_copy(A::LinearAlgebra.Symmetric) = LinearAlgebra.Symmetric(mutable_copy(parent(A)), ifelse(A.uplo == 'U', :U, :L)) +# Broadcast applies the transpose +mutable_copy(A::LinearAlgebra.Transpose) = LinearAlgebra.Transpose(mutable_copy(parent(A))) +mutable_copy(A::LinearAlgebra.Adjoint) = LinearAlgebra.Adjoint(mutable_copy(parent(A))) include("bigint.jl") include("linear_algebra.jl") +isequal_canonical(a, b) = a == b + +include("rewrite.jl") + +include("dispatch.jl") + end # module diff --git a/src/bigint.jl b/src/bigint.jl index 64b43be..aa34b1c 100644 --- a/src/bigint.jl +++ b/src/bigint.jl @@ -13,6 +13,9 @@ promote_operation(::typeof(+), ::Vararg{Type{BigInt}, N}) where {N} = BigInt function mutable_operate_to!(output::BigInt, ::typeof(+), a::BigInt, b::BigInt) return Base.GMP.MPZ.add!(output, a, b) end +#function mutable_operate_to!(output::BigInt, op::typeof(+), a::BigInt, b::LinearAlgebra.UniformScaling) +# return mutable_operate_to!(output, op, a, b.λ) +#end # * promote_operation(::typeof(*), ::Vararg{Type{BigInt}, N}) where {N} = BigInt @@ -20,6 +23,12 @@ function mutable_operate_to!(output::BigInt, ::typeof(*), a::BigInt, b::BigInt) return Base.GMP.MPZ.mul!(output, a, b) end +function mutable_operate_to!(output::BigInt, op::Union{typeof(*), typeof(+)}, + a::BigInt, b::BigInt, c::Vararg{BigInt, N}) where N + mutable_operate_to!(output, op, a, b) + return mutable_operate!(op, output, c...) +end + # add_mul function mutable_operate_to!(output::BigInt, ::typeof(add_mul), args::Vararg{BigInt, N}) where N return mutable_buffered_operate_to!(BigInt(), output, add_mul, args...) @@ -30,6 +39,9 @@ function mutable_buffered_operate_to!(buffer::BigInt, output::BigInt, ::typeof(a return mutable_operate_to!(output, +, a, buffer) end -function mutable_operate_to!(output::BigInt, op::Function, a::Integer, b::Integer) - return mutable_operate_to!(output, op, convert(BigInt, a), convert(BigInt, b)) +scaling_to_bigint(x::BigInt) = x +scaling_to_bigint(x::Number) = convert(BigInt, x) +scaling_to_bigint(J::LinearAlgebra.UniformScaling) = scaling_to_bigint(J.λ) +function mutable_operate_to!(output::BigInt, op::Function, args::Vararg{Scaling, N}) where N + return mutable_operate_to!(output, op, scaling_to_bigint.(args)...) end diff --git a/src/broadcast.jl b/src/broadcast.jl new file mode 100644 index 0000000..382c525 --- /dev/null +++ b/src/broadcast.jl @@ -0,0 +1,76 @@ +function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Type{Eltype}) where {N, Eltype} + return Array{Eltype, N} +end +function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Type{Bool}) where N + return BitArray{N} +end + +# Same as `Base.Broadcast.combine_styles` but with types as argument. +combine_styles() = Broadcast.DefaultArrayStyle{0}() +combine_styles(c::Type) = Broadcast.result_style(Broadcast.BroadcastStyle(c)) +combine_styles(c1::Type, c2::Type) = Broadcast.result_style(combine_styles(c1), combine_styles(c2)) +@inline combine_styles(c1::Type, c2::Type, cs::Vararg{Type, N}) where N = Broadcast.result_style(combine_styles(c1), combine_styles(c2, cs...)) + +function promote_broadcast(op::Function, args::Vararg{Any, N}) where N + # FIXME we could use `promote_operation` instead as + # `combine_eltypes` uses `return_type` hence it may return a non-concrete type + # and we do not handle that case. + T = Base.Broadcast.combine_eltypes(op, args) + return broadcasted_type(combine_styles(args...), T) +end + +""" + broadcast_mutability(T::Type, ::typeof(op), args::Type...)::MutableTrait + +Return `IsMutable` to indicate an object of type `T` can be modified to be +equal to `broadcast(op, args...)`. +""" +function broadcast_mutability(T::Type, op, args::Vararg{Type, N}) where N + if mutability(T) isa IsMutable && promote_broadcast(op, args...) == T + return IsMutable() + else + return NotMutable() + end +end +broadcast_mutability(x, op, args::Vararg{Any, N}) where {N} = broadcast_mutability(typeof(x), op, typeof.(args)...) +broadcast_mutability(::Type) = NotMutable() + +""" + mutable_broadcast!(op::Function, args...) + +Modify the value of `args[1]` to be equal to the value of `broadcast(op, args...)`. Can +only be called if `mutability(args[1], op, args...)` returns `true`. +""" +function mutable_broadcast! end + +function mutable_broadcasted(broadcasted::Broadcast.Broadcasted{S}) where S + function f(args::Vararg{Any, N}) where N + return operate!(broadcasted.f, args...) + end + return Broadcast.Broadcasted{S}(f, broadcasted.args, broadcasted.axes) +end + +# If A is `Symmetric`, we cannot do that as we might modify the same entry twice. +# See https://github.com/JuliaOpt/JuMP.jl/issues/2102 +function mutable_broadcast!(op::Function, A::Array, args::Vararg{Any, N}) where N + bc = Broadcast.broadcasted(op, A, args...) + instantiated = Broadcast.instantiate(bc) + return copyto!(A, mutable_broadcasted(instantiated)) +end + + +""" + broadcast!(op::Function, args...) + +Returns the value of `broadcast(op, args...)`, possibly modifying `args[1]`. +""" +function broadcast!(op::Function, args::Vararg{Any, N}) where N + return broadcast_fallback!(broadcast_mutability(args[1], op, args...), op, args...) +end + +function broadcast_fallback!(::NotMutable, op::Function, args::Vararg{Any, N}) where N + return broadcast(op, args...) +end +function broadcast_fallback!(::IsMutable, op::Function, args::Vararg{Any, N}) where N + return mutable_broadcast!(op, args...) +end diff --git a/src/dispatch.jl b/src/dispatch.jl new file mode 100644 index 0000000..6f30ed5 --- /dev/null +++ b/src/dispatch.jl @@ -0,0 +1,9 @@ +abstract type AbstractMutable end + +# Special-case because the the base version wants to do fill!(::Array{AbstractVariableRef}, zero(GenericAffExpr{Float64,eltype(x)})) +_one_indexed(A) = all(x -> isa(x, Base.OneTo), axes(A)) +function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable}) + @assert _one_indexed(x) # `LinearAlgebra.diagm` doesn't work for non-one-indexed arrays in general. + ZeroType = promote_operation(zero, eltype(x)) + return LinearAlgebra.diagm(0 => copyto!(similar(x, ZeroType), x)) +end diff --git a/src/interface.jl b/src/interface.jl index 0d4961b..eb955ad 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -13,6 +13,18 @@ function promote_operation end function promote_operation(op::Function, args::Vararg{Type, N}) where N return typeof(op(zero.(args)...)) end +promote_operation(::typeof(*), ::Type{T}) where {T} = T +function promote_operation(::typeof(*), ::Type{S}, ::Type{T}, ::Type{U}, args::Vararg{Type, N}) where {S, T, U, N} + return promote_operation(*, promote_operation(*, S, T), U, args...) +end + +# Helpful error for common mistake +function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, A::Type{<:Array}, α::Type{<:Number}) + error("Operation `$op` between `$A` and `$α` is not allowed. You should use broadcast.") +end +function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, α::Type{<:Number}, A::Type{<:Array}) + error("Operation `$op` between `$α` and `$A` is not allowed. You should use broadcast.") +end # Define Traits abstract type MutableTrait end @@ -40,7 +52,8 @@ function mutable_operate_to_fallback(::NotMutable, output, op::Function, args... end function mutable_operate_to_fallback(::IsMutable, output, op::Function, args...) - error("`mutable_operate_to!($op, $(args...))` is not implemented yet.") + error("`mutable_operate_to!($(typeof(output)), $op, ", join(typeof.(args), ", "), + ")` is not implemented yet.") end """ diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 8a47170..1c61ab2 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -1,6 +1,62 @@ -import LinearAlgebra +mutability(::Type{<:Array}) = IsMutable() -mutability(::Type{<:Vector}) = IsMutable() +# Sum + +function promote_operation(op::Union{typeof(+), typeof(-)}, ::Type{Array{S, N}}, ::Type{Array{T, N}}) where {S, T, N} + return Array{promote_operation(op, S, T), N} +end +function mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N}, B::Array{T, N}) where {S, T, N} + for i in eachindex(A) + A[i] = operate!(op, A[i], B[i]) + end + return A +end + +# UniformScaling +function promote_operation(op::typeof(+), ::Type{Array{T, 2}}, ::Type{LinearAlgebra.UniformScaling{S}}) where {S, T} + return Array{promote_operation(op, T, S), 2} +end +function promote_operation(op::typeof(+), ::Type{LinearAlgebra.UniformScaling{S}}, ::Type{Array{T, 2}}) where {S, T} + return Array{promote_operation(op, S, T), 2} +end +function mutable_operate!(::typeof(+), A::Matrix, B::LinearAlgebra.UniformScaling) + n = LinearAlgebra.checksquare(A) + for i in 1:n + A[i, i] = operate!(+, A[i, i], B) + end + return A +end +function mutable_operate!(::typeof(add_mul), A::Matrix, B::Scaling, C::Scaling, D::Vararg{Scaling, N}) where N + return mutable_operate!(+, A, *(B, C, D...)) +end +function mutable_operate!(::typeof(add_mul), A::Array{S, N}, B::Array{T, N}, α::Vararg{Scaling, M}) where {S, T, N, M} + for i in eachindex(A) + A[i] = operate!(add_mul, A[i], B[i], α...) + end + return A +end +function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α::Scaling, B::Array{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} + for i in eachindex(A) + A[i] = operate!(add_mul, A[i], α, B[i], β...) + end + return A +end +function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α1::Scaling, α2::Scaling, B::Array{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M} + return mutable_operate!(add_mul, A, α1 * α2, B, β...) +end + +# Product + +function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N} + return Array{promote_operation(op, T, S), N} +end +function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N} + return Array{promote_operation(op, S, T), N} +end + +function promote_operation(::typeof(*), ::Type{Matrix{S}}, ::Type{Vector{T}}) where {S, T} + return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)} +end function promote_operation(::typeof(*), ::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractVector{T}}) where {S, T} return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)} end diff --git a/src/rewrite.jl b/src/rewrite.jl new file mode 100644 index 0000000..11439c6 --- /dev/null +++ b/src/rewrite.jl @@ -0,0 +1,254 @@ +# Heavily inspired from `JuMP/src/parse_expr.jl` code. + +export @rewrite +macro rewrite(expr) + return rewrite(expr) +end + +struct Zero end +# We need to copy `x` as it will be used as might be given by the user and be +# given as first argument of `operate!`. +Base.:(+)(zero::Zero, x) = mutable_copy(x) +# `add_mul(zero, ...)` redirects to `muladd(..., zero)` which calls `... + zero`. +Base.:(+)(x, zero::Zero) = mutable_copy(x) + +using Base.Meta + +# See `JuMP._try_parse_idx_set` +function _try_parse_idx_set(arg::Expr) + # [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and + # Expr(:ref, :x, Expr(:kw, :i, 1)) respectively. + if arg.head === :kw || arg.head === :(=) + @assert length(arg.args) == 2 + return true, arg.args[1], arg.args[2] + elseif isexpr(arg, :call) && arg.args[1] === :in + return true, arg.args[2], arg.args[3] + else + return false, nothing, nothing + end +end + +function _parse_idx_set(arg::Expr) + parse_done, idxvar, idxset = _try_parse_idx_set(arg) + if parse_done + return idxvar, idxset + end + error("Invalid syntax: $arg") +end + +# takes a generator statement and returns a properly nested for loop +# with nested filters as specified +function _parse_gen(ex, atleaf) + if isexpr(ex, :flatten) + return _parse_gen(ex.args[1], atleaf) + end + if !isexpr(ex, :generator) + return atleaf(ex) + end + function itrsets(sets) + if isa(sets, Expr) + return sets + elseif length(sets) == 1 + return sets[1] + else + return Expr(:block, sets...) + end + end + + idxvars = [] + if isexpr(ex.args[2], :filter) # if condition + loop = Expr(:for, esc(itrsets(ex.args[2].args[2:end])), + Expr(:if, esc(ex.args[2].args[1]), + _parse_gen(ex.args[1], atleaf))) + for idxset in ex.args[2].args[2:end] + idxvar, s = _parse_idx_set(idxset) + push!(idxvars, idxvar) + end + else + loop = Expr(:for, esc(itrsets(ex.args[2:end])), + _parse_gen(ex.args[1], atleaf)) + for idxset in ex.args[2:end] + idxvar, s = _parse_idx_set(idxset) + push!(idxvars, idxvar) + end + end + return loop +end + +# See `JuMP._is_sum` +_is_sum(s::Symbol) = (s == :sum) || (s == :∑) || (s == :Σ) + +function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var=gensym()) + @assert isexpr(x,:call) + @assert length(x.args) > 1 + @assert isexpr(x.args[2],:generator) || isexpr(x.args[2],:flatten) + header = x.args[1] + if _is_sum(header) + _parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, new_var) + else + error("Expected sum outside generator expression; got $header") + end +end + +function _parse_generator_sum(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var) + # We used to preallocate the expression at the lowest level of the loop. + # When rewriting this some benchmarks revealed that it actually doesn't + # seem to help anymore, so might as well keep the code simple. + code = _parse_gen(x, t -> _rewrite(t, aff, lcoeffs, rcoeffs, aff)[2]) + return :($code; $new_var=$aff) +end + +_is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref) + +function rewrite(x) + variable = gensym() + new_variable, code = _rewrite_toplevel(x, variable) + return quote + $variable = MutableArithmetics.Zero() + $code + $new_variable + end +end + +_rewrite_toplevel(x, variable::Symbol) = _rewrite(x, variable, [], []) + +function _is_comparison(ex::Expr) + if isexpr(ex, :comparison) + return true + elseif isexpr(ex, :call) + if ex.args[1] in (:<=, :≤, :>=, :≥, :(==)) + return true + else + return false + end + else + return false + end +end + +# x[i=1] <= 2 is a somewhat common user error. Catch it here. +function _has_assignment_in_ref(ex::Expr) + if isexpr(ex, :ref) + return any(x -> isexpr(x, :(=)), ex.args) + else + return any(_has_assignment_in_ref, ex.args) + end +end +_has_assignment_in_ref(other) = false + +function rewrite_sum(terms, current::Symbol, lcoeffs::Vector, rcoeffs::Vector, output::Symbol, block = Expr(:block)) + var = current + for term in terms[1:(end-1)] + var, code = _rewrite(term, var, lcoeffs, rcoeffs) + push!(block.args, code) + end + new_output, code = _rewrite(terms[end], var, lcoeffs, rcoeffs, output) + @assert new_output == output + push!(block.args, code) + return output, block +end + +""" + _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Symbol=gensym()) + +Return `new_var, code` such that `code` is equivalent to +```julia +new_var = aff + prod(lcoefs) * x * prod(rcoeffs) +``` +""" +function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Symbol=gensym()) + if isexpr(x, :call) + if x.args[1] == :+ + return rewrite_sum(x.args[2:end], aff, lcoeffs, rcoeffs, new_var) + elseif x.args[1] == :- + block = Expr(:block) + if length(x.args) > 2 # not unary subtraction + aff_, code = _rewrite(x.args[2], aff, lcoeffs, rcoeffs) + push!(block.args, code) + start = 3 + else + aff_ = aff + start = 2 + end + return rewrite_sum(x.args[start:end], aff_, vcat(-1.0, lcoeffs), rcoeffs, new_var, block) + elseif x.args[1] == :* + # we might need to recurse on multiple arguments, e.g., + # (x+y)*(x+y) + n_expr = mapreduce(_is_complex_expr, +, x.args) + if n_expr == 1 # special case, only recurse on one argument and don't create temporary objects + which_idx = 0 + for i in 2:length(x.args) + if _is_complex_expr(x.args[i]) + which_idx = i + end + end + return _rewrite( + x.args[which_idx], aff, + vcat(lcoeffs, [esc(x.args[i]) for i in 2:(which_idx - 1)]), + vcat(rcoeffs, [esc(x.args[i]) for i in (which_idx + 1):length(x.args)]), + new_var) + else + blk = Expr(:block) + for i in 2:length(x.args) + if _is_complex_expr(x.args[i]) + s = gensym() + new_var_, parsed = _rewrite_toplevel(x.args[i], s) + push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed)) + x.args[i] = new_var_ + else + x.args[i] = esc(x.args[i]) + end + end + callexpr = Expr(:call, :(MutableArithmetics.add_mul!), aff, + lcoeffs..., x.args[2:end]..., rcoeffs...) + push!(blk.args, :($new_var = $callexpr)) + return new_var, blk + end + elseif x.args[1] == :^ && _is_complex_expr(x.args[2]) + MulType = :(MA.promote_operation(*, typeof($(x.args[2])), typeof($(x.args[2])))) + if x.args[3] == 2 + blk = Expr(:block) + s = gensym() + new_var_, parsed = _rewrite_toplevel(x.args[2], s) + push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed)) + push!(blk.args, :($new_var = MutableArithmetics.add_mul!( + $aff, $(Expr(:call, :*, lcoeffs..., new_var_, new_var_, + rcoeffs...))))) + return new_var, blk + elseif x.args[3] == 1 + return _rewrite(:(convert($MulType, $(x.args[2]))), aff, lcoeffs, rcoeffs, new_var) + elseif x.args[3] == 0 + return _rewrite(:(one($MulType)), aff, lcoeffs, rcoeffs, new_var) + else + blk = Expr(:block) + s = gensym() + new_var_, parsed = _rewrite_toplevel(x.args[2], s) + push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed)) + push!(blk.args, :($new_var = MutableArithmetics.add_mul!( + $aff, $(Expr(:call, :*, lcoeffs..., + Expr(:call, :^, new_var_, esc(x.args[3])), + rcoeffs...))))) + return new_var, blk + end + elseif x.args[1] == :/ + @assert length(x.args) == 3 + numerator = x.args[2] + denom = x.args[3] + return _rewrite(numerator, aff, lcoeffs, vcat(esc(:(1 / $denom)), rcoeffs), new_var) + elseif length(x.args) >= 2 && (isexpr(x.args[2], :generator) || isexpr(x.args[2], :flatten)) + return new_var, _parse_generator(x, aff, lcoeffs, rcoeffs, new_var) + end + elseif isexpr(x, :curly) + _error_curly(x) + end + if isa(x, Expr) && _is_comparison(x) + error("Unexpected comparison in expression $x.") + end + if isa(x, Expr) && _has_assignment_in_ref(x) + @warn "Unexpected assignment in expression $x. This will" * + " become a syntax error in a future release." + end + # at the lowest level + callexpr = Expr(:call, :(MutableArithmetics.add_mul!), aff, lcoeffs..., esc(x), rcoeffs...) + return new_var, :($new_var = $callexpr) +end diff --git a/src/shortcuts.jl b/src/shortcuts.jl index 1724060..3bfd92f 100644 --- a/src/shortcuts.jl +++ b/src/shortcuts.jl @@ -26,18 +26,12 @@ Return the product of `a`, `b`, ..., possibly modifying `a`. """ mul!(args::Vararg{Any, N}) where {N} = operate!(*, args...) -""" - add_mul(a, args...) - -Return `a + *(args...)`. Note that `add_mul(a, b, c) = muladd(b, c, a)`. -""" -function add_mul end -add_mul(a, b, c) = muladd(b, c, a) - function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) where N return promote_operation(+, T, promote_operation(*, args...)) end +mutable_operate!(::typeof(add_mul), x, y) = mutable_operate!(+, x, y) + """ add_mul_to!(output, args...) diff --git a/test/broadcast.jl b/test/broadcast.jl new file mode 100644 index 0000000..16b110d --- /dev/null +++ b/test/broadcast.jl @@ -0,0 +1,20 @@ +using Test +import MutableArithmetics +const MA = MutableArithmetics + +@testset "Int" begin + a = [1, 2] + b = 3 + @test MA.broadcast!(+, a, b) == [4, 5] + @test a == [4, 5] +end +@testset "BigInt" begin + x = BigInt(1) + y = BigInt(2) + a = [x, y] + b = 3 + @test MA.broadcast!(+, a, b) == [4, 5] + @test a == [4, 5] + @test x == 4 + @test y == 5 +end diff --git a/test/int.jl b/test/int.jl index 953bbd7..47586fb 100644 --- a/test/int.jl +++ b/test/int.jl @@ -1,3 +1,7 @@ +using Test +import MutableArithmetics +const MA = MutableArithmetics + @testset "promote_operation" begin @test MA.promote_operation(MA.zero, Int) == Int @test MA.promote_operation(MA.one, Int) == Int @@ -5,6 +9,10 @@ @test MA.promote_operation(-, Int, Int) == Int @test MA.promote_operation(*, Int, Int) == Int @test MA.promote_operation(MA.add_mul, Int, Int, Int) == Int + err = ErrorException("Operation `+` between `Array{$Int,1}` and `$Int` is not allowed. You should use broadcast.") + @test_throws err MA.promote_operation(+, Vector{Int}, Int) + err = ErrorException("Operation `+` between `$Int` and `Array{$Int,1}` is not allowed. You should use broadcast.") + @test_throws err MA.promote_operation(+, Int, Vector{Int}) end @testset "add_to! / add!" begin @test MA.mutability(Int, MA.add_to!, Int, Int) isa MA.NotMutable diff --git a/test/rewrite.jl b/test/rewrite.jl new file mode 100644 index 0000000..9f34557 --- /dev/null +++ b/test/rewrite.jl @@ -0,0 +1,530 @@ +using SparseArrays, Test +import MutableArithmetics +const MA = MutableArithmetics + +macro test_rewrite(expr) + esc(quote + @test MA.isequal_canonical(MA.@rewrite($expr), $expr) + end) +end + +function basic_operators_test(w, x, y, z) + aff = @inferred 7.1 * x + 2.5 + @test_rewrite 7.1 * x + 2.5 + aff2 = @inferred 1.2 * y + 1.2 + @test_rewrite 1.2 * y + 1.2 + q = @inferred 2.5 * y * z + aff + @test_rewrite 2.5 * y * z + aff + q2 = @inferred 8 * x * z + aff2 + @test_rewrite 8 * x * z + aff2 + @test_rewrite 2 * x * x + 1 * y * y + z + 3 + + @testset "Comparison" begin + @testset "iszero" begin + @test !iszero(x) + @test !iszero(aff) + @test iszero(zero(aff)) + @test !iszero(q) + @test iszero(zero(q)) + end + + @testset "isequal_canonical" begin + @test MA.isequal_canonical((@inferred 3w + 2y), @inferred 2y + 3w) + @test !MA.isequal_canonical((@inferred 3w + 2y + 1), @inferred 3w + 2y) + @test !MA.isequal_canonical((@inferred 3w + 2y), @inferred 3y + 2w) + @test !MA.isequal_canonical((@inferred 3w + 2y), @inferred 3w + y) + + @test !MA.isequal_canonical(aff, aff2) + @test !MA.isequal_canonical(aff2, aff) + + @test MA.isequal_canonical(q, @inferred 2.5z*y + aff) + @test !MA.isequal_canonical(q, @inferred 2.5y*z + aff2) + @test !MA.isequal_canonical(q, @inferred 2.5x*z + aff) + @test !MA.isequal_canonical(q, @inferred 2.5y*x + aff) + @test !MA.isequal_canonical(q, @inferred 1.5y*z + aff) + @test MA.isequal_canonical(q2, @inferred 8z*x + aff2) + @test !MA.isequal_canonical(q2, @inferred 8x*z + aff) + @test !MA.isequal_canonical(q2, @inferred 7x*z + aff2) + @test !MA.isequal_canonical(q2, @inferred 8x*y + aff2) + @test !MA.isequal_canonical(q2, @inferred 8y*z + aff2) + end + end + + # Different objects that must all interact: + # 1. Number + # 2. Variable + # 3. AffExpr + # 4. QuadExpr + + # 1. Number tests + @testset "Number--???" begin + # 1-1 Number--Number - nope! + # 1-2 Number--Variable + @test_rewrite 4.13 + w + @test_rewrite 3.16 - w + @test_rewrite 5.23 * w + # 1-3 Number--AffExpr + @test_rewrite 1.5 + aff + @test_rewrite 1.5 - aff + @test_rewrite 2 * aff + # 1-4 Number--QuadExpr + @test_rewrite 1.5 + q + @test_rewrite 1.5 - q + @test_rewrite 2 * q + end + + # 2. Variable tests + @testset "Variable--???" begin + # 2-0 Variable unary + @test (+x) === x + @test_rewrite -x + # 2-1 Variable--Number + @test_rewrite w + 4.13 + @test_rewrite w - 4.13 + @test_rewrite w * 4.13 + @test_rewrite w / 2.00 + @test w == w + @test_rewrite x*y - 1 + @test_rewrite x^2 + @test_rewrite x^1 + @test_rewrite x^0 + # 2-2 Variable--Variable + @test_rewrite w + x + @test_rewrite w - x + @test_rewrite w * x + @test_rewrite x - x + @test_rewrite y*z - x + # 2-3 Variable--AffExpr + @test_rewrite z + aff + @test_rewrite z - aff + @test_rewrite z * aff + @test_rewrite 7.1 * x - aff + # 2-4 Variable--QuadExpr + @test_rewrite w + q + @test_rewrite w - q + end + + # 3. AffExpr tests + @testset "AffExpr--???" begin + # 3-0 AffExpr unary + @test_rewrite +aff + @test_rewrite -aff + # 3-1 AffExpr--Number + @test_rewrite aff + 1.5 + @test_rewrite aff - 1.5 + @test_rewrite aff * 2 + @test_rewrite aff / 2 + @test aff == aff + @test_rewrite aff - 1 + @test_rewrite aff^2 + @test_rewrite (7.1*x + 2.5)^2 + @test_rewrite aff^1 + @test_rewrite (7.1*x + 2.5)^1 + @test_rewrite aff^0 + @test_rewrite (7.1*x + 2.5)^0 + # 3-2 AffExpr--Variable + @test_rewrite aff + z + @test_rewrite aff - z + @test_rewrite aff * z + @test_rewrite aff - 7.1 * x + # 3-3 AffExpr--AffExpr + @test_rewrite aff + aff2 + @test_rewrite aff - aff2 + @test_rewrite aff * aff2 + @test string((x+x)*(x+3)) == string((x+3)*(x+x)) # Issue #288 + @test_rewrite aff-aff + # 4-4 AffExpr--QuadExpr + @test_rewrite aff2 + q + @test_rewrite aff2 - q + end + + # 4. QuadExpr + # TODO: This test block and others above should be rewritten to be + # self-contained. The definitions of q, w, and aff2 are too far to + # easily check correctness of the tests. + @testset "QuadExpr--???" begin + # 4-0 QuadExpr unary + @test_rewrite +q + @test_rewrite -q + # 4-1 QuadExpr--Number + @test_rewrite q + 1.5 + @test_rewrite q - 1.5 + @test_rewrite q * 2 + @test_rewrite q / 2 + @test q == q + @test_rewrite aff2 - q + # 4-2 QuadExpr--Variable + @test_rewrite q + w + @test_rewrite q - w + # 4-3 QuadExpr--AffExpr + @test_rewrite q + aff2 + @test_rewrite q - aff2 + # 4-4 QuadExpr--QuadExpr + @test_rewrite q + q2 + @test_rewrite q - q2 + end +end + +function sum_test(matrix) + @testset "sum(j::DenseAxisArray{Variable})" begin + @test_rewrite sum(matrix) + end + @testset "sum(affs::Array{AffExpr})" begin + @test_rewrite sum([2matrix[i, j] for i in 1:size(matrix, 1), j in 1:size(matrix, 2)]) + end + @testset "sum(quads::Array{QuadExpr})" begin + @test_rewrite sum([2matrix[i, j]^2 for i in 1:size(matrix, 1), j in 1:size(matrix, 2)]) + end +end + +function dot_test(x, y, z) + @test_rewrite dot(x[1], x[1]) + @test_rewrite dot(2, x[1]) + @test_rewrite dot(x[1], 2) + + c = vcat(1:3) + @test_rewrite dot(c, x) + @test_rewrite dot(x, c) + + A = [1 3 ; 2 4] + @test_rewrite dot(A, y) + @test_rewrite dot(y, A) + + B = ones(2, 2, 2) + @test_rewrite dot(B, z) + @test_rewrite dot(z, B) + + @test_rewrite dot(x, ones(3)) - dot(y, ones(2,2)) +end + +# JuMP issue #656 +function issue_656(x) + floats = Float64[i for i in 1:2] + anys = Array{Any}(undef, 2) + anys[1] = 10 + anys[2] = 20 + x + @test dot(floats, anys) == 10 + 40 + 2x +end + +function transpose_test(x, y, z) + @test MA.isequal_canonical(x', [x[1] x[2] x[3]]) + @test MA.isequal_canonical(copy(transpose(x)), [x[1] x[2] x[3]]) + @test MA.isequal_canonical(y', [y[1,1] y[2,1] + y[1,2] y[2,2] + y[1,3] y[2,3]]) + @test MA.isequal_canonical(copy(transpose(y)), + [y[1,1] y[2,1] + y[1,2] y[2,2] + y[1,3] y[2,3]]) + @test (z')' == z + @test transpose(transpose(z)) == z +end + +function vectorized_test(x, X11, X23, Xd) + A = [2 1 0 + 1 2 1 + 0 1 2] + B = sparse(A) + X = sparse([1, 2], [1, 3], [X11, X23], 3, 3) # for testing Variable + # FIXME + #@test MA.isequal_canonical([X11 0. 0.; 0. 0. X23; 0. 0. 0.], @inferred MA._densify_with_jump_eltype(X)) + Y = sparse([1, 2], [1, 3], [2X11, 4X23], 3, 3) # for testing GenericAffExpr + Yd = [2X11 0 0 + 0 0 4X23 + 0 0 0] + Z = sparse([1, 2], [1, 3], [X11^2, 2X23^2], 3, 3) # for testing GenericQuadExpr + Zd = [X11^2 0 0 + 0 0 2X23^2 + 0 0 0] + v = [4, 5, 6] + + @testset "Sum of matrices" begin + @test_rewrite(x - x) + @test_rewrite(x + x) + @test_rewrite(x + 2x) + @test_rewrite(x - 2x) + @test_rewrite(x + x * 2) + @test_rewrite(x - x * 2) + @test_rewrite(x + 2x * 2) + @test_rewrite(x - 2x * 2) + @test_rewrite(Xd + Yd) + @test_rewrite(Xd - Yd) + @test_rewrite(Xd + 2Yd) + @test_rewrite(Xd - 2Yd) + @test_rewrite(Xd + Yd * 2) + @test_rewrite(Xd - Yd * 2) + @test_rewrite(Yd + Xd) + @test_rewrite(Yd + 2Xd) + @test_rewrite(Yd + Xd * 2) + @test_rewrite(Yd + Zd) + @test_rewrite(Yd + 2Zd) + @test_rewrite(Yd + Zd * 2) + @test_rewrite(Zd + Yd) + @test_rewrite(Zd + 2Yd) + @test_rewrite(Zd + Yd * 2) + @test_rewrite(Zd + Xd) + @test_rewrite(Zd + 2Xd) + @test_rewrite(Zd + Xd * 2) + @test_rewrite(Xd + Zd) + @test_rewrite(Xd + 2Zd) + @test_rewrite(Xd + Zd * 2) + end + + @test_rewrite(x') + @test_rewrite(x' * A) + # Complex expression + @test_rewrite(x' * ones(3, 3)) + @test_rewrite(x' * A * x) + # Complex expression + @test_rewrite(x' * ones(3, 3) * x) + + @test MA.isequal_canonical(A*x, [2x[1] + x[2] + 2x[2] + x[1] + x[3] + x[2] + 2x[3]]) + @test MA.isequal_canonical(A*x, B*x) + @test MA.isequal_canonical(A*x, MA.@rewrite(B*x)) + @test MA.isequal_canonical(MA.@rewrite(A*x), MA.@rewrite(B*x)) + @test MA.isequal_canonical(x'*A, [2x[1]+x[2]; 2x[2]+x[1]+x[3]; x[2]+2x[3]]') + @test MA.isequal_canonical(x'*A, x'*B) + @test MA.isequal_canonical(x'*A, MA.@rewrite(x'*B)) + @test MA.isequal_canonical(MA.@rewrite(x'*A), MA.@rewrite(x'*B)) + @test MA.isequal_canonical(x'*A*x, 2x[1]*x[1] + 2x[1]*x[2] + 2x[2]*x[2] + 2x[2]*x[3] + 2x[3]*x[3]) + @test MA.isequal_canonical(x'A*x, x'*B*x) + @test MA.isequal_canonical(x'*A*x, MA.@rewrite(x'*B*x)) + @test MA.isequal_canonical(MA.@rewrite(x'*A*x), MA.@rewrite(x'*B*x)) + + y = A*x + @test MA.isequal_canonical(-x, [-x[1], -x[2], -x[3]]) + @test MA.isequal_canonical(-y, [-2x[1] - x[2] + -x[1] - 2x[2] - x[3] + -x[2] - 2x[3]]) + @test MA.isequal_canonical(y .+ 1, [2x[1] + x[2] + 1 + x[1] + 2x[2] + x[3] + 1 + x[2] + 2x[3] + 1]) + @test MA.isequal_canonical(y .- 1, [ + 2x[1] + x[2] - 1 + x[1] + 2x[2] + x[3] - 1 + x[2] + 2x[3] - 1]) + @test MA.isequal_canonical(y .+ 2ones(3), [2x[1] + x[2] + 2 + x[1] + 2x[2] + x[3] + 2 + x[2] + 2x[3] + 2]) + @test MA.isequal_canonical(y .- 2ones(3), [2x[1] + x[2] - 2 + x[1] + 2x[2] + x[3] - 2 + x[2] + 2x[3] - 2]) + @test MA.isequal_canonical(2ones(3) .+ y, [2x[1] + x[2] + 2 + x[1] + 2x[2] + x[3] + 2 + x[2] + 2x[3] + 2]) + @test MA.isequal_canonical(2ones(3) .- y, [-2x[1] - x[2] + 2 + -x[1] - 2x[2] - x[3] + 2 + -x[2] - 2x[3] + 2]) + @test MA.isequal_canonical(y .+ x, [3x[1] + x[2] + x[1] + 3x[2] + x[3] + x[2] + 3x[3]]) + @test MA.isequal_canonical(x .+ y, [3x[1] + x[2] + x[1] + 3x[2] + x[3] + x[2] + 3x[3]]) + @test MA.isequal_canonical(2y .+ 2x, [6x[1] + 2x[2] + 2x[1] + 6x[2] + 2x[3] + 2x[2] + 6x[3]]) + @test MA.isequal_canonical(y .- x, [ x[1] + x[2] + x[1] + x[2] + x[3] + x[2] + x[3]]) + @test MA.isequal_canonical(x .- y, [-x[1] - x[2] + -x[1] - x[2] - x[3] + -x[2] - x[3]]) + @test MA.isequal_canonical(y .+ x[:], [3x[1] + x[2] + x[1] + 3x[2] + x[3] + x[2] + 3x[3]]) + @test MA.isequal_canonical(x[:] .+ y, [3x[1] + x[2] + x[1] + 3x[2] + x[3] + x[2] + 3x[3]]) + + @test MA.isequal_canonical(MA.@rewrite(A*x/2), A*x/2) + @test MA.isequal_canonical(X*v, [4X11; 6X23; 0]) + @test MA.isequal_canonical(v'*X, [4X11 0 5X23]) + @test MA.isequal_canonical(copy(transpose(v))*X, [4X11 0 5X23]) + @test MA.isequal_canonical(X'*v, [4X11; 0; 5X23]) + @test MA.isequal_canonical(copy(transpose(X))*v, [4X11; 0; 5X23]) + @test MA.isequal_canonical(X*A, [2X11 X11 0 + 0 X23 2X23 + 0 0 0 ]) + @test MA.isequal_canonical(A*X, [2X11 0 X23 + X11 0 2X23 + 0 0 X23]) + @test MA.isequal_canonical(A*X', [2X11 0 0 + X11 X23 0 + 0 2X23 0]) + @test MA.isequal_canonical(X'*A, [2X11 X11 0 + 0 0 0 + X23 2X23 X23]) + @test MA.isequal_canonical(copy(transpose(X))*A, [2X11 X11 0 + 0 0 0 + X23 2X23 X23]) + @test MA.isequal_canonical(A'*X, [2X11 0 X23 + X11 0 2X23 + 0 0 X23]) + @test MA.isequal_canonical(copy(transpose(X))*A, X'*A) + @test MA.isequal_canonical(copy(transpose(A))*X, A'*X) + @test MA.isequal_canonical(X*A, X*B) + @test MA.isequal_canonical(Y'*A, copy(transpose(Y))*A) + @test MA.isequal_canonical(A*Y', A*copy(transpose(Y))) + @test MA.isequal_canonical(Z'*A, copy(transpose(Z))*A) + @test MA.isequal_canonical(Xd'*Y, copy(transpose(Xd))*Y) + @test MA.isequal_canonical(Y'*Xd, copy(transpose(Y))*Xd) + @test MA.isequal_canonical(Xd'*Xd, copy(transpose(Xd))*Xd) + @test MA.isequal_canonical(A*X, B*X) + @test MA.isequal_canonical(A*X', B*X') + @test MA.isequal_canonical(A'*X, B'*X) +end + +function broadcast_test(x) + A = [1 2; + 3 4] + B = sparse(A) + y = SparseMatrixCSC(2, 2, copy(B.colptr), copy(B.rowval), vec(x)) + @test MA.isequal_canonical(A.+x, [1+x[1,1] 2+x[1,2]; + 3+x[2,1] 4+x[2,2]]) + @test MA.isequal_canonical(A.+x, B.+x) + @test MA.isequal_canonical(A.+x, A.+y) + @test MA.isequal_canonical(A.+y, B.+y) + @test MA.isequal_canonical(x.+A, [1+x[1,1] 2+x[1,2]; + 3+x[2,1] 4+x[2,2]]) + @test MA.isequal_canonical(x.+A, x.+B) + @test MA.isequal_canonical(x.+A, y.+A) + @test MA.isequal_canonical(x .+ x, [2x[1,1] 2x[1,2]; 2x[2,1] 2x[2,2]]) + @test MA.isequal_canonical(y.+A, y.+B) + @test MA.isequal_canonical(A.-x, [1-x[1,1] 2-x[1,2]; + 3-x[2,1] 4-x[2,2]]) + @test MA.isequal_canonical(A.-x, B.-x) + @test MA.isequal_canonical(A.-x, A.-y) + @test MA.isequal_canonical(x .- x, [zero(x[1] - x[1]) for _1 in 1:2, _2 in 1:2]) + @test MA.isequal_canonical(A.-y, B.-y) + @test MA.isequal_canonical(x.-A, [-1+x[1,1] -2+x[1,2]; + -3+x[2,1] -4+x[2,2]]) + @test MA.isequal_canonical(x.-A, x.-B) + @test MA.isequal_canonical(x.-A, y.-A) + @test MA.isequal_canonical(y.-A, y.-B) + @test MA.isequal_canonical(A.*x, [1*x[1,1] 2*x[1,2]; + 3*x[2,1] 4*x[2,2]]) + @test MA.isequal_canonical(A.*x, B.*x) + @test MA.isequal_canonical(A.*x, A.*y) + @test MA.isequal_canonical(A.*y, B.*y) + + @test MA.isequal_canonical(x.*A, [1*x[1,1] 2*x[1,2]; + 3*x[2,1] 4*x[2,2]]) + @test MA.isequal_canonical(x.*A, x.*B) + @test MA.isequal_canonical(x.*A, y.*A) + @test MA.isequal_canonical(y.*A, y.*B) + + @test MA.isequal_canonical(x .* x, [x[1,1]^2 x[1,2]^2; x[2,1]^2 x[2,2]^2]) + @test MA.isequal_canonical(x ./ A, [ + x[1,1] / 1 x[1,2] / 2; + x[2,1] / 3 x[2,2] / 4]) + @test MA.isequal_canonical(x ./ A, x ./ B) + @test MA.isequal_canonical(x ./ A, y ./ A) + + # TODO: Refactor to avoid calling the internal JuMP function + # `_densify_with_jump_eltype`. + #z = JuMP._densify_with_jump_eltype((2 .* y) ./ 3) + #@test MA.isequal_canonical((2 .* x) ./ 3, z) + #z = JuMP._densify_with_jump_eltype(2 * (y ./ 3)) + #@test MA.isequal_canonical(2 .* (x ./ 3), z) + #z = JuMP._densify_with_jump_eltype((x[1,1],) .* B) + #@test MA.isequal_canonical((x[1,1],) .* A, z) +end + +function non_array_test(x, x2) + # This is needed to compare arrays that have nonstandard indexing + elements_equal(A::AbstractArray{T, N}, B::AbstractArray{T, N}) where {T, N} = all(a == b for (a, b) in zip(A, B)) + + @test elements_equal(+x, +x2) + @test elements_equal(-x, -x2) + @test elements_equal(x .+ first(x), x2 .+ first(x2)) + @test elements_equal(x .- first(x), x2 .- first(x2)) + @test elements_equal(first(x) .- x, first(x2) .- x2) + @test elements_equal(first(x) .+ x, first(x2) .+ x2) + @test elements_equal(2 .* x, 2 .* x2) + @test elements_equal(first(x) .+ x2, first(x2) .+ x) + @test sum(x) == sum(x2) + if !MA._one_indexed(x2) + @test_throws DimensionMismatch x + x2 + end + # `diagm` only define with `Pair` in Julia v1.0 and v1.1 + @testset "diagm" begin + if !MA._one_indexed(x2) && eltype(x2) isa MA.AbstractMutable + @test_throws AssertionError diagm(x2) + else + @test diagm(0 => x) == diagm(0 => x2) + if VERSION >= v"1.2" + @test diagm(x) == diagm(x2) + end + end + end +end + +function unary_matrix(Q) + @test_rewrite 2Q + # See https://github.com/JuliaLang/julia/issues/32374 for `Symmetric` + @test_rewrite -Q +end + +function scalar_uniform_scaling(x) + @test_rewrite x + 2I + @test_rewrite (x + 1) + I + @test_rewrite x - 2I + @test_rewrite (x - 1) - I + @test_rewrite 2I + x + @test_rewrite I + (x + 1) + @test_rewrite 2I - x + @test_rewrite I - (x - 1) + @test_rewrite I * x + @test_rewrite I * (x + 1) + @test_rewrite (x + 1) * I +end + +function matrix_uniform_scaling(x) + @test_rewrite x + 2I + @test_rewrite (x .+ 1) + I + @test_rewrite x - 2I + @test_rewrite (x .- 1) - I + @test_rewrite 2I + x + @test_rewrite I + (x .+ 1) + @test_rewrite 2I - x + @test_rewrite I - (x .- 1) + @test_rewrite I * x + @test_rewrite I * (x .+ 1) + @test_rewrite (x .+ 1) * I + @test_rewrite (x .+ 1) + I * I + @test_rewrite (x .+ 1) + 2 * I + @test_rewrite (x .+ 1) + I * 2 +end + +using LinearAlgebra +using OffsetArrays + +@testset "@rewrite with $T" for T in [ + Int, + Float64, + BigInt + ] + basic_operators_test(T(1), T(2), T(3), T(4)) + sum_test(T[5 1 9; -7 2 4; -2 -7 5]) + S = zeros(T, 2, 2, 2) + S[1, :, :] = T[5 -8; 3 -7] + S[2, :, :] = T[-2 8; 8 -1] + dot_test(T[-7, 1, 4], T[0 -4; 6 -5], S) + issue_656(T(3)) + transpose_test(T[9, -3, 8], T[-4 4 1; 4 -8 -6], T[6, 9, 2, 4, -3]) + vectorized_test(T[3, 2, 6], T(4), T(5), T[8 1 9; 4 3 1; 2 0 8]) + broadcast_test(T[2 4; -1 3]) + x = T[2, 4, 3] + non_array_test(x, x) + non_array_test(x, OffsetArray(x, -length(x))) + non_array_test(x, view(x, :)) + non_array_test(x, sparse(x)) + unary_matrix(T[1 2; 3 4]) + unary_matrix(Symmetric(T[1 2; 2 4])) + scalar_uniform_scaling(T(3)) + matrix_uniform_scaling(T[1 2; 3 4]) + matrix_uniform_scaling(Symmetric(T[1 2; 2 4])) +end diff --git a/test/runtests.jl b/test/runtests.jl index 5326c52..108fbaf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,4 +13,8 @@ end @testset "BigInt" begin include("bigint.jl") end +@testset "Broadcast" begin + include("broadcast.jl") +end include("matmul.jl") +include("rewrite.jl")