Add dead branch prune before type inference. #3592
Changes from 3 commits
fbcf804
80428ac
a5a3a6e
2d2dc9e
55957f9
05ebb5e
@@ -7,6 +7,7 @@

from numba import ir
from numba.controlflow import CFGraph
from numba import types

#
# Analysis related to variable lifetime
@@ -259,3 +260,179 @@ def find_top_level_loops(cfg):
    for loop in cfg.loops().values():
        if loop.header not in blocks_in_loop:
            yield loop

# Functions to manipulate IR
def dead_branch_prune(func_ir, called_args):
    """
    Removes dead branches based on constant inference from function args.
    This directly mutates the IR.

    func_ir is the IR
    called_args are the actual arguments with which the function is called
    """

    DEBUG = 0

    def find_branches(func_ir):
        # find *all* branches
        branches = []
        for idx, blk in func_ir.blocks.items():
            tmp = [_ for _ in blk.find_insts(cls=ir.Branch)]

Using [...]

Ah yeah, thanks, will change.

A basic block has one jump or branch at the end by definition. A simple check like [...]

Agree, done.
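
The check suggested in this thread is not preserved on the page. As a rough sketch of what a terminator-only lookup could look like, here is a toy version; the `Assign`, `Jump`, `Branch` and `Block` classes are hypothetical stand-ins, not numba's `ir` classes:

```python
from dataclasses import dataclass, field
from typing import List


# Hypothetical stand-ins for IR nodes; these are NOT numba's classes.
@dataclass
class Assign:
    target: str
    value: object


@dataclass
class Jump:
    target: int


@dataclass
class Branch:
    cond: str
    truebr: int
    falsebr: int


@dataclass
class Block:
    body: List[object] = field(default_factory=list)


def find_branches(blocks):
    """Return (label, Branch) pairs, looking only at each block's terminator."""
    found = []
    for label, blk in blocks.items():
        # A well-formed basic block ends in exactly one jump or branch,
        # so inspecting the last statement is enough.
        if blk.body and isinstance(blk.body[-1], Branch):
            found.append((label, blk.body[-1]))
    return found


blocks = {
    0: Block([Assign("c", "x is None"), Branch("c", 10, 20)]),
    10: Block([Assign("y", 1), Jump(30)]),
    20: Block([Assign("y", 2), Jump(30)]),
    30: Block([]),
}
branches = find_branches(blocks)
```

Looking only at the last statement avoids scanning every instruction in the block and makes the "one terminator per block" assumption explicit.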

            store = dict()
            for branch in tmp:
                store['branch'] = branch
                expr = blk.find_variable_assignment(branch.cond.name)

It's safe by virtue of the next line, i.e. if it cannot be found it won't be used? However this is not ideal as noted, will switch to [...]

Yes, you are right; it's safe since [...]

                if expr is not None:

The code would be simpler if the condition variable is just saved here and resolving the expression is deferred to later.

                    val = expr.value
                    if isinstance(val, ir.Expr) and val.op == 'binop':
                        store['op'] = val
                        args = [val.lhs, val.rhs]
                        store['args'] = args
                        store['block'] = blk
                        branches.append(store)
        return branches

    def prune(func_ir, branch, all_branches, *conds):
        # this prunes a given branch and fixes up the IR
        lhs_cond, rhs_cond = conds
        take_truebr = branch['op'].fn(lhs_cond, rhs_cond)
        cond = branch['branch']
        if take_truebr:
            keep = cond.truebr
            kill = cond.falsebr
        else:
            keep = cond.falsebr
            kill = cond.truebr

        cfg = compute_cfg_from_blocks(func_ir.blocks)
        if DEBUG > 0:
            print("Pruning %s" % kill, branch['branch'], lhs_cond, rhs_cond, branch['op'].fn)
        if DEBUG > 1:
            from pprint import pprint
            cfg.dump()

        # only prune branches to blocks that have a single access route, this is
        # conservative
        kill_count = 0
        for targets in cfg._succs.values():

Try not to use attributes with name starting with [...]

True, and I wouldn't normally, just couldn't find something with the information I wanted. Realise now it's probably [...]

            if kill in targets:

Can you look at the predecessors of [...]
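
The rest of this comment is missing from the page, so the following is only a guess at the kind of check meant: rather than walking every successor set, invert the successor map once and ask how many predecessors the block has. The `predecessors` and `has_single_entry` helpers are hypothetical names, and the graph is a plain dict rather than a numba `CFGraph`:

```python
from collections import defaultdict


def predecessors(succs):
    """Invert a {block: set(successors)} map into {block: set(predecessors)}."""
    preds = defaultdict(set)
    for src, targets in succs.items():
        preds[src]  # touching the key ensures every block gets an entry
        for dst in targets:
            preds[dst].add(src)
    return dict(preds)


# Diamond-shaped toy graph: 0 branches to 10 or 20, both rejoin at 30.
succs = {0: {10, 20}, 10: {30}, 20: {30}, 30: set()}
preds = predecessors(succs)


def has_single_entry(label):
    """True if exactly one block can transfer control to `label`."""
    return len(preds[label]) == 1
```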

                kill_count += 1
        if kill_count > 1:
            if DEBUG > 0:
                print("prune rejected")
            return

        # remove dominators that are not on the backbone
        dom = cfg.dominators()
        postdom = cfg.post_dominators()
        backbone = cfg.backbone()
        rem = []

        for idx, doms in dom.items():

Remove the block if it is dominating some block and is not in the backbone? Also, removing blocks can be done once after the big loop when all constant branches are resolved.

Yes, thanks, done. Turns out that just switching branches for jumps and then computing the dead nodes in the CFG is all that's needed.
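
To make the approach in the reply above concrete, here is a minimal sketch on a toy IR, assuming each block is reduced to its terminator. The tuple encoding and the `rewrite_constant_branches` and `dead_blocks` helpers are hypothetical and are not the code that ended up in the PR:

```python
import operator

# Toy IR: label -> terminator, where a terminator is one of
#   ("jump", target)
#   ("branch", (fn, lhs_name, rhs_name), truebr, falsebr)
#   ("return",)


def rewrite_constant_branches(blocks, consts):
    """Turn branches whose operands are both known constants into jumps."""
    for label, term in list(blocks.items()):
        if term[0] != "branch":
            continue
        _, (fn, lhs, rhs), truebr, falsebr = term
        if lhs in consts and rhs in consts:
            taken = truebr if fn(consts[lhs], consts[rhs]) else falsebr
            blocks[label] = ("jump", taken)


def dead_blocks(blocks, entry):
    """Return the labels that are not reachable from `entry`."""
    seen, stack = set(), [entry]
    while stack:
        label = stack.pop()
        if label in seen:
            continue
        seen.add(label)
        term = blocks[label]
        if term[0] == "jump":
            stack.append(term[1])
        elif term[0] == "branch":
            stack.extend(term[2:])
        # a "return" terminator has no successors
    return set(blocks) - seen


blocks = {
    0: ("branch", (operator.is_, "x", "none"), 10, 20),
    10: ("jump", 30),
    20: ("branch", (operator.gt, "y", "zero"), 25, 30),
    25: ("jump", 30),
    30: ("return",),
}
consts = {"x": None, "none": None}  # only the first condition is resolvable

rewrite_constant_branches(blocks, consts)
dead = dead_blocks(blocks, entry=0)
for label in dead:
    del blocks[label]
```

Block 25 is only reachable through block 20, so it drops out of the reachability walk together with 20, without any dominator or backbone bookkeeping.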

            if kill in doms and kill not in backbone:
                rem.append(idx)
        if not rem:
            if DEBUG > 0:
                msg = "prune rejected no kill in dominators and not in backbone"
                print(msg)
            return

        for x in rem:
            func_ir.blocks.pop(x)

        block = branch['block']
        # remove computation of the branch condition, it's dead

Just to be safe func_ir._definitions should be updated probably. Also, maybe removing stuff should be left to the remove_dead pass to be more robust.

Will use [...]

hmmm, it seems in practice that this breaks parfors as there's e.g. sentinels and other dead statements present.

Seems surprising to me, could you point to examples where this happens?

This is hard to demonstrate with a sample code due to the following...
numba/numba/npyufunc/parfor.py Lines 1005 to 1012 in fdbe264
[...]
As a result, a potential reproducer like:

    import numba
    import numpy as np

    class TestPipeline(numba.compiler.BasePipeline):
        def define_pipelines(self, pm):
            pm.create_pipeline('fake_parfors')
            self.add_preprocessing_stage(pm)
            self.add_with_handling_stage(pm)
            self.add_pre_typing_stage(pm)
            pm.add_stage(self.rm_dead_stage, "DCE ahead of type inference")
            self.add_typing_stage(pm)
            self.add_optimization_stage(pm)
            pm.add_stage(self.stage_ir_legalization,
                         "ensure IR is legal prior to lowering")
            self.add_lowering_stage(pm)
            self.add_cleanup_stage(pm)

        def rm_dead_stage(self):
            numba.ir_utils.remove_dead(
                self.func_ir.blocks, self.func_ir.arg_names, self.func_ir)

    @numba.jit(parallel=True, pipeline_class=TestPipeline)
    def f(a):
        b = a + 1
        c = b + a
        return c

    x = np.arange(10.)
    f(x)

is destined to not reproduce because the [...]
If, however, the default compiler pipeline is patched with DCE ahead of typing, like this stuartarchibald@b40257b and e.g. this example is run (one of many from the [...]):

    import numba
    import numpy as np

    @numba.jit(parallel=True)
    def test_impl():
        X = np.array(1)
        Y = np.ones((10, 12))
        return np.sum(X + Y)

    test_impl()

the following appears:
[...]
in other cases of running [...]

CC @DrTodd13

@stuartarchibald @ehsantn Here's my working theory based on what I have observed. We get into the recursive lowering of functions called by test_impl (confirmed) and the remove dead code removed a variable that a subsequent "del" of that variable tried to find (confirmed). This caused an exception (confirmed) that propagated all the way back up and Numba tried to fallback to a compile of test_impl with parallel=False (I'm pretty sure here as I see the original code for test_impl being lowered). However, the compile with parallel=True changed some data structure in such a way it is no longer compatible with the unchanged version of the function IR and that is why the compile with parallel=False also failed. The last part is speculation.

Great detective work! I think this hasn't been a problem before since ParallelAccelerator passes ignore Dels altogether and just regenerate after everything is done.

Thanks @DrTodd13, this seems logical and well found!! As a generalisation, I've observed that a few of the passes/mutations performed by the ParallelAccelerator code may alter the original IR in such a way it becomes un-lowerable by the standard pipeline (particularly the inlining passes, there's a few tickets about this), this correlates.

Assuming the patch in #3727 makes the DCE pass valid for general consumption, is its use ahead of typing still dangerous (I presume so due to unknown aliasing etc)?

I believe so. Any issues with removing dead code after inference?
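
The "del of a removed variable" failure mode described in this thread can be shown on a toy statement list. This is only an illustration under assumed semantics; `naive_remove_dead` and `dangling_dels` are hypothetical helpers and do not reflect how numba's `remove_dead` is implemented:

```python
# Toy statements:
#   ("assign", target, names_read)   target = f(*names_read)
#   ("del", name)                    variable lifetime ends
#   ("return", name)


def naive_remove_dead(stmts):
    """Drop assignments whose target is never read; leaves dels untouched."""
    used = set()
    for stmt in stmts:
        if stmt[0] == "assign":
            used.update(stmt[2])
        elif stmt[0] == "return":
            used.add(stmt[1])
    return [s for s in stmts if not (s[0] == "assign" and s[1] not in used)]


def dangling_dels(stmts, arg_names=()):
    """Names that are deleted without ever having been defined."""
    defined, bad = set(arg_names), []
    for stmt in stmts:
        if stmt[0] == "assign":
            defined.add(stmt[1])
        elif stmt[0] == "del" and stmt[1] not in defined:
            bad.append(stmt[1])
    return bad


stmts = [
    ("assign", "b", ("a",)),
    ("assign", "unused", ("a",)),  # dead: never read afterwards
    ("assign", "c", ("b", "a")),
    ("del", "unused"),
    ("del", "b"),
    ("return", "c"),
]

before = dangling_dels(stmts, arg_names=("a",))
pruned = naive_remove_dead(stmts)
after = dangling_dels(pruned, arg_names=("a",))
```

In this toy model the pruned list still contains `("del", "unused")` even though the assignment is gone, which is the shape of the problem reported above; regenerating the dels after all passes have run, as the ParallelAccelerator passes are said to do, sidesteps it.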

        branch_assign = block.find_variable_assignment(cond.cond.name)
        block.body.remove(branch_assign)

        # replace the branch with a direct jump

Branch is turned into jump only when the target block is removed. However, removing target block is not a necessary condition.

        jmp = ir.Jump(keep, loc=cond.loc)
        block.body[-1] = jmp

    branches = find_branches(func_ir)
    noldbranches = len(branches)
    nnewbranches = 0

    class Unknown(object):
        pass

    def get_const_val(func_ir, arg):

No need for this function.

        """
        Resolves the arg to a const if there is a variable assignment like:
        `$label = const(thing)`
        """
        possibles = []
        for idx, blk in func_ir.blocks.items():
            found = blk.find_variable_assignment(arg.name)
            if found is not None and isinstance(found.value, ir.Const):
                possibles.append(found.value.value)
        # if there's more than one definition we don't know which const
        # propagates to here
        if len(possibles) == 1:
            return possibles[0]
        else:
            return Unknown()

    def resolve_input_arg_const(input_arg):

Ideally, this should be folded into [...]

        """
        Resolves an input arg to a constant (if possible)
        """
        idx = func_ir.arg_names.index(input_arg)
        input_arg_ty = called_args[idx]

        # comparing to None?
        if isinstance(input_arg_ty, types.NoneType):
            return None

        # is it a kwarg default
        if isinstance(input_arg_ty, types.Omitted):
            if isinstance(input_arg_ty.value, types.NoneType):
                return None
            else:
                return input_arg_ty.value

        # is it a literal
        const = getattr(input_arg_ty, 'literal_value', None)

Input arg can't be a literal. Otherwise, functions would be recompiled for every different arg value. (this is actually a feature we need in certain circumstances)

Wondering if it should actually be an option?

Yes, probably for future when value-based dispatch is supported.
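
For readers following this thread: the lookup under discussion has to distinguish "resolved to the constant None" from "could not be resolved", which is why the diff uses an `Unknown` class instead of returning `None` on failure. A small sketch of that idea follows; `NoneType`, `Omitted`, `Literal` and `Int64` are hypothetical stand-ins for type objects, not numba's `types` module, and the `_MISSING` marker is a variation that the diff itself does not use:

```python
class Unknown:
    """Sentinel meaning 'no constant could be resolved'."""


# Hypothetical stand-ins for type objects; NOT numba's types module.
class NoneType:
    pass


class Omitted:
    def __init__(self, value):
        self.value = value


class Literal:
    def __init__(self, literal_value):
        self.literal_value = literal_value


class Int64:
    pass


_MISSING = object()  # private marker so a literal value of None stays usable


def resolve_arg_const(arg_ty):
    """Map an argument's type to a constant, or an Unknown instance."""
    if isinstance(arg_ty, NoneType):
        return None  # the argument is known to be None
    if isinstance(arg_ty, Omitted):
        return arg_ty.value  # keyword argument left at its default
    value = getattr(arg_ty, "literal_value", _MISSING)
    if value is not _MISSING:
        return value
    return Unknown()


results = [
    resolve_arg_const(NoneType()),
    resolve_arg_const(Omitted(3)),
    resolve_arg_const(Literal(5)),
    resolve_arg_const(Int64()),
]
```

Callers then test `isinstance(result, Unknown)` rather than `result is None`, so a condition such as `x is None` can still be folded when `x` really is `None`.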

        if const is not None:
            return const

        return Unknown()

    # keep iterating until branch prune stabilizes
    while(noldbranches != nnewbranches):
        # This looks for branches where:
        # an arg of the condition is in input args and const/literal
        # an arg of the condition is a const/literal
        # if the condition is met it will remove blocks that are not reached and
        # patch up the branch
        for branch in branches:
            const_arg_val = False
            const_conds = []
            for arg in branch['args']:
                resolved_const = Unknown()
                if arg.name in func_ir.arg_names:
                    # it's an e.g. literal argument to the function
                    resolved_const = resolve_input_arg_const(arg.name)
                else:
                    # it's some const argument to the function
                    resolved_const = get_const_val(func_ir, arg)

                if not isinstance(resolved_const, Unknown):
                    const_conds.append(resolved_const)

            # lhs/rhs are consts
            if len(const_conds) == 2:
                if DEBUG > 1:
                    print("before".center(80, '-'))
                    print(func_ir.dump())

                prune(func_ir, branch, branches, *const_conds)

                if DEBUG > 1:
                    print("after".center(80, '-'))
                    print(func_ir.dump())

        noldbranches = len(branches)
        branches = find_branches(func_ir)
        nnewbranches = len(branches)

array analysis does some pruning for some reason: https://github.com/numba/numba/blob/master/numba/array_analysis.py#L1066
Maybe it is unnecessary now?

Would have to check, array analysis isn't on the default code path though? Also I think that code is just looking at e.g. "if 1:" or "if True:"?

That's correct, it's not on the default path.
Seems like it, but I don't remember the reason (upstream optimization produces "if True:"?). Just something to keep in mind for future, not necessarily this PR.