Skip to content

Commit

Permalink
[FuncSpec] Add Phi nodes to the InstCostVisitor.
Browse files Browse the repository at this point in the history
This patch allows constant folding of PHIs when estimating the user
bonus. Phi nodes are a special case since some of their inputs may
remain unresolved until all the specialization arguments have been
processed by the InstCostVisitor. Therefore, we keep a list of dead
basic blocks and then lazily visit the Phi nodes once the user bonus
has been computed for all the specialization arguments.

Differential Revision: https://reviews.llvm.org/D154852
  • Loading branch information
labrinea committed Jul 25, 2023
1 parent c1b8297 commit 03f1d09
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 7 deletions.
15 changes: 14 additions & 1 deletion llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
SCCPSolver &Solver;

ConstMap KnownConstants;
// Basic blocks known to be unreachable after constant propagation.
DenseSet<BasicBlock *> DeadBlocks;
// PHI nodes we have visited before.
DenseSet<Instruction *> VisitedPHIs;
// PHI nodes we have visited once without successfully constant folding them.
// Once the InstCostVisitor has processed all the specialization arguments,
// it should be possible to determine whether those PHIs can be folded
// (some of their incoming values may have become constant or dead).
SmallVector<Instruction *> PendingPHIs;

ConstMap::iterator LastVisited;

Expand All @@ -134,7 +143,10 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
TargetTransformInfo &TTI, SCCPSolver &Solver)
: DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}

Cost getUserBonus(Instruction *User, Value *Use, Constant *C);
Cost getUserBonus(Instruction *User, Value *Use = nullptr,
Constant *C = nullptr);

Cost getBonusFromPendingPHIs();

private:
friend class InstVisitor<InstCostVisitor, Constant *>;
Expand All @@ -143,6 +155,7 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
Cost estimateBranchInst(BranchInst &I);

Constant *visitInstruction(Instruction &I) { return nullptr; }
Constant *visitPHINode(PHINode &I);
Constant *visitFreezeInst(FreezeInst &I);
Constant *visitCallBase(CallBase &I);
Constant *visitLoadInst(LoadInst &I);
Expand Down
82 changes: 76 additions & 6 deletions llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ static cl::opt<unsigned> MaxClones(
"The maximum number of clones allowed for a single function "
"specialization"));

// Hidden command-line knob (default 4): PHI nodes with more incoming values
// than this are not considered for constant folding by
// InstCostVisitor::visitPHINode during the bonus estimation.
static cl::opt<unsigned> MaxIncomingPhiValues(
"funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc(
"The maximum number of incoming values a PHI node can have to be "
"considered during the specialization bonus estimation"));

static cl::opt<unsigned> MinFunctionSize(
"funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
"Don't specialize functions that have less than this number of "
Expand All @@ -104,6 +109,7 @@ static cl::opt<bool> SpecializeLiteralConstant(
// the combination of size and latency savings in comparison to the non
// specialized version of the function.
static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
DenseSet<BasicBlock *> &DeadBlocks,
ConstMap &KnownConstants, SCCPSolver &Solver,
BlockFrequencyInfo &BFI,
TargetTransformInfo &TTI) {
Expand All @@ -118,6 +124,12 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
if (!Weight)
continue;

// These blocks are considered dead as far as the InstCostVisitor is
// concerned. They haven't been proven dead yet by the Solver, but
// may become dead if we propagate the constant specialization arguments.
if (!DeadBlocks.insert(BB).second)
continue;

for (Instruction &I : *BB) {
// Disregard SSA copies.
if (auto *II = dyn_cast<IntrinsicInst>(&I))
Expand Down Expand Up @@ -152,9 +164,19 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
return nullptr;
}

// Drains PendingPHIs and returns the accumulated user bonus of every PHI that
// was queued there. Each PHI is re-evaluated through getUserBonus() without a
// (Use, Constant) pair.
Cost InstCostVisitor::getBonusFromPendingPHIs() {
  Cost Total = 0;
  // Emptiness is re-checked on every iteration instead of iterating over the
  // container, since visitPHINode() can append to PendingPHIs (presumably
  // reachable from getUserBonus(); verify against the full definition).
  while (!PendingPHIs.empty())
    Total += getUserBonus(PendingPHIs.pop_back_val());
  return Total;
}

Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
// Cache the iterator before visiting.
LastVisited = KnownConstants.insert({Use, C}).first;
LastVisited = Use ? KnownConstants.insert({Use, C}).first
: KnownConstants.end();

if (auto *I = dyn_cast<SwitchInst>(User))
return estimateSwitchInst(*I);
Expand All @@ -181,13 +203,15 @@ Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {

for (auto *U : User->users())
if (auto *UI = dyn_cast<Instruction>(U))
if (Solver.isBlockExecutable(UI->getParent()))
if (UI != User && Solver.isBlockExecutable(UI->getParent()))
Bonus += getUserBonus(UI, User, C);

return Bonus;
}

Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return 0;

Expand All @@ -208,10 +232,13 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
WorkList.push_back(BB);
}

return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
TTI);
}

Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return 0;

Expand All @@ -223,10 +250,39 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
Succ->getUniquePredecessor() == I.getParent())
WorkList.push_back(Succ);

return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
TTI);
}

// Attempts to fold the PHI node \p I to a single constant.
//
// Returns the constant shared by all incoming values that are taken into
// account, or nullptr when the PHI has too many incoming values, when an
// incoming value has no known constant, or when two incoming values resolve
// to different constants. A PHI that fails on an unknown incoming value
// during its first visit is recorded in PendingPHIs so it can be revisited.
Constant *InstCostVisitor::visitPHINode(PHINode &I) {
  const unsigned NumIncoming = I.getNumIncomingValues();
  if (NumIncoming > MaxIncomingPhiValues)
    return nullptr;

  // Remember whether this is the first time we see this PHI; only then may
  // it be queued as pending.
  const bool FirstVisit = VisitedPHIs.insert(&I).second;
  Constant *Folded = nullptr;

  for (unsigned Idx = 0; Idx < NumIncoming; ++Idx) {
    Value *Incoming = I.getIncomingValue(Idx);

    // Ignore instruction operands that are the PHI itself or that flow in
    // from a block recorded in DeadBlocks.
    if (isa<Instruction>(Incoming) &&
        (Incoming == &I || DeadBlocks.contains(I.getIncomingBlock(Idx))))
      continue;

    Constant *C = findConstantFor(Incoming, KnownConstants);
    if (!C) {
      if (FirstVisit)
        PendingPHIs.push_back(&I);
      return nullptr;
    }

    // Every remaining incoming value must agree on the same constant.
    if (Folded && Folded != C)
      return nullptr;
    Folded = C;
  }
  return Folded;
}

Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
return LastVisited->second;
return nullptr;
Expand All @@ -253,6 +309,8 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) {
}

Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (isa<ConstantPointerNull>(LastVisited->second))
return nullptr;
return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL);
Expand All @@ -275,6 +333,8 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
}

Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return nullptr;

Expand All @@ -290,6 +350,8 @@ Constant *InstCostVisitor::visitCastInst(CastInst &I) {
}

Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

bool Swap = I.getOperand(1) == LastVisited->first;
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
Constant *Other = findConstantFor(V, KnownConstants);
Expand All @@ -303,10 +365,14 @@ Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
}

// Folds the unary operator \p I using the constant stored for the most
// recently visited value as its operand. Requires LastVisited to point at a
// valid KnownConstants entry.
Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
  assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

  Constant *Operand = LastVisited->second;
  return ConstantFoldUnaryOpOperand(I.getOpcode(), Operand, DL);
}

Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

bool Swap = I.getOperand(1) == LastVisited->first;
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
Constant *Other = findConstantFor(V, KnownConstants);
Expand Down Expand Up @@ -713,13 +779,17 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
AllSpecs[Index].CallSites.push_back(&CS);
} else {
// Calculate the specialisation gain.
Cost Score = 0 - SpecCost;
Cost Score = 0;
InstCostVisitor Visitor = getInstCostVisitorFor(F);
for (ArgInfo &A : S.Args)
Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
Score += Visitor.getBonusFromPendingPHIs();

LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = "
<< Score << "\n");

// Discard unprofitable specialisations.
if (!ForceSpecialization && Score <= 0)
if (!ForceSpecialization && Score <= SpecCost)
continue;

// Create a new specialisation entry.
Expand Down
53 changes: 53 additions & 0 deletions llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,56 @@ TEST_F(FunctionSpecializationTest, Misc) {
Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
EXPECT_TRUE(Bonus == 0);
}

// Checks that a PHI node which cannot be folded while the individual
// specialization arguments are processed is accounted for afterwards by
// InstCostVisitor::getBonusFromPendingPHIs(), together with its icmp user.
TEST_F(FunctionSpecializationTest, PhiNode) {
  const char *ModuleString = R"(
define void @foo(i32 %a, i32 %b, i32 %i) {
entry:
br label %loop
loop:
switch i32 %i, label %default
[ i32 1, label %case1
i32 2, label %case2 ]
case1:
%0 = add i32 %a, 1
br label %bb
case2:
%1 = sub i32 %b, 1
br label %bb
bb:
%2 = phi i32 [ %0, %case1 ], [ %1, %case2 ], [ %2, %bb ]
%3 = icmp eq i32 %2, 2
br i1 %3, label %bb, label %loop
default:
ret void
}
)";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);

  // Skip entry, loop, case1 and case2 to land on block 'bb'.
  auto FuncIter = F->begin();
  for (int I = 0; I < 4; ++I)
    ++FuncIter;

  BasicBlock &BB = *FuncIter;

  // 'bb' starts with the phi, immediately followed by the icmp.
  Instruction &Phi = BB.front();
  Instruction &Icmp = *++BB.begin();

  // Process the arguments one statement at a time. All three calls receive
  // the same Visitor, and the operands of a single '+' expression are not
  // evaluated in a specified order, so summing them inline would leave the
  // processing order (and hence the state later consumed by
  // getBonusFromPendingPHIs) up to the compiler.
  Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
  Bonus += Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
  Bonus += Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
  EXPECT_TRUE(Bonus > 0);

  // phi + icmp
  Cost Ref = getInstCost(Phi) + getInstCost(Icmp);
  Bonus = Visitor.getBonusFromPendingPHIs();
  EXPECT_EQ(Bonus, Ref);
  EXPECT_TRUE(Bonus > 0);
}

0 comments on commit 03f1d09

Please sign in to comment.