/
rewriters.py
372 lines (296 loc) · 13.4 KB
/
rewriters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import abc
from collections import OrderedDict
from functools import wraps
from time import time

from sympy import cos, sin

from devito.equation import Eq
from devito.ir import (DataSpace, IterationSpace, Interval, IntervalGroup, Cluster,
                       detect_accesses, build_intervals)
from devito.dse.aliases import collect
from devito.dse.manipulation import (collect_nested, common_subexprs_elimination,
                                     make_is_time_invariant)
from devito.exceptions import DSEException
from devito.logger import dse_warning as warning
from devito.symbolics import (bhaskara_cos, bhaskara_sin, estimate_cost, freeze,
                              pow_to_mul, q_leaf, q_sum_of_product, q_terminalop,
                              yreplace)
from devito.tools import flatten
from devito.types import Array, Scalar
__all__ = ['BasicRewriter', 'AdvancedRewriter', 'AggressiveRewriter', 'CustomRewriter']
class State(object):

    """
    Carry the Clusters being transformed by the DSE passes, along with
    per-pass profiling data (operation counts and timings).
    """

    def __init__(self, cluster, template):
        self.clusters = [cluster]
        self.template = template

        # Profiling data, keyed by pass name; insertion order reflects the
        # order in which the passes were applied
        self.ops = OrderedDict()
        self.timings = OrderedDict()

    def update(self, clusters):
        # A falsy (None/empty) result leaves the current Clusters untouched
        if clusters:
            self.clusters = clusters
def dse_pass(func):
    """
    Decorator for a DSE pass.

    The decorated method, rather than receiving a single Cluster, is applied
    to each Cluster in ``state.clusters``; the state is then updated with the
    (flattened) transformed Clusters. Wall-clock timing and, if profiling is
    enabled, the resulting operation count are recorded in the state.
    """
    # `wraps` preserves `func`'s metadata (`__name__`, `__doc__`, ...); without
    # it, every pass would show up as "wrapper" under introspection/debugging
    @wraps(func)
    def wrapper(self, state, **kwargs):
        # Invoke the DSE pass on each Cluster
        tic = time()
        state.update(flatten([func(self, c, state.template, **kwargs)
                              for c in state.clusters]))
        toc = time()

        # Profiling: the key carries the number of previously applied passes,
        # so repeated applications of the same pass get distinct entries
        key = '%s%d' % (func.__name__, len(state.timings))
        state.timings[key] = toc - tic
        if self.profile:
            # Only dense Clusters contribute to the operation count
            candidates = [c.exprs for c in state.clusters if c.is_dense]
            state.ops[key] = estimate_cost(flatten(candidates))
    return wrapper
class AbstractRewriter(object, metaclass=abc.ABCMeta):

    """
    Transform a Cluster of SymPy expressions into one or more clusters with
    reduced operation count.

    This is an abstract class; subclasses must implement ``_pipeline``, the
    sequence of DSE passes to be applied.

    .. note::

        The original ``__metaclass__ = abc.ABCMeta`` class attribute is
        Python-2-only syntax and is silently ignored on Python 3, meaning
        ``@abc.abstractmethod`` was never enforced; the ``metaclass=``
        keyword above restores the intended behavior.
    """

    def __init__(self, profile=True, template=None):
        # Whether to record per-pass operation counts (see `dse_pass`)
        self.profile = profile

        # `template` generates unique names for the temporaries introduced
        # by the passes, hence it must be a callable
        assert callable(template)
        self.template = template

    def run(self, cluster):
        """
        Apply the subclass-defined pipeline to ``cluster``, then finalize and
        return the resulting State.
        """
        state = State(cluster, self.template)

        self._pipeline(state)

        self._finalize(state)

        return state

    @abc.abstractmethod
    def _pipeline(self, state):
        """The sequence of DSE passes; provided by subclasses."""
        return

    @dse_pass
    def _finalize(self, cluster, *args, **kwargs):
        """
        Finalize the DSE output: ::

            * Pow-->Mul. Convert integer powers in an expression to Muls,
              like a**2 => a*a.
            * Freezing. Make sure that subsequent SymPy operations applied to
              the expressions in ``cluster.exprs`` will not alter the effect of
              the DSE passes.
        """
        exprs = [pow_to_mul(e) for e in cluster.exprs]
        return cluster.rebuild([freeze(e) for e in exprs])
class BasicRewriter(AbstractRewriter):

    def _pipeline(self, state):
        self._extract_increments(state)

    @dse_pass
    def _extract_increments(self, cluster, template, **kwargs):
        """
        Extract the RHS of non-local tensor expressions performing an associative
        and commutative increment, and assign them to temporaries.
        """
        processed = []
        for e in cluster.exprs:
            if not (e.is_Increment and e.lhs.function.is_Input):
                processed.append(e)
                continue
            # Stash the increment RHS into a scalar temporary ...
            handle = Scalar(name=template(), dtype=e.dtype).indexify()
            if e.rhs.is_Number or e.rhs.is_Symbol:
                extracted = e.rhs
            else:
                extracted = e.rhs.func(*[a for a in e.rhs.args if a != e.lhs])
            # ... then increment the original LHS by the temporary
            processed.append(e.func(handle, extracted, is_Increment=False))
            processed.append(e.func(e.lhs, handle))
        return cluster.rebuild(processed)

    @dse_pass
    def _eliminate_intra_stencil_redundancies(self, cluster, template, **kwargs):
        """
        Perform common subexpression elimination, bypassing the tensor expressions
        extracted in previous passes.
        """
        def make():
            return Scalar(name=template(), dtype=cluster.dtype).indexify()

        return cluster.rebuild(common_subexprs_elimination(cluster.exprs, make))

    @dse_pass
    def _optimize_trigonometry(self, cluster, **kwargs):
        """
        Rebuild ``exprs`` replacing trigonometric functions with Bhaskara
        polynomials.
        """
        processed = [e.replace(sin, bhaskara_sin).replace(cos, bhaskara_cos)
                     for e in cluster.exprs]
        return cluster.rebuild(processed)
class AdvancedRewriter(BasicRewriter):

    MIN_COST_ALIAS = 10
    """
    Minimum operation count of an alias (i.e., "redundant") expression
    to be lifted into a vector temporary.
    """

    MIN_COST_ALIAS_INV = 50
    """
    Minimum operation count of a time-invariant alias (i.e., "redundant")
    expression to be lifted into a vector temporary. Time-invariant aliases
    are lifted outside of the time-marching loop, thus they will require
    vector temporaries as big as the entire grid.
    """

    MIN_COST_FACTORIZE = 100
    """
    Minimum operation count of an expression so that aggressive factorization
    is applied.
    """

    def _pipeline(self, state):
        self._extract_time_invariants(state)
        self._eliminate_inter_stencil_redundancies(state)
        self._eliminate_intra_stencil_redundancies(state)
        self._factorize(state)

    @dse_pass
    def _extract_time_invariants(self, cluster, template, **kwargs):
        """
        Extract time-invariant subexpressions, and assign them to temporaries.

        Only subexpressions whose operation count is at least
        ``MIN_COST_ALIAS_INV`` are extracted.
        """
        make = lambda: Scalar(name=template(), dtype=cluster.dtype).indexify()
        rule = make_is_time_invariant(cluster.exprs)
        costmodel = lambda e: estimate_cost(e, True) >= self.MIN_COST_ALIAS_INV
        processed, found = yreplace(cluster.exprs, make, rule, costmodel, eager=True)

        return cluster.rebuild(processed)

    @dse_pass
    def _factorize(self, cluster, *args, **kwargs):
        """
        Factorize transcendental functions, symbolic powers, numeric coefficients.

        If the expression has an operation count greater than
        ``self.MIN_COST_FACTORIZE``, then the algorithm is applied recursively
        until no more factorization opportunities are detected.
        """
        processed = []
        for expr in cluster.exprs:
            handle = collect_nested(expr)
            cost_handle = estimate_cost(handle)

            if cost_handle >= self.MIN_COST_FACTORIZE:
                # Iterate as long as each round of `collect_nested` keeps
                # lowering the operation count
                handle_prev = handle
                cost_prev = estimate_cost(expr)
                while cost_handle < cost_prev:
                    handle_prev, handle = handle, collect_nested(handle)
                    cost_prev, cost_handle = cost_handle, estimate_cost(handle)
                # The last iteration regressed, so roll back to the previous
                # (cheaper) form
                cost_handle, handle = cost_prev, handle_prev

            processed.append(handle)

        return cluster.rebuild(processed)

    @dse_pass
    def _eliminate_inter_stencil_redundancies(self, cluster, template, **kwargs):
        """
        Search aliasing expressions and capture them into vector temporaries.

        Examples
        --------
        1) temp = (a[x,y,z]+b[x,y,z])*c[t,x,y,z]
           >>>
           ti[x,y,z] = a[x,y,z] + b[x,y,z]
           temp = ti[x,y,z]*c[t,x,y,z]

        2) temp1 = 2.0*a[x,y,z]*b[x,y,z]
           temp2 = 3.0*a[x,y,z+1]*b[x,y,z+1]
           >>>
           ti[x,y,z] = a[x,y,z]*b[x,y,z]
           temp1 = 2.0*ti[x,y,z]
           temp2 = 3.0*ti[x,y,z+1]
        """
        exprs = cluster.exprs

        # For more information about "aliases", refer to collect.__doc__
        aliases = collect(exprs)

        # Redundancies will be stored in space-varying temporaries
        # Memoize time-invariance per RHS, as it is queried again below
        is_time_invariant = make_is_time_invariant(exprs)
        time_invariants = {e.rhs: is_time_invariant(e) for e in exprs}

        # Find the candidate expressions
        processed = []
        candidates = OrderedDict()
        for e in exprs:
            # Cost check (to keep the memory footprint under control)
            # NOTE(review): assumes `aliases.get` returns a (possibly empty)
            # collection for any RHS, never None -- verify against collect()
            naliases = len(aliases.get(e.rhs))
            # The more aliases, the higher the payoff of a single temporary
            cost = estimate_cost(e, True)*naliases
            test0 = lambda: cost >= self.MIN_COST_ALIAS and naliases > 1
            test1 = lambda: cost >= self.MIN_COST_ALIAS_INV and time_invariants[e.rhs]
            if test0() or test1():
                candidates[e.rhs] = e.lhs
            else:
                processed.append(e)

        # Create alias Clusters and all necessary substitution rules
        # for the new temporaries
        alias_clusters = []
        subs = {}
        for origin, alias in aliases.items():
            # Skip aliases none of whose members were selected above
            if all(i not in candidates for i in alias.aliased):
                continue

            # The write-to Intervals
            writeto = [Interval(i.dim, *alias.relaxed_diameter.get(i.dim, (0, 0)))
                       for i in cluster.ispace.intervals if not i.dim.is_Time]
            writeto = IntervalGroup(writeto)

            # Optimization: no need to retain a SpaceDimension if it does not
            # induce a flow/anti dependence (below, `i.offsets` captures this, by
            # telling how much halo will be needed to honour such dependences)
            dep_inducing = [i for i in writeto if any(i.offsets)]
            try:
                index = writeto.index(dep_inducing[0])
                writeto = IntervalGroup(writeto[index:])
            except IndexError:
                # No dependence-inducing Interval found; keep `writeto` whole
                warning("Couldn't optimize some of the detected redundancies")

            # Create a temporary to store `alias`
            dimensions = [d.root for d in writeto.dimensions]
            halo = [(abs(i.lower), abs(i.upper)) for i in writeto]
            array = Array(name=template(), dimensions=dimensions, halo=halo,
                          dtype=cluster.dtype)

            # Build up the expression evaluating `alias`
            # `subs` carries rules from previously processed aliases, so
            # nested/composite aliases resolve to earlier temporaries
            access = tuple(i.dim - i.lower for i in writeto)
            expression = Eq(array[access], origin.xreplace(subs))

            # Create the substitution rules so that we can use the newly created
            # temporary in place of the aliasing expressions
            for aliased, distance in alias.with_distance:
                assert all(i.dim in distance.labels for i in writeto)
                access = [i.dim - i.lower + distance[i.dim] for i in writeto]
                if aliased in candidates:
                    # It would *not* be in `candidates` if part of a composite alias
                    subs[candidates[aliased]] = array[access]
                subs[aliased] = array[access]

            # Construct the `alias` IterationSpace
            intervals, sub_iterators, directions = cluster.ispace.args
            ispace = IterationSpace(intervals.add(writeto), sub_iterators, directions)

            # Construct the `alias` DataSpace
            mapper = detect_accesses(expression)
            parts = {k: IntervalGroup(build_intervals(v)).add(ispace.intervals)
                     for k, v in mapper.items() if k}
            dspace = DataSpace(cluster.dspace.intervals, parts)

            # Create a new Cluster for `alias`
            alias_clusters.append(Cluster([expression], ispace, dspace))

        # Switch temporaries in the expression trees
        processed = [e.xreplace(subs) for e in processed]

        return alias_clusters + [cluster.rebuild(processed)]
class AggressiveRewriter(AdvancedRewriter):

    def _pipeline(self, state):
        # First round: extraction + redundancy elimination
        self._extract_sum_of_products(state)
        self._extract_time_invariants(state)
        self._eliminate_inter_stencil_redundancies(state)

        # Second round, on the transformed expressions
        self._extract_sum_of_products(state)
        self._eliminate_inter_stencil_redundancies(state)
        self._extract_sum_of_products(state)

        # Final cleanup
        self._factorize(state)
        self._eliminate_intra_stencil_redundancies(state)

    @dse_pass
    def _extract_sum_of_products(self, cluster, template, **kwargs):
        """
        Extract sub-expressions in sum-of-product form, and assign them to temporaries.
        """
        def make():
            return Scalar(name=template(), dtype=cluster.dtype).indexify()

        # Worth extracting only if neither a leaf nor a terminal operation
        costmodel = lambda e: not (q_leaf(e) or q_terminalop(e))
        processed, _ = yreplace(cluster.exprs, make, q_sum_of_product, costmodel)

        return cluster.rebuild(processed)
class CustomRewriter(AggressiveRewriter):

    """
    A DSE rewriter applying a user-selected sequence of passes.
    """

    # Public pass names, mapped to their implementation. NOTE: the
    # 'opt_transcedentals' key is misspelt, but it is part of the public
    # interface; renaming it would break existing callers
    passes_mapper = {
        'extract_sop': AggressiveRewriter._extract_sum_of_products,
        'factorize': AggressiveRewriter._factorize,
        'gcse': AggressiveRewriter._eliminate_inter_stencil_redundancies,
        'cse': AggressiveRewriter._eliminate_intra_stencil_redundancies,
        'extract_invariants': AdvancedRewriter._extract_time_invariants,
        'extract_increments': BasicRewriter._extract_increments,
        'opt_transcedentals': BasicRewriter._optimize_trigonometry
    }

    def __init__(self, passes, template=None, profile=True):
        """
        Parameters
        ----------
        passes : str or iterable of str
            The passes to apply, in order -- either a comma-separated string
            (e.g. ``'factorize,cse'``) or an iterable of names. Each name must
            be a key of ``CustomRewriter.passes_mapper``.

        Raises
        ------
        DSEException
            If any of the requested passes is unknown.
        """
        try:
            passes = passes.split(',')
        except AttributeError:
            # Already an iterable of pass names
            pass
        # Validate after normalization: previously, a comma-separated string
        # containing a misspelt pass name would only fail later, inside
        # `_pipeline`, with a raw KeyError rather than a DSEException
        if not all(i in CustomRewriter.passes_mapper for i in passes):
            raise DSEException("Unknown passes `%s`" % str(passes))
        self.passes = passes
        super(CustomRewriter, self).__init__(profile, template)

    def _pipeline(self, state):
        for i in self.passes:
            CustomRewriter.passes_mapper[i](self, state)