Skip to content

Commit

Permalink
Refactor with rewrite_sum
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Nov 22, 2019
1 parent 50f8a79 commit fe9575b
Showing 1 changed file with 52 additions and 45 deletions.
97 changes: 52 additions & 45 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,24 @@ end
# See `JuMP._is_sum`
_is_sum(s::Symbol) = (s == :sum) || (s == :∑) || (s == )

function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff=gensym())
function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var=gensym())
@assert isexpr(x,:call)
@assert length(x.args) > 1
@assert isexpr(x.args[2],:generator) || isexpr(x.args[2],:flatten)
header = x.args[1]
if _is_sum(header)
_parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, newaff)
_parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, new_var)
else
error("Expected sum outside generator expression; got $header")
end
end

function _parse_generator_sum(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff)
function _parse_generator_sum(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var)
# We used to preallocate the expression at the lowest level of the loop.
# When rewriting this some benchmarks revealed that it actually doesn't
# seem to help anymore, so might as well keep the code simple.
code = _parse_gen(x, t -> _rewrite(t, aff, lcoeffs, rcoeffs, aff)[2])
return :($code; $newaff=$aff)
return :($code; $new_var=$aff)
end

_is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref)
Expand Down Expand Up @@ -136,34 +136,41 @@ function _has_assignment_in_ref(ex::Expr)
end
_has_assignment_in_ref(other) = false

# output is assigned to newaff
function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symbol=gensym())
function rewrite_sum(terms, current::Symbol, lcoeffs::Vector, rcoeffs::Vector, output::Symbol, block = Expr(:block))
var = current
for term in terms[1:(end-1)]
var, code = _rewrite(term, var, lcoeffs, rcoeffs)
push!(block.args, code)
end
new_output, code = _rewrite(terms[end], var, lcoeffs, rcoeffs, output)
@assert new_output == output
push!(block.args, code)
return output, block
end

"""
_rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Symbol=gensym())
Return `new_var, code` such that `code` is equivalent to
```julia
new_var = aff + prod(lcoefs) * x * prod(rcoeffs)
```
"""
function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Symbol=gensym())
if isexpr(x, :call)
if x.args[1] == :+
b = Expr(:block)
aff_ = aff
for arg in x.args[2:(end-1)]
aff_, code = _rewrite(arg, aff_, lcoeffs, rcoeffs)
push!(b.args, code)
end
newaff, code = _rewrite(x.args[end], aff_, lcoeffs, rcoeffs, newaff)
push!(b.args, code)
return newaff, b
return rewrite_sum(x.args[2:end], aff, lcoeffs, rcoeffs, new_var)
elseif x.args[1] == :-
if length(x.args) == 2 # unary subtraction
return _rewrite(x.args[2], aff, vcat(-1.0, lcoeffs), rcoeffs, newaff)
else # a - b - c ...
b = Expr(:block)
block = Expr(:block)
if length(x.args) > 2 # not unary subtraction
aff_, code = _rewrite(x.args[2], aff, lcoeffs, rcoeffs)
push!(b.args, code)
for arg in x.args[3:(end-1)]
aff_,code = _rewrite(arg, aff_, vcat(-1.0, lcoeffs), rcoeffs)
push!(b.args, code)
end
newaff,code = _rewrite(x.args[end], aff_, vcat(-1.0, lcoeffs), rcoeffs, newaff)
push!(b.args, code)
return newaff, b
push!(block.args, code)
start = 3
else
aff_ = aff
start = 2
end
return rewrite_sum(x.args[start:end], aff_, vcat(-1.0, lcoeffs), rcoeffs, new_var, block)
elseif x.args[1] == :*
# we might need to recurse on multiple arguments, e.g.,
# (x+y)*(x+y)
Expand All @@ -179,57 +186,57 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb
x.args[which_idx], aff,
vcat(lcoeffs, [esc(x.args[i]) for i in 2:(which_idx - 1)]),
vcat(rcoeffs, [esc(x.args[i]) for i in (which_idx + 1):length(x.args)]),
newaff)
new_var)
else
blk = Expr(:block)
for i in 2:length(x.args)
if _is_complex_expr(x.args[i])
s = gensym()
newaff_, parsed = _rewrite_toplevel(x.args[i], s)
new_var_, parsed = _rewrite_toplevel(x.args[i], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
x.args[i] = newaff_
x.args[i] = new_var_
else
x.args[i] = esc(x.args[i])
end
end
callexpr = Expr(:call, :(MutableArithmetics.add_mul!), aff,
lcoeffs..., x.args[2:end]..., rcoeffs...)
push!(blk.args, :($newaff = $callexpr))
return newaff, blk
push!(blk.args, :($new_var = $callexpr))
return new_var, blk
end
elseif x.args[1] == :^ && _is_complex_expr(x.args[2])
MulType = :(MA.promote_operation(*, typeof($(x.args[2])), typeof($(x.args[2]))))
if x.args[3] == 2
blk = Expr(:block)
s = gensym()
newaff_, parsed = _rewrite_toplevel(x.args[2], s)
new_var_, parsed = _rewrite_toplevel(x.args[2], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
push!(blk.args, :($newaff = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs..., newaff_, newaff_,
push!(blk.args, :($new_var = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs..., new_var_, new_var_,
rcoeffs...)))))
return newaff, blk
return new_var, blk
elseif x.args[3] == 1
return _rewrite(:(convert($MulType, $(x.args[2]))), aff, lcoeffs, rcoeffs)
return _rewrite(:(convert($MulType, $(x.args[2]))), aff, lcoeffs, rcoeffs, new_var)
elseif x.args[3] == 0
return _rewrite(:(one($MulType)), aff, lcoeffs, rcoeffs)
return _rewrite(:(one($MulType)), aff, lcoeffs, rcoeffs, new_var)
else
blk = Expr(:block)
s = gensym()
newaff_, parsed = _rewrite_toplevel(x.args[2], s)
new_var_, parsed = _rewrite_toplevel(x.args[2], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
push!(blk.args, :($newaff = MutableArithmetics.add_mul!(
push!(blk.args, :($new_var = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs...,
Expr(:call, :^, newaff_, esc(x.args[3])),
Expr(:call, :^, new_var_, esc(x.args[3])),
rcoeffs...)))))
return newaff, blk
return new_var, blk
end
elseif x.args[1] == :/
@assert length(x.args) == 3
numerator = x.args[2]
denom = x.args[3]
return _rewrite(numerator, aff, lcoeffs, vcat(esc(:(1 / $denom)), rcoeffs), newaff)
return _rewrite(numerator, aff, lcoeffs, vcat(esc(:(1 / $denom)), rcoeffs), new_var)
elseif length(x.args) >= 2 && (isexpr(x.args[2], :generator) || isexpr(x.args[2], :flatten))
return newaff, _parse_generator(x,aff,lcoeffs,rcoeffs,newaff)
return new_var, _parse_generator(x, aff, lcoeffs, rcoeffs, new_var)
end
elseif isexpr(x, :curly)
_error_curly(x)
Expand All @@ -243,5 +250,5 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb
end
# at the lowest level
callexpr = Expr(:call, :(MutableArithmetics.add_mul!), aff, lcoeffs..., esc(x), rcoeffs...)
return newaff, :($newaff = $callexpr)
return new_var, :($new_var = $callexpr)
end

0 comments on commit fe9575b

Please sign in to comment.