diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp index 1d0244c54a282..260dc3a12df43 100644 --- a/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -811,23 +811,29 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, if (NumPreds != PN->getNumIncomingValues()) { assert(NumPreds < PN->getNumIncomingValues()); // Count how many times each predecessor comes to this block. - std::map PredCount; + DenseMap PredCount; for (BasicBlock *Pred : predecessors(NewBB)) - --PredCount[Pred]; - - // Figure out how many entries to remove from each PHI. - for (BasicBlock *Pred : PN->blocks()) ++PredCount[Pred]; - // At this point, the excess predecessor entries are positive in the - // map. Loop over all of the PHIs and remove excess predecessor - // entries. BasicBlock::iterator I = NewBB->begin(); + DenseMap SeenPredCount; + SeenPredCount.reserve(PredCount.size()); for (; (PN = dyn_cast(I)); ++I) { - for (const auto &[Pred, Count] : PredCount) { - for ([[maybe_unused]] unsigned _ : llvm::seq(Count)) - PN->removeIncomingValue(Pred, false); - } + SeenPredCount.clear(); + PN->removeIncomingValueIf( + [&](unsigned Idx) { + BasicBlock *IncomingBlock = PN->getIncomingBlock(Idx); + auto It = PredCount.find(IncomingBlock); + if (It == PredCount.end()) + return true; + unsigned &SeenCount = SeenPredCount[IncomingBlock]; + if (SeenCount < It->second) { + SeenCount++; + return false; + } + return true; + }, + false); } }