diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index 66e45ecbde7df..1d906270fe046 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -575,7 +575,7 @@ struct AllSwitchPaths { AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE, LoopInfo *LI, Loop *L) : Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()), ORE(ORE), - LI(LI), SwitchOuterLoop(L) {} + DefaultDest(nullptr), LI(LI), SwitchOuterLoop(L) {} std::vector &getThreadingPaths() { return TPaths; } unsigned getNumThreadingPaths() { return TPaths.size(); } @@ -587,6 +587,30 @@ struct AllSwitchPaths { unifyTPaths(); } + /// Fast helper to get the successor corresponding to a particular case value + /// for a switch statement. + BasicBlock *getNextCaseSuccessor(const APInt &NextState) { + // Precompute the value => successor mapping + if (CaseValToDest.empty()) { + for (auto Case : Switch->cases()) { + APInt CaseVal = Case.getCaseValue()->getValue(); + CaseValToDest[CaseVal] = Case.getCaseSuccessor(); + } + DefaultDest = Switch->getDefaultDest(); + } + + auto SuccIt = CaseValToDest.find(NextState); + return SuccIt == CaseValToDest.end() ? DefaultDest : SuccIt->second; + } + + void updateDefaultDest(BasicBlock *DefaultDest) { + this->DefaultDest = DefaultDest; + } + + void updateNextCase(const APInt &NextState, BasicBlock *NextCase) { + CaseValToDest[NextState] = NextCase; + } + private: // Value: an instruction that defines a switch state; // Key: the parent basic block of that instruction. @@ -818,22 +842,6 @@ struct AllSwitchPaths { TPaths = std::move(TempList); } - /// Fast helper to get the successor corresponding to a particular case value - /// for a switch statement. - BasicBlock *getNextCaseSuccessor(const APInt &NextState) { - // Precompute the value => successor mapping - if (CaseValToDest.empty()) { - for (auto Case : Switch->cases()) { - APInt CaseVal = Case.getCaseValue()->getValue(); - CaseValToDest[CaseVal] = Case.getCaseSuccessor(); - } - } - - auto SuccIt = CaseValToDest.find(NextState); - return SuccIt == CaseValToDest.end() ? Switch->getDefaultDest() - : SuccIt->second; - } - // Two states are equivalent if they have the same switch destination. // Unify the states in different threading path if the states are equivalent. void unifyTPaths() { @@ -858,6 +866,7 @@ struct AllSwitchPaths { OptimizationRemarkEmitter *ORE; std::vector TPaths; DenseMap CaseValToDest; + BasicBlock *DefaultDest; LoopInfo *LI; Loop *SwitchOuterLoop; }; @@ -1159,24 +1168,6 @@ struct TransformDFA { SSAUpdate.RewriteAllUses(&DTU->getDomTree()); } - /// Helper to get the successor corresponding to a particular case value for - /// a switch statement. - /// TODO: Unify it with SwitchPaths->getNextCaseSuccessor(SwitchInst *Switch) - /// by updating cached value => successor mapping during threading. - static BasicBlock *getNextCaseSuccessor(SwitchInst *Switch, - const APInt &NextState) { - BasicBlock *NextCase = nullptr; - for (auto Case : Switch->cases()) { - if (Case.getCaseValue()->getValue() == NextState) { - NextCase = Case.getCaseSuccessor(); - break; - } - } - if (!NextCase) - NextCase = Switch->getDefaultDest(); - return NextCase; - } - /// Clones a basic block, and adds it to the CFG. /// /// This function also includes updating phi nodes in the successors of the @@ -1231,8 +1222,7 @@ struct TransformDFA { // If BB is the last block in the path, we can simply update the one case // successor that will be reached. if (BB == SwitchPaths->getSwitchBlock()) { - SwitchInst *Switch = SwitchPaths->getSwitchInst(); - BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState); + BasicBlock *NextCase = SwitchPaths->getNextCaseSuccessor(NextState); BlocksToUpdate.push_back(NextCase); BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap); if (ClonedSucc) @@ -1283,6 +1273,15 @@ struct TransformDFA { return; Instruction *PrevTerm = PrevBB->getTerminator(); + // Update cached value => destination mapping. + if (PrevTerm == SwitchPaths->getSwitchInst()) { + for (auto Case : SwitchPaths->getSwitchInst()->cases()) + if (Case.getCaseSuccessor() == OldBB) + SwitchPaths->updateNextCase(Case.getCaseValue()->getValue(), NewBB); + if (SwitchPaths->getSwitchInst()->getDefaultDest() == OldBB) + SwitchPaths->updateDefaultDest(NewBB); + } + // Replace actual successors. for (unsigned Idx = 0; Idx < PrevTerm->getNumSuccessors(); Idx++) { if (PrevTerm->getSuccessor(Idx) == OldBB) { OldBB->removePredecessor(PrevBB, /* KeepOneInputPHIs = */ true); @@ -1341,17 +1340,20 @@ struct TransformDFA { // updated yet if (!isa(LastBlock->getTerminator())) return; - SwitchInst *Switch = cast(LastBlock->getTerminator()); - BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState); + assert(BB->getTerminator() == SwitchPaths->getSwitchInst() && + "Original last block must contain the threaded switch"); + BasicBlock *NextCase = SwitchPaths->getNextCaseSuccessor(NextState); std::vector DTUpdates; SmallPtrSet SuccSet; - for (BasicBlock *Succ : successors(LastBlock)) { - if (Succ != NextCase && SuccSet.insert(Succ).second) + for (BasicBlock *Succ : successors(LastBlock)) + if (SuccSet.insert(Succ).second && Succ != NextCase) DTUpdates.push_back({DominatorTree::Delete, LastBlock, Succ}); - } - Switch->eraseFromParent(); + if (!SuccSet.count(NextCase)) + DTUpdates.push_back({DominatorTree::Insert, LastBlock, NextCase}); + + LastBlock->getTerminator()->eraseFromParent(); BranchInst::Create(NextCase, LastBlock); DTU->applyUpdates(DTUpdates); diff --git a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll index 95d3ffaa21b30..0c0b6b5184562 100644 --- a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll +++ b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll @@ -109,7 +109,7 @@ define i32 @test2(i32 %num) { ; CHECK: for.body.jt3: ; CHECK-NEXT: [[COUNT_JT3:%.*]] = phi i32 [ [[INC_JT3:%.*]], [[FOR_INC_JT3:%.*]] ] ; CHECK-NEXT: [[STATE_JT3:%.*]] = phi i32 [ [[STATE_NEXT_JT3:%.*]], [[FOR_INC_JT3]] ] -; CHECK-NEXT: br label [[FOR_INC]] +; CHECK-NEXT: br label [[FOR_INC_JT1]] ; CHECK: case1: ; CHECK-NEXT: [[COUNT6:%.*]] = phi i32 [ [[COUNT_JT1]], [[FOR_BODY_JT1:%.*]] ], [ [[COUNT]], [[FOR_BODY]] ] ; CHECK-NEXT: [[CMP_C1:%.*]] = icmp slt i32 [[COUNT6]], 50 @@ -156,8 +156,8 @@ define i32 @test2(i32 %num) { ; CHECK-NEXT: [[DOTSI_UNFOLD_PHI4_JT2:%.*]] = phi i32 [ 2, [[STATE1_1_SI_UNFOLD_TRUE:%.*]] ] ; CHECK-NEXT: br label [[FOR_INC_JT2]] ; CHECK: for.inc: -; CHECK-NEXT: [[COUNT5:%.*]] = phi i32 [ [[COUNT_JT3]], [[FOR_BODY_JT3:%.*]] ], [ undef, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ undef, [[STATE1_2_SI_UNFOLD_FALSE:%.*]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE]] ] -; CHECK-NEXT: [[STATE_NEXT]] = phi i32 [ [[STATE2_1_SI_UNFOLD_PHI]], [[STATE2_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[DOTSI_UNFOLD_PHI4]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ 1, [[FOR_BODY_JT3]] ] +; CHECK-NEXT: [[COUNT5:%.*]] = phi i32 [ undef, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ undef, [[STATE1_2_SI_UNFOLD_FALSE:%.*]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE]] ] +; CHECK-NEXT: [[STATE_NEXT]] = phi i32 [ [[STATE2_1_SI_UNFOLD_PHI]], [[STATE2_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[DOTSI_UNFOLD_PHI4]], [[STATE1_1_SI_UNFOLD_FALSE]] ] ; CHECK-NEXT: [[INC]] = add nsw i32 [[COUNT5]], 1 ; CHECK-NEXT: [[CMP_EXIT:%.*]] = icmp slt i32 [[INC]], [[NUM:%.*]] ; CHECK-NEXT: br i1 [[CMP_EXIT]], label [[FOR_BODY]], label [[FOR_END:%.*]] @@ -167,8 +167,8 @@ define i32 @test2(i32 %num) { ; CHECK-NEXT: [[CMP_EXIT_JT2:%.*]] = icmp slt i32 [[INC_JT2]], [[NUM]] ; CHECK-NEXT: br i1 [[CMP_EXIT_JT2]], label [[FOR_BODY_JT2:%.*]], label [[FOR_END]] ; CHECK: for.inc.jt1: -; CHECK-NEXT: [[COUNT7:%.*]] = phi i32 [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[COUNT]], [[FOR_BODY]] ] -; CHECK-NEXT: [[STATE_NEXT_JT1]] = phi i32 [ 1, [[FOR_BODY]] ], [ [[STATE2_1_SI_UNFOLD_PHI_JT1]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[DOTSI_UNFOLD_PHI3_JT1]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ] +; CHECK-NEXT: [[COUNT7:%.*]] = phi i32 [ [[COUNT_JT3]], [[FOR_BODY_JT3:%.*]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[COUNT]], [[FOR_BODY]] ] +; CHECK-NEXT: [[STATE_NEXT_JT1]] = phi i32 [ 1, [[FOR_BODY]] ], [ 1, [[FOR_BODY_JT3]] ], [ [[STATE2_1_SI_UNFOLD_PHI_JT1]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[DOTSI_UNFOLD_PHI3_JT1]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ] ; CHECK-NEXT: [[INC_JT1]] = add nsw i32 [[COUNT7]], 1 ; CHECK-NEXT: [[CMP_EXIT_JT1:%.*]] = icmp slt i32 [[INC_JT1]], [[NUM]] ; CHECK-NEXT: br i1 [[CMP_EXIT_JT1]], label [[FOR_BODY_JT1]], label [[FOR_END]]