Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dead branch prune before type inference. #3592

Merged
merged 6 commits into from
Feb 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ exclude =
numba/compiler.py
numba/ctypes_support.py
numba/withcontexts.py
numba/analysis.py
numba/_version.py
numba/unicode.py
numba/inline_closurecall.py
Expand Down
151 changes: 151 additions & 0 deletions numba/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from numba import ir
from numba.controlflow import CFGraph
from numba import types

#
# Analysis related to variable lifetime
Expand All @@ -18,6 +19,7 @@
# format: {type:function}
ir_extension_usedefs = {}


def compute_use_defs(blocks):
"""
Find variable use/def per block.
Expand Down Expand Up @@ -259,3 +261,152 @@ def find_top_level_loops(cfg):
for loop in cfg.loops().values():
if loop.header not in blocks_in_loop:
yield loop


# Functions to manipulate IR
def dead_branch_prune(func_ir, called_args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

array analysis does some pruning for some reason: https://github.com/numba/numba/blob/master/numba/array_analysis.py#L1066
Maybe it is unnecessary now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would have to check, array analysis isn't on the default code path though? Also I think that code is just looking at e.g. if 1: or if True: ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct, it's not on the default path.

Seems like it, but I don't remember the reason (upstream optimization produces if True:?). Just something to keep in mind for future, not necessarily this PR.

"""
Removes dead branches based on constant inference from function args.
This directly mutates the IR.

func_ir is the IR
called_args are the actual arguments with which the function is called
"""
from .ir_utils import get_definition, guard, find_const, GuardException

DEBUG = 0

def find_branches(func_ir):
# find *all* branches
branches = []
for blk in func_ir.blocks.values():
branch_or_jump = blk.body[-1]
if isinstance(branch_or_jump, ir.Branch):
branch = branch_or_jump
condition = guard(get_definition, func_ir, branch.cond.name)
if condition is not None:
branches.append((branch, condition, blk))
return branches

def do_prune(take_truebr, blk):
keep = branch.truebr if take_truebr else branch.falsebr
# replace the branch with a direct jump
jmp = ir.Jump(keep, loc=branch.loc)
blk.body[-1] = jmp

def prune_by_type(branch, condition, blk, *conds):
# this prunes a given branch and fixes up the IR
# at least one needs to be a NoneType
lhs_cond, rhs_cond = conds
lhs_none = isinstance(lhs_cond, types.NoneType)
rhs_none = isinstance(rhs_cond, types.NoneType)
if lhs_none or rhs_none:
take_truebr = condition.fn(lhs_cond, rhs_cond)
if DEBUG > 0:
kill = branch.falsebr if take_truebr else branch.truebr
print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
condition.fn)
do_prune(take_truebr, blk)
return True
return False

def prune_by_value(branch, condition, blk, *conds):
lhs_cond, rhs_cond = conds
take_truebr = condition.fn(lhs_cond, rhs_cond)
if DEBUG > 0:
kill = branch.falsebr if take_truebr else branch.truebr
print("Pruning %s" % kill, branch, lhs_cond, rhs_cond, condition.fn)
do_prune(take_truebr, blk)
return True

class Unknown(object):
pass

def resolve_input_arg_const(input_arg):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally, this should be folded into ir_utils.find_const.

"""
Resolves an input arg to a constant (if possible)
"""
idx = func_ir.arg_names.index(input_arg)
input_arg_ty = called_args[idx]

# comparing to None?
if isinstance(input_arg_ty, types.NoneType):
return input_arg_ty

# is it a kwarg default
if isinstance(input_arg_ty, types.Omitted):
val = input_arg_ty.value
if isinstance(val, types.NoneType):
return val
elif val is None:
return types.NoneType('none')

# literal type, return the type itself so comparisons like `x == None`
# still work as e.g. x = types.int64 will never be None/NoneType so
# the branch can still be pruned
return getattr(input_arg_ty, 'literal_type', Unknown())

if DEBUG > 1:
print("before".center(80, '-'))
print(func_ir.dump())

# This looks for branches where:
# at least one arg of the condition is in input args and const
# at least one an arg of the condition is a const
# if the condition is met it will replace the branch with a jump
branch_info = find_branches(func_ir)
nullified_conditions = [] # stores conditions that have no impact post prune
for branch, condition, blk in branch_info:
const_conds = []
if isinstance(condition, ir.Expr) and condition.op == 'binop':
prune = prune_by_value
for arg in [condition.lhs, condition.rhs]:
resolved_const = Unknown()
if arg.name in func_ir.arg_names:
# it's an e.g. literal argument to the function
resolved_const = resolve_input_arg_const(arg.name)
prune = prune_by_type
else:
# it's some const argument to the function, cannot use guard
# here as the const itself may be None
try:
resolved_const = find_const(func_ir, arg)
if resolved_const is None:
resolved_const = types.NoneType('none')
except GuardException:
pass

if not isinstance(resolved_const, Unknown):
const_conds.append(resolved_const)

# lhs/rhs are consts
if len(const_conds) == 2:
# prune the branch, switch the branch for an unconditional jump
if(prune(branch, condition, blk, *const_conds)):
# add the condition to the list of nullified conditions
nullified_conditions.append(condition)

# 'ERE BE DRAGONS...
# It is the evaluation of the condition expression that often trips up type
# inference, so ideally it would be removed as it is effectively rendered
# dead by the unconditional jump if a branch was pruned. However, there may
# be references to the condition that exist in multiple places (e.g. dels)
# and we cannot run DCE here as typing has not taken place to give enough
# information to run DCE safely. Upshot of all this is the condition gets
# rewritten below into a benign const that typing will be happy with and DCE
# can remove it and its reference post typing when it is safe to do so
# (if desired).
for _, condition, blk in branch_info:
if condition in nullified_conditions:
for x in blk.body:
if isinstance(x, ir.Assign) and x.value is condition:
x.value = ir.Const(0, loc=x.loc)

# Remove dead blocks, this is safe as it relies on the CFG only.
cfg = compute_cfg_from_blocks(func_ir.blocks)
for dead in cfg.dead_nodes():
del func_ir.blocks[dead]

if DEBUG > 1:
print("after".center(80, '-'))
print(func_ir.dump())
19 changes: 19 additions & 0 deletions numba/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from numba.errors import CompilerError
from numba.ir_utils import raise_on_unsupported_feature
from numba.compiler_lock import global_compiler_lock
from numba.analysis import dead_branch_prune

# terminal color markup
_termcolor = errors.termcolor()
Expand Down Expand Up @@ -467,6 +468,23 @@ def stage_objectmode_frontend(self):
self.calltypes = defaultdict(lambda: types.pyobject)
self.return_type = types.pyobject

def stage_dead_branch_prune(self):
"""
This prunes dead branches, a dead branch is one which is derivable as
not taken at compile time purely based on const/literal evaluation.
"""
assert self.func_ir
msg = ('Internal error in pre-inference dead branch pruning '
'pass encountered during compilation of '
'function "%s"' % (self.func_id.func_name,))
with self.fallback_context(msg):
dead_branch_prune(self.func_ir, self.args)

if config.DEBUG or config.DUMP_IR:
print('branch_pruned_ir'.center(80, '-'))
print(self.func_ir.dump())
print('end branch_pruned_ir'.center(80, '-'))

def stage_nopython_frontend(self):
"""
Type inference and legalization
Expand Down Expand Up @@ -759,6 +777,7 @@ def add_pre_typing_stage(self, pm):
pm.add_stage(self.stage_preserve_ir,
"preserve IR for fallback")
pm.add_stage(self.stage_generic_rewrites, "nopython rewrites")
pm.add_stage(self.stage_dead_branch_prune, "dead branch pruning")
pm.add_stage(self.stage_inline_pass,
"inline calls to locally defined closures")

Expand Down
4 changes: 3 additions & 1 deletion numba/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def dump(self, file=None):
pprint.pprint(self._loops, stream=file)
print("CFG node-to-loops:", file=file)
pprint.pprint(self._in_loops, stream=file)
print("CFG backbone:", file=file)
pprint.pprint(self.backbone(), stream=file)

# Internal APIs

Expand Down Expand Up @@ -443,7 +445,7 @@ def _find_loops(self):
self._in_loops = in_loops

def _dump_adj_lists(self, file):
adj_lists = dict((src, list(dests))
adj_lists = dict((src, sorted(list(dests)))
for src, dests in self._succs.items())
import pprint
pprint.pprint(adj_lists, stream=file)
Expand Down