Avoid BlockFrequency overflow problems #66280

Merged: 1 commit, Sep 14, 2023

MatzeB (Contributor) commented Sep 13, 2023

Multiplying raw block frequency with an integer carries a high risk of overflow.

  • Introduce a new BlockFrequency::mul function returning a bool indicating overflow.
  • Mark function with __attribute__((warn_unused_result)) to avoid users accidentally ignoring the indicator.
  • Fix two instances where overflow was leading to wrong results for me.
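
To make the failure mode concrete, here is a minimal standalone sketch, not LLVM code. It assumes frequencies are plain uint64_t values, and the helper names (exceedsScaledWrapping, exceedsScaledChecked, demoWrapping) are hypothetical. It contrasts a comparison against a product that may wrap with one that checks for overflow first.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper (not part of LLVM): the product is computed modulo
// 2^64, so a large Base * Factor can wrap around to a small value.
inline bool exceedsScaledWrapping(uint64_t Candidate, uint64_t Base,
                                  uint64_t Factor) {
  return Candidate >= Base * Factor;
}

// Hypothetical helper: detects overflow before multiplying. If the true
// product does not fit in 64 bits, no 64-bit Candidate can reach it.
inline bool exceedsScaledChecked(uint64_t Candidate, uint64_t Base,
                                 uint64_t Factor) {
  if (Factor != 0 && Base > std::numeric_limits<uint64_t>::max() / Factor)
    return false;
  return Candidate >= Base * Factor;
}

// Self-check on an input whose true product is exactly 2^64.
inline void demoWrapping() {
  const uint64_t Base = uint64_t(1) << 62;
  assert(exceedsScaledWrapping(1, Base, 4)); // 2^62 * 4 wraps to 0
  assert(!exceedsScaledChecked(1, Base, 4)); // overflow is detected
}
```

With Base = 2^62 and Factor = 4 the wrapped product is 0, so the unchecked comparison accepts a Candidate of 1, while the checked version rejects it.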

llvmbot (Collaborator) commented Sep 13, 2023

@llvm/pr-subscribers-llvm-analysis

    Full diff: https://github.com/llvm/llvm-project/pull/66280.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Support/BlockFrequency.h (+8)
  • (modified) llvm/include/llvm/Support/Compiler.h (+8)
  • (modified) llvm/lib/Analysis/InlineCost.cpp (+6-5)
  • (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+5-5)
  • (modified) llvm/lib/Support/BlockFrequency.cpp (+8)
diff --git a/llvm/include/llvm/Support/BlockFrequency.h b/llvm/include/llvm/Support/BlockFrequency.h
index 6c624d7dad7d801..1711fb592485b4c 100644
--- a/llvm/include/llvm/Support/BlockFrequency.h
+++ b/llvm/include/llvm/Support/BlockFrequency.h
@@ -16,6 +16,8 @@
 #include <cassert>
 #include <cstdint>
 
+#include "llvm/Support/Compiler.h"
+
 namespace llvm {
 
 class BranchProbability;
@@ -76,6 +78,12 @@ class BlockFrequency {
     return NewFreq;
   }
 
+  /// Multiplies frequency with `Factor` and stores the result into `Result`.
+  /// Returns `true` if an overflow occured. Overflows are common and should be
+  /// checked by all callers.
+  bool mul(uint64_t Factor,
+           BlockFrequency *Result) const LLVM_WARN_UNUSED_RESULT;
+
   /// Shift block frequency to the right by count digits saturating to 1.
   BlockFrequency &operator>>=(const unsigned count) {
     // Frequency can never be 0 by design.
diff --git a/llvm/include/llvm/Support/Compiler.h b/llvm/include/llvm/Support/Compiler.h
index 12afe90f8facd47..9527e377317ac33 100644
--- a/llvm/include/llvm/Support/Compiler.h
+++ b/llvm/include/llvm/Support/Compiler.h
@@ -269,6 +269,14 @@
 #define LLVM_ATTRIBUTE_RETURNS_NOALIAS
 #endif
 
+/// Mark a function whose return value should not be ignored. Doing so without
+/// a `[[maybe_unused]]` produces a warning if supported by the compiler.
+#if __has_attribute(warn_unused_result)
+#define LLVM_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
+#else
+#define LLVM_WARN_UNUSED_RESULT
+#endif
+
 /// LLVM_FALLTHROUGH - Mark fallthrough cases in switch statements.
 #if defined(__cplusplus) && __cplusplus > 201402L && LLVM_HAS_CPP_ATTRIBUTE(fallthrough)
 #define LLVM_FALLTHROUGH [[fallthrough]]
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index a9de1dde7c7f717..d921047d6466f52 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -118,7 +118,7 @@ static cl::opt<int> ColdCallSiteRelFreq(
              "entry frequency, for a callsite to be cold in the absence of "
              "profile information."));
 
-static cl::opt<int> HotCallSiteRelFreq(
+static cl::opt<uint64_t> HotCallSiteRelFreq(
     "hot-callsite-rel-freq", cl::Hidden, cl::init(60),
     cl::desc("Minimum block frequency, expressed as a multiple of caller's "
              "entry frequency, for a callsite to be hot in the absence of "
@@ -1820,10 +1820,11 @@ InlineCostCallAnalyzer::getHotCallSiteThreshold(CallBase &Call,
   // potentially cache the computation of scaled entry frequency, but the added
   // complexity is not worth it unless this scaling shows up high in the
   // profiles.
-  auto CallSiteBB = Call.getParent();
-  auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB).getFrequency();
-  auto CallerEntryFreq = CallerBFI->getEntryFreq();
-  if (CallSiteFreq >= CallerEntryFreq * HotCallSiteRelFreq)
+  const BasicBlock *CallSiteBB = Call.getParent();
+  BlockFrequency CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB);
+  BlockFrequency CallerEntryFreq = CallerBFI->getEntryFreq();
+  BlockFrequency Limit;
+  if (!CallerEntryFreq.mul(HotCallSiteRelFreq, &Limit) && CallSiteFreq >= Limit)
     return Params.LocallyHotCallSiteThreshold;
 
   // Otherwise treat it normally.
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index f07fc4fc52bffba..e24361c1f93970d 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -198,7 +198,7 @@ static cl::opt<bool> BBSectionsGuidedSectionPrefix(
              "impacted, i.e., their prefixes will be decided by FDO/sampleFDO "
              "profiles."));
 
-static cl::opt<unsigned> FreqRatioToSkipMerge(
+static cl::opt<uint64_t> FreqRatioToSkipMerge(
     "cgp-freq-ratio-to-skip-merge", cl::Hidden, cl::init(2),
     cl::desc("Skip merging empty blocks if (frequency of empty block) / "
              "(frequency of destination block) is greater than this ratio"));
@@ -978,16 +978,16 @@ bool CodeGenPrepare::isMergingEmptyBlockProfitable(BasicBlock *BB,
   if (SameIncomingValueBBs.count(Pred))
     return true;
 
-  BlockFrequency PredFreq = BFI->getBlockFreq(Pred);
-  BlockFrequency BBFreq = BFI->getBlockFreq(BB);
+  BlockFrequency PredFreq = BFI->getBlockFreq(Pred).getFrequency();
+  BlockFrequency BBFreq = BFI->getBlockFreq(BB).getFrequency();
 
   for (auto *SameValueBB : SameIncomingValueBBs)
     if (SameValueBB->getUniquePredecessor() == Pred &&
         DestBB == findDestBlockOfMergeableEmptyBlock(SameValueBB))
       BBFreq += BFI->getBlockFreq(SameValueBB);
 
-  return PredFreq.getFrequency() <=
-         BBFreq.getFrequency() * FreqRatioToSkipMerge;
+  BlockFrequency Limit;
+  return !BBFreq.mul(FreqRatioToSkipMerge, &Limit) && PredFreq <= Limit;
 }
 
 /// Return true if we can merge BB into DestBB if there is a single
diff --git a/llvm/lib/Support/BlockFrequency.cpp b/llvm/lib/Support/BlockFrequency.cpp
index a4a1e477d9403f7..08fe3ef6061ecae 100644
--- a/llvm/lib/Support/BlockFrequency.cpp
+++ b/llvm/lib/Support/BlockFrequency.cpp
@@ -12,6 +12,7 @@
 
 #include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/MathExtras.h"
 
 using namespace llvm;
 
@@ -36,3 +37,10 @@ BlockFrequency BlockFrequency::operator/(BranchProbability Prob) const {
   Freq /= Prob;
   return Freq;
 }
+
+bool BlockFrequency::mul(uint64_t Factor, BlockFrequency *Result) const {
+  bool Overflow;
+  uint64_t ResultFrequency = SaturatingMultiply(Frequency, Factor, &Overflow);
+  *Result = BlockFrequency(ResultFrequency);
+  return Overflow;
+}
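
The calling convention in this revision of the patch (overflow flag as the return value, product through an out-parameter) can be tried outside LLVM with the following sketch. It is an approximation under stated assumptions: Freq is a hypothetical stand-in for BlockFrequency, a division check replaces SaturatingMultiply from MathExtras.h, and the standard C++17 [[nodiscard]] attribute plays the role of LLVM_WARN_UNUSED_RESULT.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for BlockFrequency; only multiplication is modeled.
class Freq {
  uint64_t Value;

public:
  explicit Freq(uint64_t V = 0) : Value(V) {}
  uint64_t get() const { return Value; }

  // Stores the saturated product in *Result and returns true on overflow.
  // [[nodiscard]] asks the compiler to warn when the flag is ignored.
  [[nodiscard]] bool mul(uint64_t Factor, Freq *Result) const {
    const uint64_t Max = std::numeric_limits<uint64_t>::max();
    const bool Overflow = Factor != 0 && Value > Max / Factor;
    *Result = Freq(Overflow ? Max : Value * Factor);
    return Overflow;
  }
};

// Self-check covering one in-range and one overflowing product.
inline void demoSaturatingMul() {
  Freq Limit;
  bool Overflow = Freq(10).mul(60, &Limit);
  assert(!Overflow && Limit.get() == 600);
  Overflow = Freq(uint64_t(1) << 62).mul(4, &Limit);
  assert(Overflow && Limit.get() == std::numeric_limits<uint64_t>::max());
}
```

Callers are expected to test the returned flag before trusting the stored product, as both fixed call sites in the diff do.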

MatzeB (Contributor, Author) commented Sep 13, 2023

An alternative fix would be to perform the arithmetic on a 128-bit APInt, a style you can see in other places of the codebase. In the end I decided that this solution is more elegant, because everything stays abstracted in a BlockFrequency instance instead of computing with raw integers, which was somewhat the reason for the trouble here in the first place. The overflow check should also be more efficient than a 128-bit APInt, which requires heap allocation.
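
For comparison, the wide-arithmetic alternative can be sketched without LLVM as well. The code below is an illustration only: it does not use APInt, and U128, mulWide and reachesScaledWide are hypothetical names. It computes the full 128-bit product of two 64-bit values from 32-bit halves, so the comparison never sees a wrapped result.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 128-bit result type, split into two 64-bit halves.
struct U128 {
  uint64_t Hi;
  uint64_t Lo;
};

// Full 64x64 -> 128 bit multiplication built from 32-bit partial products.
inline U128 mulWide(uint64_t A, uint64_t B) {
  const uint64_t Mask = 0xffffffffu;
  const uint64_t ALo = A & Mask, AHi = A >> 32;
  const uint64_t BLo = B & Mask, BHi = B >> 32;
  const uint64_t LoLo = ALo * BLo, HiLo = AHi * BLo;
  const uint64_t LoHi = ALo * BHi, HiHi = AHi * BHi;
  // Sum of three values below 2^32, so it cannot overflow 64 bits.
  const uint64_t Mid = (LoLo >> 32) + (HiLo & Mask) + (LoHi & Mask);
  return {HiHi + (HiLo >> 32) + (LoHi >> 32) + (Mid >> 32),
          (Mid << 32) | (LoLo & Mask)};
}

// True if Candidate >= Base * Factor, evaluated on the exact product.
inline bool reachesScaledWide(uint64_t Candidate, uint64_t Base,
                              uint64_t Factor) {
  const U128 Product = mulWide(Base, Factor);
  return Product.Hi == 0 && Candidate >= Product.Lo;
}

// Self-check: 2^62 * 4 is exactly 2^64, i.e. Hi == 1 and Lo == 0.
inline void demoWide() {
  const U128 P = mulWide(uint64_t(1) << 62, 4);
  assert(P.Hi == 1 && P.Lo == 0);
  assert(!reachesScaledWide(1, uint64_t(1) << 62, 4));
}
```

This gives the same answers as an overflow check, but every caller has to leave the frequency abstraction and work on raw integers, which is the trade-off the comment above describes.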

@MatzeB MatzeB force-pushed the overflow_fix branch 2 times, most recently from 8ede6e6 to 0bd06f1 Compare September 13, 2023 20:48
kazutakahirata (Contributor) left a comment

LGTM. Thank you for updating the patch! (And even bigger thank you for paying attention to the block frequencies in the first place!)

Please be sure to update the commit message also as we no longer use __attribute__((warn_unused_result)).

kazutakahirata (Contributor) left a comment

LGTM.

@MatzeB MatzeB mentioned this pull request Sep 13, 2023
Multiplying raw block frequency with an integer carries a high risk
of overflow.

- Add `BlockFrequency::mul`, which returns a `std::optional` with the
  product or `nullopt` to indicate an overflow.
- Fix two instances where overflow was likely.
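
The interface described in this commit message might look roughly like the following standalone sketch. It is an assumption-laden illustration rather than the merged code: Freq is a hypothetical stand-in for BlockFrequency, reachesScaled is a hypothetical caller, and the overflow test is a plain division check. C++17 is required for std::optional.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical stand-in for BlockFrequency; only multiplication is modeled.
class Freq {
  uint64_t Value;

public:
  explicit Freq(uint64_t V) : Value(V) {}
  uint64_t get() const { return Value; }

  // Returns the product, or std::nullopt if it does not fit in 64 bits.
  std::optional<Freq> mul(uint64_t Factor) const {
    if (Factor != 0 && Value > std::numeric_limits<uint64_t>::max() / Factor)
      return std::nullopt;
    return Freq(Value * Factor);
  }
};

// Hypothetical caller: an overflowed limit can never be reached.
inline bool reachesScaled(Freq Candidate, Freq Base, uint64_t Factor) {
  const std::optional<Freq> Limit = Base.mul(Factor);
  return Limit && Candidate.get() >= Limit->get();
}

// Self-check covering one in-range and one overflowing product.
inline void demoOptionalMul() {
  assert(reachesScaled(Freq(600), Freq(10), 60));
  assert(!Freq(uint64_t(1) << 62).mul(4).has_value());
}
```

Compared with a bool return plus out-parameter, the optional makes it impossible to read a product without first checking whether one exists, so no separate unused-result attribute is needed.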
@MatzeB MatzeB merged commit b0c8c45 into llvm:main Sep 14, 2023
1 of 2 checks passed
@MatzeB MatzeB deleted the overflow_fix branch September 14, 2023 18:11
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023