forked from dask/dask
/
rewrite.py
438 lines (344 loc) · 12.4 KB
/
rewrite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
from __future__ import absolute_import, division, print_function
from collections import deque
from dask.core import istask, subs
def head(task):
    """Return the top level node of a task.

    Parameters
    ----------
    task : object
        A task, a list, or any other term.

    Returns
    -------
    ``task[0]`` when ``istask(task)`` is true, the builtin ``list`` type when
    `task` is a list, and `task` itself otherwise.
    """
    if istask(task):
        return task[0]
    if isinstance(task, list):
        # Lists have no callable in front; the type itself acts as the head.
        return list
    return task
def args(task):
    """Get the arguments for the current task.

    Parameters
    ----------
    task : object
        A task, a list, or any other term.

    Returns
    -------
    ``task[1:]`` when ``istask(task)`` is true, the list itself when `task`
    is a list, and an empty tuple otherwise.
    """
    if istask(task):
        return task[1:]
    if isinstance(task, list):
        return task
    # Anything else is a leaf and has no arguments.
    return ()
class Traverser(object):
    """Traverser interface for tasks.

    Class for storing the state while performing a preorder-traversal of a
    task.

    Parameters
    ----------
    term : task
        The task to be traversed
    stack : deque, optional
        Pending terms still to be visited, last one on the right. When
        omitted, a fresh stack holding only the `END` marker is created. A
        stack that is passed in is used as-is, even when it is empty.

    Attributes
    ----------
    term
        The current element in the traversal
    current
        The head of the current element in the traversal. This is simply `head`
        applied to the attribute `term`.
    """

    def __init__(self, term, stack=None):
        self.term = term
        # Compare against None rather than truthiness: an empty stack handed
        # in by the caller (e.g. from `copy`) must be kept, not reseeded.
        if stack is None:
            self._stack = deque([END])
        else:
            self._stack = stack

    def __iter__(self):
        """Yield the head of each term in preorder until `END` is reached."""
        while self.current is not END:
            yield self.current
            self.next()

    def copy(self):
        """Copy the traverser in its current state.

        This allows the traversal to be pushed onto a stack, for easy
        backtracking."""
        return Traverser(self.term, deque(self._stack))

    def next(self):
        """Proceed to the next term in the preorder traversal."""
        subterms = args(self.term)
        if not subterms:
            # No subterms, pop off stack
            self.term = self._stack.pop()
        else:
            self.term = subterms[0]
            # Push the remaining siblings reversed so the leftmost one is
            # popped first.
            self._stack.extend(reversed(subterms[1:]))

    @property
    def current(self):
        """The head of the term currently being visited."""
        return head(self.term)

    def skip(self):
        """Skip over all subterms of the current level in the traversal"""
        self.term = self._stack.pop()
class Token(object):
    """A named marker object.

    Used to express certain objects in the traversal of a task or pattern.

    Parameters
    ----------
    name : str
        Label stored on the token; it is also what ``repr`` returns.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        # The bare name keeps printed traversals and nets compact.
        return self.name
# A single token that stands in for *every* pattern variable in a
# discrimination net: `RuleSet.add` replaces each variable with `VAR` when it
# builds the net's edges.
VAR = Token('?')
# Represents the end of the traversal of an expression. We can't use `None`,
# `False`, etc. here, as anything may be an argument to a function.
END = Token('end')
class Node(tuple):
    """A Discrimination Net node.

    A node is a 2-tuple ``(edges, patterns)`` with named accessors.

    Parameters
    ----------
    edges : dict, optional
        Mapping from edge label to child node. A new empty dict is created
        when omitted.
    patterns : list, optional
        Patterns that match at this node. A new empty list is created when
        omitted.
    """

    __slots__ = ()

    def __new__(cls, edges=None, patterns=None):
        # Test for None rather than truthiness so an empty container supplied
        # by the caller is stored (and shared) instead of being replaced.
        if edges is None:
            edges = {}
        if patterns is None:
            patterns = []
        return tuple.__new__(cls, (edges, patterns))

    @property
    def edges(self):
        """A dictionary, where the keys are edges, and the values are nodes"""
        return self[0]

    @property
    def patterns(self):
        """A list of all patterns that currently match at this node"""
        return self[1]
class RewriteRule(object):
    """A rewrite rule.

    Expresses `lhs` -> `rhs`, for variables `vars`.

    Parameters
    ----------
    lhs : task
        The left-hand-side of the rewrite rule.
    rhs : task or function
        The right-hand-side of the rewrite rule. If it's a task, variables in
        `rhs` will be replaced by terms in the subject that match the variables
        in `lhs`. If it's a function, the function will be called with a dict
        of such matches.
    vars: tuple, optional
        Tuple of variables found in the lhs. Variables can be represented as
        any hashable object; a good convention is to use strings. If there are
        no variables, this can be omitted.

    Raises
    ------
    TypeError
        If `vars` is not a tuple.

    Examples
    --------
    Here's a `RewriteRule` to replace all nested calls to `list`, so that
    `(list, (list, 'x'))` is replaced with `(list, 'x')`, where `'x'` is a
    variable.

    >>> lhs = (list, (list, 'x'))
    >>> rhs = (list, 'x')
    >>> variables = ('x',)
    >>> rule = RewriteRule(lhs, rhs, variables)

    Here's a more complicated rule that uses a callable right-hand-side. A
    callable `rhs` takes in a dictionary mapping variables to their matching
    values. This rule replaces all occurrences of `(list, 'x')` with `'x'` if
    `'x'` is a list itself.

    >>> lhs = (list, 'x')
    >>> def repl_list(sd):
    ...     x = sd['x']
    ...     if isinstance(x, list):
    ...         return x
    ...     else:
    ...         return (list, x)
    >>> rule = RewriteRule(lhs, repl_list, variables)
    """

    def __init__(self, lhs, rhs, vars=()):
        if not isinstance(vars, tuple):
            raise TypeError("vars must be a tuple of variables")
        self.lhs = lhs
        self.rhs = rhs
        # A callable rhs performs the substitution itself; a task rhs is
        # filled in by `_apply`.
        self.subs = rhs if callable(rhs) else self._apply
        # Variables in the order (and multiplicity) they occur in the
        # preorder traversal of lhs.
        self._varlist = [t for t in Traverser(lhs) if t in vars]
        # Reduce vars down to just variables found in lhs
        self.vars = tuple(sorted(set(self._varlist)))

    def _apply(self, sub_dict):
        """Substitute each variable of `sub_dict` into `rhs` and return it."""
        # NOTE(review): replacements are applied one variable at a time, so a
        # value inserted by an earlier pass is presumably visible to later
        # passes -- confirm against dask.core.subs.
        result = self.rhs
        for variable, replacement in sub_dict.items():
            result = subs(result, variable, replacement)
        return result

    def __str__(self):
        template = "RewriteRule({0}, {1}, {2})"
        return template.format(self.lhs, self.rhs, self.vars)

    def __repr__(self):
        return str(self)
class RuleSet(object):
    """A set of rewrite rules.

    Forms a structure for fast rewriting over a set of rewrite rules. This
    allows for syntactic matching of terms to patterns for many patterns at
    the same time.

    Examples
    --------

    >>> def f(*args): pass
    >>> def g(*args): pass
    >>> def h(*args): pass
    >>> from operator import add

    >>> rs = RuleSet(                 # Make RuleSet with two Rules
    ...         RewriteRule((add, 'x', 0), 'x', ('x',)),
    ...         RewriteRule((f, (g, 'x'), 'y'),
    ...                     (h, 'x', 'y'),
    ...                     ('x', 'y')))

    >>> rs.rewrite((add, 2, 0))       # Apply ruleset to single task
    2

    >>> rs.rewrite((f, (g, 'a'), 3))  # doctest: +SKIP
    (h, 'a', 3)

    >>> dsk = {'a': (add, 2, 0),      # Apply ruleset to full dask graph
    ...        'b': (f, (g, 'a'), 3)}

    >>> from toolz import valmap
    >>> valmap(rs.rewrite, dsk)       # doctest: +SKIP
    {'a': 2,
     'b': (h, 'a', 3)}

    Attributes
    ----------
    rules : list
        A list of `RewriteRule`s included in the `RuleSet`.
    """

    def __init__(self, *rules):
        """Create a `RuleSet` for a number of rules

        Parameters
        ----------
        rules
            One or more instances of RewriteRule
        """
        self._net = Node()
        self.rules = []
        for rule in rules:
            self.add(rule)

    def add(self, rule):
        """Add a rule to the RuleSet.

        Parameters
        ----------
        rule : RewriteRule

        Raises
        ------
        TypeError
            If `rule` is not a `RewriteRule`.
        """
        if not isinstance(rule, RewriteRule):
            raise TypeError("rule must be instance of RewriteRule")
        variables = rule.vars
        node = self._net
        index = len(self.rules)
        # Walk the preorder traversal of the lhs, creating edges as needed.
        for token in Traverser(rule.lhs):
            if token in variables:
                # Every variable shares the single wildcard edge.
                token = VAR
            if token not in node.edges:
                node.edges[token] = Node()
            node = node.edges[token]
        # We've reached a leaf node. Add the term index to this leaf.
        node.patterns.append(index)
        self.rules.append(rule)

    def iter_matches(self, term):
        """A generator that lazily finds matchings for term from the RuleSet.

        Parameters
        ----------
        term : task

        Yields
        ------
        Tuples of `(rule, subs)`, where `rule` is the rewrite rule being
        matched, and `subs` is a dictionary mapping the variables in the lhs
        of the rule to their matching values in the term."""
        traverser = Traverser(term)
        for indices, syms in _match(traverser, self._net):
            for i in indices:
                rule = self.rules[i]
                sub_dict = _process_match(rule, syms)
                # None means a repeated variable was bound to two different
                # subterms, so this candidate is not a real match.
                if sub_dict is not None:
                    yield rule, sub_dict

    def _rewrite(self, term):
        """Apply the rewrite rules in RuleSet to top level of term"""
        # Only the first match is wanted; returning from inside the loop
        # avoids advancing the lazy match generator any further.
        for rule, sd in self.iter_matches(term):
            return rule.subs(sd)
        return term

    def rewrite(self, task, strategy="bottom_up"):
        """Apply the `RuleSet` to `task`.

        This applies the most specific matching rule in the RuleSet to the
        task, using the provided strategy.

        Parameters
        ----------
        task: a task
            The task to be rewritten
        strategy: str, optional
            The rewriting strategy to use. Options are "bottom_up" (default),
            or "top_level".

        Raises
        ------
        KeyError
            If `strategy` is not a known strategy name.

        Examples
        --------
        Suppose there was a function `add` that returned the sum of 2 numbers,
        and another function `double` that returned twice its input:

        >>> add = lambda x, y: x + y
        >>> double = lambda x: 2*x

        Now suppose `double` was *significantly* faster than `add`, so
        you'd like to replace all expressions `(add, x, x)` with `(double,
        x)`, where `x` is a variable. This can be expressed as a rewrite rule:

        >>> rule = RewriteRule((add, 'x', 'x'), (double, 'x'), ('x',))
        >>> rs = RuleSet(rule)

        This can then be applied to terms to perform the rewriting:

        >>> term = (add, (add, 2, 2), (add, 2, 2))
        >>> rs.rewrite(term)  # doctest: +SKIP
        (double, (double, 2))

        If we only wanted to apply this to the top level of the term, the
        `strategy` kwarg can be set to "top_level".

        >>> rs.rewrite(term, strategy="top_level")  # doctest: +SKIP
        (double, (add, 2, 2))
        """
        try:
            rewriter = strategies[strategy]
        except KeyError:
            # Same exception type as a plain lookup, but say what is allowed.
            raise KeyError("Unknown strategy {0!r}; expected one of {1}"
                           .format(strategy, sorted(strategies)))
        return rewriter(self, task)
def _top_level(net, term):
return net._rewrite(term)
def _bottom_up(net, term):
    """Rewrite the subterms of `term` recursively, then `term` itself.

    Parameters
    ----------
    net : RuleSet
    term : task

    Returns
    -------
    The result of ``net._rewrite`` applied to `term` after each of its
    arguments (for a task) or elements (for a list) has been rewritten.
    """
    if istask(term):
        top = head(term)
        rewritten_args = tuple(_bottom_up(net, sub) for sub in args(term))
        term = (top,) + rewritten_args
    elif isinstance(term, list):
        term = [_bottom_up(net, sub) for sub in args(term)]
    return net._rewrite(term)
# Maps the `strategy` argument of `RuleSet.rewrite` to the function that
# implements it; each function is called as `func(ruleset, task)`.
strategies = {'top_level': _top_level,
              'bottom_up': _bottom_up}
def _match(S, N):
    """Structural matching of term S to discrimination net node N.

    Parameters
    ----------
    S : Traverser
        Traverser positioned at the start of the term to match.
    N : Node
        Root node of the discrimination net.

    Yields
    ------
    Tuples of `(patterns, matches)`, where `patterns` is the pattern list of
    a node reached when the whole term has been consumed, and `matches` is a
    tuple of the subterms bound to variable edges along the way, in traversal
    order.
    """
    stack = deque()
    restore_state_flag = False
    # matches are stored in a tuple, because all mutations result in a copy,
    # preventing operations from changing matches stored on the stack.
    matches = ()
    while True:
        if S.current is END:
            yield N.patterns, matches
            # The term is exhausted: there is nothing left to bind to a
            # variable edge (skipping here would pop an empty stack), so fall
            # straight through to backtracking.
        else:
            try:
                # This try-except block is to catch hashing errors from
                # un-hashable types. This allows for variables to be matched
                # with un-hashable objects.
                n = N.edges.get(S.current, None)
            except TypeError:
                n = None
            if n is not None and not restore_state_flag:
                # Follow the literal edge, remembering this state so the
                # variable edge can be tried later.
                stack.append((S.copy(), N, matches))
                N = n
                S.next()
                continue
            n = N.edges.get(VAR, None)
            if n is not None:
                restore_state_flag = False
                # Bind the whole current subterm to the variable and move on
                # past it.
                matches = matches + (S.term,)
                S.skip()
                N = n
                continue
        try:
            # Backtrack here
            (S, N, matches) = stack.pop()
            restore_state_flag = True
        except IndexError:
            # Nothing left to backtrack to: every alternative was explored.
            return
def _process_match(rule, syms):
"""Process a match to determine if it is correct, and to find the correct
substitution that will convert the term into the pattern.
Parameters
----------
rule : RewriteRule
syms : iterable
Iterable of subterms that match a corresponding variable.
Returns
-------
A dictionary of {vars : subterms} describing the substitution to make the
pattern equivalent with the term. Returns `None` if the match is
invalid."""
subs = {}
varlist = rule._varlist
if not len(varlist) == len(syms):
raise RuntimeError("length of varlist doesn't match length of syms.")
for v, s in zip(varlist, syms):
if v in subs and subs[v] != s:
return None
else:
subs[v] = s
return subs