New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dead branch prune before type inference. #3592
Add dead branch prune before type inference. #3592
Conversation
Codecov Report
@@ Coverage Diff @@
## master #3592 +/- ##
=========================================
Coverage ? 80.68%
=========================================
Files ? 393
Lines ? 80485
Branches ? 9164
=========================================
Hits ? 64942
Misses ? 14127
Partials ? 1416 |
Much of |
5a64996
to
a45c867
Compare
As title.
a45c867
to
80428ac
Compare
Currently failing Flake8:
|
As title.
numba/analysis.py
Outdated
# find *all* branches | ||
branches = [] | ||
for idx, blk in func_ir.blocks.items(): | ||
tmp = [_ for _ in blk.find_insts(cls=ir.Branch)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using _
for var-name that is used is unusual. Also this statement can just be tmp = list(blk.find_insts(cls=ir.Branch))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah yeah, thanks, will change.
numba/analysis.py
Outdated
# only prune branches to blocks that have a single access route, this is | ||
# conservative | ||
kill_count = 0 | ||
for targets in cfg._succs.values(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Try not to use attributes with name starting with _
, they are considered as "private"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, and I wouldn't normally, just couldn't find something with the information I wanted. Realise now it's probably .predecessors()
.
numba/analysis.py
Outdated
# conservative | ||
kill_count = 0 | ||
for targets in cfg._succs.values(): | ||
if kill in targets: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you look at the predecessors of kill
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR; this is very useful! I have some comments that I think can help robustness and avoid code duplication. Let me know if you have any questions.
numba/analysis.py
Outdated
def find_branches(func_ir): | ||
# find *all* branches | ||
branches = [] | ||
for idx, blk in func_ir.blocks.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
idx
is unused.
numba/analysis.py
Outdated
# find *all* branches | ||
branches = [] | ||
for idx, blk in func_ir.blocks.items(): | ||
tmp = [_ for _ in blk.find_insts(cls=ir.Branch)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A basic block has one jump or branch at the end by definition. A simple check like isinstance(blk.body[-1], ir.Branch)
is enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree, done.
numba/analysis.py
Outdated
store = dict() | ||
for branch in tmp: | ||
store['branch'] = branch | ||
expr = blk.find_variable_assignment(branch.cond.name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
find_variable_assignment
is potentially unsafe since it looks at assignments of this block only. It is also doing wasteful iteration over IR since definition information is available in func_ir._definitions
. Using ir_utils.get_definition
solves this problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's safe by virtue of the next line, i.e. if it cannot be found it won't be used? However this is not ideal as noted, will switch to ir_utils.get_definition
. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you are right; it's safe since blk
is current block. Sounds good.
numba/analysis.py
Outdated
for branch in tmp: | ||
store['branch'] = branch | ||
expr = blk.find_variable_assignment(branch.cond.name) | ||
if expr is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code would be simpler if the condition variable is just saved here and resolving expression is deferred to later.
numba/analysis.py
Outdated
class Unknown(object): | ||
pass | ||
|
||
def get_const_val(func_ir, arg): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for this function. ir_utils.find_const
does this work without going over the IR.
else: | ||
return Unknown() | ||
|
||
def resolve_input_arg_const(input_arg): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ideally, this should be folded into ir_utils.find_const
.
@@ -259,3 +260,179 @@ def find_top_level_loops(cfg): | |||
for loop in cfg.loops().values(): | |||
if loop.header not in blocks_in_loop: | |||
yield loop | |||
|
|||
# Functions to manipulate IR | |||
def dead_branch_prune(func_ir, called_args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
array analysis does some pruning for some reason: https://github.com/numba/numba/blob/master/numba/array_analysis.py#L1066
Maybe it is unnecessary now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would have to check, array analysis isn't on the default code path though? Also I think that code is just looking at e.g. if 1:
or if True:
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct, it's not on the default path.
Seems like it, but I don't remember the reason (upstream optimization produces if True:
?). Just something to keep in mind for future, not necessarily this PR.
numba/analysis.py
Outdated
branch_assign = block.find_variable_assignment(cond.cond.name) | ||
block.body.remove(branch_assign) | ||
|
||
# replace the branch with a direct jump |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Branch is turned into jump only when the target block is removed. However, removing target block is not a necessary condition.
numba/analysis.py
Outdated
backbone = cfg.backbone() | ||
rem = [] | ||
|
||
for idx, doms in dom.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the block if it is dominating some block and is not in the backbone?
I think a block can be removed if it doesn't have any predecessors left. Backbone check is just an assertion showing something has gone wrong.
Also, removing blocks can be done once after the big loop when all constant branches are resolved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, thanks, done. Turns out that just switching branches for jumps and then computing the dead nodes in the CFG is all that's needed.
numba/analysis.py
Outdated
func_ir.blocks.pop(x) | ||
|
||
block = branch['block'] | ||
# remove computation of the branch condition, it's dead |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to be safe func_ir._definitions should be updated probably. Also, maybe removing stuff should be left to the remove_dead pass to be more robust.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will use remove_dead
for statement level adjustments, think it will also update func_ir._definitions
too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmmm, it seems in practice that this breaks parfors as there's e.g. sentinels and other dead statements present.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems surprising to me, could you point to examples where this happens?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is hard to demonstrate with a sample code due to the following...
Parfors compilation first enters the compilation pipeline via e.g. @njit
and will use a pipeline specified by the pipeline_class
kwarg which defaults to compiler.Pipeline
. The compilation pipeline proceeds, but compilation is re-entered by the npyufunc
gufunc compilation mechanism, this only ever uses the default pipeline:
numba/numba/npyufunc/parfor.py
Lines 1005 to 1012 in fdbe264
kernel_func = compiler.compile_ir( | |
typingctx, | |
targetctx, | |
gufunc_ir, | |
gufunc_param_types, | |
types.none, | |
flags, | |
locals) |
As a result, a potential reproducer like:
import numba
import numpy as np
class TestPipeline(numba.compiler.BasePipeline):
def define_pipelines(self, pm):
pm.create_pipeline('fake_parfors')
self.add_preprocessing_stage(pm)
self.add_with_handling_stage(pm)
self.add_pre_typing_stage(pm)
pm.add_stage(self.rm_dead_stage, "DCE ahead of type inference")
self.add_typing_stage(pm)
self.add_optimization_stage(pm)
pm.add_stage(self.stage_ir_legalization,
"ensure IR is legal prior to lowering")
self.add_lowering_stage(pm)
self.add_cleanup_stage(pm)
def rm_dead_stage(self):
numba.ir_utils.remove_dead(
self.func_ir.blocks, self.func_ir.arg_names, self.func_ir)
@numba.jit(parallel=True, pipeline_class=TestPipeline)
def f(a):
b = a + 1
c = b + a
return c
x = np.arange(10.)
f(x)
is destined to not reproduce because the pipeline_class
will not be used in the call made in the compilation pipeline reentry in numba::npyufunc::parfor::_create_gufunc_for_parfor_body
.
If, however, the default compiler pipeline is patched with DCE ahead of typing, like this stuartarchibald@b40257b and e.g. this example is run (one of many from the parfors_tests.py
that fail):
import numba
import numpy as np
@numba.jit(parallel=True)
def test_impl():
X = np.array(1)
Y = np.ones((10, 12))
return np.sum(X + Y)
test_impl()
the following appears:
$ python dce_breaks_parfors.py
Removing $phi28.1.13 = $16.6.12
Removing $16.6.12 = getiter(value=$16.5.11)
Removing $16.4.10 = getattr(value=$in_arr.0.71, attr=shape)
Removing $0.3.3 = getattr(value=$0.2.2, attr=init_prange)
Removing $0.2.2 = getattr(value=$0.1.1, attr=parfor)
Removing $0.1.1 = global(numba: <module 'numba' from '<path>/numba/numba/__init__.py'>)
Removing in_arr.0 = $in_arr.0.71
Removing $const_ind_0.65 = const(int, 0)
Removing $0.10 = getattr(value=$0.9, attr=sum)
Removing $0.9 = global(np: <module 'numpy' from '<env>/lib/python3.7/site-packages/numpy/__init__.py'>)
Removing Y = $Y.70
Removing $const0.7 = const(tuple, (10, 12))
Removing $0.6 = getattr(value=$0.5, attr=ones)
Removing $0.5 = global(np: <module 'numpy' from '<env>/lib/python3.7/site-packages/numpy/__init__.py'>)
Removing X = $X.69
Removing $0.14 = val.6
Removing $48.2.27 = val.6
Removing i.20 = $parfor_index_tuple_var.68
Removing $16.2.8 = getattr(value=$16.1.7, attr=pndindex)
Removing $16.1.7 = global(numba: <module 'numba' from '<path>/numba/numba/__init__.py'>)
Removing $ravel.59_size0.93 = static_getitem(value=$ravel.59_shape.92, index=0, index_var=None)
Removing $ravel.59_shape.92 = getattr(value=$ravel.59, attr=shape)
Removing $ravel.59_size0.103 = static_getitem(value=$ravel.59_shape.102, index=0, index_var=None)
Removing $ravel.59_shape.102 = getattr(value=$ravel.59, attr=shape)
Removing $ravel.59_size0.112 = static_getitem(value=$ravel.59_shape.111, index=0, index_var=None)
Removing $ravel.59_shape.111 = getattr(value=$ravel.59, attr=shape)
Removing $ravel.59_size0.121 = static_getitem(value=$ravel.59_shape.120, index=0, index_var=None)
Removing $ravel.59_shape.120 = getattr(value=$ravel.59, attr=shape)
Removing $ravel.59_size0.131 = static_getitem(value=$ravel.59_shape.130, index=0, index_var=None)
Removing $ravel.59_shape.130 = getattr(value=$ravel.59, attr=shape)
Removing $in_arr.0.71_size1.101 = static_getitem(value=$in_arr.0.71_shape.99, index=1, index_var=None)
Removing $in_arr.0.71_size0.100 = static_getitem(value=$in_arr.0.71_shape.99, index=0, index_var=None)
Removing $in_arr.0.71_shape.99 = getattr(value=$in_arr.0.71, attr=shape)
Removing $in_arr.0.71 = call $empty_attr_attr.54($tuple_var.52, $np_typ_var.55, func=$empty_attr_attr.54, args=[Var($tuple_var.52, dce_breaks_parfors.py (8)), Var($np_typ_var.55, dce_breaks_parfors.py (8))], kws=(), vararg=None)
Removing $np_typ_var.55 = getattr(value=$np_g_var.53, attr=float64)
Removing $empty_attr_attr.54 = getattr(value=$np_g_var.53, attr=empty)
Removing $np_g_var.53 = global(np: <module 'numpy' from '<env>/lib/python3.7/site-packages/numpy/__init__.py'>)
Removing $tuple_var.52 = build_tuple(items=[Var($const0.7_size0.31, dce_breaks_parfors.py (7)), Var($const0.7_size1.32, dce_breaks_parfors.py (7))])
Removing $30.5.24 = $expr_out_var.56
Removing $phi58.2 = $phi18.1
Removing $phi58.1 = $18.3
Traceback (most recent call last):
File "<path>/numba/numba/errors.py", line 617, in new_error_context
yield
File "<path>/numba/numba/lowering.py", line 259, in lower_block
self.lower_inst(inst)
File "<path>/numba/numba/pylowering.py", line 96, in lower_inst
value = self.lower_assign(inst)
File "<path>/numba/numba/pylowering.py", line 186, in lower_assign
val = self.loadvar(value.name)
File "<path>/numba/numba/pylowering.py", line 555, in loadvar
assert name in self._live_vars, name
AssertionError: $Y.70
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "dce_breaks_parfors.py", line 10, in <module>
test_impl()
File "<path>/numba/numba/dispatcher.py", line 369, in _compile_for_args
raise e
File "<path>/numba/numba/dispatcher.py", line 326, in _compile_for_args
return self.compile(tuple(argtypes))
File "<path>/numba/numba/compiler_lock.py", line 32, in _acquire_compile_lock
return func(*args, **kwargs)
File "<path>/numba/numba/dispatcher.py", line 659, in compile
cres = self._compiler.compile(args, return_type)
File "<path>/numba/numba/dispatcher.py", line 82, in compile
pipeline_class=self.pipeline_class)
File "<path>/numba/numba/compiler.py", line 928, in compile_extra
return pipeline.compile_extra(func)
File "<path>/numba/numba/compiler.py", line 371, in compile_extra
return self._compile_bytecode()
File "<path>/numba/numba/compiler.py", line 859, in _compile_bytecode
return self._compile_core()
File "<path>/numba/numba/compiler.py", line 846, in _compile_core
res = pm.run(self.status)
File "<path>/numba/numba/compiler_lock.py", line 32, in _acquire_compile_lock
return func(*args, **kwargs)
File "<path>/numba/numba/compiler.py", line 252, in run
raise patched_exception
File "<path>/numba/numba/compiler.py", line 243, in run
stage()
File "<path>/numba/numba/compiler.py", line 687, in stage_objectmode_backend
self._backend(lowerfn, objectmode=True)
File "<path>/numba/numba/compiler.py", line 663, in _backend
lowered = lowerfn()
File "<path>/numba/numba/compiler.py", line 635, in backend_object_mode
self.flags)
File "<path>/numba/numba/compiler.py", line 1074, in py_lowering_stage
lower.lower()
File "<path>/numba/numba/lowering.py", line 178, in lower
self.lower_normal_function(self.fndesc)
File "<path>/numba/numba/lowering.py", line 219, in lower_normal_function
entry_block_tail = self.lower_function_body()
File "<path>/numba/numba/lowering.py", line 244, in lower_function_body
self.lower_block(block)
File "<path>/numba/numba/lowering.py", line 259, in lower_block
self.lower_inst(inst)
File "<env>/lib/python3.7/contextlib.py", line 130, in __exit__
self.gen.throw(type, value, traceback)
File "<path>/numba/numba/errors.py", line 625, in new_error_context
six.reraise(type(newerr), newerr, tb)
File "<path>/numba/numba/six.py", line 659, in reraise
raise value
numba.errors.LoweringError: Failed in object mode pipeline (step: object mode backend)
$Y.70
File "dce_breaks_parfors.py", line 7:
def test_impl():
<source elided>
X = np.array(1)
Y = np.ones((10, 12))
^
[1] During: lowering "Y = $Y.70" at dce_breaks_parfors.py (7)
in other cases of running numba.tests.test_parfors
test failures, the @do_scheduling
function is often missing or phi
values are missing from the typemap.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CC @DrTodd13
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@stuartarchibald @ehsantn Here's my working theory based on what I have observed. We get into the recursive lowering of functions called by test_impl (confirmed) and the remove dead code removed a variable that a subsequent "del" of that variable tried to find (confirmed). This caused an exception (confirmed) that propagated all the way back up and Numba tried to fallback to a compile of test_impl with parallel=False (I'm pretty sure here as I see the original code for test_impl being lowered). However, the compile with parallel=True changed some data structure in such a way it is no longer compatible with the unchanged version of the function IR and that is why the compile with parallel=False also failed. The last part is speculation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great detective work! I think this hasn't been a problem before since ParallelAccelerator passes ignore Dels altogether and just regenerate after everything is done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @DrTodd13, this seems logical and well found!!
As a generalisation, I've observed that a few of the passes/mutations performed by the ParallelAccelerator code may alter the original IR in such a way it becomes un-lowerable by the standard pipeline (particularly the inlining passes, there's a few tickets about this), this correlates. Assuming the patch in #3727 makes the DCE pass valid for general consumption, is its use ahead of typing still dangerous (I presume so due to unknown aliasing etc)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe so. Any issues with removing dead code after inference?
As title.
@ehsantn thanks very much for providing such helpful and detailed feedback, in light of this I've had another go at implementing the core algorithm. I've completely removed the use of a DCE pass on the basis that running it ahead of typing may well be dangerous, this can be added post typing in another PR if desirable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the effort. This looks much improved. I think it's ready to merge (just a minor optional comment).
Flake8 has a few minor complaints in |
As title.
As title.