-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
ir_print.py
82 lines (73 loc) · 2.9 KB
/
ir_print.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from numba.core import errors, ir
from numba.core.rewrites import register_rewrite, Rewrite
@register_rewrite('before-inference')
class RewritePrintCalls(Rewrite):
    """
    Rewrite calls to the print() global function to dedicated IR print() nodes.

    Only positional arguments (optionally with a ``*args`` vararg) are
    supported. Keyword arguments, whether written explicitly or passed
    through ``**kwargs``, are rejected with ``errors.UnsupportedError``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Collect every ``var = print(...)`` assignment in *block*.

        The matches are stored in ``self.prints`` (assignment -> call
        expression) and *block* in ``self.block`` for use by ``apply()``.
        Returns True if at least one such assignment was found.

        Raises errors.UnsupportedError if a matched print() call passes
        keyword arguments.
        """
        self.prints = prints = {}
        self.block = block
        # Find all assignments with a right-hand print() call
        for inst in block.find_insts(ir.Assign):
            if isinstance(inst.value, ir.Expr) and inst.value.op == 'call':
                expr = inst.value
                try:
                    callee = func_ir.infer_constant(expr.func)
                except errors.ConstantInferenceError:
                    # The callee cannot be resolved to a constant, so it
                    # cannot be identified as the builtin print().
                    continue
                if callee is print:
                    # Only positional args are supported. A ``**kwargs``
                    # argument (``varkwarg``) must be rejected as well:
                    # the ir.Print node built in apply() takes only
                    # ``args`` and ``vararg``, so it would otherwise be
                    # dropped silently. getattr() is used in case the call
                    # expression carries no such attribute.
                    has_varkwarg = getattr(expr, 'varkwarg', None) is not None
                    if expr.kws or has_varkwarg:
                        msg = ("Numba's print() function implementation does not "
                               "support keyword arguments.")
                        raise errors.UnsupportedError(msg, inst.loc)
                    prints[inst] = expr
        return len(prints) > 0

    def apply(self):
        """
        Rewrite `var = call <print function>(...)` as a sequence of
        `print(...)` and `var = const(None)`.

        Returns a new block. The original block's instructions are copied
        over in order, and each matched assignment is replaced by the
        two-instruction sequence above.
        """
        new_block = self.block.copy()
        new_block.clear()
        for inst in self.block.body:
            if inst in self.prints:
                expr = self.prints[inst]
                print_node = ir.Print(args=expr.args, vararg=expr.vararg,
                                      loc=expr.loc)
                new_block.append(print_node)
                # print() returns None, so the original target is still
                # assigned, keeping later uses of `var` valid.
                assign_node = ir.Assign(value=ir.Const(None, loc=expr.loc),
                                        target=inst.target,
                                        loc=inst.loc)
                new_block.append(assign_node)
            else:
                new_block.append(inst)
        return new_block
@register_rewrite('before-inference')
class DetectConstPrintArguments(Rewrite):
    """
    Find print() IR nodes whose positional arguments can be inferred as
    constants, and record those constants on the nodes themselves.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Scan *block* for ir.Print nodes that have not been annotated yet.

        Each positional argument that ``func_ir.infer_constant`` can
        resolve is recorded in ``self.consts`` as
        ``{print_node: {arg_index: value}}``. *block* is saved in
        ``self.block`` for ``apply()``. Returns True when at least one
        constant argument was found.
        """
        found = {}
        self.consts = found
        self.block = block
        for print_node in block.find_insts(ir.Print):
            # A non-empty ``consts`` means this node was already processed.
            if print_node.consts:
                continue
            for position, arg_var in enumerate(print_node.args):
                try:
                    value = func_ir.infer_constant(arg_var)
                except errors.ConstantInferenceError:
                    # Not inferable as a constant; skip this argument.
                    continue
                if print_node not in found:
                    found[print_node] = {}
                found[print_node][position] = value
        return bool(found)

    def apply(self):
        """
        Attach the constants gathered by match() to their print nodes.

        The nodes are updated in place, and the same block is returned.
        """
        for node in self.block.body:
            recorded = self.consts.get(node)
            if recorded is not None:
                node.consts = recorded
        return self.block