Skip to content

Commit

Permalink
Add tests for transforms.
Browse files Browse the repository at this point in the history
As title.
  • Loading branch information
stuartarchibald committed Jan 3, 2019
1 parent fbcf804 commit 5a64996
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 93 deletions.
177 changes: 177 additions & 0 deletions numba/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from numba import ir
from numba.controlflow import CFGraph
from numba import types

#
# Analysis related to variable lifetime
Expand Down Expand Up @@ -259,3 +260,179 @@ def find_top_level_loops(cfg):
for loop in cfg.loops().values():
if loop.header not in blocks_in_loop:
yield loop

# Functions to manipulate IR
def dead_branch_prune(func_ir, called_args):
"""
Removes dead branches based on constant inference from function args.
This directly mutates the IR.
func_ir is the IR
called_args are the actual arguments with which the function is called
"""

DEBUG = 0

def find_branches(func_ir):
# find *all* branches
branches = []
for idx, blk in func_ir.blocks.items():
tmp = [_ for _ in blk.find_insts(cls=ir.Branch)]
store = dict()
for branch in tmp:
store['branch'] = branch
expr = blk.find_variable_assignment(branch.cond.name)
if expr is not None:
val = expr.value
if isinstance(val, ir.Expr) and val.op == 'binop':
store['op'] = val
args = [val.lhs, val.rhs]
store['args'] = args
store['block'] = blk
branches.append(store)
return branches

def prune(func_ir, branch, all_branches, *conds):
# this prunes a given branch and fixes up the IR
lhs_cond, rhs_cond = conds
take_truebr = branch['op'].fn(lhs_cond, rhs_cond)
cond = branch['branch']
if take_truebr:
keep = cond.truebr
kill = cond.falsebr
else:
keep = cond.falsebr
kill = cond.truebr

cfg = compute_cfg_from_blocks(func_ir.blocks)
if DEBUG > 0:
print("Pruning %s" % kill, branch['branch'], lhs_cond, rhs_cond, branch['op'].fn)
if DEBUG > 1:
from pprint import pprint
cfg.dump()

# only prune branches to blocks that have a single access route, this is
# conservative
kill_count = 0
for targets in cfg._succs.values():
if kill in targets:
kill_count += 1
if kill_count > 1:
if DEBUG > 0:
print("prune rejected")
return

# remove dominators that are not on the backbone
dom = cfg.dominators()
postdom = cfg.post_dominators()
backbone = cfg.backbone()
rem = []

for idx, doms in dom.items():
if kill in doms and kill not in backbone:
rem.append(idx)
if not rem:
if DEBUG > 0:
msg = "prune rejected no kill in dominators and not in backbone"
print(msg)
return

for x in rem:
func_ir.blocks.pop(x)

block = branch['block']
# remove computation of the branch condition, it's dead
branch_assign = block.find_variable_assignment(cond.cond.name)
block.body.remove(branch_assign)

# replace the branch with a direct jump
jmp = ir.Jump(keep, loc=cond.loc)
block.body[-1] = jmp

branches = find_branches(func_ir)
noldbranches = len(branches)
nnewbranches = 0

class Unknown(object):
pass

def get_const_val(func_ir, arg):
"""
Resolves the arg to a const if there is a variable assignment like:
`$label = const(thing)`
"""
possibles = []
for idx, blk in func_ir.blocks.items():
found = blk.find_variable_assignment(arg.name)
if found is not None and isinstance(found.value, ir.Const):
possibles.append(found.value.value)
# if there's more than one definition we don't know which const
# propagates to here
if len(possibles) == 1:
return possibles[0]
else:
return Unknown()

def resolve_input_arg_const(input_arg):
"""
Resolves an input arg to a constant (if possible)
"""
idx = func_ir.arg_names.index(input_arg)
input_arg_ty = called_args[idx]

# comparing to None?
if isinstance(input_arg_ty, types.NoneType):
return None

# is it a kwarg default
if isinstance(input_arg_ty, types.Omitted):
if isinstance(input_arg_ty.value, types.NoneType):
return None
else:
return input_arg_ty.value

# is it a literal
const = getattr(input_arg_ty, 'literal_value', None)
if const is not None:
return const

return Unknown()

# keep iterating until branch prune stabilizes
while(noldbranches != nnewbranches):
# This looks for branches where:
# an arg of the condition is in input args and const/literal
# an arg of the condition is a const/literal
# if the condition is met it will remove blocks that are not reached and
# patch up the branch
for branch in branches:
const_arg_val = False
const_conds = []
for arg in branch['args']:
resolved_const = Unknown()
if arg.name in func_ir.arg_names:
# it's an e.g. literal argument to the function
resolved_const = resolve_input_arg_const(arg.name)
else:
# it's some const argument to the function
resolved_const = get_const_val(func_ir, arg)

if not isinstance(resolved_const, Unknown):
const_conds.append(resolved_const)

# lhs/rhs are consts
if len(const_conds) == 2:
if DEBUG > 1:
print("before".center(80, '-'))
print(func_ir.dump())

prune(func_ir, branch, branches, *const_conds)

if DEBUG > 1:
print("after".center(80, '-'))
print(func_ir.dump())

noldbranches = len(branches)
branches = find_branches(func_ir)
nnewbranches = len(branches)

98 changes: 8 additions & 90 deletions numba/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from numba.errors import CompilerError
from numba.ir_utils import raise_on_unsupported_feature
from numba.compiler_lock import global_compiler_lock
from numba.analysis import dead_branch_prune

# terminal color markup
_termcolor = errors.termcolor()
Expand Down Expand Up @@ -475,99 +476,16 @@ def stage_dead_branch_prune(self):
This prunes dead branches, a dead branch is one which is derivable as
not taken at compile time purely based on const/literal evaluation.
"""
func_ir = self.func_ir
# find *all* branches
branches = []
for idx, blk in func_ir.blocks.items():
tmp = [_ for _ in blk.find_insts(cls=ir.Branch)]
store = dict()
for branch in tmp:
store['branch'] = branch
expr = blk.find_variable_assignment(branch.cond.name)
if expr is not None:
val = expr.value
if isinstance(val, ir.Expr) and val.op == 'binop':
store['op'] = val
args = [val.lhs, val.rhs]
store['args'] = args
store['block'] = blk
branches.append(store)
# This looks for branches where:
# an arg of the condition is in input args and const/literal
# an arg of the condition is a const/literal
# if the condition is met it will remove blocks that are not reached and
# patch up the branch
prune = []
for branch in branches:
is_input_arg = False
const_arg_val = False
for arg in branch['args']:
if arg.name in func_ir.arg_names:
# it's an argument to the function
is_input_arg = arg.name
continue
for idx, blk in func_ir.blocks.items():
found = blk.find_variable_assignment(arg.name)
if found is not None and isinstance(found.value, ir.Const):
const_arg_val = found.value.value
continue

if is_input_arg is not False and const_arg_val is not False:
idx = func_ir.arg_names.index(is_input_arg)
input_arg_ty = self.args[idx]
proceed = False

# comparing None ?
lhs_cond = isinstance(input_arg_ty, types.NoneType)
rhs_cond = const_arg_val is None

if lhs_cond and rhs_cond:
proceed = True
else:
# comparing known types?
# is it ommitted
lhs_cond = getattr(input_arg_ty, 'value', None)

# is it a literal
if lhs_cond is None:
lhs_cond = getattr(input_arg_ty, 'literal_value', None)

rhs_cond = const_arg_val
if lhs_cond is not None and rhs_cond:
proceed = True

if not proceed:
continue # try next branch inst

take_truebr = branch['op'].fn(lhs_cond, rhs_cond)
cond = branch['branch']
if take_truebr:
keep = cond.truebr
kill = cond.falsebr
else:
keep = cond.falsebr
kill = cond.truebr

# rip out dead blocks
from numba.analysis import compute_cfg_from_blocks
cfg = compute_cfg_from_blocks(func_ir.blocks)
dom = cfg.dominators()
backbone = cfg.backbone()
rem = []
for idx, doms in dom.items():
if kill in doms and kill not in backbone:
#print("KILLING", idx)
rem.append(idx)
for x in rem:
func_ir.blocks.pop(x)

# fix up branch location, it's now just a jump
jmp = ir.Jump(keep, loc=cond.loc)
branch['block'].body[-1] = jmp
assert self.func_ir
msg = ('Internal error in pre-inference dead branch pruning '
'pass encountered during compilation of '
'function "%s"' % (self.func_id.func_name,))
with self.fallback_context(msg):
dead_branch_prune(self.func_ir, self.args)

if config.DEBUG or config.DUMP_IR:
print('branch_pruned_ir'.center(80, '-'))
print(func_ir.dump())
print(self.func_ir.dump())
print('end branch_pruned_ir'.center(80, '-'))

def stage_nopython_frontend(self):
Expand Down
13 changes: 10 additions & 3 deletions numba/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ def add_edge(self, src, dest, data=None):
If such an edge already exists, it is replaced (duplicate edges
are not possible).
"""
assert src in self._nodes
assert dest in self._nodes
try:
assert src in self._nodes
assert dest in self._nodes
except AssertionError as e:
print(e)
import pdb; pdb.set_trace()
pass
self._add_edge(src, dest, data)

def successors(self, src):
Expand Down Expand Up @@ -234,6 +239,8 @@ def dump(self, file=None):
pprint.pprint(self._loops, stream=file)
print("CFG node-to-loops:", file=file)
pprint.pprint(self._in_loops, stream=file)
print("CFG backbone:", file=file)
pprint.pprint(self.backbone(), stream=file)

# Internal APIs

Expand Down Expand Up @@ -443,7 +450,7 @@ def _find_loops(self):
self._in_loops = in_loops

def _dump_adj_lists(self, file):
adj_lists = dict((src, list(dests))
adj_lists = dict((src, sorted(list(dests)))
for src, dests in self._succs.items())
import pprint
pprint.pprint(adj_lists, stream=file)
Expand Down

0 comments on commit 5a64996

Please sign in to comment.