diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h index 5906b1aaf2ceb6..f4e571e864935d 100644 --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -1176,36 +1176,37 @@ class VACopyInst : public IntrinsicInst { Value *getSrc() const { return const_cast(getArgOperand(1)); } }; -/// This represents the llvm.instrprof_increment intrinsic. -class InstrProfIncrementInst : public IntrinsicInst { +/// A base class for all instrprof intrinsics. +class InstrProfInstBase : public IntrinsicInst { public: - static bool classof(const IntrinsicInst *I) { - return I->getIntrinsicID() == Intrinsic::instrprof_increment; - } - static bool classof(const Value *V) { - return isa(V) && classof(cast(V)); - } - + // The name of the instrumented function. GlobalVariable *getName() const { return cast( const_cast(getArgOperand(0))->stripPointerCasts()); } - + // The hash of the CFG for the instrumented function. ConstantInt *getHash() const { return cast(const_cast(getArgOperand(1))); } + // The number of counters for the instrumented function. + ConstantInt *getNumCounters() const; + // The index of the counter that this instruction acts on. + ConstantInt *getIndex() const; +}; - ConstantInt *getNumCounters() const { - return cast(const_cast(getArgOperand(2))); +/// This represents the llvm.instrprof.increment intrinsic. +class InstrProfIncrementInst : public InstrProfInstBase { +public: + static bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::instrprof_increment; } - - ConstantInt *getIndex() const { - return cast(const_cast(getArgOperand(3))); + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); } - Value *getStep() const; }; +/// This represents the llvm.instrprof.increment.step intrinsic. class InstrProfIncrementInstStep : public InstrProfIncrementInst { public: static bool classof(const IntrinsicInst *I) { @@ -1216,8 +1217,8 @@ class InstrProfIncrementInstStep : public InstrProfIncrementInst { } }; -/// This represents the llvm.instrprof_value_profile intrinsic. -class InstrProfValueProfileInst : public IntrinsicInst { +/// This represents the llvm.instrprof.value.profile intrinsic. +class InstrProfValueProfileInst : public InstrProfInstBase { public: static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::instrprof_value_profile; @@ -1226,15 +1227,6 @@ class InstrProfValueProfileInst : public IntrinsicInst { return isa(V) && classof(cast(V)); } - GlobalVariable *getName() const { - return cast( - const_cast(getArgOperand(0))->stripPointerCasts()); - } - - ConstantInt *getHash() const { - return cast(const_cast(getArgOperand(1))); - } - Value *getTargetValue() const { return cast(const_cast(getArgOperand(2))); } diff --git a/llvm/include/llvm/Transforms/Instrumentation/InstrProfiling.h b/llvm/include/llvm/Transforms/Instrumentation/InstrProfiling.h index 94b156f3b137b9..64523d7d073c02 100644 --- a/llvm/include/llvm/Transforms/Instrumentation/InstrProfiling.h +++ b/llvm/include/llvm/Transforms/Instrumentation/InstrProfiling.h @@ -100,7 +100,7 @@ class InstrProfiling : public PassInfoMixin { /// /// If the counter array doesn't yet exist, the profile data variables /// referring to them will also be created. - GlobalVariable *getOrCreateRegionCounters(InstrProfIncrementInst *Inc); + GlobalVariable *getOrCreateRegionCounters(InstrProfInstBase *Inc); /// Emit the section with compressed function names. void emitNameData(); diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp index e7555bf8bc4276..adea7abb75cf47 100644 --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -178,6 +178,18 @@ int llvm::Intrinsic::lookupLLVMIntrinsicByName(ArrayRef NameTable, return -1; } +ConstantInt *InstrProfInstBase::getNumCounters() const { + if (InstrProfValueProfileInst::classof(this)) + llvm_unreachable("InstrProfValueProfileInst does not have counters!"); + return cast(const_cast(getArgOperand(2))); +} + +ConstantInt *InstrProfInstBase::getIndex() const { + if (InstrProfValueProfileInst::classof(this)) + llvm_unreachable("Please use InstrProfValueProfileInst::getIndex()"); + return cast(const_cast(getArgOperand(3))); +} + Value *InstrProfIncrementInst::getStep() const { if (InstrProfIncrementInstStep::classof(this)) { return const_cast(getArgOperand(4)); diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index e9c4a56a90c2e7..ab179b03dd29f2 100644 --- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -445,24 +445,19 @@ llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options, return new InstrProfilingLegacyPass(Options, IsCS); } -static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) { - InstrProfIncrementInst *Inc = dyn_cast(Instr); - if (Inc) - return Inc; - return dyn_cast(Instr); -} - bool InstrProfiling::lowerIntrinsics(Function *F) { bool MadeChange = false; PromotionCandidates.clear(); for (BasicBlock &BB : *F) { for (Instruction &Instr : llvm::make_early_inc_range(BB)) { - InstrProfIncrementInst *Inc = castToIncrementInst(&Instr); - if (Inc) { - lowerIncrement(Inc); + if (auto *IPIS = dyn_cast(&Instr)) { + lowerIncrement(IPIS); MadeChange = true; - } else if (auto *Ind = dyn_cast(&Instr)) { - lowerValueProfileInst(Ind); + } else if (auto *IPI = dyn_cast(&Instr)) { + lowerIncrement(IPI); + MadeChange = true; + } else if (auto *IPVP = dyn_cast(&Instr)) { + lowerValueProfileInst(IPVP); MadeChange = true; } } @@ -539,19 +534,14 @@ static bool needsRuntimeHookUnconditionally(const Triple &TT) { /// Check if the module contains uses of any profiling intrinsics. static bool containsProfilingIntrinsics(Module &M) { - if (auto *F = M.getFunction( - Intrinsic::getName(llvm::Intrinsic::instrprof_increment))) - if (!F->use_empty()) - return true; - if (auto *F = M.getFunction( - Intrinsic::getName(llvm::Intrinsic::instrprof_increment_step))) - if (!F->use_empty()) - return true; - if (auto *F = M.getFunction( - Intrinsic::getName(llvm::Intrinsic::instrprof_value_profile))) - if (!F->use_empty()) - return true; - return false; + auto containsIntrinsic = [&](int ID) { + if (auto *F = M.getFunction(Intrinsic::getName(ID))) + return !F->use_empty(); + return false; + }; + return containsIntrinsic(llvm::Intrinsic::instrprof_increment) || + containsIntrinsic(llvm::Intrinsic::instrprof_increment_step) || + containsIntrinsic(llvm::Intrinsic::instrprof_value_profile); } bool InstrProfiling::run( @@ -770,7 +760,7 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { } /// Get the name of a profiling variable for a particular function. -static std::string getVarName(InstrProfIncrementInst *Inc, StringRef Prefix, +static std::string getVarName(InstrProfInstBase *Inc, StringRef Prefix, bool &Renamed) { StringRef NamePrefix = getInstrProfNameVarPrefix(); StringRef Name = Inc->getName()->getName().substr(NamePrefix.size()); @@ -859,7 +849,7 @@ static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) { } GlobalVariable * -InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { +InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { GlobalVariable *NamePtr = Inc->getName(); auto &PD = ProfileDataMap[NamePtr]; if (PD.RegionCounters)