Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LoopVectorize: Set branch_weight for conditional branches #72450

Merged
merged 2 commits into from
Nov 16, 2023

Conversation

MatzeB
Copy link
Contributor

@MatzeB MatzeB commented Nov 15, 2023

Consistently add branch_weights metadata in any condition branch created by LoopVectorize.cpp:

  • Will only add metadata if the original loop-latch branch had metadata assigned.
  • Most checks should rarely trigger so I am using a 127:1 ratio.
  • For the middle block we assume an equal distribution of modulo results.

Add `setBranchWeights` convenience function to ProfDataUtils.h and use
it where appropriate.
Consistently add `branch_weights` metadata in any condition branch
created from `LoopVectorize.cpp`:
- Will only add metadata if the original loop-latch branch had metadata
  assigned.
- Most checks should rarely trigger so I am using a 127:1 ratio.
- For the middle block we assume an equal distribution of modulo
  results.
@MatzeB
Copy link
Contributor Author

MatzeB commented Nov 15, 2023

Note that this depends on PR #72446 and contains these changes as well for now.

@MatzeB
Copy link
Contributor Author

MatzeB commented Nov 15, 2023

This was already reviewed and accepted as https://reviews.llvm.org/D159322 under the condition that setBranchWeights is factored out which I did now.

@MatzeB MatzeB marked this pull request as ready for review November 15, 2023 23:03
@llvmbot
Copy link
Collaborator

llvmbot commented Nov 15, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Matthias Braun (MatzeB)

Changes

Consistently add branch_weights metadata in any condition branch
created from LoopVectorize.cpp:

  • Will only add metadata if the original loop-latch branch had metadata
    assigned.
  • Most checks should rarely trigger so I am using a 127:1 ratio.
  • For the middle block we assume an equal distribution of modulo
    results.

Patch is 44.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72450.diff

14 Files Affected:

  • (modified) llvm/include/llvm/IR/ProfDataUtils.h (+4)
  • (modified) llvm/lib/IR/ProfDataUtils.cpp (+7)
  • (modified) llvm/lib/Transforms/IPO/SampleProfile.cpp (+7-7)
  • (modified) llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp (+1-2)
  • (modified) llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp (+2-4)
  • (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+3-5)
  • (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+7-11)
  • (modified) llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp (+2-4)
  • (modified) llvm/lib/Transforms/Utils/Local.cpp (+1-3)
  • (modified) llvm/lib/Transforms/Utils/LoopPeel.cpp (+4-13)
  • (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+10-9)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+60-19)
  • (added) llvm/test/Transforms/LoopVectorize/branch-weights.ll (+81)
  • (modified) llvm/test/Transforms/LoopVectorize/first-order-recurrence.ll (+30-30)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b61199372de0de8..255fa2ff1c79065 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
 /// metadata was found.
 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
 
+/// Create a new `branch_weights` metadata node and add or overwrite
+/// a `prof` metadata reference to instruction `I`.
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
+
 } // namespace llvm
 #endif
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 77b3c1cb95d686c..29536b0b090cd76 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
@@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+  MDBuilder MDB(I.getContext());
+  MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+  I.setMetadata(LLVMContext::MD_prof, BranchWeights);
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 063f7b42022ff83..6c6f0a0eca72a7a 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -56,6 +56,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/PseudoProbe.h"
 #include "llvm/IR/ValueSymbolTable.h"
 #include "llvm/ProfileData/InstrProf.h"
@@ -1710,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          I.setMetadata(LLVMContext::MD_prof,
-                        MDB.createBranchWeights(
-                            {static_cast<uint32_t>(BlockWeights[BB])}));
+          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1720,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       // clear it for cold code.
       for (auto &I : *BB) {
         if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
-          if (cast<CallBase>(I).isIndirectCall())
+          if (cast<CallBase>(I).isIndirectCall()) {
             I.setMetadata(LLVMContext::MD_prof, nullptr);
-          else
-            I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
+          } else {
+            setBranchWeights(I, {uint32_t(0)});
+          }
         }
       }
     }
@@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
     if (MaxWeight > 0 &&
         (!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
       LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
-      TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+      setBranchWeights(*TI, Weights);
       ORE->emit([&]() {
         return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
                << "most popular destination for conditional branches at "
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 97cf583510f9395..0a3d8d6000cf47d 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
       static_cast<uint32_t>(CHRBranchBias.scale(1000)),
       static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
   };
-  MDBuilder MDB(F.getContext());
-  MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*MergedBR, Weights);
   CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
             << "\n");
 }
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 5c9799235017a8a..7344fea17517191 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Value.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
       promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
 
   if (AttachProfToDirectCall) {
-    MDBuilder MDB(NewInst.getContext());
-    NewInst.setMetadata(
-        LLVMContext::MD_prof,
-        MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
+    setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
   }
 
   using namespace ore;
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index aea0f2f7cae7786..4a5a0b25bebbaf1 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
     // If A is uncovered, set weight=1.
     // This setup will allow BFI to give nonzero profile counts to only covered
     // blocks.
-    SmallVector<unsigned, 4> Weights;
+    SmallVector<uint32_t, 4> Weights;
     for (auto *Succ : successors(&BB))
       Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
     if (Weights.size() >= 2)
-      BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
-                                      MDB.createBranchWeights(Weights));
+      llvm::setBranchWeights(*BB.getTerminator(), Weights);
   }
 
   unsigned NumCorruptCoverage = 0;
@@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
 
 void llvm::setProfMetadata(Module *M, Instruction *TI,
                            ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
-  MDBuilder MDB(M->getContext());
   assert(MaxCount > 0 && "Bad max count");
   uint64_t Scale = calculateCountScale(MaxCount);
   SmallVector<unsigned, 4> Weights;
@@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
 
   misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
 
-  TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*TI, Weights);
   if (EmitBranchProbability) {
     std::string BrCondStr = getBranchCondString(TI);
     if (BrCondStr.empty())
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 7a8128c5b6c0901..2d899f100f8154d 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     if (BP >= BranchProbability(50, 100))
       continue;
 
-    SmallVector<uint32_t, 2> Weights;
+    uint32_t Weights[2];
     if (PredBr->getSuccessor(0) == PredOutEdge.second) {
-      Weights.push_back(BP.getNumerator());
-      Weights.push_back(BP.getCompl().getNumerator());
+      Weights[0] = BP.getNumerator();
+      Weights[1] = BP.getCompl().getNumerator();
     } else {
-      Weights.push_back(BP.getCompl().getNumerator());
-      Weights.push_back(BP.getNumerator());
+      Weights[0] = BP.getCompl().getNumerator();
+      Weights[1] = BP.getNumerator();
     }
-    PredBr->setMetadata(LLVMContext::MD_prof,
-                        MDBuilder(PredBr->getParent()->getContext())
-                            .createBranchWeights(Weights));
+    setBranchWeights(*PredBr, Weights);
   }
 }
 
@@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
       Weights.push_back(Prob.getNumerator());
 
     auto TI = BB->getTerminator();
-    TI->setMetadata(
-        LLVMContext::MD_prof,
-        MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+    setBranchWeights(*TI, Weights);
   }
 }
 
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index ac87ee736c0d169..6f87e4d91d2c794 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 
@@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
   misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
 
   SI.setCondition(ArgValue);
-
-  SI.setMetadata(LLVMContext::MD_prof,
-                 MDBuilder(CI->getContext()).createBranchWeights(Weights));
-
+  setBranchWeights(SI, Weights);
   return true;
 }
 
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index aacf66bfe38eb91..c8c8d02cdcbb2b3 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
           // Remove weight for this case.
           std::swap(Weights[Idx + 1], Weights.back());
           Weights.pop_back();
-          SI->setMetadata(LLVMContext::MD_prof,
-                          MDBuilder(BB->getContext()).
-                          createBranchWeights(Weights));
+          setBranchWeights(*SI, Weights);
         }
         // Remove this entry.
         BasicBlock *ParentBB = SI->getParent();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 2881444206b0bb5..7566f70661baf48 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -631,9 +631,7 @@ struct WeightInfo {
 /// To avoid dealing with division rounding we can just multiple both part
 /// of weights to E and use weight as (F - I * E, E).
 static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
+  setBranchWeights(*Term, Info.Weights);
   for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
     if (SubWeight != 0)
       // Don't set the probability of taking the edge from latch to loop header
@@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
   }
 }
 
-/// Update the weights of original exiting block after peeling off all
-/// iterations.
-static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
-}
-
 /// Clones the body of the loop L, putting it between \p InsertTop and \p
 /// InsertBot.
 /// \param IterNumber The serial number of the iteration currently being
@@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
     PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
   }
 
-  for (const auto &[Term, Info] : Weights)
-    fixupBranchWeights(Term, Info);
+  for (const auto &[Term, Info] : Weights) {
+    setBranchWeights(*Term, Info.Weights);
+  }
 
   // Update Metadata for count of peeled off iterations.
   unsigned AlreadyPeeled = 0;
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 012aa5dbb9ca004..ae155ac082d8111 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     LoopBackWeight = 0;
   }
 
-  MDBuilder MDB(LoopBI.getContext());
-  MDNode *LoopWeightMD =
-      MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
-                              SuccsSwapped ? ExitWeight1 : LoopBackWeight);
-  LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+  const uint32_t LoopBIWeights[] = {
+      SuccsSwapped ? LoopBackWeight : ExitWeight1,
+      SuccsSwapped ? ExitWeight1 : LoopBackWeight,
+  };
+  setBranchWeights(LoopBI, LoopBIWeights);
   if (HasConditionalPreHeader) {
-    MDNode *PreHeaderWeightMD =
-        MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
-                                SuccsSwapped ? ExitWeight0 : EnterWeight);
-    PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+    const uint32_t PreHeaderBIWeights[] = {
+        SuccsSwapped ? EnterWeight : ExitWeight0,
+        SuccsSwapped ? ExitWeight0 : EnterWeight,
+    };
+    setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
   }
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e9d0315d114f65c..0a10f0d6471769a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -112,10 +112,12 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -396,6 +398,19 @@ static cl::opt<bool> UseWiderVFIfCallVariantsPresent(
     cl::Hidden,
     cl::desc("Try wider VFs if they enable the use of vector variants"));
 
+// Likelyhood of bypassing the vectorized loop because assumptions about SCEV
+// variables not overflowing do not hold. See `emitSCEVChecks`.
+static constexpr uint32_t SCEVCheckBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because pointers overlap. See
+// `emitMemRuntimeChecks`.
+static constexpr uint32_t MemCheckBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because there are zero trips left
+// after prolog. See `emitIterationCountCheck`.
+static constexpr uint32_t MinItersBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because of zero trips necessary.
+// See `emitMinimumVectorEpilogueIterCountCheck`.
+static constexpr uint32_t EpilogueMinItersBypassWeights[] = {1, 127};
+
 /// A helper function that returns true if the given type is irregular. The
 /// type is irregular if its allocated size doesn't equal the store size of an
 /// element of the corresponding vector type.
@@ -1962,12 +1977,14 @@ class GeneratedRTChecks {
   SCEVExpander MemCheckExp;
 
   bool CostTooHigh = false;
+  const bool AddBranchWeights;
 
 public:
   GeneratedRTChecks(ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI,
-                    TargetTransformInfo *TTI, const DataLayout &DL)
+                    TargetTransformInfo *TTI, const DataLayout &DL,
+                    bool AddBranchWeights)
       : DT(DT), LI(LI), TTI(TTI), SCEVExp(SE, DL, "scev.check"),
-        MemCheckExp(SE, DL, "scev.check") {}
+        MemCheckExp(SE, DL, "scev.check"), AddBranchWeights(AddBranchWeights) {}
 
   /// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can
   /// accurately estimate the cost of the runtime checks. The blocks are
@@ -2160,8 +2177,10 @@ class GeneratedRTChecks {
     DT->addNewBlock(SCEVCheckBlock, Pred);
     DT->changeImmediateDominator(LoopVectorPreHeader, SCEVCheckBlock);
 
-    ReplaceInstWithInst(SCEVCheckBlock->getTerminator(),
-                        BranchInst::Create(Bypass, LoopVectorPreHeader, Cond));
+    BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, Cond);
+    if (AddBranchWeights)
+      setBranchWeights(BI, SCEVCheckBypassWeights);
+    ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), &BI);
     return SCEVCheckBlock;
   }
 
@@ -2185,9 +2204,12 @@ class GeneratedRTChecks {
     if (auto *PL = LI->getLoopFor(LoopVectorPreHeader))
       PL->addBasicBlockToLoop(MemCheckBlock, *LI);
 
-    ReplaceInstWithInst(
-        MemCheckBlock->getTerminator(),
-        BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond));
+    BranchInst &BI =
+        *BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond);
+    if (AddBranchWeights) {
+      setBranchWeights(BI, MemCheckBypassWeights);
+    }
+    ReplaceInstWithInst(MemCheckBlock->getTerminator(), &BI);
     MemCheckBlock->getTerminator()->setDebugLoc(
         Pred->getTerminator()->getDebugLoc());
 
@@ -2900,9 +2922,11 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
     // dominator of the exit blocks.
     DT->changeImmediateDominator(LoopExitBlock, TCCheckBlock);
 
-  ReplaceInstWithInst(
-      TCCheckBlock->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, MinItersBypassWeights);
+  ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
   LoopBypassBlocks.push_back(TCCheckBlock);
 }
 
@@ -3133,7 +3157,16 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() {
     IRBuilder<> B(LoopMiddleBlock->getTerminator());
     B.SetCurrentDebugLocation(ScalarLatchTerm->getDebugLoc());
     Value *CmpN = B.CreateICmpEQ(Count, VectorTripCount, "cmp.n");
-    cast<BranchInst>(LoopMiddleBlock->getTerminator())->setCondition(CmpN);
+    BranchInst &BI = *cast<BranchInst>(LoopMiddleBlock->getTerminator());
+    BI.setCondition(CmpN);
+    if (hasBranchWeightMD(*ScalarLatchTerm)) {
+      // Assume that `Count % VectorTripCount` is equally distributed.
+      unsigned TripCount = UF * VF.getKnownMinValue();
+      assert(TripCount > 0 && "trip count should not be zero");
+      MDBuilder MDB(ScalarLatchTerm->getContext());
+      MDNode *BranchWeights = MDB.createBranchWeights(1, TripCount - 1);
+      BI.setMetadata(LLVMContext::MD_prof, BranchWeights);
+    }
   }
 
 #ifdef EXPENSIVE_CHECKS
@@ -7896,9 +7929,11 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
     EPI.TripCount = Count;
   }
 
-  ReplaceInstWithInst(
-      TCCheckBlock->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, MinItersBypassWeights);
+  ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
 
   return TCCheckBlock;
 }
@@ -8042,9 +8077,11 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck(
                                          EPI.EpilogueVF, EPI.EpilogueUF),
                          "min.epilog.iters.check");
 
-  ReplaceInstWithInst(
-      Insert->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, EpilogueMinItersBypassWeights);
+  ReplaceInstWithInst(Insert->getTerminator(), &BI);
 
   LoopBypassBlocks.push_back(Insert);
   return Insert;
@@ -9731,8 +9768,10 @@ static bool processLoopInVPlanNativePath(
   VPlan &BestPlan = LVP.getBestPlanFor(VF.Width);
 
   {
+    bool AddBranchWeights =
+        hasBranchWeightMD(*L->ge...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 15, 2023

@llvm/pr-subscribers-llvm-ir

Author: Matthias Braun (MatzeB)

Changes

Consistently add branch_weights metadata in any condition branch
created from LoopVectorize.cpp:

  • Will only add metadata if the original loop-latch branch had metadata
    assigned.
  • Most checks should rarely trigger so I am using a 127:1 ratio.
  • For the middle block we assume an equal distribution of modulo
    results.

Patch is 44.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72450.diff

14 Files Affected:

  • (modified) llvm/include/llvm/IR/ProfDataUtils.h (+4)
  • (modified) llvm/lib/IR/ProfDataUtils.cpp (+7)
  • (modified) llvm/lib/Transforms/IPO/SampleProfile.cpp (+7-7)
  • (modified) llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp (+1-2)
  • (modified) llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp (+2-4)
  • (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+3-5)
  • (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+7-11)
  • (modified) llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp (+2-4)
  • (modified) llvm/lib/Transforms/Utils/Local.cpp (+1-3)
  • (modified) llvm/lib/Transforms/Utils/LoopPeel.cpp (+4-13)
  • (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+10-9)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+60-19)
  • (added) llvm/test/Transforms/LoopVectorize/branch-weights.ll (+81)
  • (modified) llvm/test/Transforms/LoopVectorize/first-order-recurrence.ll (+30-30)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b61199372de0de8..255fa2ff1c79065 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
 /// metadata was found.
 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
 
+/// Create a new `branch_weights` metadata node and add or overwrite
+/// a `prof` metadata reference to instruction `I`.
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
+
 } // namespace llvm
 #endif
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 77b3c1cb95d686c..29536b0b090cd76 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
@@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+  MDBuilder MDB(I.getContext());
+  MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+  I.setMetadata(LLVMContext::MD_prof, BranchWeights);
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 063f7b42022ff83..6c6f0a0eca72a7a 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -56,6 +56,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/PseudoProbe.h"
 #include "llvm/IR/ValueSymbolTable.h"
 #include "llvm/ProfileData/InstrProf.h"
@@ -1710,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          I.setMetadata(LLVMContext::MD_prof,
-                        MDB.createBranchWeights(
-                            {static_cast<uint32_t>(BlockWeights[BB])}));
+          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1720,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       // clear it for cold code.
       for (auto &I : *BB) {
         if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
-          if (cast<CallBase>(I).isIndirectCall())
+          if (cast<CallBase>(I).isIndirectCall()) {
             I.setMetadata(LLVMContext::MD_prof, nullptr);
-          else
-            I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
+          } else {
+            setBranchWeights(I, {uint32_t(0)});
+          }
         }
       }
     }
@@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
     if (MaxWeight > 0 &&
         (!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
       LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
-      TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+      setBranchWeights(*TI, Weights);
       ORE->emit([&]() {
         return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
                << "most popular destination for conditional branches at "
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 97cf583510f9395..0a3d8d6000cf47d 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
       static_cast<uint32_t>(CHRBranchBias.scale(1000)),
       static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
   };
-  MDBuilder MDB(F.getContext());
-  MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*MergedBR, Weights);
   CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
             << "\n");
 }
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 5c9799235017a8a..7344fea17517191 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Value.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
       promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
 
   if (AttachProfToDirectCall) {
-    MDBuilder MDB(NewInst.getContext());
-    NewInst.setMetadata(
-        LLVMContext::MD_prof,
-        MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
+    setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
   }
 
   using namespace ore;
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index aea0f2f7cae7786..4a5a0b25bebbaf1 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
     // If A is uncovered, set weight=1.
     // This setup will allow BFI to give nonzero profile counts to only covered
     // blocks.
-    SmallVector<unsigned, 4> Weights;
+    SmallVector<uint32_t, 4> Weights;
     for (auto *Succ : successors(&BB))
       Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
     if (Weights.size() >= 2)
-      BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
-                                      MDB.createBranchWeights(Weights));
+      llvm::setBranchWeights(*BB.getTerminator(), Weights);
   }
 
   unsigned NumCorruptCoverage = 0;
@@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
 
 void llvm::setProfMetadata(Module *M, Instruction *TI,
                            ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
-  MDBuilder MDB(M->getContext());
   assert(MaxCount > 0 && "Bad max count");
   uint64_t Scale = calculateCountScale(MaxCount);
   SmallVector<unsigned, 4> Weights;
@@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
 
   misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
 
-  TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*TI, Weights);
   if (EmitBranchProbability) {
     std::string BrCondStr = getBranchCondString(TI);
     if (BrCondStr.empty())
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 7a8128c5b6c0901..2d899f100f8154d 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     if (BP >= BranchProbability(50, 100))
       continue;
 
-    SmallVector<uint32_t, 2> Weights;
+    uint32_t Weights[2];
     if (PredBr->getSuccessor(0) == PredOutEdge.second) {
-      Weights.push_back(BP.getNumerator());
-      Weights.push_back(BP.getCompl().getNumerator());
+      Weights[0] = BP.getNumerator();
+      Weights[1] = BP.getCompl().getNumerator();
     } else {
-      Weights.push_back(BP.getCompl().getNumerator());
-      Weights.push_back(BP.getNumerator());
+      Weights[0] = BP.getCompl().getNumerator();
+      Weights[1] = BP.getNumerator();
     }
-    PredBr->setMetadata(LLVMContext::MD_prof,
-                        MDBuilder(PredBr->getParent()->getContext())
-                            .createBranchWeights(Weights));
+    setBranchWeights(*PredBr, Weights);
   }
 }
 
@@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
       Weights.push_back(Prob.getNumerator());
 
     auto TI = BB->getTerminator();
-    TI->setMetadata(
-        LLVMContext::MD_prof,
-        MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+    setBranchWeights(*TI, Weights);
   }
 }
 
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index ac87ee736c0d169..6f87e4d91d2c794 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 
@@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
   misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
 
   SI.setCondition(ArgValue);
-
-  SI.setMetadata(LLVMContext::MD_prof,
-                 MDBuilder(CI->getContext()).createBranchWeights(Weights));
-
+  setBranchWeights(SI, Weights);
   return true;
 }
 
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index aacf66bfe38eb91..c8c8d02cdcbb2b3 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
           // Remove weight for this case.
           std::swap(Weights[Idx + 1], Weights.back());
           Weights.pop_back();
-          SI->setMetadata(LLVMContext::MD_prof,
-                          MDBuilder(BB->getContext()).
-                          createBranchWeights(Weights));
+          setBranchWeights(*SI, Weights);
         }
         // Remove this entry.
         BasicBlock *ParentBB = SI->getParent();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 2881444206b0bb5..7566f70661baf48 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -631,9 +631,7 @@ struct WeightInfo {
 /// To avoid dealing with division rounding we can just multiple both part
 /// of weights to E and use weight as (F - I * E, E).
 static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
+  setBranchWeights(*Term, Info.Weights);
   for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
     if (SubWeight != 0)
       // Don't set the probability of taking the edge from latch to loop header
@@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
   }
 }
 
-/// Update the weights of original exiting block after peeling off all
-/// iterations.
-static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
-}
-
 /// Clones the body of the loop L, putting it between \p InsertTop and \p
 /// InsertBot.
 /// \param IterNumber The serial number of the iteration currently being
@@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
     PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
   }
 
-  for (const auto &[Term, Info] : Weights)
-    fixupBranchWeights(Term, Info);
+  for (const auto &[Term, Info] : Weights) {
+    setBranchWeights(*Term, Info.Weights);
+  }
 
   // Update Metadata for count of peeled off iterations.
   unsigned AlreadyPeeled = 0;
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 012aa5dbb9ca004..ae155ac082d8111 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     LoopBackWeight = 0;
   }
 
-  MDBuilder MDB(LoopBI.getContext());
-  MDNode *LoopWeightMD =
-      MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
-                              SuccsSwapped ? ExitWeight1 : LoopBackWeight);
-  LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+  const uint32_t LoopBIWeights[] = {
+      SuccsSwapped ? LoopBackWeight : ExitWeight1,
+      SuccsSwapped ? ExitWeight1 : LoopBackWeight,
+  };
+  setBranchWeights(LoopBI, LoopBIWeights);
   if (HasConditionalPreHeader) {
-    MDNode *PreHeaderWeightMD =
-        MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
-                                SuccsSwapped ? ExitWeight0 : EnterWeight);
-    PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+    const uint32_t PreHeaderBIWeights[] = {
+        SuccsSwapped ? EnterWeight : ExitWeight0,
+        SuccsSwapped ? ExitWeight0 : EnterWeight,
+    };
+    setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
   }
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e9d0315d114f65c..0a10f0d6471769a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -112,10 +112,12 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -396,6 +398,19 @@ static cl::opt<bool> UseWiderVFIfCallVariantsPresent(
     cl::Hidden,
     cl::desc("Try wider VFs if they enable the use of vector variants"));
 
+// Likelyhood of bypassing the vectorized loop because assumptions about SCEV
+// variables not overflowing do not hold. See `emitSCEVChecks`.
+static constexpr uint32_t SCEVCheckBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because pointers overlap. See
+// `emitMemRuntimeChecks`.
+static constexpr uint32_t MemCheckBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because there are zero trips left
+// after prolog. See `emitIterationCountCheck`.
+static constexpr uint32_t MinItersBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because of zero trips necessary.
+// See `emitMinimumVectorEpilogueIterCountCheck`.
+static constexpr uint32_t EpilogueMinItersBypassWeights[] = {1, 127};
+
 /// A helper function that returns true if the given type is irregular. The
 /// type is irregular if its allocated size doesn't equal the store size of an
 /// element of the corresponding vector type.
@@ -1962,12 +1977,14 @@ class GeneratedRTChecks {
   SCEVExpander MemCheckExp;
 
   bool CostTooHigh = false;
+  const bool AddBranchWeights;
 
 public:
   GeneratedRTChecks(ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI,
-                    TargetTransformInfo *TTI, const DataLayout &DL)
+                    TargetTransformInfo *TTI, const DataLayout &DL,
+                    bool AddBranchWeights)
       : DT(DT), LI(LI), TTI(TTI), SCEVExp(SE, DL, "scev.check"),
-        MemCheckExp(SE, DL, "scev.check") {}
+        MemCheckExp(SE, DL, "scev.check"), AddBranchWeights(AddBranchWeights) {}
 
   /// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can
   /// accurately estimate the cost of the runtime checks. The blocks are
@@ -2160,8 +2177,10 @@ class GeneratedRTChecks {
     DT->addNewBlock(SCEVCheckBlock, Pred);
     DT->changeImmediateDominator(LoopVectorPreHeader, SCEVCheckBlock);
 
-    ReplaceInstWithInst(SCEVCheckBlock->getTerminator(),
-                        BranchInst::Create(Bypass, LoopVectorPreHeader, Cond));
+    BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, Cond);
+    if (AddBranchWeights)
+      setBranchWeights(BI, SCEVCheckBypassWeights);
+    ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), &BI);
     return SCEVCheckBlock;
   }
 
@@ -2185,9 +2204,12 @@ class GeneratedRTChecks {
     if (auto *PL = LI->getLoopFor(LoopVectorPreHeader))
       PL->addBasicBlockToLoop(MemCheckBlock, *LI);
 
-    ReplaceInstWithInst(
-        MemCheckBlock->getTerminator(),
-        BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond));
+    BranchInst &BI =
+        *BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond);
+    if (AddBranchWeights) {
+      setBranchWeights(BI, MemCheckBypassWeights);
+    }
+    ReplaceInstWithInst(MemCheckBlock->getTerminator(), &BI);
     MemCheckBlock->getTerminator()->setDebugLoc(
         Pred->getTerminator()->getDebugLoc());
 
@@ -2900,9 +2922,11 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
     // dominator of the exit blocks.
     DT->changeImmediateDominator(LoopExitBlock, TCCheckBlock);
 
-  ReplaceInstWithInst(
-      TCCheckBlock->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, MinItersBypassWeights);
+  ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
   LoopBypassBlocks.push_back(TCCheckBlock);
 }
 
@@ -3133,7 +3157,16 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() {
     IRBuilder<> B(LoopMiddleBlock->getTerminator());
     B.SetCurrentDebugLocation(ScalarLatchTerm->getDebugLoc());
     Value *CmpN = B.CreateICmpEQ(Count, VectorTripCount, "cmp.n");
-    cast<BranchInst>(LoopMiddleBlock->getTerminator())->setCondition(CmpN);
+    BranchInst &BI = *cast<BranchInst>(LoopMiddleBlock->getTerminator());
+    BI.setCondition(CmpN);
+    if (hasBranchWeightMD(*ScalarLatchTerm)) {
+      // Assume that `Count % VectorTripCount` is equally distributed.
+      unsigned TripCount = UF * VF.getKnownMinValue();
+      assert(TripCount > 0 && "trip count should not be zero");
+      MDBuilder MDB(ScalarLatchTerm->getContext());
+      MDNode *BranchWeights = MDB.createBranchWeights(1, TripCount - 1);
+      BI.setMetadata(LLVMContext::MD_prof, BranchWeights);
+    }
   }
 
 #ifdef EXPENSIVE_CHECKS
@@ -7896,9 +7929,11 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
     EPI.TripCount = Count;
   }
 
-  ReplaceInstWithInst(
-      TCCheckBlock->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, MinItersBypassWeights);
+  ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
 
   return TCCheckBlock;
 }
@@ -8042,9 +8077,11 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck(
                                          EPI.EpilogueVF, EPI.EpilogueUF),
                          "min.epilog.iters.check");
 
-  ReplaceInstWithInst(
-      Insert->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, EpilogueMinItersBypassWeights);
+  ReplaceInstWithInst(Insert->getTerminator(), &BI);
 
   LoopBypassBlocks.push_back(Insert);
   return Insert;
@@ -9731,8 +9768,10 @@ static bool processLoopInVPlanNativePath(
   VPlan &BestPlan = LVP.getBestPlanFor(VF.Width);
 
   {
+    bool AddBranchWeights =
+        hasBranchWeightMD(*L->ge...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 15, 2023

@llvm/pr-subscribers-pgo

Author: Matthias Braun (MatzeB)

Changes

Consistently add branch_weights metadata in any condition branch
created from LoopVectorize.cpp:

  • Will only add metadata if the original loop-latch branch had metadata
    assigned.
  • Most checks should rarely trigger so I am using a 127:1 ratio.
  • For the middle block we assume an equal distribution of modulo
    results.

Patch is 44.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72450.diff

14 Files Affected:

  • (modified) llvm/include/llvm/IR/ProfDataUtils.h (+4)
  • (modified) llvm/lib/IR/ProfDataUtils.cpp (+7)
  • (modified) llvm/lib/Transforms/IPO/SampleProfile.cpp (+7-7)
  • (modified) llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp (+1-2)
  • (modified) llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp (+2-4)
  • (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+3-5)
  • (modified) llvm/lib/Transforms/Scalar/JumpThreading.cpp (+7-11)
  • (modified) llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp (+2-4)
  • (modified) llvm/lib/Transforms/Utils/Local.cpp (+1-3)
  • (modified) llvm/lib/Transforms/Utils/LoopPeel.cpp (+4-13)
  • (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+10-9)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+60-19)
  • (added) llvm/test/Transforms/LoopVectorize/branch-weights.ll (+81)
  • (modified) llvm/test/Transforms/LoopVectorize/first-order-recurrence.ll (+30-30)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b61199372de0de8..255fa2ff1c79065 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
 /// metadata was found.
 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
 
+/// Create a new `branch_weights` metadata node and add or overwrite
+/// a `prof` metadata reference to instruction `I`.
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
+
 } // namespace llvm
 #endif
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index 77b3c1cb95d686c..29536b0b090cd76 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
@@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+  MDBuilder MDB(I.getContext());
+  MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+  I.setMetadata(LLVMContext::MD_prof, BranchWeights);
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 063f7b42022ff83..6c6f0a0eca72a7a 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -56,6 +56,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/PseudoProbe.h"
 #include "llvm/IR/ValueSymbolTable.h"
 #include "llvm/ProfileData/InstrProf.h"
@@ -1710,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          I.setMetadata(LLVMContext::MD_prof,
-                        MDB.createBranchWeights(
-                            {static_cast<uint32_t>(BlockWeights[BB])}));
+          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1720,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       // clear it for cold code.
       for (auto &I : *BB) {
         if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
-          if (cast<CallBase>(I).isIndirectCall())
+          if (cast<CallBase>(I).isIndirectCall()) {
             I.setMetadata(LLVMContext::MD_prof, nullptr);
-          else
-            I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
+          } else {
+            setBranchWeights(I, {uint32_t(0)});
+          }
         }
       }
     }
@@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
     if (MaxWeight > 0 &&
         (!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
       LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
-      TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+      setBranchWeights(*TI, Weights);
       ORE->emit([&]() {
         return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
                << "most popular destination for conditional branches at "
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 97cf583510f9395..0a3d8d6000cf47d 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
       static_cast<uint32_t>(CHRBranchBias.scale(1000)),
       static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
   };
-  MDBuilder MDB(F.getContext());
-  MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*MergedBR, Weights);
   CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
             << "\n");
 }
diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index 5c9799235017a8a..7344fea17517191 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Value.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
       promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
 
   if (AttachProfToDirectCall) {
-    MDBuilder MDB(NewInst.getContext());
-    NewInst.setMetadata(
-        LLVMContext::MD_prof,
-        MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
+    setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
   }
 
   using namespace ore;
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index aea0f2f7cae7786..4a5a0b25bebbaf1 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
     // If A is uncovered, set weight=1.
     // This setup will allow BFI to give nonzero profile counts to only covered
     // blocks.
-    SmallVector<unsigned, 4> Weights;
+    SmallVector<uint32_t, 4> Weights;
     for (auto *Succ : successors(&BB))
       Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
     if (Weights.size() >= 2)
-      BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
-                                      MDB.createBranchWeights(Weights));
+      llvm::setBranchWeights(*BB.getTerminator(), Weights);
   }
 
   unsigned NumCorruptCoverage = 0;
@@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
 
 void llvm::setProfMetadata(Module *M, Instruction *TI,
                            ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
-  MDBuilder MDB(M->getContext());
   assert(MaxCount > 0 && "Bad max count");
   uint64_t Scale = calculateCountScale(MaxCount);
   SmallVector<unsigned, 4> Weights;
@@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
 
   misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
 
-  TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
+  setBranchWeights(*TI, Weights);
   if (EmitBranchProbability) {
     std::string BrCondStr = getBranchCondString(TI);
     if (BrCondStr.empty())
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 7a8128c5b6c0901..2d899f100f8154d 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     if (BP >= BranchProbability(50, 100))
       continue;
 
-    SmallVector<uint32_t, 2> Weights;
+    uint32_t Weights[2];
     if (PredBr->getSuccessor(0) == PredOutEdge.second) {
-      Weights.push_back(BP.getNumerator());
-      Weights.push_back(BP.getCompl().getNumerator());
+      Weights[0] = BP.getNumerator();
+      Weights[1] = BP.getCompl().getNumerator();
     } else {
-      Weights.push_back(BP.getCompl().getNumerator());
-      Weights.push_back(BP.getNumerator());
+      Weights[0] = BP.getCompl().getNumerator();
+      Weights[1] = BP.getNumerator();
     }
-    PredBr->setMetadata(LLVMContext::MD_prof,
-                        MDBuilder(PredBr->getParent()->getContext())
-                            .createBranchWeights(Weights));
+    setBranchWeights(*PredBr, Weights);
   }
 }
 
@@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
       Weights.push_back(Prob.getNumerator());
 
     auto TI = BB->getTerminator();
-    TI->setMetadata(
-        LLVMContext::MD_prof,
-        MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+    setBranchWeights(*TI, Weights);
   }
 }
 
diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
index ac87ee736c0d169..6f87e4d91d2c794 100644
--- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 
@@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
   misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
 
   SI.setCondition(ArgValue);
-
-  SI.setMetadata(LLVMContext::MD_prof,
-                 MDBuilder(CI->getContext()).createBranchWeights(Weights));
-
+  setBranchWeights(SI, Weights);
   return true;
 }
 
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index aacf66bfe38eb91..c8c8d02cdcbb2b3 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
           // Remove weight for this case.
           std::swap(Weights[Idx + 1], Weights.back());
           Weights.pop_back();
-          SI->setMetadata(LLVMContext::MD_prof,
-                          MDBuilder(BB->getContext()).
-                          createBranchWeights(Weights));
+          setBranchWeights(*SI, Weights);
         }
         // Remove this entry.
         BasicBlock *ParentBB = SI->getParent();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 2881444206b0bb5..7566f70661baf48 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -631,9 +631,7 @@ struct WeightInfo {
 /// To avoid dealing with division rounding we can just multiple both part
 /// of weights to E and use weight as (F - I * E, E).
 static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
+  setBranchWeights(*Term, Info.Weights);
   for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
     if (SubWeight != 0)
       // Don't set the probability of taking the edge from latch to loop header
@@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
   }
 }
 
-/// Update the weights of original exiting block after peeling off all
-/// iterations.
-static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
-  MDBuilder MDB(Term->getContext());
-  Term->setMetadata(LLVMContext::MD_prof,
-                    MDB.createBranchWeights(Info.Weights));
-}
-
 /// Clones the body of the loop L, putting it between \p InsertTop and \p
 /// InsertBot.
 /// \param IterNumber The serial number of the iteration currently being
@@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
     PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
   }
 
-  for (const auto &[Term, Info] : Weights)
-    fixupBranchWeights(Term, Info);
+  for (const auto &[Term, Info] : Weights) {
+    setBranchWeights(*Term, Info.Weights);
+  }
 
   // Update Metadata for count of peeled off iterations.
   unsigned AlreadyPeeled = 0;
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 012aa5dbb9ca004..ae155ac082d8111 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     LoopBackWeight = 0;
   }
 
-  MDBuilder MDB(LoopBI.getContext());
-  MDNode *LoopWeightMD =
-      MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
-                              SuccsSwapped ? ExitWeight1 : LoopBackWeight);
-  LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+  const uint32_t LoopBIWeights[] = {
+      SuccsSwapped ? LoopBackWeight : ExitWeight1,
+      SuccsSwapped ? ExitWeight1 : LoopBackWeight,
+  };
+  setBranchWeights(LoopBI, LoopBIWeights);
   if (HasConditionalPreHeader) {
-    MDNode *PreHeaderWeightMD =
-        MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
-                                SuccsSwapped ? ExitWeight0 : EnterWeight);
-    PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+    const uint32_t PreHeaderBIWeights[] = {
+        SuccsSwapped ? EnterWeight : ExitWeight0,
+        SuccsSwapped ? ExitWeight0 : EnterWeight,
+    };
+    setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
   }
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e9d0315d114f65c..0a10f0d6471769a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -112,10 +112,12 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -396,6 +398,19 @@ static cl::opt<bool> UseWiderVFIfCallVariantsPresent(
     cl::Hidden,
     cl::desc("Try wider VFs if they enable the use of vector variants"));
 
+// Likelyhood of bypassing the vectorized loop because assumptions about SCEV
+// variables not overflowing do not hold. See `emitSCEVChecks`.
+static constexpr uint32_t SCEVCheckBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because pointers overlap. See
+// `emitMemRuntimeChecks`.
+static constexpr uint32_t MemCheckBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because there are zero trips left
+// after prolog. See `emitIterationCountCheck`.
+static constexpr uint32_t MinItersBypassWeights[] = {1, 127};
+// Likelyhood of bypassing the vectorized loop because of zero trips necessary.
+// See `emitMinimumVectorEpilogueIterCountCheck`.
+static constexpr uint32_t EpilogueMinItersBypassWeights[] = {1, 127};
+
 /// A helper function that returns true if the given type is irregular. The
 /// type is irregular if its allocated size doesn't equal the store size of an
 /// element of the corresponding vector type.
@@ -1962,12 +1977,14 @@ class GeneratedRTChecks {
   SCEVExpander MemCheckExp;
 
   bool CostTooHigh = false;
+  const bool AddBranchWeights;
 
 public:
   GeneratedRTChecks(ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI,
-                    TargetTransformInfo *TTI, const DataLayout &DL)
+                    TargetTransformInfo *TTI, const DataLayout &DL,
+                    bool AddBranchWeights)
       : DT(DT), LI(LI), TTI(TTI), SCEVExp(SE, DL, "scev.check"),
-        MemCheckExp(SE, DL, "scev.check") {}
+        MemCheckExp(SE, DL, "scev.check"), AddBranchWeights(AddBranchWeights) {}
 
   /// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can
   /// accurately estimate the cost of the runtime checks. The blocks are
@@ -2160,8 +2177,10 @@ class GeneratedRTChecks {
     DT->addNewBlock(SCEVCheckBlock, Pred);
     DT->changeImmediateDominator(LoopVectorPreHeader, SCEVCheckBlock);
 
-    ReplaceInstWithInst(SCEVCheckBlock->getTerminator(),
-                        BranchInst::Create(Bypass, LoopVectorPreHeader, Cond));
+    BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, Cond);
+    if (AddBranchWeights)
+      setBranchWeights(BI, SCEVCheckBypassWeights);
+    ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), &BI);
     return SCEVCheckBlock;
   }
 
@@ -2185,9 +2204,12 @@ class GeneratedRTChecks {
     if (auto *PL = LI->getLoopFor(LoopVectorPreHeader))
       PL->addBasicBlockToLoop(MemCheckBlock, *LI);
 
-    ReplaceInstWithInst(
-        MemCheckBlock->getTerminator(),
-        BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond));
+    BranchInst &BI =
+        *BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond);
+    if (AddBranchWeights) {
+      setBranchWeights(BI, MemCheckBypassWeights);
+    }
+    ReplaceInstWithInst(MemCheckBlock->getTerminator(), &BI);
     MemCheckBlock->getTerminator()->setDebugLoc(
         Pred->getTerminator()->getDebugLoc());
 
@@ -2900,9 +2922,11 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
     // dominator of the exit blocks.
     DT->changeImmediateDominator(LoopExitBlock, TCCheckBlock);
 
-  ReplaceInstWithInst(
-      TCCheckBlock->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, MinItersBypassWeights);
+  ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
   LoopBypassBlocks.push_back(TCCheckBlock);
 }
 
@@ -3133,7 +3157,16 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() {
     IRBuilder<> B(LoopMiddleBlock->getTerminator());
     B.SetCurrentDebugLocation(ScalarLatchTerm->getDebugLoc());
     Value *CmpN = B.CreateICmpEQ(Count, VectorTripCount, "cmp.n");
-    cast<BranchInst>(LoopMiddleBlock->getTerminator())->setCondition(CmpN);
+    BranchInst &BI = *cast<BranchInst>(LoopMiddleBlock->getTerminator());
+    BI.setCondition(CmpN);
+    if (hasBranchWeightMD(*ScalarLatchTerm)) {
+      // Assume that `Count % VectorTripCount` is equally distributed.
+      unsigned TripCount = UF * VF.getKnownMinValue();
+      assert(TripCount > 0 && "trip count should not be zero");
+      MDBuilder MDB(ScalarLatchTerm->getContext());
+      MDNode *BranchWeights = MDB.createBranchWeights(1, TripCount - 1);
+      BI.setMetadata(LLVMContext::MD_prof, BranchWeights);
+    }
   }
 
 #ifdef EXPENSIVE_CHECKS
@@ -7896,9 +7929,11 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass,
     EPI.TripCount = Count;
   }
 
-  ReplaceInstWithInst(
-      TCCheckBlock->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, MinItersBypassWeights);
+  ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI);
 
   return TCCheckBlock;
 }
@@ -8042,9 +8077,11 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck(
                                          EPI.EpilogueVF, EPI.EpilogueUF),
                          "min.epilog.iters.check");
 
-  ReplaceInstWithInst(
-      Insert->getTerminator(),
-      BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters));
+  BranchInst &BI =
+      *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters);
+  if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
+    setBranchWeights(BI, EpilogueMinItersBypassWeights);
+  ReplaceInstWithInst(Insert->getTerminator(), &BI);
 
   LoopBypassBlocks.push_back(Insert);
   return Insert;
@@ -9731,8 +9768,10 @@ static bool processLoopInVPlanNativePath(
   VPlan &BestPlan = LVP.getBestPlanFor(VF.Width);
 
   {
+    bool AddBranchWeights =
+        hasBranchWeightMD(*L->ge...
[truncated]

Copy link
Member

@WenleiHe WenleiHe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm with a suggestion.. thanks for the refactoring (and it seems less invasive than I thought :))

llvm/lib/Transforms/IPO/SampleProfile.cpp Show resolved Hide resolved
@MatzeB MatzeB merged commit a9cc6fc into llvm:main Nov 16, 2023
8 checks passed
Copy link
Contributor

@david-xl david-xl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any performance data with this change (as there are heuristic based defaults used)?

static constexpr uint32_t MinItersBypassWeights[] = {1, 127};
// Likelyhood of bypassing the vectorized loop because of zero trips necessary.
// See `emitMinimumVectorEpilogueIterCountCheck`.
static constexpr uint32_t EpilogueMinItersBypassWeights[] = {1, 127};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are the default values selected? 127 seems too large as default e.g.for the epilogue minBy pass. The average trip count of the remainder loop is 0.5mainUFmainVF, how likely it is larger than EVF*EUF?

Copy link
Contributor Author

@MatzeB MatzeB Nov 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I may have confused this particular check with the very similar MinItersBypassWeights which tests whether the total amount of iterations is smaller than the loop count, which is something we cannot easily compute. But I think you are right in that this one is indeed the remainder of iterations after the vectorized body ran. Computing this based on the tripcount similar to what I am already doing for InnerLoopVectorizer::completeLoopSkeleton() makes sense. Will send a follow-up.

The number "127" itself has no real justification except that we know it should be a "big" number and that not annotating anything leaves us with a bad 1:1 default branch_weights assumed by BranchProbabilityInfo. Whether the "truth" is closer to 64 or 256 or some other big number I don't know.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also previous discussion in https://reviews.llvm.org/D159322

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Submitted the follow-up as PR #72589

@MatzeB
Copy link
Contributor Author

MatzeB commented Nov 16, 2023

There is no measurable performance change from that on our end. I am merely hunting down conditional branches without branch_weight annotations which I think we can agree are never desirable in a function that has otherwise profile annotations.

@david-xl
Copy link
Contributor

There is no measurable performance change from that on our end. I am merely hunting down conditional branches without branch_weight annotations which I think we can agree are never desirable in a function that has otherwise profile annotations.

Agree in general, but there might be undesirable outcome in some cases. There are 3 types of missing annotations: 1) real profile gets dropped; 2) real profile is unknown (due to lack of flow sensitivity) for a clone, but can be reasonablily guessed with scaling; and 3) new branches that require static heuristic to guess. It is the 3) category that can be problematic. In the code, I added a question on the default values.

@MatzeB
Copy link
Contributor Author

MatzeB commented Nov 16, 2023

I agree that this is 3) in your categories and it's a guessing game where we likely can be way off depending on program and program inputs. However I also do not see any immediately better answers for this right now.

Though for some context: I am currently comparing estimated execution counts with actual counts and this change turns llvm-test-suite/SingleSource/Benchmarks/Shootout/ary3.c (which admittedly is a really small microbenchmark) from misjudging the counts by 2x to misjudging by 2% ...

MatzeB added a commit to MatzeB/llvm-project that referenced this pull request Nov 17, 2023
This is a follow-up to PR llvm#72450 correcting the branch_weights used
for the test whether the vectorized epilogue loop should be skipped.
sr-tream pushed a commit to sr-tream/llvm-project that referenced this pull request Nov 20, 2023
Consistently add `branch_weights` metadata in any condition branch
created by `LoopVectorize.cpp`:
- Will only add metadata if the original loop-latch branch had metadata
assigned.
- Most checks should rarely trigger so I am using a 127:1 ratio.
- For the middle block we assume an equal distribution of modulo
results.
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
Consistently add `branch_weights` metadata in any condition branch
created by `LoopVectorize.cpp`:
- Will only add metadata if the original loop-latch branch had metadata
assigned.
- Most checks should rarely trigger so I am using a 127:1 ratio.
- For the middle block we assume an equal distribution of modulo
results.
MatzeB added a commit that referenced this pull request Nov 20, 2023
…#72589)

This is a follow-up to PR #72450 correcting the branch_weights used
for the test whether the vectorized epilogue loop should be skipped.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants