diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index a0876b169e0b8..a7bcbf010d1bf 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -194,10 +194,11 @@ LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I, /// Like setExplicitlyUnknownBranchWeights(...), but only sets unknown branch /// weights in the new instruction if the parent function of the original /// instruction has an entry count. This is to not confuse users by injecting -/// profile data into non-profiled functions. -LLVM_ABI void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I, - Function &F, - StringRef PassName); +/// profile data into non-profiled functions. If \p F is nullptr, we will fetch +/// the function from \p I. +LLVM_ABI void +setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I, StringRef PassName, + const Function *F = nullptr); /// Analogous to setExplicitlyUnknownBranchWeights, but for functions and their /// entry counts. diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp index 88dbd176e0d3f..95edb2e8e56d8 100644 --- a/llvm/lib/IR/IRBuilder.cpp +++ b/llvm/lib/IR/IRBuilder.cpp @@ -1019,8 +1019,7 @@ Value *IRBuilderBase::CreateSelectWithUnknownProfile(Value *C, Value *True, const Twine &Name) { Value *Ret = CreateSelectFMF(C, True, False, {}, Name); if (auto *SI = dyn_cast(Ret)) { - setExplicitlyUnknownBranchWeightsIfProfiled( - *SI, *SI->getParent()->getParent(), PassName); + setExplicitlyUnknownBranchWeightsIfProfiled(*SI, PassName); } return Ret; } diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index fc2be5188f456..94dbe1f3988b8 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -274,9 +274,12 @@ void llvm::setExplicitlyUnknownBranchWeights(Instruction &I, } void llvm::setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I, - Function &F, - StringRef PassName) { - if (std::optional EC = F.getEntryCount(); + StringRef PassName, + const Function *F) { + F = F ? F : I.getFunction(); + assert(F && "Either pass a instruction attached to a Function, or explicitly " + "pass the Function that it will be attached to"); + if (std::optional EC = F->getEntryCount(); EC && EC->getCount() > 0) setExplicitlyUnknownBranchWeights(I, PassName); } diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 7a95df4b2a47c..b575d76e897d2 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -1378,8 +1378,7 @@ static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU, IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N); // We can't know the precise weights here, as they would depend on the value // distribution of Call->getArgOperand(1). So we just mark it as "unknown". - setExplicitlyUnknownBranchWeightsIfProfiled(*SI, *Call->getFunction(), - DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*SI, DEBUG_TYPE); Type *IndexTy = DL.getIndexType(Call->getType()); SmallVector Updates; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index d85e4f7590197..9bdd8cb71f7f3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -479,7 +479,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final const Twine &NameStr = "", InsertPosition InsertBefore = nullptr) { auto *Sel = SelectInst::Create(C, S1, S2, NameStr, InsertBefore, nullptr); - setExplicitlyUnknownBranchWeightsIfProfiled(*Sel, F, DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*Sel, DEBUG_TYPE, &F); return Sel; } diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index bb6c879f4d47e..b96c194b8e7ae 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -329,8 +329,7 @@ static void buildPartialUnswitchConditionalBranch( HasBranchWeights ? ComputeProfFrom.getMetadata(LLVMContext::MD_prof) : nullptr); if (!HasBranchWeights) - setExplicitlyUnknownBranchWeightsIfProfiled( - *BR, *BR->getParent()->getParent(), DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*BR, DEBUG_TYPE); } /// Copy a set of loop invariant values, and conditionally branch on them. diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index cbc604e87cf1a..5bfb59fad99d0 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -5212,8 +5212,7 @@ bool SimplifyCFGOpt::simplifyBranchOnICmpChain(BranchInst *BI, // We don't have any info about this condition. auto *Br = TrueWhenEqual ? Builder.CreateCondBr(ExtraCase, EdgeBB, NewBB) : Builder.CreateCondBr(ExtraCase, NewBB, EdgeBB); - setExplicitlyUnknownBranchWeightsIfProfiled(*Br, *NewBB->getParent(), - DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*Br, DEBUG_TYPE); OldTI->eraseFromParent();