Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
620 lines (533 sloc) 23.7 KB
"""
Utils for IR analysis
"""
import operator
from functools import reduce
from collections import namedtuple, defaultdict
from .controlflow import CFGraph
from numba.core import types, errors, ir, consts
from numba.misc import special
#
# Analysis related to variable lifetime
#
_use_defs_result = namedtuple('use_defs_result', 'usemap,defmap')
# other packages that define new nodes add calls for finding defs
# format: {type:function}
ir_extension_usedefs = {}
def compute_use_defs(blocks):
"""
Find variable use/def per block.
"""
var_use_map = {} # { block offset -> set of vars }
var_def_map = {} # { block offset -> set of vars }
for offset, ir_block in blocks.items():
var_use_map[offset] = use_set = set()
var_def_map[offset] = def_set = set()
for stmt in ir_block.body:
if type(stmt) in ir_extension_usedefs:
func = ir_extension_usedefs[type(stmt)]
func(stmt, use_set, def_set)
continue
if isinstance(stmt, ir.Assign):
if isinstance(stmt.value, ir.Inst):
rhs_set = set(var.name for var in stmt.value.list_vars())
elif isinstance(stmt.value, ir.Var):
rhs_set = set([stmt.value.name])
elif isinstance(stmt.value, (ir.Arg, ir.Const, ir.Global,
ir.FreeVar)):
rhs_set = ()
else:
raise AssertionError('unreachable', type(stmt.value))
# If lhs not in rhs of the assignment
if stmt.target.name not in rhs_set:
def_set.add(stmt.target.name)
for var in stmt.list_vars():
# do not include locally defined vars to use-map
if var.name not in def_set:
use_set.add(var.name)
return _use_defs_result(usemap=var_use_map, defmap=var_def_map)
def compute_live_map(cfg, blocks, var_use_map, var_def_map):
"""
Find variables that must be alive at the ENTRY of each block.
We use a simple fix-point algorithm that iterates until the set of
live variables is unchanged for each block.
"""
def fix_point_progress(dct):
"""Helper function to determine if a fix-point has been reached.
"""
return tuple(len(v) for v in dct.values())
def fix_point(fn, dct):
"""Helper function to run fix-point algorithm.
"""
old_point = None
new_point = fix_point_progress(dct)
while old_point != new_point:
fn(dct)
old_point = new_point
new_point = fix_point_progress(dct)
def def_reach(dct):
"""Find all variable definition reachable at the entry of a block
"""
for offset in var_def_map:
used_or_defined = var_def_map[offset] | var_use_map[offset]
dct[offset] |= used_or_defined
# Propagate to outgoing nodes
for out_blk, _ in cfg.successors(offset):
dct[out_blk] |= dct[offset]
def liveness(dct):
"""Find live variables.
Push var usage backward.
"""
for offset in dct:
# Live vars here
live_vars = dct[offset]
for inc_blk, _data in cfg.predecessors(offset):
# Reachable at the predecessor
reachable = live_vars & def_reach_map[inc_blk]
# But not defined in the predecessor
dct[inc_blk] |= reachable - var_def_map[inc_blk]
live_map = {}
for offset in blocks.keys():
live_map[offset] = set(var_use_map[offset])
def_reach_map = defaultdict(set)
fix_point(def_reach, def_reach_map)
fix_point(liveness, live_map)
return live_map
_dead_maps_result = namedtuple('dead_maps_result', 'internal,escaping,combined')
def compute_dead_maps(cfg, blocks, live_map, var_def_map):
"""
Compute the end-of-live information for variables.
`live_map` contains a mapping of block offset to all the living
variables at the ENTRY of the block.
"""
# The following three dictionaries will be
# { block offset -> set of variables to delete }
# all vars that should be deleted at the start of the successors
escaping_dead_map = defaultdict(set)
# all vars that should be deleted within this block
internal_dead_map = defaultdict(set)
# all vars that should be deleted after the function exit
exit_dead_map = defaultdict(set)
for offset, ir_block in blocks.items():
# live vars WITHIN the block will include all the locally
# defined variables
cur_live_set = live_map[offset] | var_def_map[offset]
# vars alive in the outgoing blocks
outgoing_live_map = dict((out_blk, live_map[out_blk])
for out_blk, _data in cfg.successors(offset))
# vars to keep alive for the terminator
terminator_liveset = set(v.name
for v in ir_block.terminator.list_vars())
# vars to keep alive in the successors
combined_liveset = reduce(operator.or_, outgoing_live_map.values(),
set())
# include variables used in terminator
combined_liveset |= terminator_liveset
# vars that are dead within the block because they are not
# propagated to any outgoing blocks
internal_set = cur_live_set - combined_liveset
internal_dead_map[offset] = internal_set
# vars that escape this block
escaping_live_set = cur_live_set - internal_set
for out_blk, new_live_set in outgoing_live_map.items():
# successor should delete the unused escaped vars
new_live_set = new_live_set | var_def_map[out_blk]
escaping_dead_map[out_blk] |= escaping_live_set - new_live_set
# if no outgoing blocks
if not outgoing_live_map:
# insert var used by terminator
exit_dead_map[offset] = terminator_liveset
# Verify that the dead maps cover all live variables
all_vars = reduce(operator.or_, live_map.values(), set())
internal_dead_vars = reduce(operator.or_, internal_dead_map.values(),
set())
escaping_dead_vars = reduce(operator.or_, escaping_dead_map.values(),
set())
exit_dead_vars = reduce(operator.or_, exit_dead_map.values(), set())
dead_vars = (internal_dead_vars | escaping_dead_vars | exit_dead_vars)
missing_vars = all_vars - dead_vars
if missing_vars:
# There are no exit points
if not cfg.exit_points():
# We won't be able to verify this
pass
else:
msg = 'liveness info missing for vars: {0}'.format(missing_vars)
raise RuntimeError(msg)
combined = dict((k, internal_dead_map[k] | escaping_dead_map[k])
for k in blocks)
return _dead_maps_result(internal=internal_dead_map,
escaping=escaping_dead_map,
combined=combined)
def compute_live_variables(cfg, blocks, var_def_map, var_dead_map):
"""
Compute the live variables at the beginning of each block
and at each yield point.
The ``var_def_map`` and ``var_dead_map`` indicates the variable defined
and deleted at each block, respectively.
"""
# live var at the entry per block
block_entry_vars = defaultdict(set)
def fix_point_progress():
return tuple(map(len, block_entry_vars.values()))
old_point = None
new_point = fix_point_progress()
# Propagate defined variables and still live the successors.
# (note the entry block automatically gets an empty set)
# Note: This is finding the actual available variables at the entry
# of each block. The algorithm in compute_live_map() is finding
# the variable that must be available at the entry of each block.
# This is top-down in the dataflow. The other one is bottom-up.
while old_point != new_point:
# We iterate until the result stabilizes. This is necessary
# because of loops in the graphself.
for offset in blocks:
# vars available + variable defined
avail = block_entry_vars[offset] | var_def_map[offset]
# subtract variables deleted
avail -= var_dead_map[offset]
# add ``avail`` to each successors
for succ, _data in cfg.successors(offset):
block_entry_vars[succ] |= avail
old_point = new_point
new_point = fix_point_progress()
return block_entry_vars
#
# Analysis related to controlflow
#
def compute_cfg_from_blocks(blocks):
cfg = CFGraph()
for k in blocks:
cfg.add_node(k)
for k, b in blocks.items():
term = b.terminator
for target in term.get_targets():
cfg.add_edge(k, target)
cfg.set_entry_point(min(blocks))
cfg.process()
return cfg
def find_top_level_loops(cfg):
"""
A generator that yields toplevel loops given a control-flow-graph
"""
blocks_in_loop = set()
# get loop bodies
for loop in cfg.loops().values():
insiders = set(loop.body) | set(loop.entries) | set(loop.exits)
insiders.discard(loop.header)
blocks_in_loop |= insiders
# find loop that is not part of other loops
for loop in cfg.loops().values():
if loop.header not in blocks_in_loop:
yield _fix_loop_exit(cfg, loop)
def _fix_loop_exit(cfg, loop):
"""
Fixes loop.exits for Py3.8 bytecode CFG changes.
This is to handle `break` inside loops.
"""
# Computes the common postdoms of exit nodes
postdoms = cfg.post_dominators()
exits = reduce(
operator.and_,
[postdoms[b] for b in loop.exits],
loop.exits,
)
if exits:
# Put the non-common-exits as body nodes
body = loop.body | loop.exits - exits
return loop._replace(exits=exits, body=body)
else:
return loop
# Used to describe a nullified condition in dead branch pruning
nullified = namedtuple('nullified', 'condition, taken_br, rewrite_stmt')
# Functions to manipulate IR
def dead_branch_prune(func_ir, called_args):
"""
Removes dead branches based on constant inference from function args.
This directly mutates the IR.
func_ir is the IR
called_args are the actual arguments with which the function is called
"""
from numba.core.ir_utils import (get_definition, guard, find_const,
GuardException)
DEBUG = 0
def find_branches(func_ir):
# find *all* branches
branches = []
for blk in func_ir.blocks.values():
branch_or_jump = blk.body[-1]
if isinstance(branch_or_jump, ir.Branch):
branch = branch_or_jump
pred = guard(get_definition, func_ir, branch.cond.name)
if pred is not None and pred.op == "call":
function = guard(get_definition, func_ir, pred.func)
if (function is not None and
isinstance(function, ir.Global) and
function.value is bool):
condition = guard(get_definition, func_ir, pred.args[0])
if condition is not None:
branches.append((branch, condition, blk))
return branches
def do_prune(take_truebr, blk):
keep = branch.truebr if take_truebr else branch.falsebr
# replace the branch with a direct jump
jmp = ir.Jump(keep, loc=branch.loc)
blk.body[-1] = jmp
return 1 if keep == branch.truebr else 0
def prune_by_type(branch, condition, blk, *conds):
# this prunes a given branch and fixes up the IR
# at least one needs to be a NoneType
lhs_cond, rhs_cond = conds
lhs_none = isinstance(lhs_cond, types.NoneType)
rhs_none = isinstance(rhs_cond, types.NoneType)
if lhs_none or rhs_none:
try:
take_truebr = condition.fn(lhs_cond, rhs_cond)
except Exception:
return False, None
if DEBUG > 0:
kill = branch.falsebr if take_truebr else branch.truebr
print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
condition.fn)
taken = do_prune(take_truebr, blk)
return True, taken
return False, None
def prune_by_value(branch, condition, blk, *conds):
lhs_cond, rhs_cond = conds
try:
take_truebr = condition.fn(lhs_cond, rhs_cond)
except Exception:
return False, None
if DEBUG > 0:
kill = branch.falsebr if take_truebr else branch.truebr
print("Pruning %s" % kill, branch, lhs_cond, rhs_cond, condition.fn)
taken = do_prune(take_truebr, blk)
return True, taken
def prune_by_predicate(branch, pred, blk):
try:
# Just to prevent accidents, whilst already guarded, ensure this
# is an ir.Const
if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
raise TypeError('Expected constant Numba IR node')
take_truebr = bool(pred.value)
except TypeError:
return False, None
if DEBUG > 0:
kill = branch.falsebr if take_truebr else branch.truebr
print("Pruning %s" % kill, branch, pred)
taken = do_prune(take_truebr, blk)
return True, taken
class Unknown(object):
pass
def resolve_input_arg_const(input_arg_idx):
"""
Resolves an input arg to a constant (if possible)
"""
input_arg_ty = called_args[input_arg_idx]
# comparing to None?
if isinstance(input_arg_ty, types.NoneType):
return input_arg_ty
# is it a kwarg default
if isinstance(input_arg_ty, types.Omitted):
val = input_arg_ty.value
if isinstance(val, types.NoneType):
return val
elif val is None:
return types.NoneType('none')
# literal type, return the type itself so comparisons like `x == None`
# still work as e.g. x = types.int64 will never be None/NoneType so
# the branch can still be pruned
return getattr(input_arg_ty, 'literal_type', Unknown())
if DEBUG > 1:
print("before".center(80, '-'))
print(func_ir.dump())
# This looks for branches where:
# at least one arg of the condition is in input args and const
# at least one an arg of the condition is a const
# if the condition is met it will replace the branch with a jump
branch_info = find_branches(func_ir)
# stores conditions that have no impact post prune
nullified_conditions = []
for branch, condition, blk in branch_info:
const_conds = []
if isinstance(condition, ir.Expr) and condition.op == 'binop':
prune = prune_by_value
for arg in [condition.lhs, condition.rhs]:
resolved_const = Unknown()
arg_def = guard(get_definition, func_ir, arg)
if isinstance(arg_def, ir.Arg):
# it's an e.g. literal argument to the function
resolved_const = resolve_input_arg_const(arg_def.index)
prune = prune_by_type
else:
# it's some const argument to the function, cannot use guard
# here as the const itself may be None
try:
resolved_const = find_const(func_ir, arg)
if resolved_const is None:
resolved_const = types.NoneType('none')
except GuardException:
pass
if not isinstance(resolved_const, Unknown):
const_conds.append(resolved_const)
# lhs/rhs are consts
if len(const_conds) == 2:
# prune the branch, switch the branch for an unconditional jump
prune_stat, taken = prune(branch, condition, blk, *const_conds)
if(prune_stat):
# add the condition to the list of nullified conditions
nullified_conditions.append(nullified(condition, taken,
True))
else:
# see if this is a branch on a constant value predicate
resolved_const = Unknown()
try:
pred_call = get_definition(func_ir, branch.cond)
resolved_const = find_const(func_ir, pred_call.args[0])
if resolved_const is None:
resolved_const = types.NoneType('none')
except GuardException:
pass
if not isinstance(resolved_const, Unknown):
prune_stat, taken = prune_by_predicate(branch, condition, blk)
if(prune_stat):
# add the condition to the list of nullified conditions
nullified_conditions.append(nullified(condition, taken,
False))
# 'ERE BE DRAGONS...
# It is the evaluation of the condition expression that often trips up type
# inference, so ideally it would be removed as it is effectively rendered
# dead by the unconditional jump if a branch was pruned. However, there may
# be references to the condition that exist in multiple places (e.g. dels)
# and we cannot run DCE here as typing has not taken place to give enough
# information to run DCE safely. Upshot of all this is the condition gets
# rewritten below into a benign const that typing will be happy with and DCE
# can remove it and its reference post typing when it is safe to do so
# (if desired). It is required that the const is assigned a value that
# indicates the branch taken as its mutated value would be read in the case
# of object mode fall back in place of the condition itself. For
# completeness the func_ir._definitions and ._consts are also updated to
# make the IR state self consistent.
deadcond = [x.condition for x in nullified_conditions]
for _, cond, blk in branch_info:
if cond in deadcond:
for x in blk.body:
if isinstance(x, ir.Assign) and x.value is cond:
# rewrite the condition as a true/false bit
nullified_info = nullified_conditions[deadcond.index(cond)]
# only do a rewrite of conditions, predicates need to retain
# their value as they may be used later.
if nullified_info.rewrite_stmt:
branch_bit = nullified_info.taken_br
x.value = ir.Const(branch_bit, loc=x.loc)
# update the specific definition to the new const
defns = func_ir._definitions[x.target.name]
repl_idx = defns.index(cond)
defns[repl_idx] = x.value
# Remove dead blocks, this is safe as it relies on the CFG only.
cfg = compute_cfg_from_blocks(func_ir.blocks)
for dead in cfg.dead_nodes():
del func_ir.blocks[dead]
# if conditions were nullified then consts were rewritten, update
if nullified_conditions:
func_ir._consts = consts.ConstantInference(func_ir)
if DEBUG > 1:
print("after".center(80, '-'))
print(func_ir.dump())
def rewrite_semantic_constants(func_ir, called_args):
"""
This rewrites values known to be constant by their semantics as ir.Const
nodes, this is to give branch pruning the best chance possible of killing
branches. An example might be rewriting len(tuple) as the literal length.
func_ir is the IR
called_args are the actual arguments with which the function is called
"""
DEBUG = 0
if DEBUG > 1:
print(("rewrite_semantic_constants: " +
func_ir.func_id.func_name).center(80, '-'))
print("before".center(80, '*'))
func_ir.dump()
def rewrite_statement(func_ir, stmt, new_val):
"""
Rewrites the stmt as a ir.Const new_val and fixes up the entries in
func_ir._definitions
"""
stmt.value = ir.Const(new_val, stmt.loc)
defns = func_ir._definitions[stmt.target.name]
repl_idx = defns.index(val)
defns[repl_idx] = stmt.value
def rewrite_array_ndim(val, func_ir, called_args):
# rewrite Array.ndim as const(ndim)
if getattr(val, 'op', None) == 'getattr':
if val.attr == 'ndim':
arg_def = guard(get_definition, func_ir, val.value)
if isinstance(arg_def, ir.Arg):
argty = called_args[arg_def.index]
if isinstance(argty, types.Array):
rewrite_statement(func_ir, stmt, argty.ndim)
def rewrite_tuple_len(val, func_ir, called_args):
# rewrite len(tuple) as const(len(tuple))
if getattr(val, 'op', None) == 'call':
func = guard(get_definition, func_ir, val.func)
if (func is not None and isinstance(func, ir.Global) and
getattr(func, 'value', None) is len):
(arg,) = val.args
arg_def = guard(get_definition, func_ir, arg)
if isinstance(arg_def, ir.Arg):
argty = called_args[arg_def.index]
if isinstance(argty, types.BaseTuple):
rewrite_statement(func_ir, stmt, argty.count)
from numba.core.ir_utils import get_definition, guard
for blk in func_ir.blocks.values():
for stmt in blk.body:
if isinstance(stmt, ir.Assign):
val = stmt.value
if isinstance(val, ir.Expr):
rewrite_array_ndim(val, func_ir, called_args)
rewrite_tuple_len(val, func_ir, called_args)
if DEBUG > 1:
print("after".center(80, '*'))
func_ir.dump()
print('-' * 80)
def find_literally_calls(func_ir, argtypes):
"""An analysis to find `numba.literally` call inside the given IR.
When an unsatisfied literal typing request is found, a `ForceLiteralArg`
exception is raised.
Parameters
----------
func_ir : numba.ir.FunctionIR
argtypes : Sequence[numba.types.Type]
The argument types.
"""
from numba.core import ir_utils
marked_args = set()
first_loc = {}
# Scan for literally calls
for blk in func_ir.blocks.values():
for assign in blk.find_exprs(op='call'):
var = ir_utils.guard(ir_utils.get_definition, func_ir, assign.func)
if isinstance(var, (ir.Global, ir.FreeVar)):
fnobj = var.value
else:
fnobj = ir_utils.guard(ir_utils.resolve_func_from_module,
func_ir, var)
if fnobj is special.literally:
# Found
[arg] = assign.args
defarg = func_ir.get_definition(arg)
if isinstance(defarg, ir.Arg):
argindex = defarg.index
marked_args.add(argindex)
first_loc.setdefault(argindex, assign.loc)
# Signal the dispatcher to force literal typing
for pos in marked_args:
query_arg = argtypes[pos]
do_raise = (isinstance(query_arg, types.InitialValue) and
query_arg.initial_value is None)
if do_raise:
loc = first_loc[pos]
raise errors.ForceLiteralArg(marked_args, loc=loc)
if not isinstance(query_arg, (types.Literal, types.InitialValue)):
loc = first_loc[pos]
raise errors.ForceLiteralArg(marked_args, loc=loc)
You can’t perform that action at this time.