diff --git a/src/rewrite.jl b/src/rewrite.jl index d56f811..11439c6 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -78,24 +78,24 @@ end # See `JuMP._is_sum` _is_sum(s::Symbol) = (s == :sum) || (s == :∑) || (s == :Σ) -function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff=gensym()) +function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var=gensym()) @assert isexpr(x,:call) @assert length(x.args) > 1 @assert isexpr(x.args[2],:generator) || isexpr(x.args[2],:flatten) header = x.args[1] if _is_sum(header) - _parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, newaff) + _parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, new_var) else error("Expected sum outside generator expression; got $header") end end -function _parse_generator_sum(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff) +function _parse_generator_sum(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var) # We used to preallocate the expression at the lowest level of the loop. # When rewriting this some benchmarks revealed that it actually doesn't # seem to help anymore, so might as well keep the code simple. code = _parse_gen(x, t -> _rewrite(t, aff, lcoeffs, rcoeffs, aff)[2]) - return :($code; $newaff=$aff) + return :($code; $new_var=$aff) end _is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref) @@ -136,34 +136,41 @@ function _has_assignment_in_ref(ex::Expr) end _has_assignment_in_ref(other) = false -# output is assigned to newaff -function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symbol=gensym()) +function rewrite_sum(terms, current::Symbol, lcoeffs::Vector, rcoeffs::Vector, output::Symbol, block = Expr(:block)) + var = current + for term in terms[1:(end-1)] + var, code = _rewrite(term, var, lcoeffs, rcoeffs) + push!(block.args, code) + end + new_output, code = _rewrite(terms[end], var, lcoeffs, rcoeffs, output) + @assert new_output == output + push!(block.args, code) + return output, block +end + +""" + _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Symbol=gensym()) + +Return `new_var, code` such that `code` is equivalent to +```julia +new_var = aff + prod(lcoefs) * x * prod(rcoeffs) +``` +""" +function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Symbol=gensym()) if isexpr(x, :call) if x.args[1] == :+ - b = Expr(:block) - aff_ = aff - for arg in x.args[2:(end-1)] - aff_, code = _rewrite(arg, aff_, lcoeffs, rcoeffs) - push!(b.args, code) - end - newaff, code = _rewrite(x.args[end], aff_, lcoeffs, rcoeffs, newaff) - push!(b.args, code) - return newaff, b + return rewrite_sum(x.args[2:end], aff, lcoeffs, rcoeffs, new_var) elseif x.args[1] == :- - if length(x.args) == 2 # unary subtraction - return _rewrite(x.args[2], aff, vcat(-1.0, lcoeffs), rcoeffs, newaff) - else # a - b - c ... - b = Expr(:block) + block = Expr(:block) + if length(x.args) > 2 # not unary subtraction aff_, code = _rewrite(x.args[2], aff, lcoeffs, rcoeffs) - push!(b.args, code) - for arg in x.args[3:(end-1)] - aff_,code = _rewrite(arg, aff_, vcat(-1.0, lcoeffs), rcoeffs) - push!(b.args, code) - end - newaff,code = _rewrite(x.args[end], aff_, vcat(-1.0, lcoeffs), rcoeffs, newaff) - push!(b.args, code) - return newaff, b + push!(block.args, code) + start = 3 + else + aff_ = aff + start = 2 end + return rewrite_sum(x.args[start:end], aff_, vcat(-1.0, lcoeffs), rcoeffs, new_var, block) elseif x.args[1] == :* # we might need to recurse on multiple arguments, e.g., # (x+y)*(x+y) @@ -179,57 +186,57 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb x.args[which_idx], aff, vcat(lcoeffs, [esc(x.args[i]) for i in 2:(which_idx - 1)]), vcat(rcoeffs, [esc(x.args[i]) for i in (which_idx + 1):length(x.args)]), - newaff) + new_var) else blk = Expr(:block) for i in 2:length(x.args) if _is_complex_expr(x.args[i]) s = gensym() - newaff_, parsed = _rewrite_toplevel(x.args[i], s) + new_var_, parsed = _rewrite_toplevel(x.args[i], s) push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed)) - x.args[i] = newaff_ + x.args[i] = new_var_ else x.args[i] = esc(x.args[i]) end end callexpr = Expr(:call, :(MutableArithmetics.add_mul!), aff, lcoeffs..., x.args[2:end]..., rcoeffs...) - push!(blk.args, :($newaff = $callexpr)) - return newaff, blk + push!(blk.args, :($new_var = $callexpr)) + return new_var, blk end elseif x.args[1] == :^ && _is_complex_expr(x.args[2]) MulType = :(MA.promote_operation(*, typeof($(x.args[2])), typeof($(x.args[2])))) if x.args[3] == 2 blk = Expr(:block) s = gensym() - newaff_, parsed = _rewrite_toplevel(x.args[2], s) + new_var_, parsed = _rewrite_toplevel(x.args[2], s) push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed)) - push!(blk.args, :($newaff = MutableArithmetics.add_mul!( - $aff, $(Expr(:call, :*, lcoeffs..., newaff_, newaff_, + push!(blk.args, :($new_var = MutableArithmetics.add_mul!( + $aff, $(Expr(:call, :*, lcoeffs..., new_var_, new_var_, rcoeffs...))))) - return newaff, blk + return new_var, blk elseif x.args[3] == 1 - return _rewrite(:(convert($MulType, $(x.args[2]))), aff, lcoeffs, rcoeffs) + return _rewrite(:(convert($MulType, $(x.args[2]))), aff, lcoeffs, rcoeffs, new_var) elseif x.args[3] == 0 - return _rewrite(:(one($MulType)), aff, lcoeffs, rcoeffs) + return _rewrite(:(one($MulType)), aff, lcoeffs, rcoeffs, new_var) else blk = Expr(:block) s = gensym() - newaff_, parsed = _rewrite_toplevel(x.args[2], s) + new_var_, parsed = _rewrite_toplevel(x.args[2], s) push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed)) - push!(blk.args, :($newaff = MutableArithmetics.add_mul!( + push!(blk.args, :($new_var = MutableArithmetics.add_mul!( $aff, $(Expr(:call, :*, lcoeffs..., - Expr(:call, :^, newaff_, esc(x.args[3])), + Expr(:call, :^, new_var_, esc(x.args[3])), rcoeffs...))))) - return newaff, blk + return new_var, blk end elseif x.args[1] == :/ @assert length(x.args) == 3 numerator = x.args[2] denom = x.args[3] - return _rewrite(numerator, aff, lcoeffs, vcat(esc(:(1 / $denom)), rcoeffs), newaff) + return _rewrite(numerator, aff, lcoeffs, vcat(esc(:(1 / $denom)), rcoeffs), new_var) elseif length(x.args) >= 2 && (isexpr(x.args[2], :generator) || isexpr(x.args[2], :flatten)) - return newaff, _parse_generator(x,aff,lcoeffs,rcoeffs,newaff) + return new_var, _parse_generator(x, aff, lcoeffs, rcoeffs, new_var) end elseif isexpr(x, :curly) _error_curly(x) @@ -243,5 +250,5 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb end # at the lowest level callexpr = Expr(:call, :(MutableArithmetics.add_mul!), aff, lcoeffs..., esc(x), rcoeffs...) - return newaff, :($newaff = $callexpr) + return new_var, :($new_var = $callexpr) end