diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index e66a1a2ccb318..1caed93b1b668 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -96,6 +96,11 @@ static cl::opt cl::desc("View the CFG before DFA Jump Threading"), cl::Hidden, cl::init(false)); +static cl::opt EarlyExitHeuristic( + "dfa-early-exit-heuristic", + cl::desc("Exit early if an unpredictable value come from the same loop"), + cl::Hidden, cl::init(true)); + static cl::opt MaxPathLength( "dfa-max-path-length", cl::desc("Max number of blocks searched to find a threading path"), @@ -405,7 +410,7 @@ struct MainSwitch { /// /// Also, collect select instructions to unfold. bool isCandidate(const SwitchInst *SI) { - std::deque Q; + std::deque> Q; SmallSet SeenValues; SelectInsts.clear(); @@ -415,25 +420,28 @@ struct MainSwitch { return false; // The switch must be in a loop. - if (!LI->getLoopFor(SI->getParent())) + const Loop *L = LI->getLoopFor(SI->getParent()); + if (!L) return false; - addToQueue(SICond, Q, SeenValues); + addToQueue(SICond, nullptr, Q, SeenValues); while (!Q.empty()) { - Value *Current = Q.front(); + Value *Current = Q.front().first; + BasicBlock *CurrentIncomingBB = Q.front().second; Q.pop_front(); if (auto *Phi = dyn_cast(Current)) { - for (Value *Incoming : Phi->incoming_values()) { - addToQueue(Incoming, Q, SeenValues); + for (BasicBlock *IncomingBB : Phi->blocks()) { + Value *Incoming = Phi->getIncomingValueForBlock(IncomingBB); + addToQueue(Incoming, IncomingBB, Q, SeenValues); } LLVM_DEBUG(dbgs() << "\tphi: " << *Phi << "\n"); } else if (SelectInst *SelI = dyn_cast(Current)) { if (!isValidSelectInst(SelI)) return false; - addToQueue(SelI->getTrueValue(), Q, SeenValues); - addToQueue(SelI->getFalseValue(), Q, SeenValues); + addToQueue(SelI->getTrueValue(), CurrentIncomingBB, Q, SeenValues); + addToQueue(SelI->getFalseValue(), CurrentIncomingBB, Q, SeenValues); LLVM_DEBUG(dbgs() << "\tselect: " << *SelI << "\n"); if (auto *SelIUse = dyn_cast(SelI->user_back())) SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse)); @@ -446,6 +454,18 @@ struct MainSwitch { // initial switch values that can be ignored (they will hit the // unthreaded switch) but this assumption will get checked later after // paths have been enumerated (in function getStateDefMap). + + // If the unpredictable value comes from the same inner loop it is + // likely that it will also be on the enumerated paths, causing us to + // exit after we have enumerated all the paths. This heuristic save + // compile time because a search for all the paths can become expensive. + if (EarlyExitHeuristic && + L->contains(LI->getLoopFor(CurrentIncomingBB))) { + LLVM_DEBUG(dbgs() + << "\tExiting early due to unpredictability heuristic.\n"); + return false; + } + continue; } } @@ -453,11 +473,12 @@ struct MainSwitch { return true; } - void addToQueue(Value *Val, std::deque &Q, + void addToQueue(Value *Val, BasicBlock *BB, + std::deque> &Q, SmallSet &SeenValues) { if (SeenValues.contains(Val)) return; - Q.push_back(Val); + Q.push_back({Val, BB}); SeenValues.insert(Val); } diff --git a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll index df725b9a7fa47..696bd55d2dfdd 100644 --- a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll +++ b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt -S -passes=dfa-jump-threading %s | FileCheck %s +; RUN: opt -S -passes=dfa-jump-threading -dfa-early-exit-heuristic=false %s | FileCheck %s ; These tests check if selects are unfolded properly for jump threading ; opportunities. There are three different patterns to consider: diff --git a/llvm/test/Transforms/DFAJumpThreading/unpredictable-heuristic.ll b/llvm/test/Transforms/DFAJumpThreading/unpredictable-heuristic.ll new file mode 100644 index 0000000000000..9743f0acc8165 --- /dev/null +++ b/llvm/test/Transforms/DFAJumpThreading/unpredictable-heuristic.ll @@ -0,0 +1,124 @@ +; REQUIRES: asserts +; RUN: opt -S -passes=dfa-jump-threading %s -debug-only=dfa-jump-threading 2>&1 | FileCheck %s + +; CHECK-COUNT-3: Exiting early due to unpredictability heuristic. + +@.str.1 = private unnamed_addr constant [3 x i8] c"10\00", align 1 +@.str.2 = private unnamed_addr constant [3 x i8] c"30\00", align 1 +@.str.3 = private unnamed_addr constant [3 x i8] c"20\00", align 1 +@.str.4 = private unnamed_addr constant [3 x i8] c"40\00", align 1 + +define void @test1(i32 noundef %num, i32 noundef %num2) { +entry: + br label %while.body + +while.body: ; preds = %entry, %sw.epilog + %num.addr.0 = phi i32 [ %num, %entry ], [ %num.addr.1, %sw.epilog ] + switch i32 %num.addr.0, label %sw.default [ + i32 10, label %sw.bb + i32 30, label %sw.bb1 + i32 20, label %sw.bb2 + i32 40, label %sw.bb3 + ] + +sw.bb: ; preds = %while.body + %call.i = tail call i32 @bar(ptr noundef nonnull @.str.1) + br label %sw.epilog + +sw.bb1: ; preds = %while.body + %call.i4 = tail call i32 @bar(ptr noundef nonnull @.str.2) + br label %sw.epilog + +sw.bb2: ; preds = %while.body + %call.i5 = tail call i32 @bar(ptr noundef nonnull @.str.3) + br label %sw.epilog + +sw.bb3: ; preds = %while.body + %call.i6 = tail call i32 @bar(ptr noundef nonnull @.str.4) + %call = tail call noundef i32 @foo() + %add = add nsw i32 %call, %num2 + br label %sw.epilog + +sw.default: ; preds = %while.body + ret void + +sw.epilog: ; preds = %sw.bb3, %sw.bb2, %sw.bb1, %sw.bb + %num.addr.1 = phi i32 [ %add, %sw.bb3 ], [ 40, %sw.bb2 ], [ 20, %sw.bb1 ], [ 30, %sw.bb ] + br label %while.body +} + + +define void @test2(i32 noundef %num, i32 noundef %num2) { +entry: + br label %while.body + +while.body: ; preds = %entry, %sw.epilog + %num.addr.0 = phi i32 [ %num, %entry ], [ %num.addr.1, %sw.epilog ] + switch i32 %num.addr.0, label %sw.default [ + i32 10, label %sw.epilog + i32 30, label %sw.bb1 + i32 20, label %sw.bb2 + i32 40, label %sw.bb3 + ] + +sw.bb1: ; preds = %while.body + br label %sw.epilog + +sw.bb2: ; preds = %while.body + br label %sw.epilog + +sw.bb3: ; preds = %while.body + br label %sw.epilog + +sw.default: ; preds = %while.body + ret void + +sw.epilog: ; preds = %while.body, %sw.bb3, %sw.bb2, %sw.bb1 + %.str.4.sink = phi ptr [ @.str.4, %sw.bb3 ], [ @.str.3, %sw.bb2 ], [ @.str.2, %sw.bb1 ], [ @.str.1, %while.body ] + %num.addr.1 = phi i32 [ %num2, %sw.bb3 ], [ 40, %sw.bb2 ], [ 20, %sw.bb1 ], [ 30, %while.body ] + %call.i6 = tail call i32 @bar(ptr noundef nonnull %.str.4.sink) + br label %while.body +} + + +define void @test3(i32 noundef %num, i32 noundef %num2) { +entry: + %add = add nsw i32 %num2, 40 + br label %while.body + +while.body: ; preds = %entry, %sw.epilog + %num.addr.0 = phi i32 [ %num, %entry ], [ %num.addr.1, %sw.epilog ] + switch i32 %num.addr.0, label %sw.default [ + i32 10, label %sw.bb + i32 30, label %sw.bb1 + i32 20, label %sw.bb2 + i32 40, label %sw.bb3 + ] + +sw.bb: ; preds = %while.body + %call.i = tail call i32 @bar(ptr noundef nonnull @.str.1) + br label %sw.epilog + +sw.bb1: ; preds = %while.body + %call.i5 = tail call i32 @bar(ptr noundef nonnull @.str.2) + br label %sw.epilog + +sw.bb2: ; preds = %while.body + %call.i6 = tail call i32 @bar(ptr noundef nonnull @.str.3) + br label %sw.epilog + +sw.bb3: ; preds = %while.body + %call.i7 = tail call i32 @bar(ptr noundef nonnull @.str.4) + br label %sw.epilog + +sw.default: ; preds = %while.body + ret void + +sw.epilog: ; preds = %sw.bb3, %sw.bb2, %sw.bb1, %sw.bb + %num.addr.1 = phi i32 [ %add, %sw.bb3 ], [ 40, %sw.bb2 ], [ 20, %sw.bb1 ], [ 30, %sw.bb ] + br label %while.body +} + + +declare noundef i32 @foo() +declare noundef i32 @bar(ptr nocapture noundef readonly)