Skip to content

Commit

Permalink
Merge pull request #32 from JuliaOpt/bl/empty_sum
Browse files Browse the repository at this point in the history
Fix @rewrite for empty sum
  • Loading branch information
blegat committed Jan 10, 2020
2 parents 3562fde + 3cc2504 commit 49581ba
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/Test/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ function iszero_test(x)
@test MA.isequal_canonical(x_copy, x)
end

function empty_sum_test(x)
@test MA.isequal_canonical(MA.@rewrite(x + sum(1 for i in 1:0) * sum(x for i in 1:0)), x)
@test MA.isequal_canonical(MA.@rewrite(x + sum(x for i in 1:0) * sum(1 for i in 1:0)), x)
@test MA.isequal_canonical(MA.@rewrite(x + sum(1 for i in 1:0) * 1^2), x)
@test MA.isequal_canonical(MA.@rewrite(x + 1^2 * sum(1 for i in 1:0)), x)
# `1^2` is considered as a complex expression by `@rewrite`.
@test MA.isequal_canonical(MA.@rewrite(x + 1^2 * sum(1 for i in 1:0) * sum(x for i in 1:0)), x)
@test MA.isequal_canonical(MA.@rewrite(x + sum(1 for i in 1:0) * 1^2 * sum(x for i in 1:0)), x)
@test MA.isequal_canonical(MA.@rewrite(x + 1^2 * sum(1 for i in 1:0) * sum(x for i in 1:0) * 1^2), x)
@test MA.isequal_canonical(MA.@rewrite(x .+ sum(1 for i in 1:0) * sum(x for i in 1:0)), x)
@test MA.isequal_canonical(MA.@rewrite(x .+ 1^2 * sum(1 for i in 1:0) * sum(x for i in 1:0) * 1^2), x)
end

function cube_test(x)
@test_rewrite x^3
@test_rewrite (x + 1)^3
Expand Down Expand Up @@ -71,6 +84,7 @@ end

const scalar_tests = Dict(
"mul_scalar_array" => mul_scalar_array_test,
"empty_sum" => empty_sum_test,
"cube" => cube_test,
"iszero" => iszero_test,
"scalar_in_any" => scalar_in_any_test,
Expand Down
30 changes: 30 additions & 0 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,36 @@ end
broadcast!(::Union{typeof(add_mul), typeof(+)}, ::Zero, x) = copy_if_mutable(x)
broadcast!(::typeof(add_mul), ::Zero, x, y) = x * y

# Needed in `@rewrite(1 .+ sum(1 for i in 1:0) * 1^2)`
Base.:*(z::Zero, ::Any) = z
Base.:*(::Any, z::Zero) = z
Base.:*(z::Zero, ::Zero) = z
Base.:+(::Zero, x::Any) = x
Base.:+(x::Any, ::Zero) = x
Base.:+(z::Zero, ::Zero) = z

# Needed by `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)`
# since we don't require mutable type to support Zero in
# `mutable_operate!`.
_any_zero() = false
_any_zero(::Any, args::Vararg{Any, N}) where {N} = _any_zero(args...)
_any_zero(::Zero, ::Vararg{Any, N}) where {N} = true
function operate!(op::Union{typeof(add_mul), typeof(sub_mul)},
x, args::Vararg{Any, N}) where N
if _any_zero(args...)
return x
else
return operate_fallback!(mutability(x, op, x, args...), op, x, args...)
end
end

# Needed for `@rewrite(BigInt(1) .+ sum(1 for i in 1:0) * 1^2)`
Base.broadcastable(z::Zero) = Ref(z)
Base.ndims(::Type{Zero}) = 0
Base.length(::Zero) = 1
Base.iterate(z::Zero) = (z, nothing)
Base.iterate(::Zero, ::Nothing) = nothing

using Base.Meta

# See `JuMP._try_parse_idx_set`
Expand Down
11 changes: 11 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@ using LinearAlgebra, SparseArrays, Test
import MutableArithmetics
const MA = MutableArithmetics

@testset "Zero" begin
z = MA.Zero()
#@test zero(MA.Zero) isa MA.Zero
@test z + z isa MA.Zero
@test z + 1 == 1
@test 1 + z == 1
@test z * z isa MA.Zero
@test z * 1 isa MA.Zero
@test 1 * z isa MA.Zero
end

# Test that the macro call `m` throws an error exception during pre-compilation
macro test_macro_throws(error, m)
# See https://discourse.julialang.org/t/test-throws-with-macros-after-pr-23533/5878
Expand Down

0 comments on commit 49581ba

Please sign in to comment.