/
rewrite_generic.jl
330 lines (310 loc) · 12.6 KB
/
rewrite_generic.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
# Copyright (c) 2019 MutableArithmetics.jl contributors
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain
# one at http://mozilla.org/MPL/2.0/.
# `*` is rewritten differently from the other operators, which is why these two
# array methods of `operate!` are needed.

# `y` has the element type of `x`: scale `x` in place with a fused broadcast.
function operate!(::typeof(*), x::AbstractArray{T}, y::T) where {T}
    x .*= y
    return x
end

# Any other `y`: delegate the product to `operate` and broadcast-assign the
# result back into `x`.
function operate!(::typeof(*), x::AbstractArray, y)
    x .= operate(*, x, y)
    return x
end
"""
    _rewrite_generic(stack::Expr, expr::T)::Tuple{Any,Bool}

This method is the heart of the rewrite logic. It converts `expr` into a mutable
equivalent, places any intermediate calculations onto `stack`, and returns a
tuple containing the return value---which is either `expr` or a `gensym`ed
variable equivalent to `expr`---and a boolean flag that indicates whether the
return value can be mutated by future callers.

This is only the declaration of the generic function; the behavior for a given
`expr` is provided by methods that dispatch on the type of `expr`.
"""
function _rewrite_generic end
"""
    _rewrite_generic(::Expr, x)

Catch-all method for any `x` that has no more specific method. No rewriting is
attempted: `x` is returned escaped and unchanged, and the flag is `false` so
that future callers do not mutate it.
"""
function _rewrite_generic(::Expr, x)
    return esc(x), false
end
"""
    _rewrite_generic(::Expr, x::Number)

If `x` is a `Number` at macro expansion time, it _must_ be a constant literal.

The literal is returned as-is, but the flag is `true`: callers may mutate the
value without fear. Put another way, they don't need to wrap the value in
`copy_if_mutable(x)` before using it as the first argument to `operate!!`.

This most commonly happens in situations like `x^2`.
"""
function _rewrite_generic(::Expr, x::Number)
    return x, true
end
# Return `true` if `expr` is a call with exactly one argument and that argument
# is a `:generator` expression, e.g. `f(i for i in I)`.
function _is_generator(expr)
    if !Meta.isexpr(expr, :call, 2)
        return false
    end
    return Meta.isexpr(expr.args[2], :generator)
end
# Return `true` if `expr` is a call with exactly one argument and that argument
# is a `:flatten` expression, e.g. `f(i for i in I for j in J)`.
function _is_flatten(expr)
    if !Meta.isexpr(expr, :call, 2)
        return false
    end
    return Meta.isexpr(expr.args[2], :flatten)
end
# Return `true` if `expr` is a call whose argument list is a `:parameters`
# expression (the keyword arguments after `;`) followed by exactly one
# positional argument, e.g. `f(x; init = 0)`.
function _is_parameters(expr)
    if !Meta.isexpr(expr, :call, 3)
        return false
    end
    return Meta.isexpr(expr.args[2], :parameters)
end
# Return `true` if `expr` is a `:kw` expression (`name = value` inside a call)
# whose name is `kwarg`.
function _is_kwarg(expr, kwarg::Symbol)
    Meta.isexpr(expr, :kw) || return false
    return expr.args[1] == kwarg
end
"""
    _rewrite_generic(stack::Expr, expr::Expr)

This method is the heart of the rewrite logic. It converts `expr` into a mutable
equivalent.

Intermediate assignments are pushed onto `stack.args`. The return value is a
tuple `(value, is_mutable)`: `value` is either the escaped `expr` (when no
rewrite is attempted) or a `gensym`ed variable that is assigned on `stack`, and
`is_mutable` says whether future callers may mutate `value`.

The cases are handled in this order:

 * Anything that is not a `:call`, a zero-argument call, or a call whose first
   argument is a splat is returned unchanged.
 * A `+`, `-`, `.+`, or `.-` call with a splat in any argument position is
   returned unchanged.
 * `sum`, `Σ`, and `∑` over a generator (optionally with a single `init`
   keyword) are forwarded to `_rewrite_generic_generator`. Any other call on a
   generator or with a `:parameters` block is returned unchanged.
 * `+`, `-`, `*`, `.+`, and `.-` are rewritten into `operate`, `operate!!`, or
   `broadcast!!` calls. Unary `+`, `.+`, and `*` rewrite their only argument.
 * Every other call is kept as a non-mutating call with rewritten arguments.
"""
function _rewrite_generic(stack::Expr, expr::Expr)
    if !Meta.isexpr(expr, :call)
        # In situations like `x[i]`, we do not attempt to rewrite. Return `expr`
        # and don't let future callers mutate.
        return esc(expr), false
    elseif Meta.isexpr(expr, :call, 1)
        # A zero-argument function
        return esc(expr), false
    elseif Meta.isexpr(expr.args[2], :(...))
        # If the first argument is a splat.
        return esc(expr), false
    elseif expr.args[1] in (:+, :-, :.+, :.-) &&
           any(arg -> Meta.isexpr(arg, :(...)), expr.args)
        # A later argument of an additive operator is a splat, as in
        # `+(a, xs...)`. The nested rewrite below would emit a single
        # `add_mul(root, xs...)` call, which multiplies the splatted values
        # together instead of adding them. Leave the expression untouched.
        return esc(expr), false
    elseif _is_generator(expr) || _is_flatten(expr) || _is_parameters(expr)
        if !(expr.args[1] in (:sum, :Σ, :∑))
            # We don't know what this is. Return the expression and don't let
            # future callers mutate.
            return esc(expr), false
        end
        # This is a generator expression like `sum(i for i in args)`. Generators
        # come in two forms: `sum(i for i=I, j=J)` or `sum(i for i=I for j=J)`.
        # The latter is a `:flatten` expression and needs additional handling,
        # but we delay this complexity for _rewrite_generic_generator.
        if Meta.isexpr(expr.args[2], :parameters)
            # The summation has keyword arguments. We can deal with `init`, but
            # not any of the others.
            p = expr.args[2]
            if length(p.args) == 1 && _is_kwarg(p.args[1], :init)
                # sum(iter ; init) form!
                root = gensym()
                init, _ = _rewrite_generic(stack, p.args[1].args[2])
                push!(stack.args, :($root = $init))
                return _rewrite_generic_generator(stack, :+, expr.args[3], root)
            else
                # We don't know how to deal with this
                return esc(expr), false
            end
        else
            # Summations use :+ as the reduction operator.
            init_expr = expr.args[2].args[end]
            if Meta.isexpr(init_expr, :(=)) && init_expr.args[1] == :init
                # sum(iter, init) form! The trailing `init = value` is parsed
                # as the last iteration specification of the generator, so
                # strip it from a copy before rewriting the generator.
                root = gensym()
                init, _ = _rewrite_generic(stack, init_expr.args[2])
                push!(stack.args, :($root = $init))
                new_expr = copy(expr.args[2])
                pop!(new_expr.args)
                return _rewrite_generic_generator(stack, :+, new_expr, root)
            elseif Meta.isexpr(expr.args[2], :flatten)
                # sum(iter for iter, init) form! Here `init = value` is the
                # last iteration specification of the innermost generator.
                first_generator = expr.args[2].args[1].args[1]
                init_expr = first_generator.args[end]
                if Meta.isexpr(init_expr, :(=)) && init_expr.args[1] == :init
                    root = gensym()
                    init, _ = _rewrite_generic(stack, init_expr.args[2])
                    push!(stack.args, :($root = $init))
                    new_expr = copy(expr.args[2])
                    pop!(new_expr.args[1].args[1].args)
                    return _rewrite_generic_generator(stack, :+, new_expr, root)
                end
            end
            return _rewrite_generic_generator(stack, :+, expr.args[2])
        end
    end
    # At this point, we have an expression like `op(args...)`. We can either
    # choose to convert the operation to its mutable equivalent, or return the
    # non-mutating operation.
    if expr.args[1] == :+
        # +(args...) => add_mul(add_mul(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # +(arg)
            return _rewrite_generic(stack, expr.args[2])
        elseif length(expr.args) == 3 && _is_call(expr.args[3], :*)
            # +(x, *(y...)) => add_mul(x, y...)
            x, is_mutable = _rewrite_generic(stack, expr.args[2])
            rhs = if is_mutable
                Expr(:call, operate!!, add_mul, x)
            else
                Expr(:call, operate, add_mul, x)
            end
            for i in 2:length(expr.args[3].args)
                yi, _ = _rewrite_generic(stack, expr.args[3].args[i])
                push!(rhs.args, yi)
            end
            root = gensym()
            push!(stack.args, :($root = $rhs))
            return root, true
        end
        return _rewrite_generic_to_nested_op(stack, expr, add_mul)
    elseif expr.args[1] == :-
        # -(args...) => sub_mul(sub_mul(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # -(arg)
            return _rewrite_generic(stack, Expr(:call, :*, -1, expr.args[2]))
        end
        return _rewrite_generic_to_nested_op(stack, expr, sub_mul)
    elseif expr.args[1] == :*
        # *(args...) => *(*(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # *(arg)
            # Unary multiplication is handled like unary `+`: rewrite the only
            # argument instead of failing during macro expansion.
            return _rewrite_generic(stack, expr.args[2])
        end
        arg1, is_mutable = _rewrite_generic(stack, expr.args[2])
        arg2, _ = _rewrite_generic(stack, expr.args[3])
        rhs = if is_mutable
            Expr(:call, operate!!, *, arg1, arg2)
        else
            Expr(:call, operate, *, arg1, arg2)
        end
        root = gensym()
        push!(stack.args, :($root = $rhs))
        # Fold any remaining factors into the running product, one at a time.
        for i in 4:length(expr.args)
            arg, _ = _rewrite_generic(stack, expr.args[i])
            rhs = if is_mutable
                Expr(:call, operate!!, *, root, arg)
            else
                Expr(:call, operate, *, root, arg)
            end
            root = gensym()
            push!(stack.args, :($root = $rhs))
        end
        return root, is_mutable
    elseif expr.args[1] == :.+
        # .+(args...) => add_mul.(add_mul.(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # .+(arg)
            return _rewrite_generic(stack, expr.args[2])
        end
        return _rewrite_generic_to_nested_op(
            stack,
            expr,
            add_mul;
            broadcast = true,
        )
    elseif expr.args[1] == :.-
        # .-(args...) => sub_mul.(sub_mul.(arg1, arg2), arg3)
        @assert length(expr.args) > 1
        if length(expr.args) == 2  # .-(arg)
            return _rewrite_generic(stack, Expr(:call, :.*, -1, expr.args[2]))
        end
        return _rewrite_generic_to_nested_op(
            stack,
            expr,
            sub_mul;
            broadcast = true,
        )
    else
        # Use the non-mutating call.
        result = Expr(:call, esc(expr.args[1]))
        for i in 2:length(expr.args)
            arg, _ = _rewrite_generic(stack, expr.args[i])
            push!(result.args, arg)
        end
        root = gensym()
        push!(stack.args, Expr(:(=), root, result))
        # This value isn't safe to mutate, because it might be a reference to
        # another object.
        return root, false
    end
end
"""
    _rewrite_generic_to_nested_op(stack, expr, op; broadcast::Bool = false)

Rewrite the call `expr` as a left-to-right chain of updates with `op`.

The first argument of `expr` is rewritten and, if it is not flagged as mutable,
wrapped in `copy_if_mutable` so that the chain has a value it is allowed to
mutate. Each remaining argument is then rewritten and combined with the running
value through `operate!!(op, running, arg)`, or `broadcast!!(op, running, arg)`
when `broadcast` is `true`. Every step is assigned to a fresh `gensym`ed
variable that is pushed onto `stack`.

Returns the final variable and `true`, because the result may be mutated by
future callers.
"""
function _rewrite_generic_to_nested_op(stack, expr, op; broadcast::Bool = false)
    accumulator, can_mutate = _rewrite_generic(stack, expr.args[2])
    if !can_mutate
        # The first argument isn't mutable, so start the chain from a copy.
        safe_start = Expr(:call, copy_if_mutable, accumulator)
        accumulator = gensym()
        push!(stack.args, Expr(:(=), accumulator, safe_start))
    end
    for term in expr.args[3:end]
        value, _ = _rewrite_generic(stack, term)
        mutator = broadcast ? broadcast!! : operate!!
        update = Expr(:call, mutator, op, accumulator, value)
        # Each step gets its own variable so later steps refer to the newest
        # result.
        accumulator = gensym()
        push!(stack.args, Expr(:(=), accumulator, update))
    end
    return accumulator, true
end
# Return `true` if `expr` is a `:call` expression whose callee is `op`.
function _is_call(expr, op)
    Meta.isexpr(expr, :call) || return false
    return expr.args[1] == op
end
"""
    _rewrite_generic_generator(
        stack::Expr,
        op::Symbol,
        expr::Expr,
        root = nothing,
    )

Special handling for generator expressions.

`op` is `:+` and `expr` is a `:generator` or `:flatten` expression.

`root` is the variable that accumulates the reduction. If it is `nothing`, a
fresh `gensym`ed variable initialised to `Zero()` is pushed onto `stack`;
otherwise the caller must already have assigned `root` on `stack`.

The generator is turned into nested `for` loops (with `if` blocks for filters)
whose body updates `root` through `operate!!(add_mul, root, ...)`. The loops are
pushed onto `stack`, and `(root, true)` is returned.
"""
function _rewrite_generic_generator(
    stack::Expr,
    op::Symbol,
    expr::Expr,
    root = nothing,
)
    @assert op == :+
    is_flatten = Meta.isexpr(expr, :flatten)
    if is_flatten
        expr = expr.args[1]
    end
    # The value we're going to mutate. Start it off at `Zero`.
    if root === nothing
        root = gensym()
        push!(stack.args, Expr(:(=), root, Zero()))
    end
    # We need a new stack to go inside our for-loops since we want to
    # recursively rewrite the inner part as well.
    new_stack = quote end
    if _is_call(expr.args[1], op) &&
       !any(arg -> Meta.isexpr(arg, :(...)), expr.args[1].args)
        # Optimization time! Instead of operate!!(op, root, op(args...)),
        # rewrite as operate!!(op, root, arg) for arg in args.
        #
        # This is skipped when an argument is a splat, as in `+(xs...)`:
        # `operate!!(add_mul, root, xs...)` would add the product of the
        # splatted values instead of their sum. Such terms fall through to the
        # general branch below.
        for arg in expr.args[1].args[2:end]
            value, _ = _rewrite_generic(new_stack, arg)
            rhs = Expr(:call, operate!!, add_mul, root, value)
            push!(new_stack.args, :($root = $rhs))
        end
    elseif op == :+ && _is_call(expr.args[1], :*)
        # Optimization time! Instead of operate!!(+, root, *(args...)), rewrite
        # this as operate!!(add_mul, root, args...)
        rhs = Expr(:call, operate!!, add_mul, root)
        for arg in expr.args[1].args[2:end]
            value, _ = _rewrite_generic(new_stack, arg)
            push!(rhs.args, value)
        end
        push!(new_stack.args, :($root = $rhs))
    elseif is_flatten
        # The first argument is itself a generator
        _rewrite_generic_generator(new_stack, op, expr.args[1], root)
    else
        # expr.args[1] is the inner part of the loop. Rewrite it. We don't care
        # if it is mutable because we need a new value every iteration.
        inner, _ = _rewrite_generic(new_stack, expr.args[1])
        # Now build up the summation part of the inner loop. `root` is the
        # accumulator, so it is updated in place with `operate!!`.
        rhs = Expr(:call, operate!!, add_mul, root, inner)
        push!(new_stack.args, :($root = $rhs))
    end
    # This is a little complicated: walk back out of the generator statements,
    # wrapping each level in a for loop and then overwriting the `new_stack`
    # variable.
    #
    # !!! warning
    #     The Julia syntax sum(i for i in 1:2, j in 1:i) is incorrect, but we
    #     handle it anyway! Because the user will write dependencies from left
    #     to right, we need to wrap from right to left.
    for i in length(expr.args):-1:2
        new_stack = _iterable_condition(new_stack, expr.args[i])
    end
    # Finally, push our new_stack onto the old `stack`...
    push!(stack.args, new_stack)
    # and return the `root`. It is flagged as mutable for future callers.
    return root, true
end
# Wrap `new_stack` in the loop structure described by `expr`, one iteration
# specification of a generator.
#
# A plain specification such as `i = I` becomes `for i = I; new_stack; end`.
# A `:filter` expression, whose first argument is the condition and whose
# remaining arguments are iteration specifications, becomes nested `for` loops
# (the first specification outermost) around `if condition; new_stack; end`.
function _iterable_condition(new_stack, expr)
    if Meta.isexpr(expr, :filter)
        predicate = expr.args[1]
        guarded = quote
            if $(esc(predicate))
                $new_stack
            end
        end
        # A filter might be over multiple index sets. Folding from the right
        # puts the last index set innermost.
        wrap(iterator, body) = Expr(:for, esc(iterator), body)
        return foldr(wrap, expr.args[2:end]; init = guarded)
    end
    return Expr(:for, esc(expr), new_stack)
end