Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add setBranchWeigths convenience function. NFC #72446

Merged
merged 1 commit into from
Nov 16, 2023

Conversation

MatzeB
Copy link
Contributor

@MatzeB MatzeB commented Nov 15, 2023

Add setBranchWeights convenience function to ProfDataUtils.h and use
it where appropriate.

@MatzeB MatzeB requested a review from ilovepi November 15, 2023 22:34
Copy link

github-actions bot commented Nov 15, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Add `setBranchWeights` convenience function to ProfDataUtils.h and use
it where appropriate.
@MatzeB
Copy link
Contributor Author

MatzeB commented Nov 15, 2023

Adding a global setBranchWeights convenience function as requested in review of https://reviews.llvm.org/D159322

@MatzeB MatzeB added the PGO Profile Guided Optimizations label Nov 15, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Nov 15, 2023

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-pgo

Author: Matthias Braun (MatzeB)

Changes

Add setBranchWeights convenience function to ProfDataUtils.h and use
it where appropriate.


Full diff: https://github.com/llvm/llvm-project/pull/72446.diff

11 Files Affected:

  • (modified) llvm/include/llvm/IR/ProfDataUtils.h (+4)
  • (modified) llvm/lib/IR/ProfDataUtils.cpp (+7)
  • (modified) llvm/lib/Transforms/IPO/SampleProfile.cpp (+7-7)
  • (modified) llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp (+1-2)
  • (modified) llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp (+2-4)
  • (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+3-5)
  • (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+7-11)
  • (modified) llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp (+2-4)
  • (modified) llvm/lib/Transforms/Utils/Local.cpp (+1-3)
  • (modified) llvm/lib/Transforms/Utils/LoopPeel.cpp (+4-13)
  • (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+10-9)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b61199372de0de8..255fa2ff1c79065 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
 /// metadata was found.
 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
 
+/// Create a new `branch_weights` metadata node and add or overwrite
+/// a `prof` metadata reference to instruction `I`.
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
+
 } // namespace llvm
 #endif
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 77b3c1cb95d686c..29536b0b090cd76 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
@@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+  MDBuilder MDB(I.getContext());
+  MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+  I.setMetadata(LLVMContext::MD_prof, BranchWeights);
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 063f7b42022ff83..6c6f0a0eca72a7a 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -56,6 +56,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/PseudoProbe.h"
 #include "llvm/IR/ValueSymbolTable.h"
 #include "llvm/ProfileData/InstrProf.h"
@@ -1710,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          I.setMetadata(LLVMContext::MD_prof,
-                        MDB.createBranchWeights(
-                            {static_cast<uint32_t>(BlockWeights[BB])}));
+          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1720,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       // clear it for cold code.
       for (auto &I : *BB) {
         if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
-          if (cast<CallBase>(I).isIndirectCall())
+          if (cast<CallBase>(I).isIndirectCall()) {
             I.setMetadata(LLVMContext::MD_prof, nullptr);
-          else
-            I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
+          } else {
+            setBranchWeights(I, {uint32_t(0)});
+          }
         }
       }
     }
@@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
     if (MaxWeight > 0 &&
         (!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
       LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
-      TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+      setBranchWeights(*TI, Weights);
       ORE->emit([&]() {
         return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
                << "most popular destination for conditional branches at "
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 97cf583510f9395..0a3d8d6000cf47d 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
       static_cast<uint32_t>(CHRBranchBias.scale(1000)),
       static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
   };
-  MDBuilder MDB(F.getContext());
-  MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*MergedBR, Weights);
   CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
             << "\n");
 }
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 5c9799235017a8a..7344fea17517191 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Value.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
       promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
 
   if (AttachProfToDirectCall) {
-    MDBuilder MDB(NewInst.getContext());
-    NewInst.setMetadata(
-        LLVMContext::MD_prof,
-        MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
+    setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
   }
 
   using namespace ore;
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index aea0f2f7cae7786..4a5a0b25bebbaf1 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
     // If A is uncovered, set weight=1.
     // This setup will allow BFI to give nonzero profile counts to only covered
     // blocks.
-    SmallVector<unsigned, 4> Weights;
+    SmallVector<uint32_t, 4> Weights;
     for (auto *Succ : successors(&BB))
       Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
     if (Weights.size() >= 2)
-      BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
-                                      MDB.createBranchWeights(Weights));
+      llvm::setBranchWeights(*BB.getTerminator(), Weights);
   }
 
   unsigned NumCorruptCoverage = 0;
@@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
 
 void llvm::setProfMetadata(Module *M, Instruction *TI,
                            ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
-  MDBuilder MDB(M->getContext());
   assert(MaxCount > 0 && "Bad max count");
   uint64_t Scale = calculateCountScale(MaxCount);
   SmallVector<unsigned, 4> Weights;
@@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
 
   misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
 
-  TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*TI, Weights);
   if (EmitBranchProbability) {
     std::string BrCondStr = getBranchCondString(TI);
     if (BrCondStr.empty())
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 7a8128c5b6c0901..2d899f100f8154d 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     if (BP >= BranchProbability(50, 100))
       continue;
 
-    SmallVector<uint32_t, 2> Weights;
+    uint32_t Weights[2];
     if (PredBr->getSuccessor(0) == PredOutEdge.second) {
-      Weights.push_back(BP.getNumerator());
-      Weights.push_back(BP.getCompl().getNumerator());
+      Weights[0] = BP.getNumerator();
+      Weights[1] = BP.getCompl().getNumerator();
     } else {
-      Weights.push_back(BP.getCompl().getNumerator());
-      Weights.push_back(BP.getNumerator());
+      Weights[0] = BP.getCompl().getNumerator();
+      Weights[1] = BP.getNumerator();
     }
-    PredBr->setMetadata(LLVMContext::MD_prof,
-                        MDBuilder(PredBr->getParent()->getContext())
-                            .createBranchWeights(Weights));
+    setBranchWeights(*PredBr, Weights);
   }
 }
 
@@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
       Weights.push_back(Prob.getNumerator());
 
     auto TI = BB->getTerminator();
-    TI->setMetadata(
-        LLVMContext::MD_prof,
-        MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+    setBranchWeights(*TI, Weights);
   }
 }
 
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index ac87ee736c0d169..6f87e4d91d2c794 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 
@@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
   misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
 
   SI.setCondition(ArgValue);
-
-  SI.setMetadata(LLVMContext::MD_prof,
-                 MDBuilder(CI->getContext()).createBranchWeights(Weights));
-
+  setBranchWeights(SI, Weights);
   return true;
 }
 
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index aacf66bfe38eb91..c8c8d02cdcbb2b3 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
           // Remove weight for this case.
           std::swap(Weights[Idx + 1], Weights.back());
           Weights.pop_back();
-          SI->setMetadata(LLVMContext::MD_prof,
-                          MDBuilder(BB->getContext()).
-                          createBranchWeights(Weights));
+          setBranchWeights(*SI, Weights);
         }
         // Remove this entry.
         BasicBlock *ParentBB = SI->getParent();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 2881444206b0bb5..7566f70661baf48 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -631,9 +631,7 @@ struct WeightInfo {
 /// To avoid dealing with division rounding we can just multiple both part
 /// of weights to E and use weight as (F - I * E, E).
 static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
+  setBranchWeights(*Term, Info.Weights);
   for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
     if (SubWeight != 0)
       // Don't set the probability of taking the edge from latch to loop header
@@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
   }
 }
 
-/// Update the weights of original exiting block after peeling off all
-/// iterations.
-static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
-}
-
 /// Clones the body of the loop L, putting it between \p InsertTop and \p
 /// InsertBot.
 /// \param IterNumber The serial number of the iteration currently being
@@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
     PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
   }
 
-  for (const auto &[Term, Info] : Weights)
-    fixupBranchWeights(Term, Info);
+  for (const auto &[Term, Info] : Weights) {
+    setBranchWeights(*Term, Info.Weights);
+  }
 
   // Update Metadata for count of peeled off iterations.
   unsigned AlreadyPeeled = 0;
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 012aa5dbb9ca004..ae155ac082d8111 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     LoopBackWeight = 0;
   }
 
-  MDBuilder MDB(LoopBI.getContext());
-  MDNode *LoopWeightMD =
-      MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
-                              SuccsSwapped ? ExitWeight1 : LoopBackWeight);
-  LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+  const uint32_t LoopBIWeights[] = {
+      SuccsSwapped ? LoopBackWeight : ExitWeight1,
+      SuccsSwapped ? ExitWeight1 : LoopBackWeight,
+  };
+  setBranchWeights(LoopBI, LoopBIWeights);
   if (HasConditionalPreHeader) {
-    MDNode *PreHeaderWeightMD =
-        MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
-                                SuccsSwapped ? ExitWeight0 : EnterWeight);
-    PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+    const uint32_t PreHeaderBIWeights[] = {
+        SuccsSwapped ? EnterWeight : ExitWeight0,
+        SuccsSwapped ? ExitWeight0 : EnterWeight,
+    };
+    setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
   }
 }
 

@MatzeB MatzeB marked this pull request as ready for review November 15, 2023 22:46
Copy link
Member

@mtrofin mtrofin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should the commit title say [nfc]?

@MatzeB MatzeB changed the title Add setBranchWeigths convenience function Add setBranchWeigths convenience function; NFC Nov 15, 2023
@MatzeB MatzeB changed the title Add setBranchWeigths convenience function; NFC Add setBranchWeigths convenience function. NFC Nov 15, 2023
Copy link
Contributor

@ilovepi ilovepi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I was surprised that Instruction didn't implement this already, but I think this is a good change. Thanks for the nice improvement.

@MatzeB MatzeB merged commit cb4627d into llvm:main Nov 16, 2023
7 checks passed
sr-tream pushed a commit to sr-tream/llvm-project that referenced this pull request Nov 20, 2023
Add `setBranchWeights` convenience function to ProfDataUtils.h and use
it where appropriate.
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
Add `setBranchWeights` convenience function to ProfDataUtils.h and use
it where appropriate.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:ir llvm:transforms PGO Profile Guided Optimizations
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants