"""
Implement transformations on Numba IR
"""
from collections import namedtuple, defaultdict
import logging
import operator
from numba.core.analysis import (compute_cfg_from_blocks, compute_use_defs,
                                 find_top_level_loops)
from numba.core import errors, ir, ir_utils
from numba.core.utils import PYVERSION, _lazy_pformat
_logger = logging.getLogger(__name__)
def _extract_loop_lifting_candidates(cfg, blocks):
"""
    Returns a list of loops that are candidates for loop lifting
"""
    # check the well-formedness of the loop
def same_exit_point(loop):
"all exits must point to the same location"
outedges = set()
for k in loop.exits:
succs = set(x for x, _ in cfg.successors(k))
if not succs:
                # If the exit point has no successor, it contains a return
                # statement, which is not handled by the looplifting code.
                # Thus, this loop is not a candidate.
_logger.debug("return-statement in loop.")
return False
outedges |= succs
ok = len(outedges) == 1
_logger.debug("same_exit_point=%s (%s)", ok, outedges)
return ok
def one_entry(loop):
"there is one entry"
ok = len(loop.entries) == 1
_logger.debug("one_entry=%s", ok)
return ok
def cannot_yield(loop):
"cannot have yield inside the loop"
insiders = set(loop.body) | set(loop.entries) | set(loop.exits)
for blk in map(blocks.__getitem__, insiders):
for inst in blk.body:
if isinstance(inst, ir.Assign):
if isinstance(inst.value, ir.Yield):
_logger.debug("has yield")
return False
_logger.debug("no yield")
return True
_logger.info('finding looplift candidates')
# the check for cfg.entry_point in the loop.entries is to prevent a bad
# rewrite where a prelude for a lifted loop would get written into block -1
# if a loop entry were in block 0
candidates = []
for loop in find_top_level_loops(cfg):
_logger.debug("top-level loop: %s", loop)
if (same_exit_point(loop) and one_entry(loop) and cannot_yield(loop) and
cfg.entry_point() not in loop.entries):
candidates.append(loop)
_logger.debug("add candidate: %s", loop)
return candidates
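
# Rough illustration (not part of the module's API): under the checks above, a
# loop such as
#
#     for x in it:        # single entry block, yield-free body
#         acc += x        # all exits converge on one location
#
# is a lifting candidate, whereas a loop whose body contains a ``return`` or a
# ``yield``, or whose exit blocks branch to different locations, is skipped.
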
def find_region_inout_vars(blocks, livemap, callfrom, returnto, body_block_ids):
"""Find input and output variables to a block region.
"""
inputs = livemap[callfrom]
outputs = livemap[returnto]
    # ensure that live variables are actually used in the blocks, otherwise
    # remove them; this saves having to create something valid to run through
    # the post-processor to achieve a similar result
loopblocks = {}
for k in body_block_ids:
loopblocks[k] = blocks[k]
used_vars = set()
def_vars = set()
defs = compute_use_defs(loopblocks)
for vs in defs.usemap.values():
used_vars |= vs
for vs in defs.defmap.values():
def_vars |= vs
used_or_defined = used_vars | def_vars
# note: sorted for stable ordering
inputs = sorted(set(inputs) & used_or_defined)
outputs = sorted(set(outputs) & used_or_defined & def_vars)
return inputs, outputs
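
# For example (illustrative only): for a region covering
#
#     for i in range(n):
#         acc += i
#
# the computed inputs would typically be ``['acc', 'n']`` and the outputs
# ``['acc']``, assuming ``acc`` is live (used) after the loop.
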
_loop_lift_info = namedtuple('loop_lift_info',
'loop,inputs,outputs,callfrom,returnto')
def _loop_lift_get_candidate_infos(cfg, blocks, livemap):
"""
Returns information on looplifting candidates.
"""
loops = _extract_loop_lifting_candidates(cfg, blocks)
loopinfos = []
for loop in loops:
[callfrom] = loop.entries # requirement checked earlier
        an_exit = next(iter(loop.exits))  # any one of the exit blocks
if len(loop.exits) > 1:
# has multiple exits
[(returnto, _)] = cfg.successors(an_exit) # requirement checked earlier
else:
# does not have multiple exits
returnto = an_exit
local_block_ids = set(loop.body) | set(loop.entries) | set(loop.exits)
inputs, outputs = find_region_inout_vars(
blocks=blocks,
livemap=livemap,
callfrom=callfrom,
returnto=returnto,
body_block_ids=local_block_ids,
)
lli = _loop_lift_info(loop=loop, inputs=inputs, outputs=outputs,
callfrom=callfrom, returnto=returnto)
loopinfos.append(lli)
return loopinfos
def _loop_lift_modify_call_block(liftedloop, block, inputs, outputs, returnto):
"""
    Transform the calling block of the top-level function to call the
    lifted loop.
"""
scope = block.scope
loc = block.loc
blk = ir.Block(scope=scope, loc=loc)
ir_utils.fill_block_with_call(
newblock=blk,
callee=liftedloop,
label_next=returnto,
inputs=inputs,
outputs=outputs,
)
return blk
def _loop_lift_prepare_loop_func(loopinfo, blocks):
"""
    Transform loop blocks in place for use as a lifted loop.
"""
entry_block = blocks[loopinfo.callfrom]
scope = entry_block.scope
loc = entry_block.loc
# Lowering assumes the first block to be the one with the smallest offset
firstblk = min(blocks) - 1
blocks[firstblk] = ir_utils.fill_callee_prologue(
block=ir.Block(scope=scope, loc=loc),
inputs=loopinfo.inputs,
label_next=loopinfo.callfrom,
)
blocks[loopinfo.returnto] = ir_utils.fill_callee_epilogue(
block=ir.Block(scope=scope, loc=loc),
outputs=loopinfo.outputs,
)
def _loop_lift_modify_blocks(func_ir, loopinfo, blocks,
typingctx, targetctx, flags, locals):
"""
    Modify the blocks in place to call into the lifted loop.
    Returns the LiftedLoop object created for the lifted loop.
"""
from numba.core.dispatcher import LiftedLoop
# Copy loop blocks
loop = loopinfo.loop
loopblockkeys = set(loop.body) | set(loop.entries)
if len(loop.exits) > 1:
# has multiple exits
loopblockkeys |= loop.exits
loopblocks = dict((k, blocks[k].copy()) for k in loopblockkeys)
# Modify the loop blocks
_loop_lift_prepare_loop_func(loopinfo, loopblocks)
    # Since Python 3.13, the [END_FOR, POP_TOP] sequence becomes the start of
    # the block, causing the block to carry the line number of the start of
    # the previous loop. Fix this using the loc of the first getiter.
getiter_exprs = []
for blk in loopblocks.values():
getiter_exprs.extend(blk.find_exprs(op="getiter"))
first_getiter = min(getiter_exprs, key=lambda x: x.loc.line)
loop_loc = first_getiter.loc
# Create a new IR for the lifted loop
lifted_ir = func_ir.derive(blocks=loopblocks,
arg_names=tuple(loopinfo.inputs),
arg_count=len(loopinfo.inputs),
force_non_generator=True,
loc=loop_loc)
liftedloop = LiftedLoop(lifted_ir,
typingctx, targetctx, flags, locals)
# modify for calling into liftedloop
callblock = _loop_lift_modify_call_block(liftedloop, blocks[loopinfo.callfrom],
loopinfo.inputs, loopinfo.outputs,
loopinfo.returnto)
# remove blocks
for k in loopblockkeys:
del blocks[k]
# update main interpreter callsite into the liftedloop
blocks[loopinfo.callfrom] = callblock
return liftedloop
def _has_multiple_loop_exits(cfg, lpinfo):
"""Returns True if there is more than one exit in the loop.
NOTE: "common exits" refers to the situation where a loop exit has another
loop exit as its successor. In that case, we do not need to alter it.
"""
if len(lpinfo.exits) <= 1:
return False
exits = set(lpinfo.exits)
pdom = cfg.post_dominators()
# Eliminate blocks that have other blocks as post-dominators.
processed = set()
remain = set(exits) # create a copy to work on
while remain:
node = remain.pop()
processed.add(node)
exits -= pdom[node] - {node}
remain = exits - processed
return len(exits) > 1
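
# Illustrative shape of a multi-exit loop (assumed example): a loop such as
#
#     while cond:
#         if other:
#             break           # exits into one block
#         ...                 # the normal exit goes to another block
#
# can produce two distinct exit blocks; this predicate then returns True and
# ``_pre_looplift_transform`` below merges the exits via
# ``_fix_multi_exit_blocks``.
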
def _pre_looplift_transform(func_ir):
"""Canonicalize loops for looplifting.
"""
from numba.core.postproc import PostProcessor
cfg = compute_cfg_from_blocks(func_ir.blocks)
# For every loop that has multiple exits, combine the exits into one.
for loop_info in cfg.loops().values():
if _has_multiple_loop_exits(cfg, loop_info):
func_ir, _common_key = _fix_multi_exit_blocks(
func_ir, loop_info.exits
)
# Reset and reprocess the func_ir
func_ir._reset_analysis_variables()
PostProcessor(func_ir).run()
return func_ir
def loop_lifting(func_ir, typingctx, targetctx, flags, locals):
"""
Loop lifting transformation.
    Given an interpreter `func_ir`, returns a 2-tuple of
    `(toplevel_interp, [loop0_interp, loop1_interp, ....])`
"""
func_ir = _pre_looplift_transform(func_ir)
blocks = func_ir.blocks.copy()
cfg = compute_cfg_from_blocks(blocks)
loopinfos = _loop_lift_get_candidate_infos(cfg, blocks,
func_ir.variable_lifetime.livemap)
loops = []
if loopinfos:
_logger.debug('loop lifting this IR with %d candidates:\n%s',
len(loopinfos),
_lazy_pformat(func_ir, lazy_func=lambda x: x.dump_to_string()))
for loopinfo in loopinfos:
lifted = _loop_lift_modify_blocks(func_ir, loopinfo, blocks,
typingctx, targetctx, flags, locals)
loops.append(lifted)
# Make main IR
main = func_ir.derive(blocks=blocks)
return main, loops
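
# Hedged usage sketch (the decorator call below is an assumption about typical
# use and is not exercised by this module): loop lifting is what allows a
# function compiled in object mode to still run an inner loop in nopython
# mode, e.g.
#
#     from numba import jit
#
#     @jit(forceobj=True)           # object mode with loop lifting enabled
#     def total(values):
#         acc = 0
#         for v in values:          # this loop is extracted as a LiftedLoop
#             acc += v              # and compiled in nopython mode
#         return acc
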
def canonicalize_cfg_single_backedge(blocks):
"""
Rewrite loops that have multiple backedges.
"""
cfg = compute_cfg_from_blocks(blocks)
newblocks = blocks.copy()
def new_block_id():
return max(newblocks.keys()) + 1
def has_multiple_backedges(loop):
count = 0
for k in loop.body:
blk = blocks[k]
edges = blk.terminator.get_targets()
# is a backedge?
if loop.header in edges:
count += 1
if count > 1:
# early exit
return True
return False
def yield_loops_with_multiple_backedges():
for lp in cfg.loops().values():
if has_multiple_backedges(lp):
yield lp
def replace_target(term, src, dst):
def replace(target):
return (dst if target == src else target)
if isinstance(term, ir.Branch):
return ir.Branch(cond=term.cond,
truebr=replace(term.truebr),
falsebr=replace(term.falsebr),
loc=term.loc)
elif isinstance(term, ir.Jump):
return ir.Jump(target=replace(term.target), loc=term.loc)
else:
assert not term.get_targets()
return term
def rewrite_single_backedge(loop):
"""
Add new tail block that gathers all the backedges
"""
header = loop.header
tailkey = new_block_id()
for blkkey in loop.body:
blk = newblocks[blkkey]
if header in blk.terminator.get_targets():
newblk = blk.copy()
                # rewrite the backedge into a jump to the new tail block
newblk.body[-1] = replace_target(blk.terminator, header,
tailkey)
newblocks[blkkey] = newblk
# create new tail block
entryblk = newblocks[header]
tailblk = ir.Block(scope=entryblk.scope, loc=entryblk.loc)
# add backedge
tailblk.append(ir.Jump(target=header, loc=tailblk.loc))
newblocks[tailkey] = tailblk
for loop in yield_loops_with_multiple_backedges():
rewrite_single_backedge(loop)
return newblocks
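
# Sketch of the pattern being rewritten (illustrative): a loop body with a
# ``continue`` typically produces two jumps back to the header,
#
#     while cond:
#         if skip:
#             continue      # backedge 1
#         work()            # falls through into backedge 2
#
# after the rewrite, both jumps target a synthetic tail block, which holds the
# single remaining backedge to the header.
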
def canonicalize_cfg(blocks):
"""
Rewrite the given blocks to canonicalize the CFG.
Returns a new dictionary of blocks.
"""
return canonicalize_cfg_single_backedge(blocks)
def with_lifting(func_ir, typingctx, targetctx, flags, locals):
"""With-lifting transformation
Rewrite the IR to extract all withs.
Only the top-level withs are extracted.
    Returns a tuple of (the_new_ir, the_lifted_with_ir).
"""
from numba.core import postproc
def dispatcher_factory(func_ir, objectmode=False, **kwargs):
from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith
myflags = flags.copy()
if objectmode:
# Lifted with-block cannot looplift
myflags.enable_looplift = False
# Lifted with-block uses object mode
myflags.enable_pyobject = True
myflags.force_pyobject = True
myflags.no_cpython_wrapper = False
cls = ObjModeLiftedWith
else:
cls = LiftedWith
return cls(func_ir, typingctx, targetctx, myflags, locals, **kwargs)
    # find where the with-context regions are
withs, func_ir = find_setupwiths(func_ir)
if not withs:
return func_ir, []
postproc.PostProcessor(func_ir).run() # ensure we have variable lifetime
assert func_ir.variable_lifetime
vlt = func_ir.variable_lifetime
blocks = func_ir.blocks.copy()
cfg = vlt.cfg
    # For each with-region, mutate it according to
    # the kind of contextmanager
sub_irs = []
for (blk_start, blk_end) in withs:
body_blocks = []
for node in _cfg_nodes_in_region(cfg, blk_start, blk_end):
body_blocks.append(node)
_legalize_with_head(blocks[blk_start])
# Find the contextmanager
cmkind, extra = _get_with_contextmanager(func_ir, blocks, blk_start)
# Mutate the body and get new IR
sub = cmkind.mutate_with_body(func_ir, blocks, blk_start, blk_end,
body_blocks, dispatcher_factory,
extra)
sub_irs.append(sub)
if not sub_irs:
# Unchanged
new_ir = func_ir
else:
new_ir = func_ir.derive(blocks)
return new_ir, sub_irs
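
# Hedged example of the kind of region this extracts (assumes the usual
# ``numba.objmode`` context manager; ``pure_python_work`` is a placeholder for
# arbitrary Python-only code and is not part of this module):
#
#     from numba import njit, objmode
#
#     @njit
#     def f(x):
#         y = x + 1
#         with objmode(z='float64'):     # this with-block is lifted into its
#             z = pure_python_work(y)    # own object-mode dispatcher
#         return z
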
def _get_with_contextmanager(func_ir, blocks, blk_start):
"""Get the global object used for the context manager
"""
_illegal_cm_msg = "Illegal use of context-manager."
def get_var_dfn(var):
"""Get the definition given a variable"""
return func_ir.get_definition(var)
def get_ctxmgr_obj(var_ref):
"""Return the context-manager object and extra info.
The extra contains the arguments if the context-manager is used
as a call.
"""
        # If the contextmanager is used as a Call
dfn = func_ir.get_definition(var_ref)
if isinstance(dfn, ir.Expr) and dfn.op == 'call':
args = [get_var_dfn(x) for x in dfn.args]
kws = {k: get_var_dfn(v) for k, v in dfn.kws}
extra = {'args': args, 'kwargs': kws}
var_ref = dfn.func
else:
extra = None
ctxobj = ir_utils.guard(ir_utils.find_outer_value, func_ir, var_ref)
# check the contextmanager object
if ctxobj is ir.UNDEFINED:
raise errors.CompilerError(
"Undefined variable used as context manager",
loc=blocks[blk_start].loc,
)
if ctxobj is None:
raise errors.CompilerError(_illegal_cm_msg, loc=dfn.loc)
return ctxobj, extra
# Scan the start of the with-region for the contextmanager
for stmt in blocks[blk_start].body:
if isinstance(stmt, ir.EnterWith):
var_ref = stmt.contextmanager
ctxobj, extra = get_ctxmgr_obj(var_ref)
if not hasattr(ctxobj, 'mutate_with_body'):
raise errors.CompilerError(
"Unsupported context manager in use",
loc=blocks[blk_start].loc,
)
return ctxobj, extra
# No contextmanager found?
raise errors.CompilerError(
"malformed with-context usage",
loc=blocks[blk_start].loc,
)
def _legalize_with_head(blk):
"""Given *blk*, the head block of the with-context, check that it doesn't
do anything else.
"""
counters = defaultdict(int)
for stmt in blk.body:
counters[type(stmt)] += 1
if counters.pop(ir.EnterWith) != 1:
raise errors.CompilerError(
"with's head-block must have exactly 1 ENTER_WITH",
loc=blk.loc,
)
if counters.pop(ir.Jump, 0) != 1:
raise errors.CompilerError(
"with's head-block must have exactly 1 JUMP",
loc=blk.loc,
)
# Can have any number of del
counters.pop(ir.Del, None)
# There MUST NOT be any other statements
if counters:
raise errors.CompilerError(
"illegal statements in with's head-block",
loc=blk.loc,
)
def _cfg_nodes_in_region(cfg, region_begin, region_end):
"""Find the set of CFG nodes that are in the given region
"""
region_nodes = set()
stack = [region_begin]
while stack:
tos = stack.pop()
succlist = list(cfg.successors(tos))
        # a single-block function will have an empty successor list
if succlist:
succs, _ = zip(*succlist)
nodes = set([node for node in succs
if node not in region_nodes and
node != region_end])
stack.extend(nodes)
region_nodes |= nodes
return region_nodes
def find_setupwiths(func_ir):
"""Find all top-level with.
Returns a list of ranges for the with-regions.
"""
def find_ranges(blocks):
cfg = compute_cfg_from_blocks(blocks)
sus_setups, sus_pops = set(), set()
# traverse the cfg and collect all suspected SETUP_WITH and POP_BLOCK
# statements so that we can iterate over them
for label, block in blocks.items():
for stmt in block.body:
if ir_utils.is_setup_with(stmt):
sus_setups.add(label)
if ir_utils.is_pop_block(stmt):
sus_pops.add(label)
        # now that we do have the statements, iterate through them in reverse
        # topo order and, starting from each setup, look for pop_blocks
setup_with_to_pop_blocks_map = defaultdict(set)
for setup_block in cfg.topo_sort(sus_setups, reverse=True):
            # begin the search for the matching pop_block(s)
to_visit, seen = [], []
to_visit.append(setup_block)
while to_visit:
# get whatever is next and record that we have seen it
block = to_visit.pop()
seen.append(block)
# go through the body of the block, looking for statements
for stmt in blocks[block].body:
# raise detected before pop_block
if ir_utils.is_raise(stmt):
raise errors.CompilerError(
'unsupported control flow due to raise '
'statements inside with block'
)
# if a pop_block, process it
if ir_utils.is_pop_block(stmt) and block in sus_pops:
                        # record this pop_block as belonging to this setup
setup_with_to_pop_blocks_map[setup_block].add(block)
# remove the block from blocks to be matched
sus_pops.remove(block)
# stop looking, we have reached the frontier
break
                    # if we are still here and this is the block's terminator,
                    # add all its targets to the to_visit stack, unless we
                    # have seen them already
if ir_utils.is_terminator(stmt):
for t in stmt.get_targets():
if t not in seen:
to_visit.append(t)
return setup_with_to_pop_blocks_map
blocks = func_ir.blocks
# initial find, will return a dictionary, mapping indices of blocks
# containing SETUP_WITH statements to a set of indices of blocks containing
# POP_BLOCK statements
with_ranges_dict = find_ranges(blocks)
# rewrite the CFG in case there are multiple POP_BLOCK statements for one
# with
func_ir = consolidate_multi_exit_withs(with_ranges_dict, blocks, func_ir)
# here we need to turn the withs back into a list of tuples so that the
# rest of the code can cope
with_ranges_tuple = [(s, list(p)[0])
for (s, p) in with_ranges_dict.items()]
# check for POP_BLOCKS with multiple outgoing edges and reject
for (_, p) in with_ranges_tuple:
targets = blocks[p].terminator.get_targets()
if len(targets) != 1:
raise errors.CompilerError(
"unsupported control flow: with-context contains branches "
"(i.e. break/return/raise) that can leave the block "
)
    # now we check for returns inside the with and, if found, rewrite them
for (_, p) in with_ranges_tuple:
target_block = blocks[p]
if ir_utils.is_return(func_ir.blocks[
target_block.terminator.get_targets()[0]].terminator):
_rewrite_return(func_ir, p)
# now we need to rewrite the tuple such that we have SETUP_WITH matching the
# successor of the block that contains the POP_BLOCK.
with_ranges_tuple = [(s, func_ir.blocks[p].terminator.get_targets()[0])
for (s, p) in with_ranges_tuple]
# finally we check for nested with statements and reject them
with_ranges_tuple = _eliminate_nested_withs(with_ranges_tuple)
return with_ranges_tuple, func_ir
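
# Roughly (for exposition only): for source such as
#
#     with ctx:
#         body()
#     after()
#
# the returned list holds one ``(setup_block, target_block)`` pair, where
# ``setup_block`` contains the with entry (``ir.EnterWith``) and
# ``target_block`` is the successor of the block holding the matching
# POP_BLOCK, i.e. where execution resumes after the with-region.
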
def _rewrite_return(func_ir, target_block_label):
"""Rewrite a return block inside a with statement.
    Parameters
    ----------
func_ir: Function IR
the CFG to transform
target_block_label: int
the block index/label of the block containing the POP_BLOCK statement
This implements a CFG transformation to insert a block between two other
blocks.
The input situation is:
┌───────────────┐
│ top │
│ POP_BLOCK │
│ bottom │
└───────┬───────┘
│
┌───────▼───────┐
│ │
│ RETURN │
│ │
└───────────────┘
If such a pattern is detected in IR, it means there is a `return` statement
within a `with` context. The basic idea is to rewrite the CFG as follows:
┌───────────────┐
│ top │
│ POP_BLOCK │
│ │
└───────┬───────┘
│
┌───────▼───────┐
│ │
│ bottom │
│ │
└───────┬───────┘
│
┌───────▼───────┐
│ │
│ RETURN │
│ │
└───────────────┘
We split the block that contains the `POP_BLOCK` statement into two blocks.
Everything from the beginning of the block up to and including the
`POP_BLOCK` statement is considered the 'top' and everything below is
considered 'bottom'. Finally the jump statements are re-wired to make sure
the CFG remains valid.
"""
    # get the block itself from its index/label
target_block = func_ir.blocks[target_block_label]
# get the index of the block containing the return
target_block_successor_label = target_block.terminator.get_targets()[0]
# the return block
target_block_successor = func_ir.blocks[target_block_successor_label]
# create the new return block with an appropriate label
max_label = ir_utils.find_max_label(func_ir.blocks)
new_label = max_label + 1
# create the new return block
new_block_loc = target_block_successor.loc
new_block_scope = ir.Scope(None, loc=new_block_loc)
new_block = ir.Block(new_block_scope, loc=new_block_loc)
# Split the block containing the POP_BLOCK into top and bottom
# Block must be of the form:
# -----------------
# <some stmts>
# POP_BLOCK
# <some more stmts>
# JUMP
# -----------------
top_body, bottom_body = [], []
pop_blocks = [*target_block.find_insts(ir.PopBlock)]
assert len(pop_blocks) == 1
assert len([*target_block.find_insts(ir.Jump)]) == 1
assert isinstance(target_block.body[-1], ir.Jump)
pb_marker = pop_blocks[0]
pb_is = target_block.body.index(pb_marker)
top_body.extend(target_block.body[:pb_is])
top_body.append(ir.Jump(target_block_successor_label, target_block.loc))
bottom_body.extend(target_block.body[pb_is:-1])
bottom_body.append(ir.Jump(new_label, target_block.loc))
# get the contents of the return block
return_body = func_ir.blocks[target_block_successor_label].body
# finally, re-assign all blocks
new_block.body.extend(return_body)
target_block_successor.body.clear()
target_block_successor.body.extend(bottom_body)
target_block.body.clear()
target_block.body.extend(top_body)
# finally, append the new return block and rebuild the IR properties
func_ir.blocks[new_label] = new_block
func_ir._definitions = ir_utils.build_definitions(func_ir.blocks)
return func_ir
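
# The pattern handled above arises (illustrative sketch, assuming an
# objmode-style context manager) from source such as
#
#     with objmode(y='float64'):
#         y = g(x)
#         return y          # return directly inside the with-block
#
# where the block containing the POP_BLOCK would otherwise flow straight into
# a RETURN block, which the with-lifting machinery cannot consume directly.
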
def _eliminate_nested_withs(with_ranges):
known_ranges = []
def within_known_range(start, end, known_ranges):
for a, b in known_ranges:
# FIXME: this should be a comparison in topological order, right
# now we are comparing the integers of the blocks, stuff probably
# works by accident.
if start > a and end < b:
return True
return False
for s, e in sorted(with_ranges):
if not within_known_range(s, e, known_ranges):
known_ranges.append((s, e))
return known_ranges
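
# For instance (block labels are invented for illustration): given the ranges
# ``[(10, 40), (18, 30)]``, the inner range ``(18, 30)`` lies strictly inside
# ``(10, 40)`` and is dropped, leaving only the top-level with-region.
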
def consolidate_multi_exit_withs(withs: dict, blocks, func_ir):
"""Modify the FunctionIR to merge the exit blocks of with constructs.
"""
for k in withs:
        vs: set = withs[k]
if len(vs) > 1:
func_ir, common = _fix_multi_exit_blocks(
func_ir, vs, split_condition=ir_utils.is_pop_block,
)
withs[k] = {common}
return func_ir
def _fix_multi_exit_blocks(func_ir, exit_nodes, *, split_condition=None):
"""Modify the FunctionIR to create a single common exit node given the
original exit nodes.
Parameters
----------
func_ir :
The FunctionIR. Mutated inplace.
exit_nodes :
The original exit nodes. A sequence of block keys.
split_condition : callable or None
If not None, it is a callable with the signature
`split_condition(statement)` that determines if the `statement` is the
splitting point (e.g. `POP_BLOCK`) in an exit node.
If it's None, the exit node is not split.
"""
# Convert the following:
#
# | |
# +-------+ +-------+
# | exit0 | | exit1 |
# +-------+ +-------+
# | |
# +-------+ +-------+
# | after0| | after1|
# +-------+ +-------+
# | |
#
# To roughly:
#
# | |
# +-------+ +-------+
# | exit0 | | exit1 |
# +-------+ +-------+
# | |
# +-----+-----+
# |
# +---------+
# | common |
# +---------+
# |
# +-------+
# | post |
# +-------+
# |
# +-----+-----+
# | |
# +-------+ +-------+
# | after0| | after1|
# +-------+ +-------+
blocks = func_ir.blocks
# Getting the scope
any_blk = min(func_ir.blocks.values())
scope = any_blk.scope
# Getting the maximum block label
max_label = max(func_ir.blocks) + 1
# Define the new common block for the new exit.
common_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
common_label = max_label
max_label += 1
blocks[common_label] = common_block
# Define the new block after the exit.
post_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
post_label = max_label
max_label += 1
blocks[post_label] = post_block
# Adjust each exit node
remainings = []
for i, k in enumerate(exit_nodes):
blk = blocks[k]
# split the block if needed
if split_condition is not None:
for pt, stmt in enumerate(blk.body):
if split_condition(stmt):
break
else:
# no splitting
pt = -1
before = blk.body[:pt]
after = blk.body[pt:]
remainings.append(after)
# Add control-point variable to mark which exit block this is.
blk.body = before
loc = blk.loc
blk.body.append(
ir.Assign(value=ir.Const(i, loc=loc),
target=scope.get_or_define("$cp", loc=loc),
loc=loc)
)
# Replace terminator with a jump to the common block
assert not blk.is_terminated
blk.body.append(ir.Jump(common_label, loc=ir.unknown_loc))
if split_condition is not None:
# Move the splitting statement to the common block
common_block.body.append(remainings[0][0])
assert not common_block.is_terminated
# Append jump from common block to post block
common_block.body.append(ir.Jump(post_label, loc=loc))
# Make if-else tree to jump to target
remain_blocks = []
for remain in remainings:
remain_blocks.append(max_label)
max_label += 1
switch_block = post_block
loc = ir.unknown_loc
for i, remain in enumerate(remainings):
match_expr = scope.redefine("$cp_check", loc=loc)
match_rhs = scope.redefine("$cp_rhs", loc=loc)
# Do comparison to match control-point variable to the exit block
switch_block.body.append(
ir.Assign(
value=ir.Const(i, loc=loc),
target=match_rhs,
loc=loc
),
)
# Add assignment for the comparison
switch_block.body.append(
ir.Assign(
value=ir.Expr.binop(
fn=operator.eq, lhs=scope.get("$cp"), rhs=match_rhs,
loc=loc,
),
target=match_expr,
loc=loc
),
)
# Insert jump to the next case
[jump_target] = remain[-1].get_targets()
switch_block.body.append(
ir.Branch(match_expr, jump_target, remain_blocks[i], loc=loc),
)
switch_block = ir.Block(scope=scope, loc=loc)
blocks[remain_blocks[i]] = switch_block
# Add the final jump
switch_block.body.append(ir.Jump(jump_target, loc=loc))
return func_ir, common_label
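
# Worked shape of the rewrite above (illustrative): with two exit nodes, each
# exit assigns ``$cp = 0`` or ``$cp = 1`` before jumping to the common block,
# and the post block then dispatches with the equivalent of
#
#     if $cp == 0: jump after0
#     elif $cp == 1: jump after1
#
# so every original exit funnels through the single common node while still
# reaching its original successor.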