forked from swiftlang/swift
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TFCanonicalizeCFG.cpp
1442 lines (1300 loc) · 56.3 KB
/
TFCanonicalizeCFG.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//===--- TFCanonicalizeCFG.cpp - Transform CFG into SESE regions ----------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file is responsible for canonicalizing semi-structured CFGs into well
// structured CFGs that can be reduced into Single-Entry-Single-Exit (SESE)
// regions, composed of properly nested "if" diamonds and while loops. This
// transformation gives the subsequent graph generation pass a simple form that
// is known to translate into XLA control flow primitives.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "canonicalize-cfg-for-xla"
#include "TFCanonicalizeCFG.h"
#include "TFUtilities.h"
#include "swift/AST/TensorFlow.h"
#include "swift/Basic/TransformArrayRef.h"
#include "swift/SILOptimizer/PassManager/Passes.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/LoopUtils.h"
#include "swift/SILOptimizer/Utils/CFG.h"
#include "swift/SIL/Dominance.h"
#include "swift/SIL/GraphOperationBuilder.h"
#include "swift/SIL/LoopInfo.h"
#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/SILCloner.h"
#include "swift/SIL/SILConstants.h"
#include "swift/SIL/SILUndef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
using namespace swift;
using namespace tf;
/// Gates SESERegionBuilder::ensureSingleExitFromLoops(), which rewrites loops
/// so that the header is the only block the loop exits from.
static llvm::cl::opt<bool> TFEnsureSingleLoopExit(
    "tf-ensure-single-loop-exit", llvm::cl::init(true),
    llvm::cl::desc("Transform loops to have a single exit from header."));

/// Gates the undef-avoidance logic of SingleExitLoopTransformer: merging the
/// loop's exit blocks into a common exit block, and substituting values that
/// are available at the preheader for would-be undef preheader arguments.
static llvm::cl::opt<bool> TFNoUndefsInSESE(
    "tf-no-undefs-in-sese", llvm::cl::init(true),
    llvm::cl::desc(
        "Try to eliminate undefs when performing "
        "sese canonicalization of loops (intended for debugging)."));
//===----------------------------------------------------------------------===//
// SESERegionTree Implementation
//===----------------------------------------------------------------------===//
// Out-of-line destructor; there is nothing to do beyond member destruction.
SESERegionTree::~SESERegionTree() = default;
void SESERegionTree::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
/// Print this region by dispatching to the concrete subclass's print(),
/// starting at column `indent`.
void SESERegionTree::print(llvm::raw_ostream &OS, unsigned indent) const {
  switch (getKind()) {
  case SingleBlock:
    return cast<SingleBlockSESERegion>(this)->print(OS, indent);
  case Sequence:
    return cast<SequenceSESERegion>(this)->print(OS, indent);
  case WhileLoop:
    return cast<WhileLoopSESERegion>(this)->print(OS, indent);
  case Conditional:
    return cast<ConditionalSESERegion>(this)->print(OS, indent);
  }
}
/// Print a leaf region as `block <operand>` at column `indent`.
void SingleBlockSESERegion::print(llvm::raw_ostream &OS,
                                  unsigned indent) const {
  llvm::raw_ostream &out = OS.indent(indent);
  out << "block ";
  BB->printAsOperand(out);
}
/// Print a sequence as `[sequence`, then each child on its own line indented
/// two columns further, then a closing `]`.
void SequenceSESERegion::print(llvm::raw_ostream &OS, unsigned indent) const {
  const unsigned childIndent = indent + 2;
  OS.indent(indent) << "[sequence";
  for (const auto &child : getNodes()) {
    OS << "\n";
    child->print(OS, childIndent);
  }
  OS << "]";
}
/// Print a while-loop region: one line naming its preheader, header and exit
/// blocks, then the loop body indented two columns further, then `>`.
void WhileLoopSESERegion::print(llvm::raw_ostream &OS, unsigned indent) const {
  // Emit `label` immediately followed by the operand spelling of `block`.
  auto emitLabeledBlock = [&OS](const char *label, auto *block) {
    OS << label;
    block->printAsOperand(OS);
  };

  OS.indent(indent);
  emitLabeledBlock("<while Preheader: ", preheader);
  emitLabeledBlock(", Header: ", header);
  emitLabeledBlock(", exit: ", exit);
  OS << "\n";
  body->print(OS, indent + 2);
  OS << ">";
}
/// Print a conditional region: the branching block, then the true arm and the
/// false arm on separate lines (indented two columns further), then `}`.
/// An absent arm is rendered as a placeholder string.
void ConditionalSESERegion::print(llvm::raw_ostream &OS,
                                  unsigned indent) const {
  const unsigned armIndent = indent + 2;
  // Print one arm, or `placeholder` when the arm is null.
  auto printArm = [&](const SESERegionTree *arm, const char *placeholder) {
    if (!arm) {
      OS.indent(armIndent) << placeholder;
      return;
    }
    arm->print(OS, armIndent);
  };

  OS.indent(indent) << "{condition Header: ";
  branchBB->printAsOperand(OS);
  OS << "\n";
  printArm(getTrue(), "<<null true condition block>>");
  OS << "\n";
  printArm(getFalse(), "<<null false condition block>>");
  OS << "}";
}
//===----------------------------------------------------------------------===//
// CFG Canonicalization Implementation
//===----------------------------------------------------------------------===//
/// Merge each block with its successor wherever mergeBasicBlockWithSuccessor
/// allows it. Partitioning and other transformations can leave around lots of
/// unconditional branches between blocks that formerly had control edges;
/// folding them makes later passes simpler.
///
/// \returns true if at least one merge happened.
bool tf::contractUncondBranches(SILFunction *fn, DominanceInfo* DI, SILLoopInfo *LI) {
  bool merged = false;
  // The block list is mutated while we walk it, so the iterator is advanced
  // by hand. After a successful merge the same block is revisited (it now
  // has new successors that may be mergeable too); the iterator is rebuilt
  // from the block because the old one may have been invalidated.
  auto it = fn->begin();
  while (it != fn->end()) {
    SILBasicBlock *block = &*it;
    if (!mergeBasicBlockWithSuccessor(block, DI, LI)) {
      ++it;
      continue;
    }
    merged = true;
    it = SILFunction::iterator(block);
  }
  return merged;
}
namespace {
/// Builds the SESE region tree for a function.
///
/// processLoops() canonicalizes the function's loops and records one
/// WhileLoopSESERegion per loop in `loopPreheaders`, keyed by the loop's
/// preheader. processAcyclicRegion() then decomposes an acyclic region into
/// sequences and conditionals, consuming (and taking ownership of) those loop
/// nodes when it reaches their preheaders.
class SESERegionBuilder {
  DominanceInfo DI;
  PostDominanceInfo PDI;
  SILLoopInfo LI;
  SILFunction* F;

  /// processLoops fills in this mapping: it is keyed by the preheader block
  /// of each loop, and points to the region produced for it.
  /// NOTE(review): these pointers are owning. processAcyclicRegionExcludingEnd
  /// moves an entry into a unique_ptr when it consumes it; an entry that is
  /// never consumed is not freed by this class.
  llvm::DenseMap<SILBasicBlock*, WhileLoopSESERegion*> loopPreheaders;

public:
  /// Computes dominance, post-dominance and loop information for `F`.
  explicit SESERegionBuilder(SILFunction *F)
      : DI(F), PDI(F), LI(F, &DI), F(F) {}

  std::unique_ptr<SESERegionTree>
  processAcyclicRegion(SILBasicBlock *startBB, SILBasicBlock *endBB);

  /// Process all of the top-level loops in the function in post-order.
  void processLoops() {
    // Apply the standard SIL loop canonicalization transformations. The
    // guarantees they provide are listed below.
    if (canonicalizeAllLoops(&DI, &LI)) {
      // Recalculate PDI if canonicalization made any changes.
      PDI.recalculate(*F);
    }
    // CanonicalizeAllLoops only ensures the following properties:
    //  - There is a unique preheader block.
    //  - Exit blocks have only predecessors within the loop.
    //  - There is a single latch block and a single back edge to the header.
    // We also need to ensure the following invariants for lowering:
    //  - The header is the only block from which the loop is exited.
    //  - There is a single exit block.
    if (TFEnsureSingleLoopExit) {
      ensureSingleExitFromLoops();
    }
    for (auto *loop : LI)
      processLoop(loop);
  }

private:
  /// Like processAcyclicRegion, but does not include endBB in the result;
  /// returns a null tree when startBB == endBB.
  std::unique_ptr<SESERegionTree>
  processAcyclicRegionExcludingEnd(SILBasicBlock *startBB,
                                   SILBasicBlock *endBB);
  void processLoop(SILLoop *loop);
  void ensureSingleExitFromLoops();
};
} // end anonymous namespace
/// Transform the specified acyclic region (possibly with internal while loop
/// nodes collapsed) from startBB to endBB inclusive into properly nested SESE
/// regions and return them.
///
/// Everything before endBB is built by processAcyclicRegionExcludingEnd, and
/// endBB is appended as a single-block region. If the prefix is already a
/// sequence, endBB is added to it in place rather than nesting sequences.
std::unique_ptr<SESERegionTree>
SESERegionBuilder::processAcyclicRegion(SILBasicBlock *startBB,
                                        SILBasicBlock *endBB) {
  std::unique_ptr<SESERegionTree> prefix =
      processAcyclicRegionExcludingEnd(startBB, endBB);
  std::unique_ptr<SESERegionTree> tail(new SingleBlockSESERegion(endBB));

  // startBB == endBB: the whole region is just the end block.
  if (!prefix)
    return tail;

  // Extend an existing sequence instead of wrapping it in another one.
  if (auto *sequence = dyn_cast<SequenceSESERegion>(prefix.get())) {
    sequence->addNode(std::move(tail));
    return prefix;
  }

  std::unique_ptr<SESERegionTree> pair[] = {std::move(prefix),
                                            std::move(tail)};
  return std::unique_ptr<SESERegionTree>(new SequenceSESERegion(pair));
}
/// Transform the specified acyclic region (possibly with internal while loop
/// nodes collapsed) that starts at startBB and stops just before endBB into
/// properly nested SESE regions and return them.
///
/// This does not process endBB, and returns a null region tree if
/// startBB == endBB. endBB must post-dominate startBB (asserted below).
///
/// Loop nodes recorded in `loopPreheaders` are consumed when their preheader
/// is reached: the map entry is erased and ownership of the
/// WhileLoopSESERegion moves into the returned tree.
///
std::unique_ptr<SESERegionTree>
SESERegionBuilder::processAcyclicRegionExcludingEnd(SILBasicBlock *startBB,
SILBasicBlock *endBB) {
assert(PDI.dominates(endBB, startBB) &&
"endBB is required to post-dominate startBB");
SmallVector<std::unique_ptr<SESERegionTree>, 4> results;
// Iteratively work our way up the post-dominator tree (moving startBB until
// we reach the endBB), producing a sequence of loops, diamonds, and single
// block nodes as we go.
while (startBB != endBB) {
// If startBB is the preheader of a loop, that loop has already been
// processed and collapsed into a single node. Use that node and resume at
// the loop's exit block.
auto loopIt = loopPreheaders.find(startBB);
if (loopIt != loopPreheaders.end()) {
auto whileNode = loopIt->second;
loopPreheaders.erase(loopIt);
results.push_back(std::unique_ptr<SESERegionTree>(whileNode));
startBB = whileNode->getExit();
continue;
}
// If startBB ends with an unconditional branch, then just add it to the
// sequence and keep going. The destination of the branch may have multiple
// predecessors in the case where the destination is endBB. Those other
// predecessors could be coming from a parent conditional region.
if (auto *branch = dyn_cast<BranchInst>(startBB->getTerminator())) {
auto startRegion = new SingleBlockSESERegion(startBB);
results.push_back(std::unique_ptr<SESERegionTree>(startRegion));
startBB = branch->getDestBB();
continue;
}
// Otherwise, we know that startBB ends with a conditional branch.
// (cast<> asserts this; any other terminator kind is unexpected here.)
auto *condBr = cast<CondBranchInst>(startBB->getTerminator());
// Get the immediate postdominator of startBB. It closes the SESE
// subregion that begins at startBB: postidom post-dominates startBB and is
// itself post-dominated by endBB. Note that the postidom may in fact be
// endBB.
// NOTE(review): if startBB's immediate post-dominator is the virtual root
// of the post-dominator tree, getBlock() presumably yields null; this code
// assumes that cannot happen here.
auto postidom = PDI[startBB]->getIDom()->getBlock();
// Analyze the successors of the branch: each of them is post dominated by
// postidom (and any of the successors may be exactly postidom).
auto trueRegion =
processAcyclicRegionExcludingEnd(condBr->getTrueBB(), postidom);
auto falseRegion =
processAcyclicRegionExcludingEnd(condBr->getFalseBB(), postidom);
// Finally, form our conditional region.
auto condRegion = new ConditionalSESERegion(startBB, std::move(trueRegion),
std::move(falseRegion));
results.push_back(std::unique_ptr<SESERegionTree>(condRegion));
startBB = postidom;
}
// No nodes: null tree. One node: return it unwrapped. Otherwise: a sequence.
switch (results.size()) {
case 0: return std::unique_ptr<SESERegionTree>();
case 1: return std::move(results[0]);
default:
auto result = new SequenceSESERegion(results);
return std::unique_ptr<SESERegionTree>(result);
}
}
/// Create a TF handle of the appropriate integer type for the given bitwidth
/// that is initialized to the given value.
///
/// The handle is the single result of a "Const" graph_op. Its `dtype` and
/// `value` attributes describe a builtin integer of `bitwidth` bits holding
/// `value`. Device placement is delegated to `deviceInfo`, with the device
/// string for DeviceType::ALL passed as the op's device.
static SILValue createTFIntegerConst(GraphFunctionDeviceInfo &deviceInfo,
                                     SILBuilder &builder, SILLocation location,
                                     unsigned bitwidth, intmax_t value) {
  ASTContext &ctx = builder.getASTContext();
  SILType builtinIntTy = SILType::getBuiltinIntegerType(bitwidth, ctx);

  // Literals take attributes specifying the dtype, value, and device.
  GraphOperationBuilder constOp("Const");
  constOp.addAttribute(
      {ctx.getIdentifier("dtype$dtype"),
       convertSwiftTypeToConstantTFDataType(builtinIntTy.getASTType())});
  constOp.addAttribute(
      {ctx.getIdentifier("value$tensor"),
       SymbolicValue::getInteger(APInt(bitwidth, value), ctx.getAllocator())});
  deviceInfo.handleDevicePlacement(
      "Const",
      /*opDevice*/ getDeviceString(DeviceType::ALL),
      ctx, &constOp);

  GraphOperationInst *node =
      constOp.build(builder, ctx, location,
                    {convertElementTypeToTensorValueType(builtinIntTy)});
  assert(node->getNumResults() == 1);
  return node->getResults()[0];
}
namespace {
/// Clones basic blocks within a single function.
///
/// Blocks and values that were cloned are recorded in BBMap/ValueMap, so
/// references to them in cloned instructions are remapped to the clones.
/// References to anything that was not cloned keep pointing at the original.
class BasicBlockCloner : public SILClonerWithScopes<BasicBlockCloner> {
private:
  /// The flag to track if this cloner was used to clone any blocks.
  bool cloned;

public:
  explicit BasicBlockCloner(SILFunction &F)
      : SILClonerWithScopes(F), cloned(false) {}

  /// Returns true if initBlock has created at least one new block.
  bool hasCloned() const { return cloned; }

  /// Create a block and clone everything except the instructions: an empty
  /// block for `bb` with clones of its arguments. If `bb` was already
  /// initialized, the existing clone is returned.
  SILBasicBlock *initBlock(SILBasicBlock *bb) {
    auto bbIt = BBMap.find(bb);
    if (bbIt != BBMap.end())
      return bbIt->second;
    cloned = true;
    SILFunction &F = getBuilder().getFunction();
    SILBasicBlock *newBB = F.createBasicBlock();
    getBuilder().setInsertionPoint(newBB);
    BBMap[bb] = newBB;
    // If the basic block has arguments, clone them as well.
    for (auto *arg : bb->getArguments()) {
      // Create a new argument and copy it into the ValueMap so future
      // references use it.
      ValueMap[arg] = newBB->createPhiArgument(
          arg->getType(), arg->getOwnershipKind(), arg->getDecl());
    }
    return newBB;
  }

  /// Clone all the instructions of `bb` into its (already initialized) clone
  /// and return the cloned block.
  SILBasicBlock *cloneBlock(SILBasicBlock *bb) {
    auto bbIt = BBMap.find(bb);
    assert(bbIt != BBMap.end() && "Block is not initialied before cloning.");
    SILBasicBlock *newBB = bbIt->second;
    getBuilder().setInsertionPoint(newBB);
    for (auto &inst : *bb) {
      visit(&inst);
    }
    return newBB;
  }

  /// initBlock followed by cloneBlock.
  SILBasicBlock *initAndCloneBlock(SILBasicBlock *bb) {
    initBlock(bb);
    return cloneBlock(bb);
  }

  /// Handle references to basic blocks when cloning.
  SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) {
    // If the block was not cloned by this cloner, directly reference it.
    // Otherwise, use the cloned block.
    auto bbIt = BBMap.find(bb);
    if (bbIt != BBMap.end())
      return bbIt->second;
    return bb;
  }

  /// Map `Value` to its clone if one was recorded; otherwise return it as is.
  SILValue remapValue(SILValue Value) {
    auto VI = ValueMap.find(Value);
    if (VI != ValueMap.end())
      return VI->second;
    return Value;
  }

  /// Record that `oldValue` is replaced by `newValue` in cloned code. Each
  /// value may be mapped only once.
  void updateValueMap(SILValue oldValue, SILValue newValue) {
    auto emplaceResult = ValueMap.try_emplace(oldValue, newValue);
    assert(emplaceResult.second &&
           "Remapping value multiple times during SESE cloning.");
    (void)emplaceResult; // Only inspected by the assert above.
  }
};
} // namespace
namespace {
/// A helper class to transform a loop to have a single exit from the header.
///
/// Construction records facts about the loop before anything is mutated (see
/// initialize()); transform() then rewrites the CFG.
class SingleExitLoopTransformer {
public:
  SingleExitLoopTransformer(GraphFunctionDeviceInfo *deviceInfo,
                            SILLoopInfo *LI, DominanceInfo *DI, SILLoop *loop,
                            PostDominanceInfo *PDI)
      : deviceInfo(deviceInfo), DI(DI), PDI(PDI), LI(LI), loop(loop),
        header(loop->getHeader()), preheader(loop->getLoopPreheader()),
        latch(loop->getLoopLatch()), currentFn(header->getParent()),
        oldHeaderNumArgs(header->getNumArguments()),
        hasUndefsAtPreheader(false) {
    assert(preheader && "Canonicalization should have given us one preheader");
    assert(latch && "Canonicalization should have given us one latch block");
    initialize();
  }

  /// Transforms the loop to ensure it has a single exit from the header.
  /// Returns true if the CFG was changed.
  bool transform();

private:
  // Helper functions

  /// Record loop information before the transformation: equivalence classes
  /// induced by argument passing, the exit blocks, the edges to rewire (with
  /// critical edges split), and the preheader substitutions for escaping
  /// values and exit-block arguments.
  void initialize();

  /// Return a map that captures information about what SILValue should be
  /// used at the pre-header of the loop for every SILValue in the given
  /// `values` set. An entry maps to an equivalent value available at the
  /// preheader when one can be found (only attempted if TFNoUndefsInSESE is
  /// set); otherwise it maps to an undef.
  llvm::DenseMap<SILValue, SILValue>
  getPreheaderSubstMap(const SmallPtrSetImpl<SILValue> &values) const;

  /// Transform the loop by moving and cloning nodes (as needed) so that the
  /// nearest common post dominator of the current exit blocks becomes a single
  /// exit block for the loop. Consider the following snippet:
  ///   while(...) {
  ///     if (...) {
  ///       ...
  ///       break;
  ///     }
  ///   }
  /// Recall that the blocks within if do not belong to the while loop in the
  /// SIL IR. The transformation implemented in this function has the effect of
  /// moving the blocks back into the loop. Blocks on those paths that are also
  /// reachable from outside the loop are cloned, and the clones are moved in.
  void ensureSingleExitBlock();

  /// Unroll the body of the loop once.
  void unrollLoopBody();

  /// Compute escaping values and what values to use as arguments at preheader.
  llvm::DenseMap<SILValue, SILValue> computeEscapingValuesSubstMap() const;

  /// Create a new header for the loop and return a pair consisting of
  /// the new block and the phi argument corresponding to the exitIndex.
  std::pair<SILBasicBlock *, SILValue> createNewHeader();

  /// Create a new latch block and clone all the arguments from new header.
  SILBasicBlock *createNewLatch(SILBasicBlock *newHeader);

  /// Replace the preheader->header edge with preheader->newHeader edge
  /// and update the arguments of preheader to match that of newheader.
  void patchPreheader(SILBasicBlock *newHeader);

  /// Patch the edges that go to the header and exit blocks to the new header or
  /// latch block as appropriate. Also, return the map consisting of indices
  /// assigned to the exit blocks.
  llvm::DenseMap<SILBasicBlock *, intmax_t>
  patchEdges(SILBasicBlock *newHeader, SILBasicBlock *latchBlock);

  /// Build a demuxing if..then..else block that can be used as a single
  /// exit block. This will not have any effect if there is a single exit
  /// block already.
  /// \param exitIndices Map from exit blocks to the assigned index.
  /// \param exitIndexArg SILValue that holds the exitIndex set in the loop.
  SILBasicBlock *createNewExitBlockWithDemux(
      const llvm::DenseMap<SILBasicBlock *, intmax_t> &exitIndices,
      SILValue exitIndexArg);

  /// Replace `termInst` with a new TermInst such that the edge to the target
  /// consists of the original arguments and the given `newArgs` at the end.
  /// The old termInst is removed and this returns the newly constructed one.
  TermInst *appendArguments(TermInst *termInst, SILBasicBlock *target,
                            ArrayRef<SILValue> newArgs);

  /// Return an undef value of the given type.
  SILValue getUndef(SILType type) const {
    return SILUndef::get(type, currentFn->getModule());
  }

  // Configuration for graph construction.
  GraphFunctionDeviceInfo *deviceInfo;
  // Analyses and the loop being transformed. Note: the declaration order of
  // these members is the order in which the constructor initializes them.
  DominanceInfo *DI;
  PostDominanceInfo *PDI;
  SILLoopInfo *LI;
  SILLoop *loop;
  SILBasicBlock *header;
  SILBasicBlock *preheader;
  SILBasicBlock *latch;
  SILFunction *currentFn;
  /// Number of arguments the header had when this transformer was created.
  unsigned oldHeaderNumArgs;
  /// Flag to track if we have undefs at preheader corresponding to escaping
  /// values and exit args.
  bool hasUndefsAtPreheader;
  /// Equivalence classes induced by argument passing.
  llvm::EquivalenceClasses<SILValue> equivalentValues;
  /// Exit blocks of the loop (deduplicated).
  SmallPtrSet<SILBasicBlock*, 8> exitBlocks;
  /// The list of edges that need to be rewired to the newHeader or
  /// the new latchBlock as appropriate.
  SmallVector<std::pair<const SILBasicBlock *, const SILBasicBlock *>, 8>
      edgesToFix;
  /// Identify the set of values that escape the loop. The key represents the
  /// escaping value and the associated value will be used as the argument at
  /// the preheader.
  llvm::DenseMap<SILValue, SILValue> escapingValueSubstMap;
  /// Similar to escapingValueSubstMap, but for arguments of exit blocks.
  llvm::DenseMap<SILValue, SILValue> exitArgSubstMap;
  /// Map from an arg of an exit block to the corresponding newly added header
  /// arg.
  llvm::DenseMap<SILValue, SILValue> exitArgHeaderArgMap;
};
} // end anonymous namespace
/// Gather what the transformation needs to know before the CFG is mutated.
/// This covers:
///  - equivalence classes of values connected through block arguments;
///  - the (optionally merged) exit blocks and the preheader substitutions for
///    their arguments;
///  - the edges to rewire (critical ones are split here);
///  - the preheader substitutions for values escaping the loop.
void SingleExitLoopTransformer::initialize() {
// Remember some information from the loop before it gets transformed.
// Build the equivalence classes induced by argument passing. Every block
// argument in the whole function (not only the loop) is unioned with each of
// its incoming values.
for (auto &bb : *currentFn) {
for (auto arg : bb.getArguments()) {
SmallVector<SILValue, 8> incomingValues;
arg->getIncomingPhiValues(incomingValues);
for (SILValue incomingValue : incomingValues) {
equivalentValues.unionSets(arg, incomingValue);
}
}
}
// Compute common exit block if needed. This may move (and clone) blocks into
// the loop, so it must run before the exit blocks are collected below.
// NOTE(review): arguments of blocks cloned by ensureSingleExitBlock are not
// part of equivalentValues, which was built above.
if (TFNoUndefsInSESE) {
ensureSingleExitBlock();
}
// Initialize Exit Blocks
{
SmallVector<SILBasicBlock*, 8> exitBlockList;
// This API returns duplicates if there are multiple exiting edges to the
// same exit block. Dedup them here.
loop->getExitBlocks(exitBlockList);
exitBlocks.insert(exitBlockList.begin(), exitBlockList.end());
assert(!exitBlocks.empty() && "A loop has no exit blocks.");
if (TFNoUndefsInSESE) {
// Find preheader-available stand-ins for every exit-block argument.
SmallPtrSet<SILValue, 8> exitArgs;
for (const SILBasicBlock *bb : exitBlocks) {
for (auto arg : bb->getArguments()) {
exitArgs.insert(arg);
}
}
exitArgSubstMap = getPreheaderSubstMap(exitArgs);
}
}
// All the exiting edges need to be rewired, plus the latch -> header back
// edge.
loop->getExitEdges(edgesToFix);
edgesToFix.emplace_back(latch, header);
// Split critical edges in edgesToFix before we do any transformations.
// NOTE(review): no dominator tree is passed to splitIfCriticalEdge (DT is
// nullptr), so DI does not know about the newly created split blocks; only
// LI is updated. Confirm later DI queries never involve those blocks.
for (auto &edge : edgesToFix) {
SILBasicBlock *src = const_cast<SILBasicBlock *>(edge.first);
SILBasicBlock *tgt = const_cast<SILBasicBlock *>(edge.second);
SILBasicBlock *splitBlock = splitIfCriticalEdge(src, tgt, /*DT*/ nullptr, LI);
if (splitBlock != nullptr) {
// If the edge was critical then splitBlock would have been inserted
// between src and tgt as follows: src -> splitBlock -> tgt. Therefore,
// update src as the new edge to patch would be splitBlock -> tgt.
edge.first = splitBlock;
}
}
// Computed last so that blocks added to the loop above are included.
escapingValueSubstMap = computeEscapingValuesSubstMap();
}
/// Make the nearest common post dominator of the loop's exit blocks the loop's
/// only exit block. Every block on a path from an exit block to that post
/// dominator is pulled into the loop. A block on such a path that is not
/// dominated by the header (so it is reachable from outside the loop) is
/// cloned, and only the clone is moved in. DI and PDI are recomputed if any
/// block was cloned.
///
/// If the loop has no exit blocks, or its exit blocks have no common post
/// dominator, the CFG is left unchanged.
void SingleExitLoopTransformer::ensureSingleExitBlock() {
  BasicBlockCloner cloner(*currentFn);

  // Identify the common post dominator.
  SmallVector<SILBasicBlock*, 8> exitBlockList;
  loop->getExitBlocks(exitBlockList);
  // Nothing to merge (and nothing to dereference) for a loop without exits.
  if (exitBlockList.empty())
    return;
  auto exitBlockIter = exitBlockList.begin();
  SILBasicBlock *nearestCommonPD = *exitBlockIter++;
  while (exitBlockIter != exitBlockList.end()) {
    nearestCommonPD =
        PDI->findNearestCommonDominator(nearestCommonPD, *exitBlockIter++);
    assert(nearestCommonPD);
    // In release builds, give up rather than feed a null block back into
    // findNearestCommonDominator on the next iteration.
    if (!nearestCommonPD)
      return;
  }
  LLVM_DEBUG(
      llvm::dbgs() << "Common Exit Block : "
                   << SILPrintContext(llvm::dbgs()).getID(nearestCommonPD)
                   << "\n");

  // Collect all the blocks from each exit block up to the nearest common PD.
  SmallPtrSet<SILBasicBlock *, 32> blocksToBeMoved;
  for (SILBasicBlock *exitBlock : exitBlockList) {
    if (exitBlock == nearestCommonPD) continue;
    SmallPtrSet<SILBasicBlock *, 32> worklist;
    worklist.insert(exitBlock);
    while (!worklist.empty()) {
      SILBasicBlock *current = *worklist.begin();
      blocksToBeMoved.insert(current);
      worklist.erase(current);
      for (unsigned edgeIdx : indices(current->getSuccessors())) {
        // We have to look up the successors each time around the loop since
        // if we retarget an edge below, `current` block will have a new
        // terminator implying a new successor list. The new edge will be
        // placed at the same spot in the new terminator where the old edge was
        // in the old terminator. Thus as long as we use indices, we will visit
        // all edges appropriately and not deal with touching stale memory.
        auto succs = current->getSuccessors();
        auto *succ = succs[edgeIdx].getBB();
        // Skip if (1) already processed or (2) reached common pd.
        if (blocksToBeMoved.count(succ) > 0 || succ == nearestCommonPD) {
          continue;
        }
        if (DI->properlyDominates(header, succ)) {
          worklist.insert(succ);
          continue;
        }
        // If `succ` is not dominated by `header`, then `succ` is reachable
        // from a node outside of this loop. We might have to clone `succ` in
        // such cases.

        // Before cloning make sure that header -> succ is *not* backedge of a
        // parent loop. This can happen when we have labeled breaks in loops.
        // We cannot clone the blocks in such cases. Simply continue. This is
        // still OK for our purposes because we will find an equivalent value
        // at the header for any value that escapes along this edge.
        if (DI->properlyDominates(succ, header)) continue;

        // Clone the block and rewire the edge.
        SILBasicBlock *clonedSucc = cloner.initAndCloneBlock(succ);
        changeBranchTarget(current->getTerminator(), edgeIdx, clonedSucc,
                           /*preserveArgs*/ true);
        worklist.insert(clonedSucc);
      }
    }
  }

  // Move blocks from each exit block up to the common exit block into the loop.
  for (SILBasicBlock *outsideBlock : blocksToBeMoved) {
    // Make sure that outsideBlock is *ONLY* reachable from a block in the loop
    // or blocksToBeMoved.
    assert(llvm::all_of(outsideBlock->getPredecessorBlocks(),
                        [this, &blocksToBeMoved](const SILBasicBlock *pred) {
                          return LI->getLoopFor(pred) == loop ||
                                 blocksToBeMoved.count(pred) > 0;
                        }) &&
           "Nodes being moved are reachable from outside loop.");

    // Update loop info if this belongs to a parent loop.
    SILLoop *outsideBlockLoop = LI->getLoopFor(outsideBlock);
    if (outsideBlockLoop != nullptr) {
      // FIXME: We don't deal with cases where the nodes being moved in
      // belong to another loop yet. e.g.,
      //   while ... {
      //     if ... {
      //       for(...) {...}
      //       break;
      //     }
      //   }
      // Check that `loop` is nested within `outsideBlockLoop`.
      assert(outsideBlockLoop->contains(loop) &&
             "Nodes being moved belong to a non-nested loop.");
      // Detach the block from `outsideBlockLoop` *and* every loop enclosing
      // it. addBasicBlockToLoop below adds the block to `loop` and to all of
      // its ancestors, so a stale entry in any ancestor would be duplicated.
      for (SILLoop *enclosing = outsideBlockLoop; enclosing != nullptr;
           enclosing = enclosing->getParentLoop())
        enclosing->removeBlockFromLoop(outsideBlock);
      LI->changeLoopFor(outsideBlock, nullptr);
    }
    loop->addBasicBlockToLoop(outsideBlock, LI->getBase());
  }

  if (cloner.hasCloned()) {
    // TODO(https://bugs.swift.org/browse/SR-8336): the transformations here are
    // simple that we should be able to incrementally update the DI & PDI.
    DI->recalculate(*currentFn);
    PDI->recalculate(*currentFn);
  }
}
/// Collect every value defined in the loop that has a user outside the loop,
/// and map each one to the value to use at the preheader (see
/// getPreheaderSubstMap).
///
/// Arguments of the header are not considered; createNewHeader forwards all
/// of them to the new header.
llvm::DenseMap<SILValue, SILValue>
SingleExitLoopTransformer::computeEscapingValuesSubstMap() const {
  llvm::SmallPtrSet<SILValue, 8> escapingValues;

  // Record `value` if at least one of its users lives outside the loop.
  auto recordIfEscaping = [this, &escapingValues](SILValue value) {
    bool usedOutsideLoop =
        llvm::any_of(value->getUses(), [this](const Operand *use) {
          return !loop->contains(use->getUser()->getParent());
        });
    if (usedOutsideLoop)
      escapingValues.insert(value);
  };

  for (const SILBasicBlock *bb : loop->getBlocks()) {
    if (bb != header) {
      for (auto arg : bb->getArguments())
        recordIfEscaping(arg);
    }
    for (const SILInstruction &inst : *bb) {
      for (SILValue result : inst.getResults())
        recordIfEscaping(result);
    }
  }
  return getPreheaderSubstMap(escapingValues);
}
/// Map each value in `values` to the value that should stand in for it at the
/// loop preheader.
///
/// By default that is an undef of the same type. When TFNoUndefsInSESE is
/// set, the value's argument-passing equivalence class is searched instead,
/// and the first member that properly dominates the preheader's terminator
/// is used.
llvm::DenseMap<SILValue, SILValue>
SingleExitLoopTransformer::getPreheaderSubstMap(
    const SmallPtrSetImpl<SILValue> &values) const {
  llvm::DenseMap<SILValue, SILValue> substMap;
  for (const SILValue value : values) {
    SILValue replacement = getUndef(value->getType());
    if (TFNoUndefsInSESE) {
      // Walk every member of value's equivalence class (an empty range if the
      // value is in no class).
      for (auto member = equivalentValues.findLeader(value),
                last = equivalentValues.member_end();
           member != last; ++member) {
        if (DI->properlyDominates(*member, preheader->getTerminator())) {
          // Found a definition that we could use.
          replacement = *member;
          break;
        }
      }
    }
    substMap[value] = replacement;
  }
  return substMap;
}
/// Appends `newArgs` to the arguments passed along the edge from
/// `termInst`'s block to `target`. Deletes the old TermInst and returns a new
/// TermInst with the appropriate number of args.
///
/// `target` must already have the extra arguments (asserted). Only `br` and
/// `cond_br` are supported. For any other terminator this asserts; in release
/// builds the terminator is left in place and returned unchanged, rather than
/// being erased with no replacement.
TermInst *SingleExitLoopTransformer::appendArguments(
    TermInst *termInst, SILBasicBlock *target, ArrayRef<SILValue> newArgs) {
  SILBuilder builder(termInst);
  TermInst *newTermInst = nullptr;

  // Copy `currentArgs`; if `appendNewArgs` is set, also append `newArgs`.
  auto createArgsArray = [target, &newArgs](
                             const OperandValueArrayRef &currentArgs,
                             bool appendNewArgs) {
    SmallVector<SILValue, 8> args(currentArgs.begin(), currentArgs.end());
    if (appendNewArgs) {
      args.append(newArgs.begin(), newArgs.end());
      assert(args.size() == target->getNumArguments() &&
             "Number of final arguments does not match target's arguments.");
    }
    (void)target; // Only used by the assert above.
    return args;
  };

  if (auto *branch = dyn_cast<BranchInst>(termInst)) {
    assert(branch->getDestBB() == target &&
           "Incoming edge block and target do not match");
    SmallVector<SILValue, 8> args =
        createArgsArray(branch->getArgs(), /*appendNewArgs=*/true);
    newTermInst =
        builder.createBranch(branch->getLoc(), branch->getDestBB(), args);
  } else if (auto *condBranch = dyn_cast<CondBranchInst>(termInst)) {
    bool isTrueEdge = condBranch->getTrueBB() == target;
    assert((isTrueEdge || condBranch->getFalseBB() == target) &&
           "Incoming edge block and target do not match");
    SmallVector<SILValue, 8> trueArgs = createArgsArray(
        condBranch->getTrueArgs(), /*appendNewArgs=*/isTrueEdge);
    SmallVector<SILValue, 8> falseArgs = createArgsArray(
        condBranch->getFalseArgs(), /*appendNewArgs=*/!isTrueEdge);
    newTermInst = builder.createCondBranch(
        condBranch->getLoc(), condBranch->getCondition(),
        condBranch->getTrueBB(), trueArgs, condBranch->getFalseBB(), falseArgs,
        condBranch->getTrueBBCount(), condBranch->getFalseBBCount());
  } else {
    // At the moment we can only add arguments to br and cond_br.
    assert(false && "Terminator is not a branch or conditional branch.");
    // Never erase a terminator we could not replace.
    return termInst;
  }

  // Remove the old terminator instruction.
  termInst->dropAllReferences();
  termInst->eraseFromParent();
  return newTermInst;
}
/// Creates the new loop header that all loop back-edges and exits will be
/// routed through, and adds it to `loop`.
///
/// The new header's argument list is, in order:
///   1. clones of the old header's arguments (all uses of the old header
///      arguments are redirected to these, and the old header's arguments
///      are dropped);
///   2. one phi argument per key of escapingValueSubstMap; every use of that
///      escaping value located outside the loop is rewritten to use it;
///   3. only when TFNoUndefsInSESE is set: one phi argument per key of
///      exitArgSubstMap, recorded in exitArgHeaderArgMap;
///   4. the exit index (a tensor of Builtin.Int32);
///   5. the stayInLoop flag (a tensor of Builtin.Int1).
/// patchPreheader and patchEdges append branch arguments in this same order.
/// NOTE(review): this presumably relies on escapingValueSubstMap and
/// exitArgSubstMap not being modified between these calls, so their
/// iteration orders agree — confirm.
///
/// \returns the new header block and its exit-index argument.
std::pair<SILBasicBlock *, SILValue>
SingleExitLoopTransformer::createNewHeader() {
SILBuilder builder(header);
ASTContext &context = builder.getASTContext();
SILBasicBlock *newHeader = currentFn->createBasicBlock();
loop->addBasicBlockToLoop(newHeader, LI->getBase()); // Should be done first.
// Clone arguments and change all uses to the new header's arguments.
newHeader->cloneArgumentList(header);
// NOTE(review): headerArgs is populated but never read in this function.
SmallVector<SILValue, 8> headerArgs;
for (auto i : indices(header->getArguments())) {
header->getArgument(i)->replaceAllUsesWith(newHeader->getArgument(i));
headerArgs.push_back(newHeader->getArgument(i));
}
header->dropAllArguments();
// Add phi arguments in the new header corresponding to the escaping values.
for (const auto &kv : escapingValueSubstMap) {
SILValue escapingValue = kv.first;
SILValue newValue = newHeader->createPhiArgument(
escapingValue->getType(), escapingValue.getOwnershipKind());
// Replace uses *outside* of the loop with the new value.
auto UI = escapingValue->use_begin(), E = escapingValue->use_end();
while (UI != E) {
Operand *use = *UI;
// Increment iterator before we invalidate it
// when we invoke Operand::set below.
++UI;
if (loop->contains(use->getUser()->getParent())) {
continue;
}
use->set(newValue);
}
}
if (TFNoUndefsInSESE) {
// Add arguments in the new header corresponding to exit block arguments.
for (const auto &kv : exitArgSubstMap) {
SILValue arg = kv.first;
SILValue newValue =
newHeader->createPhiArgument(arg->getType(), arg.getOwnershipKind());
exitArgHeaderArgMap[kv.first] = newValue;
}
}
// An integer to identify the exit edge.
SILValue exitIndexArg = newHeader->createPhiArgument(
convertElementTypeToTensorValueType(
SILType::getBuiltinIntegerType(32, context)),
ValueOwnershipKind::Owned);
// A boolean corresponding to the stayInLoop flag.
// NOTE(review): both flags go through convertElementTypeToTensorValueType,
// yet exitIndex is created Owned and stayInLoop Trivial; confirm the
// asymmetry is intended.
newHeader->createPhiArgument(convertElementTypeToTensorValueType(
SILType::getBuiltinIntegerType(1, context)),
ValueOwnershipKind::Trivial);
return std::make_pair(newHeader, exitIndexArg);
}
/// Creates a new latch block for the loop. The latch clones the argument list
/// of `newHeader`, branches unconditionally to `newHeader` forwarding every
/// one of those arguments, and is registered as a member of `loop`.
///
/// \returns the new latch block.
SILBasicBlock *
SingleExitLoopTransformer::createNewLatch(SILBasicBlock *newHeader) {
  SILBasicBlock *newLatch = currentFn->createBasicBlock();
  newLatch->cloneArgumentList(newHeader);
  SILBuilder latchBuilder(newLatch);

  // Forward the latch's own arguments, in order, straight to the new header.
  SmallVector<SILValue, 8> forwardedArgs;
  for (unsigned idx = 0, numArgs = newLatch->getNumArguments();
       idx != numArgs; ++idx) {
    forwardedArgs.push_back(newLatch->getArgument(idx));
  }

  // Attribute the back-edge to the original header terminator's location.
  SILLocation branchLoc(
      getUserSourceLocation(header->getTerminator()->getDebugLocation()));
  latchBuilder.createBranch(branchLoc, newHeader, forwardedArgs);

  loop->addBasicBlockToLoop(newLatch, LI->getBase());
  return newLatch;
}
/// Redirects the preheader's edge from the old header to `newHeader`, keeping
/// the edge's existing arguments. It then appends initial values for the
/// extra header arguments, in order:
///   - the substitute value of each escaping value;
///   - the substitute value of each exit argument (only with
///     TFNoUndefsInSESE);
///   - exit index 0;
///   - stayInLoop = true.
/// Sets `hasUndefsAtPreheader` if any substitute value is a SILUndef.
void SingleExitLoopTransformer::patchPreheader(SILBasicBlock *newHeader) {
  replaceBranchTarget(preheader->getTerminator(), header, newHeader,
                      /*preserveArgs*/ true);

  // New constants are materialized right before the retargeted terminator.
  TermInst *preheaderTerm = preheader->getTerminator();
  SILBuilder builder(preheaderTerm);
  SILLocation loc(getUserSourceLocation(preheaderTerm->getDebugLocation()));

  // State computed inside the loop is not available in the preheader, so a
  // substitute may be an undef. Such a value is never accessed at runtime,
  // but we remember that one was emitted.
  SmallVector<SILValue, 8> extraArgs;
  auto addIncoming = [&](SILValue initialValue) {
    if (isa<SILUndef>(initialValue))
      hasUndefsAtPreheader = true;
    extraArgs.push_back(initialValue);
  };
  for (const auto &entry : escapingValueSubstMap)
    addIncoming(entry.second);
  if (TFNoUndefsInSESE) {
    for (const auto &entry : exitArgSubstMap)
      addIncoming(entry.second);
  }

  // The exit index starts at 0; the stayInLoop flag starts out true.
  extraArgs.push_back(createTFIntegerConst(*deviceInfo, builder, loc,
                                           /*bitwidth*/ 32,
                                           /*exitIndex*/ 0));
  extraArgs.push_back(createTFIntegerConst(*deviceInfo, builder, loc,
                                           /*bitwidth*/ 1, true));
  appendArguments(preheader->getTerminator(), newHeader, extraArgs);
}
/// Redirects every edge in edgesToFix to `latchBlock` and assigns an integer
/// exit index to each edge target.
///
/// For an edge src -> tgt, the rewritten branch passes, in order:
///   1. if tgt is inside the loop, the edge's original arguments (kept via
///      replaceBranchTarget's preserveArgs); otherwise the first
///      oldHeaderNumArgs arguments of `newHeader`, unchanged;
///   2. one value per escaping value: the value itself if it properly
///      dominates src's terminator, else the matching `newHeader` argument;
///   3. only with TFNoUndefsInSESE, one value per exit argument: the value
///      this edge passed to it when tgt owns that argument, else the
///      matching `newHeader` argument;
///   4. a 32-bit exitIndex constant identifying tgt;
///   5. a 1-bit stayInLoop constant, true iff tgt is inside the loop.
/// This mirrors the argument order created by createNewHeader.
///
/// \returns a map from target block to exit index. The header's exit block
/// (if any) and the header itself both map to 0. Every other target gets
/// 1, 2, ... in the order it is first seen in edgesToFix.
llvm::DenseMap<SILBasicBlock *, intmax_t>
SingleExitLoopTransformer::patchEdges(SILBasicBlock *newHeader,
SILBasicBlock *latchBlock) {
llvm::DenseMap<SILBasicBlock *, intmax_t> exitIndices;
// Identify the exit from the header (if any) and assign '0' as its index.
SILBasicBlock *headerExit = nullptr;
for (SILBasicBlock *succ : header->getSuccessorBlocks()) {
if (loop->contains(succ)) continue;
assert(headerExit == nullptr && "Loop header has more than one exit node.");
headerExit = succ;
}
// Note: headerExit can be a nullptr as header need not be an exiting block.
// e.g.,
// while (true) {
// ...
// if ... break;
// }
if (headerExit != nullptr) {
exitIndices.insert({headerExit, 0});
}
// edgesToFix also has edges to the header that are not exit edges.
// To simplify the code below, simply assign an arbitrary value (say 0)
// as the index for header. This would get simplified once we unify the
// stayInLoop flag and exitIndex into one value.
exitIndices.insert({header, 0});
unsigned nextExitIndex = 1;
for (const auto &edge : edgesToFix) {
SILBasicBlock *src = const_cast<SILBasicBlock *>(edge.first);
SILBasicBlock *tgt = const_cast<SILBasicBlock *>(edge.second);
// Every edge is routed through the single latch.
SILBasicBlock *newTgt = latchBlock;
bool stayInLoop = loop->contains(tgt);
// Track the incoming value for the exit arguments if this is an exit edge
// with arguments. This will be used to unify all the values to be passed
// to the exit nodes in the loop header.
// Such edges must come from an unconditional br (asserted below). This
// assumes the branch passes exactly tgt->getNumArguments() values.
llvm::DenseMap<SILValue, SILValue> exitArgIncomingValue;
if (TFNoUndefsInSESE && !stayInLoop && tgt->getNumArguments() != 0) {
auto *termInst = src->getTerminator();
auto *branch = dyn_cast<BranchInst>(termInst);
assert(branch && "Critical edges should have been split.");
for (unsigned i = 0; i < branch->getNumArgs(); ++i) {
exitArgIncomingValue[tgt->getArgument(i)] = branch->getArg(i);
}
}
// Existing arguments are kept only for edges that stay in the loop. Exit
// edges re-pass the new header's arguments below instead.
replaceBranchTarget(src->getTerminator(), tgt, newTgt,
/*preserveArgs=*/stayInLoop);
// Set up additional arguments.
// If we are exiting the loop, then we should simply use newHeader args.
SmallVector<SILValue, 8> newArgs;
if (!stayInLoop) {
for (unsigned i = 0; i < oldHeaderNumArgs; ++i) {
newArgs.push_back(newHeader->getArgument(i));
}
}
SILBuilder builder(src->getTerminator());
SILLocation location(
getUserSourceLocation(src->getTerminator()->getDebugLocation()));
// Find an appropriate value to use for each escaping value.
// argIndex walks the new header's arguments that follow the old header's.
unsigned argIndex = oldHeaderNumArgs;
for (const auto &kv : escapingValueSubstMap) {
const SILValue escapingValue = kv.first;
// The escaping value can be passed directly only where it dominates the
// branch; otherwise forward what the new header currently holds.
if (DI->properlyDominates(escapingValue, src->getTerminator())) {
newArgs.push_back(escapingValue);
} else {
newArgs.push_back(newHeader->getArgument(argIndex));
}
++argIndex;
}
if (TFNoUndefsInSESE) {
// Let p0, p1, ..., pn be the arguments at new header corresponding to exit
// arguments a0, a1, a2, ..., an. For an exit edge
//   br exit_i(x, y) -> exit_i(a2, a3),
// change the source as br new_latch(p0, p1, x, y, p4, ..., pn)
for (const auto &kv : exitArgSubstMap) {
const SILValue exitArg = kv.first;
auto iter = exitArgIncomingValue.find(exitArg);
if (iter != exitArgIncomingValue.end()) {
newArgs.push_back(iter->second);
} else {
newArgs.push_back(newHeader->getArgument(argIndex));
}
++argIndex;
}
}
// `exitIndex` to identify the block to which we exit from the loop.
// (insert a new value or get the old key value pair.)
auto emplaceResult = exitIndices.try_emplace(tgt, nextExitIndex);
if (emplaceResult.second) {
// Increment index as we inserted a new entry into the table.
++nextExitIndex;
}
auto kvPair = *emplaceResult.first;
newArgs.push_back(createTFIntegerConst(*deviceInfo, builder, location,
/*bitwidth*/ 32,
/*exitIndex*/ kvPair.second));
// `stayInLoop` flag
newArgs.push_back(createTFIntegerConst(*deviceInfo, builder, location,
/*bitwidth*/ 1, stayInLoop));
appendArguments(src->getTerminator(), newTgt, newArgs);
}
return exitIndices;
}