Skip to content

Commit

Permalink
Reland [FuncSpec] Add Phi nodes to the InstCostVisitor.
Browse files Browse the repository at this point in the history
This patch allows constant folding of PHIs when estimating the user
bonus. Phi nodes are a special case since some of their inputs may
remain unresolved until all the specialization arguments have been
processed by the InstCostVisitor. Therefore, we keep a list of dead
basic blocks and then lazily visit the Phi nodes once the user bonus
has been computed for all the specialization arguments.

Differential Revision: https://reviews.llvm.org/D154852
  • Loading branch information
labrinea committed Jul 31, 2023
1 parent 3f75d38 commit 893d3a6
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 8 deletions.
19 changes: 18 additions & 1 deletion llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
SCCPSolver &Solver;

ConstMap KnownConstants;
// Basic blocks known to be unreachable after constant propagation.
DenseSet<BasicBlock *> DeadBlocks;
// PHI nodes we have visited before.
DenseSet<Instruction *> VisitedPHIs;
// PHI nodes we have visited once without successfully constant folding them.
// Once the InstCostVisitor has processed all the specialization arguments,
// it should be possible to determine whether those PHIs can be folded
// (some of their incoming values may have become constant or dead).
SmallVector<Instruction *> PendingPHIs;

ConstMap::iterator LastVisited;

Expand All @@ -131,7 +140,14 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
TargetTransformInfo &TTI, SCCPSolver &Solver)
: DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}

Cost getUserBonus(Instruction *User, Value *Use, Constant *C);
bool isBlockExecutable(BasicBlock *BB) {
return Solver.isBlockExecutable(BB) && !DeadBlocks.contains(BB);
}

Cost getUserBonus(Instruction *User, Value *Use = nullptr,
Constant *C = nullptr);

Cost getBonusFromPendingPHIs();

private:
friend class InstVisitor<InstCostVisitor, Constant *>;
Expand All @@ -140,6 +156,7 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
Cost estimateBranchInst(BranchInst &I);

Constant *visitInstruction(Instruction &I) { return nullptr; }
Constant *visitPHINode(PHINode &I);
Constant *visitFreezeInst(FreezeInst &I);
Constant *visitCallBase(CallBase &I);
Constant *visitLoadInst(LoadInst &I);
Expand Down
90 changes: 83 additions & 7 deletions llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ static cl::opt<unsigned> MaxClones(
"The maximum number of clones allowed for a single function "
"specialization"));

static cl::opt<unsigned> MaxIncomingPhiValues(
"funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc(
"The maximum number of incoming values a PHI node can have to be "
"considered during the specialization bonus estimation"));

static cl::opt<unsigned> MinFunctionSize(
"funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
"Don't specialize functions that have less than this number of "
Expand All @@ -104,6 +109,7 @@ static cl::opt<bool> SpecializeLiteralConstant(
// the combination of size and latency savings in comparison to the non
// specialized version of the function.
static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
DenseSet<BasicBlock *> &DeadBlocks,
ConstMap &KnownConstants, SCCPSolver &Solver,
BlockFrequencyInfo &BFI,
TargetTransformInfo &TTI) {
Expand All @@ -118,6 +124,12 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
if (!Weight)
continue;

// These blocks are considered dead as far as the InstCostVisitor
// is concerned. They haven't been proven dead yet by the Solver,
// but may become if we propagate the specialization arguments.
if (!DeadBlocks.insert(BB).second)
continue;

for (Instruction &I : *BB) {
// Disregard SSA copies.
if (auto *II = dyn_cast<IntrinsicInst>(&I))
Expand Down Expand Up @@ -152,9 +164,25 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
return nullptr;
}

Cost InstCostVisitor::getBonusFromPendingPHIs() {
Cost Bonus = 0;
while (!PendingPHIs.empty()) {
Instruction *Phi = PendingPHIs.pop_back_val();
// The pending PHIs could have been proven dead by now.
if (isBlockExecutable(Phi->getParent()))
Bonus += getUserBonus(Phi);
}
return Bonus;
}

Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
// We have already propagated a constant for this user.
if (KnownConstants.contains(User))
return 0;

// Cache the iterator before visiting.
LastVisited = KnownConstants.insert({Use, C}).first;
LastVisited = Use ? KnownConstants.insert({Use, C}).first
: KnownConstants.end();

if (auto *I = dyn_cast<SwitchInst>(User))
return estimateSwitchInst(*I);
Expand All @@ -181,13 +209,15 @@ Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {

for (auto *U : User->users())
if (auto *UI = dyn_cast<Instruction>(U))
if (Solver.isBlockExecutable(UI->getParent()))
if (UI != User && isBlockExecutable(UI->getParent()))
Bonus += getUserBonus(UI, User, C);

return Bonus;
}

Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return 0;

Expand All @@ -208,10 +238,13 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
WorkList.push_back(BB);
}

return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
TTI);
}

Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return 0;

Expand All @@ -223,10 +256,39 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
Succ->getUniquePredecessor() == I.getParent())
WorkList.push_back(Succ);

return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
TTI);
}

Constant *InstCostVisitor::visitPHINode(PHINode &I) {
if (I.getNumIncomingValues() > MaxIncomingPhiValues)
return nullptr;

bool Inserted = VisitedPHIs.insert(&I).second;
Constant *Const = nullptr;

for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) {
Value *V = I.getIncomingValue(Idx);
if (auto *Inst = dyn_cast<Instruction>(V))
if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx)))
continue;
Constant *C = findConstantFor(V, KnownConstants);
if (!C) {
if (Inserted)
PendingPHIs.push_back(&I);
return nullptr;
}
if (!Const)
Const = C;
else if (C != Const)
return nullptr;
}
return Const;
}

Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
return LastVisited->second;
return nullptr;
Expand All @@ -253,6 +315,8 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) {
}

Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (isa<ConstantPointerNull>(LastVisited->second))
return nullptr;
return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL);
Expand All @@ -275,6 +339,8 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
}

Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return nullptr;

Expand All @@ -290,6 +356,8 @@ Constant *InstCostVisitor::visitCastInst(CastInst &I) {
}

Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

bool Swap = I.getOperand(1) == LastVisited->first;
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
Constant *Other = findConstantFor(V, KnownConstants);
Expand All @@ -303,10 +371,14 @@ Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
}

Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL);
}

Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

bool Swap = I.getOperand(1) == LastVisited->first;
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
Constant *Other = findConstantFor(V, KnownConstants);
Expand Down Expand Up @@ -713,13 +785,17 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
AllSpecs[Index].CallSites.push_back(&CS);
} else {
// Calculate the specialisation gain.
Cost Score = 0 - SpecCost;
Cost Score = 0;
InstCostVisitor Visitor = getInstCostVisitorFor(F);
for (ArgInfo &A : S.Args)
Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
Score += Visitor.getBonusFromPendingPHIs();

LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = "
<< Score << "\n");

// Discard unprofitable specialisations.
if (!ForceSpecialization && Score <= 0)
if (!ForceSpecialization && Score <= SpecCost)
continue;

// Create a new specialisation entry.
Expand Down Expand Up @@ -798,7 +874,7 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C,
Cost TotalCost = 0;
for (auto *U : A->users())
if (auto *UI = dyn_cast<Instruction>(U))
if (Solver.isBlockExecutable(UI->getParent()))
if (Visitor.isBlockExecutable(UI->getParent()))
TotalCost += Visitor.getUserBonus(UI, A, C);

LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus "
Expand Down
66 changes: 66 additions & 0 deletions llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,69 @@ TEST_F(FunctionSpecializationTest, Misc) {
Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
EXPECT_TRUE(Bonus == 0);
}

TEST_F(FunctionSpecializationTest, PhiNode) {
const char *ModuleString = R"(
define void @foo(i32 %a, i32 %b, i32 %i) {
entry:
br label %loop
loop:
%0 = phi i32 [ %a, %entry ], [ %3, %bb ]
switch i32 %i, label %default
[ i32 1, label %case1
i32 2, label %case2 ]
case1:
%1 = add i32 %0, 1
br label %bb
case2:
%2 = phi i32 [ %a, %entry ], [ %0, %loop ]
br label %bb
bb:
%3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
%4 = icmp eq i32 %3, 1
br i1 %4, label %bb, label %loop
default:
ret void
}
)";

Module &M = parseModule(ModuleString);
Function *F = M.getFunction("foo");
FunctionSpecializer Specializer = getSpecializerFor(F);
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);

auto FuncIter = F->begin();
BasicBlock &Loop = *++FuncIter;
BasicBlock &Case1 = *++FuncIter;
BasicBlock &Case2 = *++FuncIter;
BasicBlock &BB = *++FuncIter;

Instruction &PhiLoop = Loop.front();
Instruction &Add = Case1.front();
Instruction &PhiCase2 = Case2.front();
Instruction &BrBB = Case2.back();
Instruction &PhiBB = BB.front();
Instruction &Icmp = *++BB.begin();

Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
EXPECT_EQ(Bonus, 0);

Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
EXPECT_EQ(Bonus, 0);

// phi + br
Cost Ref = getInstCost(PhiCase2) + getInstCost(BrBB);
Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);

// phi + phi + add + icmp
Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
getInstCost(Icmp);
Bonus = Visitor.getBonusFromPendingPHIs();
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
}

0 comments on commit 893d3a6

Please sign in to comment.