Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
"""
Implement transformation on Numba IR
"""
from collections import namedtuple, defaultdict
import logging
import operator
from numba.core.analysis import compute_cfg_from_blocks, find_top_level_loops
from numba.core import errors, ir, ir_utils
from numba.core.analysis import compute_use_defs, compute_cfg_from_blocks
from numba.core.utils import PYVERSION
# Module-level logger; used for info/debug tracing of the loop-lifting
# candidate search and of the IR about to be loop-lifted.
_logger = logging.getLogger(__name__)
def _extract_loop_lifting_candidates(cfg, blocks):
    """Return the top-level loops of *cfg* that can be loop-lifted.

    A top-level loop is kept only when all of the following hold (checked in
    this order, short-circuiting):

    * every exit block has at least one successor, and together they lead to
      exactly one block;
    * the loop has exactly one entry;
    * no block of the loop (body, entries or exits) assigns from a ``yield``;
    * ``cfg.entry_point()`` is not one of the loop's entries.

    The surviving loops are returned in a list, in the order produced by
    ``find_top_level_loops``.
    """

    def same_exit_point(loop):
        "all exits must point to the same location"
        exit_targets = set()
        for exit_label in loop.exits:
            succ_labels = {label for label, _ in cfg.successors(exit_label)}
            if not succ_labels:
                # An exit without successors holds a return statement, which
                # the looplifting code does not handle, so reject the loop.
                _logger.debug("return-statement in loop.")
                return False
            exit_targets.update(succ_labels)
        ok = len(exit_targets) == 1
        _logger.debug("same_exit_point=%s (%s)", ok, exit_targets)
        return ok

    def one_entry(loop):
        "there is one entry"
        ok = len(loop.entries) == 1
        _logger.debug("one_entry=%s", ok)
        return ok

    def cannot_yield(loop):
        "cannot have yield inside the loop"
        region = set(loop.body) | set(loop.entries) | set(loop.exits)
        found_yield = any(
            isinstance(inst, ir.Assign) and isinstance(inst.value, ir.Yield)
            for label in region
            for inst in blocks[label].body
        )
        if found_yield:
            _logger.debug("has yield")
            return False
        _logger.debug("no yield")
        return True

    _logger.info('finding looplift candidates')
    candidates = []
    for loop in find_top_level_loops(cfg):
        _logger.debug("top-level loop: %s", loop)
        # A loop whose entries include cfg.entry_point() is refused:
        # otherwise the prelude of the lifted loop would be written into
        # block -1.
        eligible = (same_exit_point(loop)
                    and one_entry(loop)
                    and cannot_yield(loop)
                    and cfg.entry_point() not in loop.entries)
        if eligible:
            candidates.append(loop)
            _logger.debug("add candidate: %s", loop)
    return candidates
def find_region_inout_vars(blocks, livemap, callfrom, returnto, body_block_ids):
    """Find input and output variables to a block region.

    Parameters
    ----------
    blocks : dict
        Mapping of block label to block for the whole function.
    livemap : mapping
        Mapping of block label to the variables live at that block (the
        loop-lifting caller passes ``func_ir.variable_lifetime.livemap``).
    callfrom : label
        Block the region is entered from; its live variables are the
        candidate inputs.
    returnto : label
        Block execution continues at after the region; its live variables
        are the candidate outputs.
    body_block_ids : iterable of labels
        Labels of the blocks forming the region.

    Returns
    -------
    (inputs, outputs) : tuple of two sorted lists
        ``inputs`` are the variables live at *callfrom* that the region uses
        or defines; ``outputs`` are the variables live at *returnto* that
        the region defines.
    """
    live_in = livemap[callfrom]
    live_out = livemap[returnto]

    # Keep only variables the region actually touches, which saves having to
    # create something valid just to get it through postproc.
    region_blocks = {label: blocks[label] for label in body_block_ids}
    usedefs = compute_use_defs(region_blocks)
    used = set().union(*usedefs.usemap.values())
    defined = set().union(*usedefs.defmap.values())
    touched = used | defined

    # Sorting keeps the ordering stable across runs.
    inputs = sorted(set(live_in) & touched)
    # ``touched & defined`` is simply ``defined``.
    outputs = sorted(set(live_out) & defined)
    return inputs, outputs
# Record describing one loop-lifting candidate: the loop object, its sorted
# input and output variables, the block the loop is entered from
# (``callfrom``) and the block execution resumes at afterwards (``returnto``).
_loop_lift_info = namedtuple(
    'loop_lift_info',
    ['loop', 'inputs', 'outputs', 'callfrom', 'returnto'],
)
def _loop_lift_get_candidate_infos(cfg, blocks, livemap):
    """Describe every loop-lifting candidate of the function.

    Each loop chosen by ``_extract_loop_lifting_candidates`` is turned into a
    ``_loop_lift_info`` holding the loop, its sorted input/output variables
    (from ``find_region_inout_vars``), the block it is entered from
    (``callfrom``) and the block execution resumes at (``returnto``).
    """
    infos = []
    for loop in _extract_loop_lifting_candidates(cfg, blocks):
        # Exactly one entry was already required when selecting candidates.
        [callfrom] = loop.entries
        some_exit = next(iter(loop.exits))
        if len(loop.exits) <= 1:
            # Post-Py3.8 loops DO NOT have multiple exits; the exit block is
            # itself where execution resumes.
            returnto = some_exit
        else:
            # Pre-Py3.8 loops may have several exits, all sharing the single
            # successor verified during candidate selection.
            [(returnto, _)] = cfg.successors(some_exit)

        region = set(loop.body) | set(loop.entries) | set(loop.exits)
        inputs, outputs = find_region_inout_vars(
            blocks=blocks,
            livemap=livemap,
            callfrom=callfrom,
            returnto=returnto,
            body_block_ids=region,
        )
        infos.append(_loop_lift_info(
            loop=loop,
            inputs=inputs,
            outputs=outputs,
            callfrom=callfrom,
            returnto=returnto,
        ))
    return infos
def _loop_lift_modify_call_block(liftedloop, block, inputs, outputs, returnto):
    """Build the block that calls *liftedloop* from the top-level function.

    A brand-new block sharing *block*'s scope and location is filled by
    ``ir_utils.fill_block_with_call`` with a call to *liftedloop* using
    *inputs*/*outputs*, continuing at label *returnto*. *block* itself is only
    read (for its scope and loc), never modified. Returns the new block.
    """
    call_blk = ir.Block(scope=block.scope, loc=block.loc)
    ir_utils.fill_block_with_call(
        newblock=call_blk,
        callee=liftedloop,
        label_next=returnto,
        inputs=inputs,
        outputs=outputs,
    )
    return call_blk
def _loop_lift_prepare_loop_func(loopinfo, blocks):
    """Turn the copied loop *blocks* into the body of a standalone function.

    Mutates *blocks* in place: a callee prologue (built by
    ``ir_utils.fill_callee_prologue`` from ``loopinfo.inputs``) is added under
    a label one below the current minimum and continues at
    ``loopinfo.callfrom``; the block at ``loopinfo.returnto`` is replaced by a
    callee epilogue built from ``loopinfo.outputs``. Returns None.
    """
    entry = blocks[loopinfo.callfrom]
    # Lowering assumes the first block is the one with the smallest label, so
    # the prologue goes just below the current minimum.
    prologue_label = min(blocks) - 1
    blocks[prologue_label] = ir_utils.fill_callee_prologue(
        block=ir.Block(scope=entry.scope, loc=entry.loc),
        inputs=loopinfo.inputs,
        label_next=loopinfo.callfrom,
    )
    blocks[loopinfo.returnto] = ir_utils.fill_callee_epilogue(
        block=ir.Block(scope=entry.scope, loc=entry.loc),
        outputs=loopinfo.outputs,
    )
def _loop_lift_modify_blocks(func_ir, loopinfo, blocks,
                             typingctx, targetctx, flags, locals):
    """Lift one loop out of *blocks* and replace it with a call.

    The loop's blocks are copied and turned into a standalone IR, from which
    a ``LiftedLoop`` is built. In *blocks* (mutated in place) the original
    loop blocks are deleted and the ``callfrom`` block is replaced by a block
    calling the lifted loop.

    Returns the ``LiftedLoop`` object.
    """
    from numba.core.dispatcher import LiftedLoop

    loop = loopinfo.loop
    keys = set(loop.body) | set(loop.entries)
    if len(loop.exits) > 1:
        # Pre-Py3.8 may have multiple exits
        keys |= loop.exits
    # Work on copies: the originals are dropped from *blocks* further down.
    loop_blocks = {key: blocks[key].copy() for key in keys}

    _loop_lift_prepare_loop_func(loopinfo, loop_blocks)
    lifted_ir = func_ir.derive(
        blocks=loop_blocks,
        arg_names=tuple(loopinfo.inputs),
        arg_count=len(loopinfo.inputs),
        force_non_generator=True,
    )
    liftedloop = LiftedLoop(lifted_ir, typingctx, targetctx, flags, locals)

    callblock = _loop_lift_modify_call_block(
        liftedloop,
        blocks[loopinfo.callfrom],
        loopinfo.inputs,
        loopinfo.outputs,
        loopinfo.returnto,
    )
    for key in keys:
        del blocks[key]
    # The caller now reaches the lifted loop through this call block.
    blocks[loopinfo.callfrom] = callblock
    return liftedloop
def _has_multiple_loop_exits(cfg, lpinfo):
    """Tell whether a loop has more than one distinct exit.

    Exits are pruned by post-dominance: whenever one exit post-dominates
    another exit of the same loop, the post-dominating one is discarded,
    because every path through the other exit reaches it anyway ("common
    exits"). Returns True only if more than one exit survives.
    """
    if len(lpinfo.exits) <= 1:
        return False
    survivors = set(lpinfo.exits)
    postdoms = cfg.post_dominators()
    visited = set()
    worklist = set(survivors)
    while worklist:
        current = worklist.pop()
        visited.add(current)
        # Drop every other exit that post-dominates the current one.
        survivors.difference_update(postdoms[current] - {current})
        worklist = survivors - visited
    return len(survivors) > 1
def _pre_looplift_transform(func_ir):
    """Canonicalize loops for looplifting.

    Every loop for which ``_has_multiple_loop_exits`` is true gets its exits
    merged into one common exit by ``_fix_multi_exit_blocks`` (which mutates
    *func_ir* in place). Afterwards the analysis state of *func_ir* is reset
    and ``PostProcessor`` is re-run. Returns the resulting FunctionIR.

    NOTE(review): the CFG is computed once up front and is not recomputed
    after each rewrite, so later loops are examined against the original
    CFG -- confirm this is safe when several (e.g. nested) loops need fixing.
    """
    from numba.core.postproc import PostProcessor

    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for loop_info in cfg.loops().values():
        if not _has_multiple_loop_exits(cfg, loop_info):
            continue
        # Merge all exits of this loop into a single common exit block.
        func_ir, _ = _fix_multi_exit_blocks(func_ir, loop_info.exits)

    # Rebuild the analysis information for the rewritten IR.
    func_ir._reset_analysis_variables()
    PostProcessor(func_ir).run()
    return func_ir
def loop_lifting(func_ir, typingctx, targetctx, flags, locals):
    """
    Loop lifting transformation.

    Given an interpreter ``func_ir``, return a 2-tuple
    ``(toplevel_ir, [lifted_loop0, lifted_loop1, ...])``. The list holds the
    ``LiftedLoop`` objects created by ``_loop_lift_modify_blocks``, one per
    lifted loop. ``toplevel_ir`` is derived from a copy of the blocks in
    which every lifted loop has been replaced by a call to it.
    """
    func_ir = _pre_looplift_transform(func_ir)
    # Work on a shallow copy so the dict func_ir.blocks is not rewritten.
    blocks = func_ir.blocks.copy()
    cfg = compute_cfg_from_blocks(blocks)
    livemap = func_ir.variable_lifetime.livemap
    infos = _loop_lift_get_candidate_infos(cfg, blocks, livemap)

    lifted = []
    if infos:
        _logger.debug('loop lifting this IR with %d candidates:\n%s',
                      len(infos), func_ir.dump_to_string())
        for info in infos:
            lifted.append(_loop_lift_modify_blocks(
                func_ir, info, blocks, typingctx, targetctx, flags, locals,
            ))

    main = func_ir.derive(blocks=blocks)
    return main, lifted
def canonicalize_cfg_single_backedge(blocks):
    """
    Rewrite loops that have multiple backedges.

    For every loop (as reported by ``cfg.loops()``) where more than one body
    block jumps back to the loop header, each such block is replaced by a copy
    whose terminator targets a new tail block instead, and the tail block
    jumps to the header. The result is a new dict of blocks; the *blocks*
    dict passed in is not modified (rewritten blocks are copies).
    """
    cfg = compute_cfg_from_blocks(blocks)
    newblocks = blocks.copy()

    def new_block_id():
        # One past the largest label currently in ``newblocks``.
        return max(newblocks.keys()) + 1

    def has_multiple_backedges(loop):
        # Reads the ORIGINAL ``blocks``, so the answer is unaffected by
        # rewrites already applied to ``newblocks``.
        count = 0
        for k in loop.body:
            blk = blocks[k]
            edges = blk.terminator.get_targets()
            # is a backedge?
            if loop.header in edges:
                count += 1
                if count > 1:
                    # early exit
                    return True
        return False

    def yield_loops_with_multiple_backedges():
        # Lazy generator: evaluated while ``rewrite_single_backedge`` is
        # mutating ``newblocks`` in the loop at the bottom of this function.
        for lp in cfg.loops().values():
            if has_multiple_backedges(lp):
                yield lp

    def replace_target(term, src, dst):
        # Return a terminator equal to *term* but with jump target *src*
        # replaced by *dst*; terminators without targets are returned as-is.
        def replace(target):
            return (dst if target == src else target)

        if isinstance(term, ir.Branch):
            return ir.Branch(cond=term.cond,
                             truebr=replace(term.truebr),
                             falsebr=replace(term.falsebr),
                             loc=term.loc)
        elif isinstance(term, ir.Jump):
            return ir.Jump(target=replace(term.target), loc=term.loc)
        else:
            assert not term.get_targets()
            return term

    def rewrite_single_backedge(loop):
        """
        Add new tail block that gathers all the backedges
        """
        header = loop.header
        tailkey = new_block_id()
        for blkkey in loop.body:
            blk = newblocks[blkkey]
            if header in blk.terminator.get_targets():
                newblk = blk.copy()
                # rewrite backedge into jumps to new tail block
                newblk.body[-1] = replace_target(blk.terminator, header,
                                                 tailkey)
                newblocks[blkkey] = newblk
        # create new tail block
        entryblk = newblocks[header]
        tailblk = ir.Block(scope=entryblk.scope, loc=entryblk.loc)
        # add backedge
        tailblk.append(ir.Jump(target=header, loc=tailblk.loc))
        newblocks[tailkey] = tailblk

    for loop in yield_loops_with_multiple_backedges():
        rewrite_single_backedge(loop)

    return newblocks
def canonicalize_cfg(blocks):
    """Canonicalize the CFG described by *blocks*.

    Currently the only canonicalization applied is funnelling the backedges
    of each loop through a single tail block
    (``canonicalize_cfg_single_backedge``). Returns a new dictionary of
    blocks.
    """
    canonical = canonicalize_cfg_single_backedge(blocks)
    return canonical
def with_lifting(func_ir, typingctx, targetctx, flags, locals):
    """With-lifting transformation.

    Rewrite the IR to extract all withs. Only the top-level withs are
    extracted.

    Returns a 2-tuple ``(new_ir, sub_irs)``: ``sub_irs`` is a list with
    whatever each context manager's ``mutate_with_body`` returned, one entry
    per with-region; ``new_ir`` is derived from the rewritten blocks, or is
    the IR from ``find_setupwiths`` unchanged when no with-region exists.
    """
    from numba.core import postproc

    def dispatcher_factory(func_ir, objectmode=False, **kwargs):
        from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith

        myflags = flags.copy()
        if not objectmode:
            return LiftedWith(func_ir, typingctx, targetctx, myflags, locals,
                              **kwargs)
        # Object-mode lifted with-block: no looplifting inside it, compiled
        # in (forced) object mode, with the CPython wrapper enabled.
        myflags.enable_looplift = False
        myflags.enable_pyobject = True
        myflags.force_pyobject = True
        myflags.no_cpython_wrapper = False
        return ObjModeLiftedWith(func_ir, typingctx, targetctx, myflags,
                                 locals, **kwargs)

    # locate the with-context regions
    withs, func_ir = find_setupwiths(func_ir)
    if not withs:
        return func_ir, []

    postproc.PostProcessor(func_ir).run()  # ensure we have variable lifetime
    assert func_ir.variable_lifetime
    lifetime = func_ir.variable_lifetime
    blocks = func_ir.blocks.copy()
    cfg = lifetime.cfg

    # Let the context manager of each region rewrite that region's body.
    sub_irs = []
    for blk_start, blk_end in withs:
        body_blocks = list(_cfg_nodes_in_region(cfg, blk_start, blk_end))
        _legalize_with_head(blocks[blk_start])
        cmkind, extra = _get_with_contextmanager(func_ir, blocks, blk_start)
        sub_irs.append(cmkind.mutate_with_body(
            func_ir, blocks, blk_start, blk_end, body_blocks,
            dispatcher_factory, extra,
        ))

    new_ir = func_ir.derive(blocks) if sub_irs else func_ir
    return new_ir, sub_irs
def _get_with_contextmanager(func_ir, blocks, blk_start):
    """Get the global object used as the context manager of a with-region.

    Scans block *blk_start* for its first ``ir.EnterWith`` and resolves the
    context manager to a global value.

    Returns ``(ctxobj, extra)`` where ``extra`` is None, or, when the context
    manager is used as a call, a dict ``{'args': [...], 'kwargs': {...}}``
    holding the definitions of the call's arguments.

    Raises ``errors.CompilerError`` if the block has no ``EnterWith``, the
    context manager is undefined or cannot be resolved to a global, or the
    object has no ``mutate_with_body`` attribute.
    """
    _illegal_cm_msg = "Illegal use of context-manager."

    def get_var_dfn(var):
        """Get the definition given a variable"""
        return func_ir.get_definition(var)

    def get_ctxmgr_obj(var_ref):
        """Resolve *var_ref* to ``(ctxobj, extra)``; see the outer docstring.
        """
        dfn = func_ir.get_definition(var_ref)
        extra = None
        if isinstance(dfn, ir.Expr) and dfn.op == 'call':
            # Used as ``with cm(...)``: keep the argument definitions and
            # resolve the callee instead.
            extra = {
                'args': [get_var_dfn(arg) for arg in dfn.args],
                'kwargs': {name: get_var_dfn(val) for name, val in dfn.kws},
            }
            var_ref = dfn.func

        ctxobj = ir_utils.guard(ir_utils.find_global_value, func_ir, var_ref)
        if ctxobj is ir.UNDEFINED:
            raise errors.CompilerError(
                "Undefined variable used as context manager",
                loc=blocks[blk_start].loc,
            )
        if ctxobj is None:
            raise errors.CompilerError(_illegal_cm_msg, loc=dfn.loc)
        return ctxobj, extra

    enter_with = next(
        (stmt for stmt in blocks[blk_start].body
         if isinstance(stmt, ir.EnterWith)),
        None,
    )
    if enter_with is None:
        # No contextmanager found?
        raise errors.CompilerError(
            "malformed with-context usage",
            loc=blocks[blk_start].loc,
        )

    ctxobj, extra = get_ctxmgr_obj(enter_with.contextmanager)
    if not hasattr(ctxobj, 'mutate_with_body'):
        raise errors.CompilerError(
            "Unsupported context manager in use",
            loc=blocks[blk_start].loc,
        )
    return ctxobj, extra
def _legalize_with_head(blk):
    """Given *blk*, the head block of the with-context, check that it doesn't
    do anything else.

    The head block must contain exactly one ``ir.EnterWith`` and exactly one
    ``ir.Jump``; any number of ``ir.Del`` statements is tolerated and any
    other statement is rejected.

    Raises
    ------
    errors.CompilerError
        If *blk* violates any of the constraints above.
    """
    counters = defaultdict(int)
    for stmt in blk.body:
        counters[type(stmt)] += 1
    # ``defaultdict.pop`` does not consult the default factory, so a missing
    # key must be given an explicit default of 0; otherwise a head block
    # without EnterWith/Jump raised KeyError instead of the CompilerError.
    if counters.pop(ir.EnterWith, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 ENTER_WITH",
            loc=blk.loc,
        )
    if counters.pop(ir.Jump, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 JUMP",
            loc=blk.loc,
        )
    # Can have any number of del
    counters.pop(ir.Del, None)
    # There MUST NOT be any other statements
    if counters:
        raise errors.CompilerError(
            "illegal statements in with's head-block",
            loc=blk.loc,
        )
def _cfg_nodes_in_region(cfg, region_begin, region_end):
    """Find the set of CFG nodes that are in the given region.

    Walks the successors of *region_begin* (via ``cfg.successors``, which
    yields ``(node, data)`` pairs) without ever stepping onto *region_end*.
    Returns every node reached. *region_begin* itself is only included if it
    is reachable from itself, and *region_end* is never included.
    """
    region_nodes = set()
    stack = [region_begin]
    while stack:
        tos = stack.pop()
        # Iterate the pairs directly: ``zip(*...)`` unpacking failed with a
        # ValueError for nodes without successors (e.g. return blocks).
        for succ, _ in cfg.successors(tos):
            if succ not in region_nodes and succ != region_end:
                region_nodes.add(succ)
                stack.append(succ)
    return region_nodes
def find_setupwiths(func_ir):
    """Find all top-level with.

    Returns a 2-tuple ``(with_ranges, func_ir)``. ``with_ranges`` is a list of
    ``(setup_label, end_label)`` pairs, one per top-level with-region, where
    ``setup_label`` is the block holding the SETUP_WITH and ``end_label`` is
    the successor of the block holding the matching POP_BLOCK. ``func_ir`` is
    the (possibly rewritten) FunctionIR.

    Raises
    ------
    errors.CompilerError
        For control flow the with-lifting cannot handle: a raise inside the
        with-block, certain returns inside it (Python-version dependent), or
        a POP_BLOCK block whose terminator has more than one target.
    """
    def find_ranges(blocks):
        """Map each SETUP_WITH block label to the set of labels of the
        POP_BLOCK blocks that close it."""
        cfg = compute_cfg_from_blocks(blocks)
        sus_setups, sus_pops = set(), set()
        # traverse the cfg and collect all suspected SETUP_WITH and POP_BLOCK
        # statements so that we can iterate over them
        for label, block in blocks.items():
            for stmt in block.body:
                if ir_utils.is_setup_with(stmt):
                    sus_setups.add(label)
                if ir_utils.is_pop_block(stmt):
                    sus_pops.add(label)

        # now that we do have the statements, iterate through them in reverse
        # topo order and from each start looking for pop_blocks
        setup_with_to_pop_blocks_map = defaultdict(set)
        for setup_block in cfg.topo_sort(sus_setups, reverse=True):
            # begin pop_block search
            to_visit, seen = [setup_block], set()
            while to_visit:
                block = to_visit.pop()
                if block in seen:
                    # A block can be pushed several times before it is first
                    # popped (e.g. two converging paths). It must not be
                    # processed again: for an already-matched POP_BLOCK block
                    # the match test below no longer fires, so the walk would
                    # continue past the end of the with-region.
                    continue
                seen.add(block)
                # go through the body of the block, looking for statements
                for stmt in blocks[block].body:
                    # raise detected before pop_block
                    if ir_utils.is_raise(stmt):
                        raise errors.CompilerError(
                            'unsupported control flow due to raise '
                            'statements inside with block'
                        )
                    # special case 3.7, return before POP_BLOCK
                    if PYVERSION < (3, 8) and ir_utils.is_return(stmt):
                        raise errors.CompilerError(
                            'unsupported control flow: due to return '
                            'statements inside with block'
                        )
                    # if a pop_block, process it
                    if ir_utils.is_pop_block(stmt) and block in sus_pops:
                        # record this block as closing this setup
                        setup_with_to_pop_blocks_map[setup_block].add(block)
                        # remove the block from blocks to be matched
                        sus_pops.remove(block)
                        # stop looking, we have reached the frontier
                        break
                    # if we are still here, by the block terminator,
                    # add all its targets to the to_visit stack, unless we
                    # have seen them already
                    if ir_utils.is_terminator(stmt):
                        for t in stmt.get_targets():
                            if t not in seen:
                                to_visit.append(t)
        return setup_with_to_pop_blocks_map

    blocks = func_ir.blocks
    # initial find, will return a dictionary, mapping indices of blocks
    # containing SETUP_WITH statements to a set of indices of blocks containing
    # POP_BLOCK statements
    with_ranges_dict = find_ranges(blocks)
    # rewrite the CFG in case there are multiple POP_BLOCK statements for one
    # with; afterwards every value is a singleton set
    func_ir = consolidate_multi_exit_withs(with_ranges_dict, blocks, func_ir)

    # turn the withs back into a list of tuples so that the rest of the code
    # can cope
    with_ranges_tuple = [(s, next(iter(p)))
                         for (s, p) in with_ranges_dict.items()]

    # check for POP_BLOCKS with multiple outgoing edges and reject
    for (_, p) in with_ranges_tuple:
        targets = blocks[p].terminator.get_targets()
        if len(targets) != 1:
            raise errors.CompilerError(
                "unsupported control flow: with-context contains branches "
                "(i.e. break/return/raise) that can leave the block "
            )
    # now we check for returns inside with and reject or rewrite them
    for (_, p) in with_ranges_tuple:
        target_block = blocks[p]
        if ir_utils.is_return(func_ir.blocks[
                target_block.terminator.get_targets()[0]].terminator):
            if PYVERSION == (3, 8):
                # 3.8 needs to bail here, if this is the case, because the
                # later code can't handle it.
                raise errors.CompilerError(
                    "unsupported control flow: due to return statements "
                    "inside with block"
                )
            _rewrite_return(func_ir, p)

    # now we need to rewrite the tuple such that we have SETUP_WITH matching
    # the successor of the block that contains the POP_BLOCK.
    with_ranges_tuple = [(s, func_ir.blocks[p].terminator.get_targets()[0])
                         for (s, p) in with_ranges_tuple]

    # finally we check for nested with statements and reject them
    with_ranges_tuple = _eliminate_nested_withs(with_ranges_tuple)

    return with_ranges_tuple, func_ir
def _rewrite_return(func_ir, target_block_label):
    """Rewrite a return block inside a with statement.

    Arguments
    ---------
    func_ir: Function IR
      the CFG to transform (mutated in place and also returned)
    target_block_label: int
      the block index/label of the block containing the POP_BLOCK statement

    This implements a CFG transformation to insert a block between two other
    blocks.

    The input situation is::

        ┌───────────────┐
        │      top      │
        │   POP_BLOCK   │
        │    bottom     │
        └───────┬───────┘
                │
        ┌───────▼───────┐
        │               │
        │    RETURN     │
        │               │
        └───────────────┘

    If such a pattern is detected in IR, it means there is a `return`
    statement within a `with` context. The basic idea is to rewrite the CFG
    as follows::

        ┌───────────────┐
        │      top      │
        │               │
        └───────┬───────┘
                │
        ┌───────▼───────┐
        │   POP_BLOCK   │
        │    bottom     │
        └───────┬───────┘
                │
        ┌───────▼───────┐
        │               │
        │    RETURN     │
        │               │
        └───────────────┘

    We split the block that contains the `POP_BLOCK` statement into two
    blocks. Everything *before* the `POP_BLOCK` statement is the 'top'; it
    stays in the original block, which now jumps to the label of the former
    return block. The `POP_BLOCK` statement and everything after it (except
    the final jump) is the 'bottom'; it replaces the body of the former return
    block and jumps to a newly allocated block that receives the return
    statements. `find_setupwiths` then uses the label of the block starting
    with `POP_BLOCK` as the end of the with-region.
    """
    # the block itself from the index
    target_block = func_ir.blocks[target_block_label]
    # get the index of the block containing the return
    target_block_successor_label = target_block.terminator.get_targets()[0]
    # the return block
    target_block_successor = func_ir.blocks[target_block_successor_label]
    # create the new return block with an appropriate label
    max_label = ir_utils.find_max_label(func_ir.blocks)
    new_label = max_label + 1
    # create the new return block; it reuses the scope of the block whose
    # statements move into it, instead of a fresh parentless
    # ``ir.Scope(None, ...)`` that would be disconnected from the variables
    # those statements refer to.
    new_block_loc = target_block_successor.loc
    new_block = ir.Block(target_block_successor.scope, loc=new_block_loc)

    # Split the block containing the POP_BLOCK into top and bottom
    # Block must be of the form:
    # -----------------
    # <some stmts>
    # POP_BLOCK
    # <some more stmts>
    # JUMP
    # -----------------
    top_body, bottom_body = [], []
    pop_blocks = [*target_block.find_insts(ir.PopBlock)]
    assert len(pop_blocks) == 1
    assert len([*target_block.find_insts(ir.Jump)]) == 1
    assert isinstance(target_block.body[-1], ir.Jump)
    pb_marker = pop_blocks[0]
    pb_is = target_block.body.index(pb_marker)
    # top: statements strictly before POP_BLOCK, then jump to the old
    # return block's label
    top_body.extend(target_block.body[:pb_is])
    top_body.append(ir.Jump(target_block_successor_label, target_block.loc))
    # bottom: POP_BLOCK and what follows (minus the old jump), then jump to
    # the new return block
    bottom_body.extend(target_block.body[pb_is:-1])
    bottom_body.append(ir.Jump(new_label, target_block.loc))
    # get the contents of the return block; this aliases the successor's
    # body list, so it is copied into the new block before being cleared
    return_body = func_ir.blocks[target_block_successor_label].body
    # finally, re-assign all blocks
    new_block.body.extend(return_body)
    target_block_successor.body.clear()
    target_block_successor.body.extend(bottom_body)
    target_block.body.clear()
    target_block.body.extend(top_body)

    # finally, append the new return block and rebuild the IR properties
    func_ir.blocks[new_label] = new_block
    func_ir._definitions = ir_utils.build_definitions(func_ir.blocks)
    return func_ir
def _eliminate_nested_withs(with_ranges):
    """Drop with-ranges that lie strictly inside an already-kept range.

    *with_ranges* is an iterable of ``(start, end)`` label pairs. The pairs
    are visited in sorted order; a pair is kept unless its start is greater
    than, and its end is less than, those of some range kept before it.
    Returns the kept pairs as a list of tuples, in sorted order.
    """
    kept = []
    for begin, finish in sorted(with_ranges):
        # FIXME: this should be a comparison in topological order; right now
        # the integer labels of the blocks are compared, which probably only
        # works by accident.
        nested = any(begin > lo and finish < hi for lo, hi in kept)
        if not nested:
            kept.append((begin, finish))
    return kept
def consolidate_multi_exit_withs(withs: dict, blocks, func_ir):
    """Modify the FunctionIR to merge the exit blocks of with constructs.

    Parameters
    ----------
    withs : dict
        Maps the label of each block containing a SETUP_WITH to the set of
        labels of the blocks containing its POP_BLOCK. Mutated in place:
        every entry with more than one POP_BLOCK block is replaced by a
        singleton set holding the label of the new common exit block.
    blocks :
        Unused; kept so existing callers keep working.
    func_ir :
        The FunctionIR; rewritten in place by ``_fix_multi_exit_blocks``
        whenever exits are merged.

    Returns
    -------
    The (possibly rewritten) FunctionIR.
    """
    # Replacing the value of an existing key does not resize the dict, so it
    # is safe while iterating over ``items()``.
    for setup_label, pop_labels in withs.items():
        if len(pop_labels) > 1:
            func_ir, common = _fix_multi_exit_blocks(
                func_ir, pop_labels, split_condition=ir_utils.is_pop_block,
            )
            withs[setup_label] = {common}
    return func_ir
def _fix_multi_exit_blocks(func_ir, exit_nodes, *, split_condition=None):
    """Modify the FunctionIR to create a single common exit node given the
    original exit nodes.

    Each exit node (enumerated as ``i``) is truncated at its split point, gets
    ``$cp = i`` appended and is made to jump to a new ``common`` block.
    ``common`` jumps to a new ``post`` block, from which a chain of
    comparison blocks on ``$cp`` dispatches to the original jump target of
    each exit.

    Parameters
    ----------
    func_ir :
        The FunctionIR. Mutated inplace.
    exit_nodes :
        The original exit nodes. A sequence of block keys.
    split_condition : callable or None
        If not None, it is a callable with the signature
        `split_condition(statement)` that determines if the `statement` is the
        splitting point (e.g. `POP_BLOCK`) in an exit node.
        If it's None, only the terminator is split off the exit node.

    Returns
    -------
    (func_ir, common_label)
        The same (mutated) FunctionIR and the label of the new common block.

    NOTE(review): only the split statement of the *first* exit is moved into
    the common block, and of each exit's tail only the terminator is used
    (for its single jump target); any other tail statements are not placed in
    any block. Confirm callers never rely on such statements.
    """
    # Convert the following:
    #
    #    |           |
    # +-------+   +-------+
    # | exit0 |   | exit1 |
    # +-------+   +-------+
    #     |           |
    # +-------+   +-------+
    # | after0|   | after1|
    # +-------+   +-------+
    #     |           |
    #
    # To roughly:
    #
    #    |           |
    # +-------+   +-------+
    # | exit0 |   | exit1 |
    # +-------+   +-------+
    #     |           |
    #     +-----+-----+
    #           |
    #      +---------+
    #      | common  |
    #      +---------+
    #           |
    #       +-------+
    #       | post  |
    #       +-------+
    #           |
    #     +-----+-----+
    #     |           |
    # +-------+   +-------+
    # | after0|   | after1|
    # +-------+   +-------+

    blocks = func_ir.blocks
    # Getting the scope
    # NOTE(review): this takes the min over ir.Block objects themselves (not
    # labels) merely to pick some block whose scope is reused; presumably all
    # blocks share the function's scope -- confirm.
    any_blk = min(func_ir.blocks.values())
    scope = any_blk.scope
    # First unused block label; incremented as each new block is allocated.
    max_label = max(func_ir.blocks) + 1
    # Define the new common block for the new exit.
    common_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    common_label = max_label
    max_label += 1
    blocks[common_label] = common_block
    # Define the new block after the exit.
    post_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    post_label = max_label
    max_label += 1
    blocks[post_label] = post_block

    # Adjust each exit node
    remainings = []
    for i, k in enumerate(exit_nodes):
        blk = blocks[k]

        # split the block if needed
        if split_condition is not None:
            # pt ends as the index of the first matching statement, or as the
            # last index of the body if nothing matches.
            for pt, stmt in enumerate(blk.body):
                if split_condition(stmt):
                    break
        else:
            # no splitting: only the terminator goes to the tail
            pt = -1

        before = blk.body[:pt]
        after = blk.body[pt:]
        remainings.append(after)

        # Add control-point variable to mark which exit block this is.
        blk.body = before
        loc = blk.loc
        blk.body.append(
            ir.Assign(value=ir.Const(i, loc=loc),
                      target=scope.get_or_define("$cp", loc=loc),
                      loc=loc)
        )
        # Replace terminator with a jump to the common block
        assert not blk.is_terminated
        blk.body.append(ir.Jump(common_label, loc=ir.unknown_loc))

    if split_condition is not None:
        # Move the splitting statement to the common block
        # (only the first exit's, see the NOTE in the docstring)
        common_block.body.append(remainings[0][0])
    assert not common_block.is_terminated
    # Append jump from common block to post block
    # (``loc`` here is left over from the last exit processed above; it is
    # unbound if ``exit_nodes`` is empty)
    common_block.body.append(ir.Jump(post_label, loc=loc))

    # Make if-else tree to jump to target
    # Pre-allocate one label per exit for the chain of comparison blocks.
    remain_blocks = []
    for remain in remainings:
        remain_blocks.append(max_label)
        max_label += 1

    switch_block = post_block
    loc = ir.unknown_loc
    for i, remain in enumerate(remainings):
        match_expr = scope.redefine("$cp_check", loc=loc)
        match_rhs = scope.redefine("$cp_rhs", loc=loc)

        # Do comparison to match control-point variable to the exit block
        switch_block.body.append(
            ir.Assign(
                value=ir.Const(i, loc=loc),
                target=match_rhs,
                loc=loc
            ),
        )

        # Add assignment for the comparison
        switch_block.body.append(
            ir.Assign(
                value=ir.Expr.binop(
                    fn=operator.eq, lhs=scope.get("$cp"), rhs=match_rhs,
                    loc=loc,
                ),
                target=match_expr,
                loc=loc
            ),
        )

        # Insert jump to the next case
        # (each exit's tail terminator must have exactly one target)
        [jump_target] = remain[-1].get_targets()
        # true -> this exit's original target; false -> next comparison block
        switch_block.body.append(
            ir.Branch(match_expr, jump_target, remain_blocks[i], loc=loc),
        )
        switch_block = ir.Block(scope=scope, loc=loc)
        blocks[remain_blocks[i]] = switch_block

    # Add the final jump
    # (the last comparison block falls back unconditionally to the last
    # exit's target)
    switch_block.body.append(ir.Jump(jump_target, loc=loc))

    return func_ir, common_label