Skip to content

Commit

Permalink
Revert "[FuncSpec] Add Phi nodes to the InstCostVisitor."
Browse files Browse the repository at this point in the history
This reverts commit 03f1d09
because of a crash reported on https://reviews.llvm.org/D154852
  • Loading branch information
labrinea committed Jul 26, 2023
1 parent 9957225 commit bc849e5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 143 deletions.
15 changes: 1 addition & 14 deletions llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,6 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
SCCPSolver &Solver;

ConstMap KnownConstants;
// Basic blocks known to be unreachable after constant propagation.
DenseSet<BasicBlock *> DeadBlocks;
// PHI nodes we have visited before.
DenseSet<Instruction *> VisitedPHIs;
// PHI nodes we have visited once without successfully constant folding them.
// Once the InstCostVisitor has processed all the specialization arguments,
// it should be possible to determine whether those PHIs can be folded
// (some of their incoming values may have become constant or dead).
SmallVector<Instruction *> PendingPHIs;

ConstMap::iterator LastVisited;

Expand All @@ -143,10 +134,7 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
TargetTransformInfo &TTI, SCCPSolver &Solver)
: DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {}

Cost getUserBonus(Instruction *User, Value *Use = nullptr,
Constant *C = nullptr);

Cost getBonusFromPendingPHIs();
Cost getUserBonus(Instruction *User, Value *Use, Constant *C);

private:
friend class InstVisitor<InstCostVisitor, Constant *>;
Expand All @@ -155,7 +143,6 @@ class InstCostVisitor : public InstVisitor<InstCostVisitor, Constant *> {
Cost estimateBranchInst(BranchInst &I);

Constant *visitInstruction(Instruction &I) { return nullptr; }
Constant *visitPHINode(PHINode &I);
Constant *visitFreezeInst(FreezeInst &I);
Constant *visitCallBase(CallBase &I);
Constant *visitLoadInst(LoadInst &I);
Expand Down
82 changes: 6 additions & 76 deletions llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ static cl::opt<unsigned> MaxClones(
"The maximum number of clones allowed for a single function "
"specialization"));

static cl::opt<unsigned> MaxIncomingPhiValues(
"funcspec-max-incoming-phi-values", cl::init(4), cl::Hidden, cl::desc(
"The maximum number of incoming values a PHI node can have to be "
"considered during the specialization bonus estimation"));

static cl::opt<unsigned> MinFunctionSize(
"funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
"Don't specialize functions that have less than this number of "
Expand All @@ -109,7 +104,6 @@ static cl::opt<bool> SpecializeLiteralConstant(
// the combination of size and latency savings in comparison to the non
// specialized version of the function.
static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
DenseSet<BasicBlock *> &DeadBlocks,
ConstMap &KnownConstants, SCCPSolver &Solver,
BlockFrequencyInfo &BFI,
TargetTransformInfo &TTI) {
Expand All @@ -124,12 +118,6 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
if (!Weight)
continue;

// These blocks are considered dead as far as the InstCostVisitor is
// concerned. They haven't been proven dead yet by the Solver, but
// may become if we propagate the constant specialization arguments.
if (!DeadBlocks.insert(BB).second)
continue;

for (Instruction &I : *BB) {
// Disregard SSA copies.
if (auto *II = dyn_cast<IntrinsicInst>(&I))
Expand Down Expand Up @@ -164,19 +152,9 @@ static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
return nullptr;
}

Cost InstCostVisitor::getBonusFromPendingPHIs() {
Cost Bonus = 0;
while (!PendingPHIs.empty()) {
Instruction *Phi = PendingPHIs.pop_back_val();
Bonus += getUserBonus(Phi);
}
return Bonus;
}

Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
// Cache the iterator before visiting.
LastVisited = Use ? KnownConstants.insert({Use, C}).first
: KnownConstants.end();
LastVisited = KnownConstants.insert({Use, C}).first;

if (auto *I = dyn_cast<SwitchInst>(User))
return estimateSwitchInst(*I);
Expand All @@ -203,15 +181,13 @@ Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {

for (auto *U : User->users())
if (auto *UI = dyn_cast<Instruction>(U))
if (UI != User && Solver.isBlockExecutable(UI->getParent()))
if (Solver.isBlockExecutable(UI->getParent()))
Bonus += getUserBonus(UI, User, C);

return Bonus;
}

Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return 0;

Expand All @@ -232,13 +208,10 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) {
WorkList.push_back(BB);
}

return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
TTI);
return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
}

Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return 0;

Expand All @@ -250,39 +223,10 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
Succ->getUniquePredecessor() == I.getParent())
WorkList.push_back(Succ);

return estimateBasicBlocks(WorkList, DeadBlocks, KnownConstants, Solver, BFI,
TTI);
}

Constant *InstCostVisitor::visitPHINode(PHINode &I) {
if (I.getNumIncomingValues() > MaxIncomingPhiValues)
return nullptr;

bool Inserted = VisitedPHIs.insert(&I).second;
Constant *Const = nullptr;

for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) {
Value *V = I.getIncomingValue(Idx);
if (auto *Inst = dyn_cast<Instruction>(V))
if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx)))
continue;
Constant *C = findConstantFor(V, KnownConstants);
if (!C) {
if (Inserted)
PendingPHIs.push_back(&I);
return nullptr;
}
if (!Const)
Const = C;
else if (C != Const)
return nullptr;
}
return Const;
return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI);
}

Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second))
return LastVisited->second;
return nullptr;
Expand All @@ -309,8 +253,6 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) {
}

Constant *InstCostVisitor::visitLoadInst(LoadInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (isa<ConstantPointerNull>(LastVisited->second))
return nullptr;
return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL);
Expand All @@ -333,8 +275,6 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
}

Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

if (I.getCondition() != LastVisited->first)
return nullptr;

Expand All @@ -350,8 +290,6 @@ Constant *InstCostVisitor::visitCastInst(CastInst &I) {
}

Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

bool Swap = I.getOperand(1) == LastVisited->first;
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
Constant *Other = findConstantFor(V, KnownConstants);
Expand All @@ -365,14 +303,10 @@ Constant *InstCostVisitor::visitCmpInst(CmpInst &I) {
}

Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL);
}

Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");

bool Swap = I.getOperand(1) == LastVisited->first;
Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
Constant *Other = findConstantFor(V, KnownConstants);
Expand Down Expand Up @@ -779,17 +713,13 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost,
AllSpecs[Index].CallSites.push_back(&CS);
} else {
// Calculate the specialisation gain.
Cost Score = 0;
Cost Score = 0 - SpecCost;
InstCostVisitor Visitor = getInstCostVisitorFor(F);
for (ArgInfo &A : S.Args)
Score += getSpecializationBonus(A.Formal, A.Actual, Visitor);
Score += Visitor.getBonusFromPendingPHIs();

LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization score = "
<< Score << "\n");

// Discard unprofitable specialisations.
if (!ForceSpecialization && Score <= SpecCost)
if (!ForceSpecialization && Score <= 0)
continue;

// Create a new specialisation entry.
Expand Down
53 changes: 0 additions & 53 deletions llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,56 +287,3 @@ TEST_F(FunctionSpecializationTest, Misc) {
Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
EXPECT_TRUE(Bonus == 0);
}

TEST_F(FunctionSpecializationTest, PhiNode) {
const char *ModuleString = R"(
define void @foo(i32 %a, i32 %b, i32 %i) {
entry:
br label %loop
loop:
switch i32 %i, label %default
[ i32 1, label %case1
i32 2, label %case2 ]
case1:
%0 = add i32 %a, 1
br label %bb
case2:
%1 = sub i32 %b, 1
br label %bb
bb:
%2 = phi i32 [ %0, %case1 ], [ %1, %case2 ], [ %2, %bb ]
%3 = icmp eq i32 %2, 2
br i1 %3, label %bb, label %loop
default:
ret void
}
)";

Module &M = parseModule(ModuleString);
Function *F = M.getFunction("foo");
FunctionSpecializer Specializer = getSpecializerFor(F);
InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);

auto FuncIter = F->begin();
for (int I = 0; I < 4; ++I)
++FuncIter;

BasicBlock &BB = *FuncIter;

Instruction &Phi = BB.front();
Instruction &Icmp = *++BB.begin();

Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor) +
Specializer.getSpecializationBonus(F->getArg(1), One, Visitor) +
Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
EXPECT_TRUE(Bonus > 0);

// phi + icmp
Cost Ref = getInstCost(Phi) + getInstCost(Icmp);
Bonus = Visitor.getBonusFromPendingPHIs();
EXPECT_EQ(Bonus, Ref);
EXPECT_TRUE(Bonus > 0);
}

0 comments on commit bc849e5

Please sign in to comment.