diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h index 02e73e247cbe0..349d5a7a08795 100644 --- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h +++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h @@ -52,6 +52,7 @@ #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/InstVisitor.h" #include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/SCCPSolver.h" @@ -69,6 +70,9 @@ using SpecMap = DenseMap>; // Just a shorter abbreviation to improve indentation. using Cost = InstructionCost; +// Map of known constants found during the specialization bonus estimation. +using ConstMap = DenseMap; + // Specialization signature, used to uniquely designate a specialization within // a function. struct SpecSig { @@ -115,6 +119,39 @@ struct Spec { : F(F), Sig(S), Score(Score) {} }; +class InstCostVisitor : public InstVisitor { + const DataLayout &DL; + BlockFrequencyInfo &BFI; + TargetTransformInfo &TTI; + SCCPSolver &Solver; + + ConstMap KnownConstants; + + ConstMap::iterator LastVisited; + +public: + InstCostVisitor(const DataLayout &DL, BlockFrequencyInfo &BFI, + TargetTransformInfo &TTI, SCCPSolver &Solver) + : DL(DL), BFI(BFI), TTI(TTI), Solver(Solver) {} + + Cost getUserBonus(Instruction *User, Value *Use, Constant *C); + +private: + friend class InstVisitor; + + Cost estimateSwitchInst(SwitchInst &I); + Cost estimateBranchInst(BranchInst &I); + + Constant *visitInstruction(Instruction &I) { return nullptr; } + Constant *visitLoadInst(LoadInst &I); + Constant *visitGetElementPtrInst(GetElementPtrInst &I); + Constant *visitSelectInst(SelectInst &I); + Constant *visitCastInst(CastInst &I); + Constant *visitCmpInst(CmpInst &I); + Constant *visitUnaryOperator(UnaryOperator &I); + Constant *visitBinaryOperator(BinaryOperator &I); +}; + class FunctionSpecializer { /// The IPSCCP Solver. @@ -151,6 +188,16 @@ class FunctionSpecializer { bool run(); + InstCostVisitor getInstCostVisitorFor(Function *F) { + auto &BFI = (GetBFI)(*F); + auto &TTI = (GetTTI)(*F); + return InstCostVisitor(M.getDataLayout(), BFI, TTI, Solver); + } + + /// Compute a bonus for replacing argument \p A with constant \p C. + Cost getSpecializationBonus(Argument *A, Constant *C, + InstCostVisitor &Visitor); + private: Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call); @@ -194,9 +241,6 @@ class FunctionSpecializer { /// Compute and return the cost of specializing function \p F. Cost getSpecializationCost(Function *F); - /// Compute a bonus for replacing argument \p A with constant \p C. - Cost getSpecializationBonus(Argument *A, Constant *C); - /// Determine if it is possible to specialise the function for constant values /// of the formal parameter \p A. bool isArgumentInteresting(Argument *A); diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index 87cc0f63e565d..a970253d9b1c8 100644 --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -48,11 +48,14 @@ #include "llvm/Transforms/IPO/FunctionSpecialization.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueLattice.h" #include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/ConstantFold.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -94,6 +97,210 @@ static cl::opt SpecializeLiteralConstant( "Enable specialization of functions that take a literal constant as an " "argument")); +// Estimates the instruction cost of all the basic blocks in \p WorkList. +// The successors of such blocks are added to the list as long as they are +// executable and they have a unique predecessor. \p WorkList represents +// the basic blocks of a specialization which become dead once we replace +// instructions that are known to be constants. The aim here is to estimate +// the combination of size and latency savings in comparison to the non +// specialized version of the function. +static Cost estimateBasicBlocks(SmallVectorImpl &WorkList, + ConstMap &KnownConstants, SCCPSolver &Solver, + BlockFrequencyInfo &BFI, + TargetTransformInfo &TTI) { + Cost Bonus = 0; + + // Accumulate the instruction cost of each basic block weighted by frequency. + while (!WorkList.empty()) { + BasicBlock *BB = WorkList.pop_back_val(); + + uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() / + BFI.getEntryFreq(); + if (!Weight) + continue; + + for (Instruction &I : *BB) { + // Disregard SSA copies. + if (auto *II = dyn_cast(&I)) + if (II->getIntrinsicID() == Intrinsic::ssa_copy) + continue; + // If it's a known constant we have already accounted for it. + if (KnownConstants.contains(&I)) + continue; + + Bonus += Weight * + TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus + << " after user " << I << "\n"); + } + + // Keep adding dead successors to the list as long as they are + // executable and they have a unique predecessor. + for (BasicBlock *SuccBB : successors(BB)) + if (Solver.isBlockExecutable(SuccBB) && + SuccBB->getUniquePredecessor() == BB) + WorkList.push_back(SuccBB); + } + return Bonus; +} + +static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) { + if (auto It = KnownConstants.find(V); It != KnownConstants.end()) + return It->second; + return nullptr; +} + +Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) { + // Cache the iterator before visiting. + LastVisited = KnownConstants.insert({Use, C}).first; + + if (auto *I = dyn_cast(User)) + return estimateSwitchInst(*I); + + if (auto *I = dyn_cast(User)) + return estimateBranchInst(*I); + + C = visit(*User); + if (!C) + return 0; + + KnownConstants.insert({User, C}); + + uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() / + BFI.getEntryFreq(); + if (!Weight) + return 0; + + Cost Bonus = Weight * + TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus + << " for user " << *User << "\n"); + + for (auto *U : User->users()) + if (auto *UI = dyn_cast(U)) + if (Solver.isBlockExecutable(UI->getParent())) + Bonus += getUserBonus(UI, User, C); + + return Bonus; +} + +Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) { + if (I.getCondition() != LastVisited->first) + return 0; + + auto *C = cast(LastVisited->second); + BasicBlock *Succ = I.findCaseValue(C)->getCaseSuccessor(); + // Initialize the worklist with the dead basic blocks. These are the + // destination labels which are different from the one corresponding + // to \p C. They should be executable and have a unique predecessor. + SmallVector WorkList; + for (const auto &Case : I.cases()) { + BasicBlock *BB = Case.getCaseSuccessor(); + if (BB == Succ || !Solver.isBlockExecutable(BB) || + BB->getUniquePredecessor() != I.getParent()) + continue; + WorkList.push_back(BB); + } + + return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI); +} + +Cost InstCostVisitor::estimateBranchInst(BranchInst &I) { + if (I.getCondition() != LastVisited->first) + return 0; + + BasicBlock *Succ = I.getSuccessor(LastVisited->second->isOneValue()); + // Initialize the worklist with the dead successor as long as + // it is executable and has a unique predecessor. + SmallVector WorkList; + if (Solver.isBlockExecutable(Succ) && + Succ->getUniquePredecessor() == I.getParent()) + WorkList.push_back(Succ); + + return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI); +} + +Constant *InstCostVisitor::visitLoadInst(LoadInst &I) { + if (isa(LastVisited->second)) + return nullptr; + return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL); +} + +Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { + SmallVector Operands; + Operands.reserve(I.getNumOperands()); + + for (unsigned Idx = 0, E = I.getNumOperands(); Idx != E; ++Idx) { + Value *V = I.getOperand(Idx); + auto *C = dyn_cast(V); + if (!C) + C = findConstantFor(V, KnownConstants); + if (!C) + return nullptr; + Operands.push_back(C); + } + + auto *Ptr = cast(Operands[0]); + auto Ops = ArrayRef(Operands.begin() + 1, Operands.end()); + return ConstantFoldGetElementPtr(I.getSourceElementType(), Ptr, + I.isInBounds(), std::nullopt, Ops); +} + +Constant *InstCostVisitor::visitSelectInst(SelectInst &I) { + if (I.getCondition() != LastVisited->first) + return nullptr; + + Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue() + : I.getTrueValue(); + auto *C = dyn_cast(V); + if (!C) + C = findConstantFor(V, KnownConstants); + return C; +} + +Constant *InstCostVisitor::visitCastInst(CastInst &I) { + return ConstantFoldCastOperand(I.getOpcode(), LastVisited->second, + I.getType(), DL); +} + +Constant *InstCostVisitor::visitCmpInst(CmpInst &I) { + bool Swap = I.getOperand(1) == LastVisited->first; + Value *V = Swap ? I.getOperand(0) : I.getOperand(1); + auto *Other = dyn_cast(V); + if (!Other) + Other = findConstantFor(V, KnownConstants); + + if (!Other) + return nullptr; + + Constant *Const = LastVisited->second; + return Swap ? + ConstantFoldCompareInstOperands(I.getPredicate(), Other, Const, DL) + : ConstantFoldCompareInstOperands(I.getPredicate(), Const, Other, DL); +} + +Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) { + return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL); +} + +Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) { + bool Swap = I.getOperand(1) == LastVisited->first; + Value *V = Swap ? I.getOperand(0) : I.getOperand(1); + auto *Other = dyn_cast(V); + if (!Other) + Other = findConstantFor(V, KnownConstants); + + if (!Other) + return nullptr; + + Constant *Const = LastVisited->second; + return dyn_cast_or_null(Swap ? + simplifyBinOp(I.getOpcode(), Other, Const, SimplifyQuery(DL)) + : simplifyBinOp(I.getOpcode(), Const, Other, SimplifyQuery(DL))); +} + Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) { Value *StoreValue = nullptr; @@ -412,10 +619,6 @@ CodeMetrics &FunctionSpecializer::analyzeFunction(Function *F) { CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues); for (BasicBlock &BB : *F) Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues); - - LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function " - << F->getName() << " is " << Metrics.NumInsts - << " instructions\n"); } return Metrics; } @@ -496,8 +699,9 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost, } else { // Calculate the specialisation gain. Cost Score = 0 - SpecCost; + InstCostVisitor Visitor = getInstCostVisitorFor(F); for (ArgInfo &A : S.Args) - Score += getSpecializationBonus(A.Formal, A.Actual); + Score += getSpecializationBonus(A.Formal, A.Actual, Visitor); // Discard unprofitable specialisations. if (!ForceSpecialization && Score <= 0) @@ -584,49 +788,23 @@ Cost FunctionSpecializer::getSpecializationCost(Function *F) { // Otherwise, set the specialization cost to be the cost of all the // instructions in the function. - return Metrics.NumInsts * InlineConstants::getInstrCost(); -} - -static Cost getUserBonus(User *U, TargetTransformInfo &TTI, - BlockFrequencyInfo &BFI) { - auto *I = dyn_cast_or_null(U); - // If not an instruction we do not know how to evaluate. - // Keep minimum possible cost for now so that it doesnt affect - // specialization. - if (!I) - return 0; - - uint64_t Weight = BFI.getBlockFreq(I->getParent()).getFrequency() / - BFI.getEntryFreq(); - if (!Weight) - return 0; - - Cost Bonus = Weight * - TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency); - - // Traverse recursively if there are more uses. - // TODO: Any other instructions to be added here? - if (I->mayReadFromMemory() || I->isCast()) - for (auto *User : I->users()) - Bonus += getUserBonus(User, TTI, BFI); - - return Bonus; + return Metrics.NumInsts; } /// Compute a bonus for replacing argument \p A with constant \p C. -Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C) { - Function *F = A->getParent(); - auto &TTI = (GetTTI)(*F); - auto &BFI = (GetBFI)(*F); +Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, + InstCostVisitor &Visitor) { LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " << C->getNameOrAsOperand() << "\n"); Cost TotalCost = 0; - for (auto *U : A->users()) { - TotalCost += getUserBonus(U, TTI, BFI); - LLVM_DEBUG(dbgs() << "FnSpecialization: User cost "; - TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n"); - } + for (auto *U : A->users()) + if (auto *UI = dyn_cast(U)) + if (Solver.isBlockExecutable(UI->getParent())) + TotalCost += Visitor.getUserBonus(UI, A, C); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus " + << TotalCost << " for argument " << *A << "\n"); // The below heuristic is only concerned with exposing inlining // opportunities via indirect call promotion. If the argument is not a diff --git a/llvm/unittests/Transforms/IPO/CMakeLists.txt b/llvm/unittests/Transforms/IPO/CMakeLists.txt index 3b16d81ae3b29..4e4372179b46c 100644 --- a/llvm/unittests/Transforms/IPO/CMakeLists.txt +++ b/llvm/unittests/Transforms/IPO/CMakeLists.txt @@ -12,6 +12,7 @@ add_llvm_unittest(IPOTests LowerTypeTests.cpp WholeProgramDevirt.cpp AttributorTest.cpp + FunctionSpecializationTest.cpp ) set_property(TARGET IPOTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests") diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp new file mode 100644 index 0000000000000..16c9a505e4498 --- /dev/null +++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp @@ -0,0 +1,258 @@ +//===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Constants.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/IPO/FunctionSpecialization.h" +#include "llvm/Transforms/Utils/SCCPSolver.h" +#include "gtest/gtest.h" +#include + +namespace llvm { + +class FunctionSpecializationTest : public testing::Test { +protected: + LLVMContext Ctx; + FunctionAnalysisManager FAM; + std::unique_ptr M; + std::unique_ptr Solver; + + FunctionSpecializationTest() { + FAM.registerPass([&] { return TargetLibraryAnalysis(); }); + FAM.registerPass([&] { return TargetIRAnalysis(); }); + FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); + FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); + FAM.registerPass([&] { return LoopAnalysis(); }); + FAM.registerPass([&] { return AssumptionAnalysis(); }); + FAM.registerPass([&] { return DominatorTreeAnalysis(); }); + FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); + FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + } + + Module &parseModule(const char *ModuleString) { + SMDiagnostic Err; + M = parseAssemblyString(ModuleString, Err, Ctx); + EXPECT_TRUE(M); + return *M; + } + + FunctionSpecializer getSpecializerFor(Function *F) { + auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { + return FAM.getResult(F); + }; + auto GetTTI = [this](Function &F) -> TargetTransformInfo & { + return FAM.getResult(F); + }; + auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { + return FAM.getResult(F); + }; + auto GetAC = [this](Function &F) -> AssumptionCache & { + return FAM.getResult(F); + }; + auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn { + DominatorTree &DT = FAM.getResult(F); + return { std::make_unique(F, DT, + FAM.getResult(F)), + &DT, FAM.getCachedResult(F) }; + }; + + Solver = std::make_unique(M->getDataLayout(), GetTLI, Ctx); + + Solver->addAnalysis(*F, GetAnalysis(*F)); + Solver->markBlockExecutable(&F->front()); + for (Argument &Arg : F->args()) + Solver->markOverdefined(&Arg); + Solver->solveWhileResolvedUndefsIn(*M); + + return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, + GetAC); + } + + Cost getInstCost(Instruction &I) { + auto &TTI = FAM.getResult(*I.getFunction()); + auto &BFI = FAM.getResult(*I.getFunction()); + + return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() * + TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + } +}; + +} // namespace llvm + +using namespace llvm; + +TEST_F(FunctionSpecializationTest, SwitchInst) { + const char *ModuleString = R"( + define void @foo(i32 %a, i32 %b, i32 %i) { + entry: + switch i32 %i, label %default + [ i32 1, label %case1 + i32 2, label %case2 ] + case1: + %0 = mul i32 %a, 2 + %1 = sub i32 6, 5 + br label %bb1 + case2: + %2 = and i32 %b, 3 + %3 = sdiv i32 8, 2 + br label %bb2 + bb1: + %4 = add i32 %0, %b + br label %default + bb2: + %5 = or i32 %2, %a + br label %default + default: + ret void + } + )"; + + Module &M = parseModule(ModuleString); + Function *F = M.getFunction("foo"); + FunctionSpecializer Specializer = getSpecializerFor(F); + InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); + + Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); + + auto FuncIter = F->begin(); + BasicBlock &Case1 = *++FuncIter; + BasicBlock &Case2 = *++FuncIter; + BasicBlock &BB1 = *++FuncIter; + BasicBlock &BB2 = *++FuncIter; + + Instruction &Mul = Case1.front(); + Instruction &And = Case2.front(); + Instruction &Sdiv = *++Case2.begin(); + Instruction &BrBB2 = Case2.back(); + Instruction &Add = BB1.front(); + Instruction &Or = BB2.front(); + Instruction &BrDefault = BB2.back(); + + // mul + Cost Ref = getInstCost(Mul); + Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); + EXPECT_EQ(Bonus, Ref); + + // and + or + add + Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); + Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); + EXPECT_EQ(Bonus, Ref); + + // sdiv + br + br + Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault); + Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); + EXPECT_EQ(Bonus, Ref); +} + +TEST_F(FunctionSpecializationTest, BranchInst) { + const char *ModuleString = R"( + define void @foo(i32 %a, i32 %b, i1 %cond) { + entry: + br i1 %cond, label %bb0, label %bb2 + bb0: + %0 = mul i32 %a, 2 + %1 = sub i32 6, 5 + br label %bb1 + bb1: + %2 = add i32 %0, %b + %3 = sdiv i32 8, 2 + br label %bb2 + bb2: + ret void + } + )"; + + Module &M = parseModule(ModuleString); + Function *F = M.getFunction("foo"); + FunctionSpecializer Specializer = getSpecializerFor(F); + InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); + + Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); + Constant *False = ConstantInt::getFalse(M.getContext()); + + auto FuncIter = F->begin(); + BasicBlock &BB0 = *++FuncIter; + BasicBlock &BB1 = *++FuncIter; + + Instruction &Mul = BB0.front(); + Instruction &Sub = *++BB0.begin(); + Instruction &BrBB1 = BB0.back(); + Instruction &Add = BB1.front(); + Instruction &Sdiv = *++BB1.begin(); + Instruction &BrBB2 = BB1.back(); + + // mul + Cost Ref = getInstCost(Mul); + Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); + EXPECT_EQ(Bonus, Ref); + + // add + Ref = getInstCost(Add); + Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); + EXPECT_EQ(Bonus, Ref); + + // sub + br + sdiv + br + Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) + + getInstCost(BrBB2); + Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor); + EXPECT_EQ(Bonus, Ref); +} + +TEST_F(FunctionSpecializationTest, Misc) { + const char *ModuleString = R"( + @g = constant [2 x i32] zeroinitializer, align 4 + + define i32 @foo(i8 %a, i1 %cond, ptr %b) { + %cmp = icmp eq i8 %a, 10 + %ext = zext i1 %cmp to i32 + %sel = select i1 %cond, i32 %ext, i32 1 + %gep = getelementptr i32, ptr %b, i32 %sel + %ld = load i32, ptr %gep + ret i32 %ld + } + )"; + + Module &M = parseModule(ModuleString); + Function *F = M.getFunction("foo"); + FunctionSpecializer Specializer = getSpecializerFor(F); + InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); + + GlobalVariable *GV = M.getGlobalVariable("g"); + Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); + Constant *True = ConstantInt::getTrue(M.getContext()); + + auto BlockIter = F->front().begin(); + Instruction &Icmp = *BlockIter++; + Instruction &Zext = *BlockIter++; + Instruction &Select = *BlockIter++; + Instruction &Gep = *BlockIter++; + Instruction &Load = *BlockIter++; + + // icmp + zext + Cost Ref = getInstCost(Icmp) + getInstCost(Zext); + Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); + EXPECT_EQ(Bonus, Ref); + + // select + Ref = getInstCost(Select); + Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor); + EXPECT_EQ(Bonus, Ref); + + // gep + load + Ref = getInstCost(Gep) + getInstCost(Load); + Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor); + EXPECT_EQ(Bonus, Ref); +}