Skip to content

Commit

Permalink
Merge pull request #17 from JuliaOpt/bl/hygiene
Browse files Browse the repository at this point in the history
Add hygiene tests
  • Loading branch information
blegat committed Nov 29, 2019
2 parents 9dedc01 + 6b49148 commit d74f9d2
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ julia> MutableArithmetics.rewrite_generator(:(i for i in 1:2 if isodd(i)), i ->
```
"""
function rewrite_generator(ex, inner)
# `i + j for i in 1:2 for j in 1:2` is a `flatten` expression
if isexpr(ex, :flatten)
return rewrite_generator(ex.args[1], inner)
end
if !isexpr(ex, :generator)
return inner(ex)
end
# `i + j for i in 1:2, j in 1:2` is a `generator` expression
function itrsets(sets)
if isa(sets, Expr)
return sets
Expand Down Expand Up @@ -98,7 +100,7 @@ _is_sum(s::Symbol) = (s == :sum) || (s == :∑) || (s == :Σ)
function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, new_var=gensym())
@assert isexpr(x,:call)
@assert length(x.args) > 1
@assert isexpr(x.args[2],:generator) || isexpr(x.args[2],:flatten)
@assert isexpr(x.args[2], :generator) || isexpr(x.args[2], :flatten)
header = x.args[1]
if _is_sum(header)
_parse_generator_sum(x.args[2], aff, lcoeffs, rcoeffs, new_var)
Expand All @@ -117,17 +119,25 @@ end

_is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref)

function rewrite_and_return(x)
variable, code = rewrite(x)
return :($code; $variable)
end
function rewrite(x)
variable = gensym()
new_variable, code = rewrite_to(x, variable)
return new_variable, :($variable = MutableArithmetics.Zero(); $code)
code = rewrite_and_return(x)
return variable, :($variable = $code)
end
function rewrite_and_return(x)
variable = gensym()
output_variable, code = _rewrite_to(x, variable)
# We need to use `let` because `rewrite(:(sum(i for i in 1:2))`
return quote
let
$variable = MutableArithmetics.Zero()
$code
$output_variable
end
end
end

rewrite_to(x, variable::Symbol) = _rewrite(x, variable, [], [])
_rewrite_to(x, variable::Symbol) = _rewrite(x, variable, [], [])

function _is_comparison(ex::Expr)
if isexpr(ex, :comparison)
Expand Down Expand Up @@ -209,7 +219,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Sym
for i in 2:length(x.args)
if _is_complex_expr(x.args[i])
s = gensym()
new_var_, parsed = rewrite_to(x.args[i], s)
new_var_, parsed = _rewrite_to(x.args[i], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
x.args[i] = new_var_
else
Expand All @@ -226,7 +236,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Sym
if x.args[3] == 2
blk = Expr(:block)
s = gensym()
new_var_, parsed = rewrite_to(x.args[2], s)
new_var_, parsed = _rewrite_to(x.args[2], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
push!(blk.args, :($new_var = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs..., new_var_, new_var_,
Expand All @@ -239,7 +249,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Sym
else
blk = Expr(:block)
s = gensym()
new_var_, parsed = rewrite_to(x.args[2], s)
new_var_, parsed = _rewrite_to(x.args[2], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
push!(blk.args, :($new_var = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs...,
Expand Down

0 comments on commit d74f9d2

Please sign in to comment.