/
registry.py
98 lines (84 loc) · 3.57 KB
/
registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from collections import defaultdict
from numba.core import config
class Rewrite(object):
    '''Abstract base class that every Numba rewrite derives from.

    Subclasses override match() to recognise a pattern in an IR block and
    apply() to produce the rewritten block.
    '''

    def __init__(self, state=None):
        '''Create the rewrite.

        *state* is accepted for the benefit of subclasses; the base class
        does not use it.
        '''

    def match(self, func_ir, block, typemap, calltypes):
        '''Report whether *block* contains terms this rewrite handles.

        The base implementation never matches; subclasses overload it.
        '''
        return False

    def apply(self):
        '''Return a rewritten IR basic block once a match has been found.

        The base implementation always raises NotImplementedError;
        subclasses overload it.
        '''
        raise NotImplementedError("Abstract Rewrite.apply() called!")
class RewriteRegistry(object):
    '''Registry of Rewrite subclasses, grouped by the "kind" under which
    they were registered.
    '''

    # The only kind names accepted by register() and apply().
    _kinds = frozenset(['before-inference', 'after-inference'])

    def __init__(self):
        '''Create an empty registry.

        ``self.rewrites`` maps each kind to the list of Rewrite subclasses
        registered for it, in registration order.
        '''
        self.rewrites = defaultdict(list)

    def _check_kind(self, kind):
        '''Raise KeyError unless *kind* is one of the supported kinds.'''
        if kind not in self._kinds:
            raise KeyError("invalid kind %r" % (kind,))

    def register(self, kind):
        """
        Decorator adding a subclass of Rewrite to the registry for
        the given *kind*.

        Raises KeyError if *kind* is not a supported kind.  The returned
        decorator raises TypeError if the decorated class is not a
        subclass of Rewrite, and otherwise returns the class unchanged.
        """
        self._check_kind(kind)

        def do_register(rewrite_cls):
            if not issubclass(rewrite_cls, Rewrite):
                raise TypeError('{0} is not a subclass of Rewrite'.format(
                    rewrite_cls))
            self.rewrites[kind].append(rewrite_cls)
            return rewrite_cls

        return do_register

    def apply(self, kind, state):
        '''Exhaustively apply every rewrite registered for *kind* to all
        basic blocks of ``state.func_ir``.

        *state* must provide ``func_ir``, ``typemap`` and ``calltypes``;
        it is also passed to each rewrite's constructor.  The blocks
        mapping ``state.func_ir.blocks`` is updated in place.

        Raises KeyError if *kind* is not a supported kind.
        '''
        # Explicit check rather than ``assert`` so that the validation is
        # not stripped under ``python -O`` and matches register().
        self._check_kind(kind)

        blocks = state.func_ir.blocks
        # Snapshot of the label -> block mapping, used afterwards to tell
        # which blocks were replaced.
        old_blocks = blocks.copy()
        for rewrite_cls in self.rewrites[kind]:
            # Exhaustively apply a rewrite until it stops matching.
            rewrite = rewrite_cls(state)
            work_list = list(blocks.items())
            while work_list:
                key, block = work_list.pop()
                matches = rewrite.match(state.func_ir, block, state.typemap,
                                        state.calltypes)
                if not matches:
                    continue
                dump = config.DEBUG or config.DUMP_IR
                if dump:
                    print("_" * 70)
                    print("REWRITING (%s):" % rewrite_cls.__name__)
                    block.dump()
                    print("_" * 60)
                new_block = rewrite.apply()
                blocks[key] = new_block
                # Re-queue the new block: the same rewrite may match again.
                work_list.append((key, new_block))
                if dump:
                    new_block.dump()
                    print("_" * 70)
        # If any blocks were changed, perform a sanity check.  A label that
        # is absent from the snapshot compares unequal to None and is
        # therefore verified too, instead of raising KeyError.
        for key, block in blocks.items():
            if block != old_blocks.get(key):
                block.verify()

        # Some passes, e.g. _inline_const_arraycall are known to occasionally
        # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
        # things with ir.Del, but the placement is not optimal. The lines below
        # fix-up the IR so that ref counts are valid and optimally placed,
        # see #4093 for context. This has to be run here opposed to in
        # apply() as the CFG needs computing so full IR is needed.
        from numba.core import postproc
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run()
# Module-level RewriteRegistry instance.
rewrite_registry = RewriteRegistry()
# Alias for the module-level registry's register() method, which returns a
# class decorator for the given kind.
register_rewrite = rewrite_registry.register