
Add dead branch prune before type inference. #3592

Merged: 6 commits into numba:master on Feb 27, 2019

Conversation

stuartarchibald (Contributor):

As title.

codecov-io commented Dec 12, 2018

Codecov Report

❗ No coverage uploaded for pull request base (master@865b87a).
The diff coverage is n/a.

@@            Coverage Diff            @@
##             master    #3592   +/-   ##
=========================================
  Coverage          ?   80.68%           
=========================================
  Files             ?      393           
  Lines             ?    80485           
  Branches          ?     9164           
=========================================
  Hits              ?    64942           
  Misses            ?    14127           
  Partials          ?     1416

sklam (Member) commented Dec 18, 2018

Much of stage_dead_branch_prune should move into numba.analysis.
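
A minimal sketch of that suggestion (not the code from this PR): the pipeline stage becomes a thin wrapper and the reusable logic lives in numba.analysis, where it can be imported and unit-tested directly. The pipeline attribute names below (func_ir, args) are assumptions for illustration; the dead_branch_prune(func_ir, called_args) signature comes from this PR's diff.

from numba import analysis

def stage_dead_branch_prune(pipeline):
    # Delegate to the analysis module; `pipeline.func_ir` and `pipeline.args`
    # stand in for the function IR and the argument types the function was
    # called with (the `called_args` of the analysis function).
    analysis.dead_branch_prune(pipeline.func_ir, pipeline.args)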

@stuartarchibald stuartarchibald changed the title WIP: Add dead branch prune before type inference. Add dead branch prune before type inference. Jan 3, 2019
seibert (Contributor) commented Jan 3, 2019

Currently failing Flake8:

2019-01-03T17:15:45.6970830Z numba/tests/test_analysis.py:4:1: F401 'copy.deepcopy' imported but unused
2019-01-03T17:15:45.6971795Z numba/tests/test_analysis.py:5:1: F401 'numpy as np' imported but unused
2019-01-03T17:15:45.6972330Z numba/tests/test_analysis.py:7:1: F401 'numba.unittest_support as unittest' imported but unused
2019-01-03T17:15:45.6973045Z numba/tests/test_analysis.py:8:1: F401 'numba.compiler.Flags' imported but unused
2019-01-03T17:15:45.6973358Z numba/tests/test_analysis.py:9:1: F401 'numba.jit' imported but unused
2019-01-03T17:15:45.6974019Z numba/tests/test_analysis.py:9:1: F401 'numba.typeof' imported but unused
2019-01-03T17:15:45.6974961Z numba/tests/test_analysis.py:9:1: F401 'numba.errors' imported but unused
2019-01-03T17:15:45.6975219Z numba/tests/test_analysis.py:9:1: F401 'numba.utils' imported but unused
2019-01-03T17:15:45.6975525Z numba/tests/test_analysis.py:9:1: F401 'numba.config' imported but unused
2019-01-03T17:15:45.6975773Z numba/tests/test_analysis.py:10:1: F401 '.support.tag' imported but unused
2019-01-03T17:15:45.6976044Z numba/tests/test_analysis.py:13:1: F401 'numba.ir_utils.get_ir_of_code' imported but unused
2019-01-03T17:15:45.6976364Z numba/tests/test_analysis.py:138:17: F841 local variable 'z' is assigned to but never used
2019-01-03T17:15:45.6976639Z numba/tests/test_analysis.py:205:17: F841 local variable 'dead' is assigned to but never used
2019-01-03T17:15:45.6976907Z numba/tests/test_analysis.py:240:17: F841 local variable 'dead' is assigned to but never used
2019-01-03T17:15:45.6977396Z numba/tests/test_analysis.py:249:9: F841 local variable 'cfunc' is assigned to but never used
2019-01-03T17:15:45.6977666Z numba/tests/test_analysis.py:264:9: F841 local variable 'cfunc' is assigned to but never used

@seibert seibert requested a review from ehsantn January 3, 2019 20:20
    # find *all* branches
    branches = []
    for idx, blk in func_ir.blocks.items():
        tmp = [_ for _ in blk.find_insts(cls=ir.Branch)]
Member:

Using _ for a variable name that is actually used is unusual. Also, this statement can just be tmp = list(blk.find_insts(cls=ir.Branch)).

Contributor Author:

ah yeah, thanks, will change.

# only prune branches to blocks that have a single access route, this is
# conservative
kill_count = 0
for targets in cfg._succs.values():
Member:

Try not to use attributes with names starting with _; they are considered "private".

Contributor Author:

True, and I wouldn't normally, just couldn't find something with the information I wanted. Realise now it's probably .predecessors().

    # conservative
    kill_count = 0
    for targets in cfg._succs.values():
        if kill in targets:
Member:

Can you look at the predecessors of kill instead?
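
A sketch of what the suggested predecessor-based check could look like, using the public CFG API rather than the private _succs map. compute_cfg_from_blocks lives in numba.analysis; the exact shape of the items yielded by predecessors() may vary by version, so only the count is used here.

from numba.analysis import compute_cfg_from_blocks

def has_single_access_route(func_ir, kill):
    # True when the candidate block label `kill` has exactly one incoming
    # edge, i.e. there is only one route into it.
    cfg = compute_cfg_from_blocks(func_ir.blocks)
    return len(list(cfg.predecessors(kill))) == 1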

ehsantn (Collaborator) left a comment:

Thanks for the PR; this is very useful! I have some comments that I think can help robustness and avoid code duplication. Let me know if you have any questions.

def find_branches(func_ir):
    # find *all* branches
    branches = []
    for idx, blk in func_ir.blocks.items():
Collaborator:

idx is unused.

    # find *all* branches
    branches = []
    for idx, blk in func_ir.blocks.items():
        tmp = [_ for _ in blk.find_insts(cls=ir.Branch)]
Collaborator:

A basic block has one jump or branch at the end by definition. A simple check like isinstance(blk.body[-1], ir.Branch) is enough.

Contributor Author:

agree, done.
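
A sketch of the simplified collection using the block terminator, per the suggestion above (the find_branches name and the (block, branch) pairing are illustrative, not the PR's exact code):

from numba import ir

def find_branches(func_ir):
    # A well-formed block ends in exactly one terminator, so checking the
    # last statement is enough; no need to scan the block with find_insts.
    branches = []
    for blk in func_ir.blocks.values():
        if blk.body and isinstance(blk.body[-1], ir.Branch):
            branches.append((blk, blk.body[-1]))
    return branches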

        store = dict()
        for branch in tmp:
            store['branch'] = branch
            expr = blk.find_variable_assignment(branch.cond.name)
Collaborator:

find_variable_assignment is potentially unsafe since it looks at assignments of this block only. It is also doing wasteful iteration over the IR, since definition information is available in func_ir._definitions. Using ir_utils.get_definition solves this problem.

Contributor Author:

It's safe by virtue of the next line, i.e. if it cannot be found it won't be used? However this is not ideal as noted, will switch to ir_utils.get_definition. Thanks.

Collaborator:

Yes, you are right; it's safe since blk is the current block. Sounds good.
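
A sketch of resolving a branch condition via ir_utils.get_definition, as suggested. It is wrapped in ir_utils.guard because, as I recall, get_definition raises a GuardException when it cannot resolve a single definition; the helper name is illustrative.

from numba import ir_utils

def branch_condition_defn(func_ir, branch):
    # Look the condition up via func_ir._definitions instead of re-scanning
    # the block; guard() turns an unresolvable lookup into None.
    return ir_utils.guard(ir_utils.get_definition, func_ir, branch.cond)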

        for branch in tmp:
            store['branch'] = branch
            expr = blk.find_variable_assignment(branch.cond.name)
            if expr is not None:
Collaborator:

The code would be simpler if the condition variable were just saved here and resolving the expression were deferred until later.

class Unknown(object):
    pass

def get_const_val(func_ir, arg):
Collaborator:

No need for this function. ir_utils.find_const does this work without going over the IR.

    else:
        return Unknown()

def resolve_input_arg_const(input_arg):
Collaborator:

Ideally, this should be folded into ir_utils.find_const.
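
A sketch of leaning on ir_utils.find_const instead of a hand-rolled get_const_val. It is wrapped in ir_utils.guard because, as I recall, find_const signals "not a constant" with a GuardException rather than a sentinel such as Unknown; the helper name is illustrative.

from numba import ir_utils

def try_resolve_const(func_ir, var):
    # Returns the compile-time constant value assigned to `var`, or None if
    # it cannot be resolved to a constant.
    return ir_utils.guard(ir_utils.find_const, func_ir, var)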

@@ -259,3 +260,179 @@ def find_top_level_loops(cfg):
    for loop in cfg.loops().values():
        if loop.header not in blocks_in_loop:
            yield loop

# Functions to manipulate IR
def dead_branch_prune(func_ir, called_args):
Collaborator:

array analysis does some pruning for some reason: https://github.com/numba/numba/blob/master/numba/array_analysis.py#L1066
Maybe it is unnecessary now?

Contributor Author:

Would have to check, array analysis isn't on the default code path though? Also I think that code is just looking at e.g. if 1: or if True: ?

Collaborator:

That's correct, it's not on the default path.

Seems like it, but I don't remember the reason (an upstream optimization produces if True:?). Just something to keep in mind for the future, not necessarily this PR.

branch_assign = block.find_variable_assignment(cond.cond.name)
block.body.remove(branch_assign)

# replace the branch with a direct jump
Collaborator:

The branch is turned into a jump only when the target block is removed. However, removing the target block is not a necessary condition.
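
A sketch of the rewrite itself, independent of whether the dead target block is later removed: once the branch condition is known to be constant, the terminating ir.Branch is simply replaced with an ir.Jump to the surviving target. The keep_label parameter (the label of whichever of truebr/falsebr survives) and the function name are illustrative.

from numba import ir

def rewrite_branch_as_jump(block, keep_label):
    # Replace the conditional terminator with a direct jump to the surviving
    # target; removing the pruned target block is a separate decision.
    branch = block.body[-1]
    assert isinstance(branch, ir.Branch)
    block.body[-1] = ir.Jump(keep_label, branch.loc)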

backbone = cfg.backbone()
rem = []

for idx, doms in dom.items():
Collaborator:

Remove the block if it is dominating some block and is not in the backbone?
I think a block can be removed if it doesn't have any predecessors left. The backbone check is just an assertion showing that something has gone wrong.

Also, removing blocks can be done once after the big loop when all constant branches are resolved.

Contributor Author:

Yes, thanks, done. Turns out that just switching branches for jumps and then computing the dead nodes in the CFG is all that's needed.
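
A sketch of that simplification: after the constant branches have been rewritten as jumps, rebuild the CFG and drop whatever is no longer reachable. It assumes the CFG object exposes dead_nodes() returning the unreachable labels, which matches my recollection of numba's CFGraph; the function name is illustrative.

from numba.analysis import compute_cfg_from_blocks

def remove_unreachable_blocks(func_ir):
    # Rebuild the CFG after branch rewriting and drop every block that is
    # no longer reachable from the entry node.
    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for label in cfg.dead_nodes():
        del func_ir.blocks[label]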

func_ir.blocks.pop(x)

block = branch['block']
# remove computation of the branch condition, it's dead
Collaborator:

Just to be safe, func_ir._definitions should probably be updated. Also, maybe removing statements should be left to the remove_dead pass to be more robust.

Contributor Author:

Will use remove_dead for statement-level adjustments; I think it will also update func_ir._definitions.

Contributor Author:

Hmmm, it seems in practice that this breaks parfors, as there are e.g. sentinels and other dead statements present.

Collaborator:

Seems surprising to me, could you point to examples where this happens?

Contributor Author:

This is hard to demonstrate with sample code due to the following...
Parfors compilation first enters the compilation pipeline via e.g. @njit and will use a pipeline specified by the pipeline_class kwarg, which defaults to compiler.Pipeline. The compilation pipeline proceeds, but compilation is re-entered by the npyufunc gufunc compilation mechanism, which only ever uses the default pipeline:

numba/numba/npyufunc/parfor.py, lines 1005 to 1012 in fdbe264:

    kernel_func = compiler.compile_ir(
        typingctx,
        targetctx,
        gufunc_ir,
        gufunc_param_types,
        types.none,
        flags,
        locals)

As a result, a potential reproducer like:

import numba
import numpy as np

class TestPipeline(numba.compiler.BasePipeline):
    def define_pipelines(self, pm):
        pm.create_pipeline('fake_parfors')
        self.add_preprocessing_stage(pm)
        self.add_with_handling_stage(pm)
        self.add_pre_typing_stage(pm)
        pm.add_stage(self.rm_dead_stage, "DCE ahead of type inference")
        self.add_typing_stage(pm)
        self.add_optimization_stage(pm)
        pm.add_stage(self.stage_ir_legalization,
                    "ensure IR is legal prior to lowering")
        self.add_lowering_stage(pm)
        self.add_cleanup_stage(pm)

    def rm_dead_stage(self):
        numba.ir_utils.remove_dead(
            self.func_ir.blocks, self.func_ir.arg_names, self.func_ir)

@numba.jit(parallel=True, pipeline_class=TestPipeline)
def f(a):
    b = a + 1
    c = b + a
    return c

x = np.arange(10.)

f(x)

is destined not to reproduce, because pipeline_class will not be used in the call made during compilation pipeline reentry in numba::npyufunc::parfor::_create_gufunc_for_parfor_body.

If, however, the default compiler pipeline is patched with DCE ahead of typing, as in stuartarchibald@b40257b, and e.g. this example is run (one of many from parfors_tests.py that fail):

import numba
import numpy as np

@numba.jit(parallel=True)
def test_impl():
    X = np.array(1)
    Y = np.ones((10, 12))
    return np.sum(X + Y)

test_impl()

the following appears:

$ python dce_breaks_parfors.py
Removing $phi28.1.13 = $16.6.12
Removing $16.6.12 = getiter(value=$16.5.11)
Removing $16.4.10 = getattr(value=$in_arr.0.71, attr=shape)
Removing $0.3.3 = getattr(value=$0.2.2, attr=init_prange)
Removing $0.2.2 = getattr(value=$0.1.1, attr=parfor)
Removing $0.1.1 = global(numba: <module 'numba' from '<path>/numba/numba/__init__.py'>)
Removing in_arr.0 = $in_arr.0.71
Removing $const_ind_0.65 = const(int, 0)
Removing $0.10 = getattr(value=$0.9, attr=sum)
Removing $0.9 = global(np: <module 'numpy' from '<env>/lib/python3.7/site-packages/numpy/__init__.py'>)
Removing Y = $Y.70
Removing $const0.7 = const(tuple, (10, 12))
Removing $0.6 = getattr(value=$0.5, attr=ones)
Removing $0.5 = global(np: <module 'numpy' from '<env>/lib/python3.7/site-packages/numpy/__init__.py'>)
Removing X = $X.69
Removing $0.14 = val.6
Removing $48.2.27 = val.6
Removing i.20 = $parfor_index_tuple_var.68
Removing $16.2.8 = getattr(value=$16.1.7, attr=pndindex)
Removing $16.1.7 = global(numba: <module 'numba' from '<path>/numba/numba/__init__.py'>)
Removing $ravel.59_size0.93 = static_getitem(value=$ravel.59_shape.92, index=0, index_var=None)
Removing $ravel.59_shape.92 = getattr(value=$ravel.59, attr=shape)
Removing $ravel.59_size0.103 = static_getitem(value=$ravel.59_shape.102, index=0, index_var=None)
Removing $ravel.59_shape.102 = getattr(value=$ravel.59, attr=shape)
Removing $ravel.59_size0.112 = static_getitem(value=$ravel.59_shape.111, index=0, index_var=None)
Removing $ravel.59_shape.111 = getattr(value=$ravel.59, attr=shape)
Removing $ravel.59_size0.121 = static_getitem(value=$ravel.59_shape.120, index=0, index_var=None)
Removing $ravel.59_shape.120 = getattr(value=$ravel.59, attr=shape)
Removing $ravel.59_size0.131 = static_getitem(value=$ravel.59_shape.130, index=0, index_var=None)
Removing $ravel.59_shape.130 = getattr(value=$ravel.59, attr=shape)
Removing $in_arr.0.71_size1.101 = static_getitem(value=$in_arr.0.71_shape.99, index=1, index_var=None)
Removing $in_arr.0.71_size0.100 = static_getitem(value=$in_arr.0.71_shape.99, index=0, index_var=None)
Removing $in_arr.0.71_shape.99 = getattr(value=$in_arr.0.71, attr=shape)
Removing $in_arr.0.71 = call $empty_attr_attr.54($tuple_var.52, $np_typ_var.55, func=$empty_attr_attr.54, args=[Var($tuple_var.52, dce_breaks_parfors.py (8)), Var($np_typ_var.55, dce_breaks_parfors.py (8))], kws=(), vararg=None)
Removing $np_typ_var.55 = getattr(value=$np_g_var.53, attr=float64)
Removing $empty_attr_attr.54 = getattr(value=$np_g_var.53, attr=empty)
Removing $np_g_var.53 = global(np: <module 'numpy' from '<env>/lib/python3.7/site-packages/numpy/__init__.py'>)
Removing $tuple_var.52 = build_tuple(items=[Var($const0.7_size0.31, dce_breaks_parfors.py (7)), Var($const0.7_size1.32, dce_breaks_parfors.py (7))])
Removing $30.5.24 = $expr_out_var.56
Removing $phi58.2 = $phi18.1
Removing $phi58.1 = $18.3

Traceback (most recent call last):
  File "<path>/numba/numba/errors.py", line 617, in new_error_context
    yield
  File "<path>/numba/numba/lowering.py", line 259, in lower_block
    self.lower_inst(inst)
  File "<path>/numba/numba/pylowering.py", line 96, in lower_inst
    value = self.lower_assign(inst)
  File "<path>/numba/numba/pylowering.py", line 186, in lower_assign
    val = self.loadvar(value.name)
  File "<path>/numba/numba/pylowering.py", line 555, in loadvar
    assert name in self._live_vars, name
AssertionError: $Y.70

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "dce_breaks_parfors.py", line 10, in <module>
    test_impl()
  File "<path>/numba/numba/dispatcher.py", line 369, in _compile_for_args
    raise e
  File "<path>/numba/numba/dispatcher.py", line 326, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "<path>/numba/numba/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "<path>/numba/numba/dispatcher.py", line 659, in compile
    cres = self._compiler.compile(args, return_type)
  File "<path>/numba/numba/dispatcher.py", line 82, in compile
    pipeline_class=self.pipeline_class)
  File "<path>/numba/numba/compiler.py", line 928, in compile_extra
    return pipeline.compile_extra(func)
  File "<path>/numba/numba/compiler.py", line 371, in compile_extra
    return self._compile_bytecode()
  File "<path>/numba/numba/compiler.py", line 859, in _compile_bytecode
    return self._compile_core()
  File "<path>/numba/numba/compiler.py", line 846, in _compile_core
    res = pm.run(self.status)
  File "<path>/numba/numba/compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "<path>/numba/numba/compiler.py", line 252, in run
    raise patched_exception
  File "<path>/numba/numba/compiler.py", line 243, in run
    stage()
  File "<path>/numba/numba/compiler.py", line 687, in stage_objectmode_backend
    self._backend(lowerfn, objectmode=True)
  File "<path>/numba/numba/compiler.py", line 663, in _backend
    lowered = lowerfn()
  File "<path>/numba/numba/compiler.py", line 635, in backend_object_mode
    self.flags)
  File "<path>/numba/numba/compiler.py", line 1074, in py_lowering_stage
    lower.lower()
  File "<path>/numba/numba/lowering.py", line 178, in lower
    self.lower_normal_function(self.fndesc)
  File "<path>/numba/numba/lowering.py", line 219, in lower_normal_function
    entry_block_tail = self.lower_function_body()
  File "<path>/numba/numba/lowering.py", line 244, in lower_function_body
    self.lower_block(block)
  File "<path>/numba/numba/lowering.py", line 259, in lower_block
    self.lower_inst(inst)
  File "<env>/lib/python3.7/contextlib.py", line 130, in __exit__
    self.gen.throw(type, value, traceback)
  File "<path>/numba/numba/errors.py", line 625, in new_error_context
    six.reraise(type(newerr), newerr, tb)
  File "<path>/numba/numba/six.py", line 659, in reraise
    raise value
numba.errors.LoweringError: Failed in object mode pipeline (step: object mode backend)
$Y.70

File "dce_breaks_parfors.py", line 7:
def test_impl():
    <source elided>
    X = np.array(1)
    Y = np.ones((10, 12))
    ^

[1] During: lowering "Y = $Y.70" at dce_breaks_parfors.py (7)

In other numba.tests.test_parfors test failures, the @do_scheduling function is often missing, or phi values are missing from the typemap.


Collaborator:

@stuartarchibald @ehsantn Here's my working theory based on what I have observed. We get into the recursive lowering of functions called by test_impl (confirmed), and the remove-dead pass removed a variable that a subsequent "del" of that variable tried to find (confirmed). This caused an exception (confirmed) that propagated all the way back up, and Numba tried to fall back to a compile of test_impl with parallel=False (I'm pretty sure here, as I see the original code for test_impl being lowered). However, the compile with parallel=True changed some data structure in such a way that it is no longer compatible with the unchanged version of the function IR, and that is why the compile with parallel=False also failed. The last part is speculation.

Collaborator:

Great detective work! I think this hasn't been a problem before since the ParallelAccelerator passes ignore Dels altogether and just regenerate them after everything is done.

Contributor Author:

Thanks @DrTodd13, this seems logical and well found!!

As a generalisation, I've observed that a few of the passes/mutations performed by the ParallelAccelerator code may alter the original IR in such a way that it becomes un-lowerable by the standard pipeline (particularly the inlining passes; there are a few tickets about this), which correlates with the above. Assuming the patch in #3727 makes the DCE pass valid for general consumption, is its use ahead of typing still dangerous (I presume so, due to unknown aliasing etc.)?

Collaborator:

I believe so. Any issues with removing dead code after inference?

@stuartarchibald stuartarchibald added 4 - Waiting on author Waiting for author to respond to review and removed 3 - Ready for Review labels Jan 8, 2019
@sklam sklam moved this from In Progress to Reviewed... discussion/fixes taking place in Active Jan 15, 2019
@stuartarchibald stuartarchibald added the 4 - Waiting on reviewer Waiting for reviewer to respond to author label Jan 16, 2019
@stuartarchibald stuartarchibald added the R&D Research and Development label Feb 7, 2019
stuartarchibald (Contributor Author):

@ehsantn thanks very much for providing such helpful and detailed feedback; in light of this I've had another go at implementing the core algorithm. I've completely removed the use of a DCE pass on the basis that running it ahead of typing may well be dangerous; it can be added post-typing in another PR if desirable.

@stuartarchibald stuartarchibald removed the 4 - Waiting on author Waiting for author to respond to review label Feb 7, 2019
ehsantn (Collaborator) left a comment:

Thanks for the effort. This looks much improved. I think it's ready to merge (just a minor optional comment).

ehsantn (Collaborator) commented Feb 17, 2019

Flake8 has a few minor complaints in test_analysis.py. https://dev.azure.com/numba/numba/_build/results?buildId=956

@stuartarchibald stuartarchibald added 5 - Ready to merge Review and testing done, is ready to merge and removed 4 - Waiting on reviewer Waiting for reviewer to respond to author labels Feb 21, 2019
@stuartarchibald stuartarchibald merged commit c40ed0f into numba:master Feb 27, 2019
Active automation moved this from Reviewed... discussion/fixes taking place to Done Feb 27, 2019