diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h index b20212a21b326..5efb3cb240ff8 100644 --- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h +++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h @@ -123,6 +123,15 @@ class InstCostVisitor : public InstVisitor { SCCPSolver &Solver; ConstMap KnownConstants; + // Basic blocks known to be unreachable after constant propagation. + DenseSet DeadBlocks; + // PHI nodes we have visited before. + DenseSet VisitedPHIs; + // PHI nodes we have visited once without successfully constant folding them. + // Once the InstCostVisitor has processed all the specialization arguments, + // it should be possible to determine whether those PHIs can be folded + // (some of their incoming values may have become constant or dead). + SmallVector PendingPHIs; ConstMap::iterator LastVisited; @@ -131,7 +140,14 @@ class InstCostVisitor : public InstVisitor { TargetTransformInfo &TTI, SCCPSolver &Solver) : DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {} - Cost getUserBonus(Instruction *User, Value *Use, Constant *C); + bool isBlockExecutable(BasicBlock *BB) { + return Solver.isBlockExecutable(BB) && !DeadBlocks.contains(BB); + } + + Cost getUserBonus(Instruction *User, Value *Use = nullptr, + Constant *C = nullptr); + + Cost getBonusFromPendingPHIs(); private: friend class InstVisitor; @@ -140,6 +156,7 @@ class InstCostVisitor : public InstVisitor { Cost estimateBranchInst(BranchInst &I); Constant *visitInstruction(Instruction &I) { return nullptr; } + Constant *visitPHINode(PHINode &I); Constant *visitFreezeInst(FreezeInst &I); Constant *visitCallBase(CallBase &I); Constant *visitLoadInst(LoadInst &I); diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index ac5dbc7cfb2a5..d917342a7d290 100644 --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -78,6 +78,11 @@ static cl::opt MaxClones( "The maximum number of clones allowed for a single function " "specialization")); +static cl::opt MaxIncomingPhiValues( + "funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc( + "The maximum number of incoming values a PHI node can have to be " + "considered during the specialization bonus estimation")); + static cl::opt MinFunctionSize( "funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc( "Don't specialize functions that have less than this number of " @@ -104,6 +109,7 @@ static cl::opt SpecializeLiteralConstant( // the combination of size and latency savings in comparison to the non // specialized version of the function. static Cost estimateBasicBlocks(SmallVectorImpl &WorkList, + DenseSet &DeadBlocks, ConstMap &KnownConstants, SCCPSolver &Solver, BlockFrequencyInfo &BFI, TargetTransformInfo &TTI) { @@ -118,6 +124,12 @@ static Cost estimateBasicBlocks(SmallVectorImpl &WorkList, if (!Weight) continue; + // These blocks are considered dead as far as the InstCostVisitor + // is concerned. They haven't been proven dead yet by the Solver, + // but may become if we propagate the specialization arguments. + if (!DeadBlocks.insert(BB).second) + continue; + for (Instruction &I : *BB) { // Disregard SSA copies. if (auto *II = dyn_cast(&I)) @@ -152,9 +164,25 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) { return nullptr; } +Cost InstCostVisitor::getBonusFromPendingPHIs() { + Cost Bonus = 0; + while (!PendingPHIs.empty()) { + Instruction *Phi = PendingPHIs.pop_back_val(); + // The pending PHIs could have been proven dead by now. + if (isBlockExecutable(Phi->getParent())) + Bonus += getUserBonus(Phi); + } + return Bonus; +} + Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) { + // We have already propagated a constant for this user. + if (KnownConstants.contains(User)) + return 0; + // Cache the iterator before visiting. - LastVisited = KnownConstants.insert({Use, C}).first; + LastVisited = Use ? KnownConstants.insert({Use, C}).first + : KnownConstants.end(); if (auto *I = dyn_cast(User)) return estimateSwitchInst(*I); @@ -181,13 +209,15 @@ Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) { for (auto *U : User->users()) if (auto *UI = dyn_cast(U)) - if (Solver.isBlockExecutable(UI->getParent())) + if (UI != User && isBlockExecutable(UI->getParent())) Bonus += getUserBonus(UI, User, C); return Bonus; } Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (I.getCondition() != LastVisited->first) return 0; @@ -208,10 +238,13 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) { WorkList.push_back(BB); } - return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI); + return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI, + TTI); } Cost InstCostVisitor::estimateBranchInst(BranchInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (I.getCondition() != LastVisited->first) return 0; @@ -223,10 +256,39 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) { Succ->getUniquePredecessor() == I.getParent()) WorkList.push_back(Succ); - return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI); + return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI, + TTI); +} + +Constant *InstCostVisitor::visitPHINode(PHINode &I) { + if (I.getNumIncomingValues() > MaxIncomingPhiValues) + return nullptr; + + bool Inserted = VisitedPHIs.insert(&I).second; + Constant *Const = nullptr; + + for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) { + Value *V = I.getIncomingValue(Idx); + if (auto *Inst = dyn_cast(V)) + if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx))) + continue; + Constant *C = findConstantFor(V, KnownConstants); + if (!C) { + if (Inserted) + PendingPHIs.push_back(&I); + return nullptr; + } + if (!Const) + Const = C; + else if (C != Const) + return nullptr; + } + return Const; } Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second)) return LastVisited->second; return nullptr; @@ -253,6 +315,8 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) { } Constant *InstCostVisitor::visitLoadInst(LoadInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (isa(LastVisited->second)) return nullptr; return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL); @@ -275,6 +339,8 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { } Constant *InstCostVisitor::visitSelectInst(SelectInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (I.getCondition() != LastVisited->first) return nullptr; @@ -290,6 +356,8 @@ Constant *InstCostVisitor::visitCastInst(CastInst &I) { } Constant *InstCostVisitor::visitCmpInst(CmpInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + bool Swap = I.getOperand(1) == LastVisited->first; Value *V = Swap ? I.getOperand(0) : I.getOperand(1); Constant *Other = findConstantFor(V, KnownConstants); @@ -303,10 +371,14 @@ Constant *InstCostVisitor::visitCmpInst(CmpInst &I) { } Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL); } Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + bool Swap = I.getOperand(1) == LastVisited->first; Value *V = Swap ? I.getOperand(0) : I.getOperand(1); Constant *Other = findConstantFor(V, KnownConstants); @@ -713,13 +785,17 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost, AllSpecs[Index].CallSites.push_back(&CS); } else { // Calculate the specialisation gain. - Cost Score = 0 - SpecCost; + Cost Score = 0; InstCostVisitor Visitor = getInstCostVisitorFor(F); for (ArgInfo &A : S.Args) Score += getSpecializationBonus(A.Formal, A.Actual, Visitor); + Score += Visitor.getBonusFromPendingPHIs(); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = " + << Score << "\n"); // Discard unprofitable specialisations. - if (!ForceSpecialization && Score <= 0) + if (!ForceSpecialization && Score <= SpecCost) continue; // Create a new specialisation entry. @@ -798,7 +874,7 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, Cost TotalCost = 0; for (auto *U : A->users()) if (auto *UI = dyn_cast(U)) - if (Solver.isBlockExecutable(UI->getParent())) + if (Visitor.isBlockExecutable(UI->getParent())) TotalCost += Visitor.getUserBonus(UI, A, C); LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus " diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp index 81da6d8f6ed5c..6018263cad658 100644 --- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp +++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp @@ -302,3 +302,69 @@ TEST_F(FunctionSpecializationTest, Misc) { Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor); EXPECT_TRUE(Bonus == 0); } + +TEST_F(FunctionSpecializationTest, PhiNode) { + const char *ModuleString = R"( + define void @foo(i32 %a, i32 %b, i32 %i) { + entry: + br label %loop + loop: + %0 = phi i32 [ %a, %entry ], [ %3, %bb ] + switch i32 %i, label %default + [ i32 1, label %case1 + i32 2, label %case2 ] + case1: + %1 = add i32 %0, 1 + br label %bb + case2: + %2 = phi i32 [ %a, %entry ], [ %0, %loop ] + br label %bb + bb: + %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ] + %4 = icmp eq i32 %3, 1 + br i1 %4, label %bb, label %loop + default: + ret void + } + )"; + + Module &M = parseModule(ModuleString); + Function *F = M.getFunction("foo"); + FunctionSpecializer Specializer = getSpecializerFor(F); + InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); + + Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); + + auto FuncIter = F->begin(); + BasicBlock &Loop = *++FuncIter; + BasicBlock &Case1 = *++FuncIter; + BasicBlock &Case2 = *++FuncIter; + BasicBlock &BB = *++FuncIter; + + Instruction &PhiLoop = Loop.front(); + Instruction &Add = Case1.front(); + Instruction &PhiCase2 = Case2.front(); + Instruction &BrBB = Case2.back(); + Instruction &PhiBB = BB.front(); + Instruction &Icmp = *++BB.begin(); + + Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); + EXPECT_EQ(Bonus, 0); + + Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); + EXPECT_EQ(Bonus, 0); + + // phi + br + Cost Ref = getInstCost(PhiCase2) + getInstCost(BrBB); + Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); + EXPECT_EQ(Bonus, Ref); + EXPECT_TRUE(Bonus > 0); + + // phi + phi + add + icmp + Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) + + getInstCost(Icmp); + Bonus = Visitor.getBonusFromPendingPHIs(); + EXPECT_EQ(Bonus, Ref); + EXPECT_TRUE(Bonus > 0); +} +