Skip to content

Commit 2873d9f

Browse files
[LoopIdiomVectorize] Fix FindFirstByte successors (#156945)
This refactors fixSuccessorPhis from LoopIdiomVectorize::transformByteCompare and uses it in LoopIdiomVectorize::expandFindFirstByte to ensure that all successor Phis have incoming values from the vector basic blocks. Fixes #156588. --------- Co-authored-by: Ricardo Jesus <rjj@nvidia.com>
1 parent 8748581 commit 2873d9f

File tree

2 files changed

+341
-51
lines changed

2 files changed

+341
-51
lines changed

llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp

Lines changed: 48 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,10 @@ class LoopIdiomVectorize {
170170
bool recognizeFindFirstByte();
171171

172172
Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU,
173-
unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
174-
BasicBlock *ExitFail, Value *SearchStart,
175-
Value *SearchEnd, Value *NeedleStart,
176-
Value *NeedleEnd);
173+
unsigned VF, Type *CharTy, Value *IndPhi,
174+
BasicBlock *ExitSucc, BasicBlock *ExitFail,
175+
Value *SearchStart, Value *SearchEnd,
176+
Value *NeedleStart, Value *NeedleEnd);
177177

178178
void transformFindFirstByte(PHINode *IndPhi, unsigned VF, Type *CharTy,
179179
BasicBlock *ExitSucc, BasicBlock *ExitFail,
@@ -242,6 +242,37 @@ bool LoopIdiomVectorize::run(Loop *L) {
242242
return false;
243243
}
244244

245+
static void fixSuccessorPhis(Loop *L, Value *ScalarRes, Value *VectorRes,
246+
BasicBlock *SuccBB, BasicBlock *IncBB) {
247+
for (PHINode &PN : SuccBB->phis()) {
248+
// Look through the incoming values to find ScalarRes, meaning this is a
249+
// PHI collecting the results of the transformation.
250+
bool ResPhi = false;
251+
for (Value *Op : PN.incoming_values())
252+
if (Op == ScalarRes) {
253+
ResPhi = true;
254+
break;
255+
}
256+
257+
// Any PHI that depended upon the result of the transformation needs a new
258+
// incoming value from IncBB.
259+
if (ResPhi)
260+
PN.addIncoming(VectorRes, IncBB);
261+
else {
262+
// There should be no other outside uses of other values in the
263+
// original loop. Any incoming values should either:
264+
// 1. Be for blocks outside the loop, which aren't interesting. Or ..
265+
// 2. These are from blocks in the loop with values defined outside
266+
// the loop. We should a similar incoming value from CmpBB.
267+
for (BasicBlock *BB : PN.blocks())
268+
if (L->contains(BB)) {
269+
PN.addIncoming(PN.getIncomingValueForBlock(BB), IncBB);
270+
break;
271+
}
272+
}
273+
}
274+
}
275+
245276
bool LoopIdiomVectorize::recognizeByteCompare() {
246277
// Currently the transformation only works on scalable vector types, although
247278
// there is no fundamental reason why it cannot be made to work for fixed
@@ -935,42 +966,10 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
935966
DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
936967
}
937968

938-
auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
939-
for (PHINode &PN : SuccBB->phis()) {
940-
// At this point we've already replaced all uses of the result from the
941-
// loop with ByteCmp. Look through the incoming values to find ByteCmp,
942-
// meaning this is a Phi collecting the results of the byte compare.
943-
bool ResPhi = false;
944-
for (Value *Op : PN.incoming_values())
945-
if (Op == ByteCmpRes) {
946-
ResPhi = true;
947-
break;
948-
}
949-
950-
// Any PHI that depended upon the result of the byte compare needs a new
951-
// incoming value from CmpBB. This is because the original loop will get
952-
// deleted.
953-
if (ResPhi)
954-
PN.addIncoming(ByteCmpRes, CmpBB);
955-
else {
956-
// There should be no other outside uses of other values in the
957-
// original loop. Any incoming values should either:
958-
// 1. Be for blocks outside the loop, which aren't interesting. Or ..
959-
// 2. These are from blocks in the loop with values defined outside
960-
// the loop. We should a similar incoming value from CmpBB.
961-
for (BasicBlock *BB : PN.blocks())
962-
if (CurLoop->contains(BB)) {
963-
PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
964-
break;
965-
}
966-
}
967-
}
968-
};
969-
970969
// Ensure all Phis in the successors of CmpBB have an incoming value from it.
971-
fixSuccessorPhis(EndBB);
970+
fixSuccessorPhis(CurLoop, ByteCmpRes, ByteCmpRes, EndBB, CmpBB);
972971
if (EndBB != FoundBB)
973-
fixSuccessorPhis(FoundBB);
972+
fixSuccessorPhis(CurLoop, ByteCmpRes, ByteCmpRes, FoundBB, CmpBB);
974973

975974
// The new CmpBB block isn't part of the loop, but will need to be added to
976975
// the outer loop if there is one.
@@ -1168,8 +1167,9 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
11681167

11691168
Value *LoopIdiomVectorize::expandFindFirstByte(
11701169
IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy,
1171-
BasicBlock *ExitSucc, BasicBlock *ExitFail, Value *SearchStart,
1172-
Value *SearchEnd, Value *NeedleStart, Value *NeedleEnd) {
1170+
Value *IndPhi, BasicBlock *ExitSucc, BasicBlock *ExitFail,
1171+
Value *SearchStart, Value *SearchEnd, Value *NeedleStart,
1172+
Value *NeedleEnd) {
11731173
// Set up some types and constants that we intend to reuse.
11741174
auto *PtrTy = Builder.getPtrTy();
11751175
auto *I64Ty = Builder.getInt64Ty();
@@ -1369,6 +1369,12 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
13691369
MatchLCSSA->addIncoming(Search, BB2);
13701370
MatchPredLCSSA->addIncoming(MatchPred, BB2);
13711371

1372+
// Ensure all Phis in the successors of BB3/BB5 have an incoming value from
1373+
// them.
1374+
fixSuccessorPhis(CurLoop, IndPhi, MatchVal, ExitSucc, BB3);
1375+
if (ExitSucc != ExitFail)
1376+
fixSuccessorPhis(CurLoop, IndPhi, MatchVal, ExitFail, BB5);
1377+
13721378
if (VerifyLoops) {
13731379
OuterLoop->verifyLoop();
13741380
InnerLoop->verifyLoop();
@@ -1390,21 +1396,12 @@ void LoopIdiomVectorize::transformFindFirstByte(
13901396
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
13911397
Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
13921398

1393-
Value *MatchVal =
1394-
expandFindFirstByte(Builder, DTU, VF, CharTy, ExitSucc, ExitFail,
1395-
SearchStart, SearchEnd, NeedleStart, NeedleEnd);
1399+
expandFindFirstByte(Builder, DTU, VF, CharTy, IndPhi, ExitSucc, ExitFail,
1400+
SearchStart, SearchEnd, NeedleStart, NeedleEnd);
13961401

13971402
assert(PHBranch->isUnconditional() &&
13981403
"Expected preheader to terminate with an unconditional branch.");
13991404

1400-
// Add new incoming values with the result of the transformation to PHINodes
1401-
// of ExitSucc that use IndPhi.
1402-
for (auto *U : llvm::make_early_inc_range(IndPhi->users())) {
1403-
auto *PN = dyn_cast<PHINode>(U);
1404-
if (PN && PN->getParent() == ExitSucc)
1405-
PN->addIncoming(MatchVal, cast<Instruction>(MatchVal)->getParent());
1406-
}
1407-
14081405
if (VerifyLoops && CurLoop->getParentLoop()) {
14091406
CurLoop->getParentLoop()->verifyLoop();
14101407
if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))

0 commit comments

Comments
 (0)