Skip to content

Commit

Permalink
Use fixed-point representation for BranchProbability.
Browse files Browse the repository at this point in the history
BranchProbability now is represented by its numerator and denominator in uint32_t type. This patch changes this representation into a fixed point that is represented by the numerator in uint32_t type and a constant denominator 1<<31. This is quite similar to the representation of BlockMass in BlockFrequencyInfoImpl.h. There are several pros and cons of this change:

Pros:

1. It uses only a half space of the current one.
2. Some operations are much faster like plus, subtraction, comparison, and scaling by an integer.

Cons:

1. Constructing a probability using arbitrary numerator and denominator needs additional calculations.
2. It is a little less precise than before as we use a fixed denominator. For example, 1 - 1/3 may not be exactly identical to 1 / 3 (this will lead to many BranchProbability unit test failures). This should not matter when we only use it for branch probability. If we use it like a rational value for some precise calculations we may need another construct like ValueRatio.

One important reason for this change is that we propose to store branch probabilities instead of edge weights in MachineBasicBlock. We also want clients to use probability instead of weight when adding successors to a MBB. The current BranchProbability has more space which may be a concern.

Differential revision: http://reviews.llvm.org/D12603

llvm-svn: 248633
  • Loading branch information
Cong Hou committed Sep 25, 2015
1 parent a0bc859 commit 15ea016
Show file tree
Hide file tree
Showing 18 changed files with 370 additions and 275 deletions.
9 changes: 5 additions & 4 deletions llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
Expand Up @@ -1190,10 +1190,11 @@ raw_ostream &BlockFrequencyInfoImpl<BT>::print(raw_ostream &OS) const {
if (!F)
return OS;
OS << "block-frequency-info: " << F->getName() << "\n";
for (const BlockT &BB : *F)
OS << " - " << bfi_detail::getBlockName(&BB)
<< ": float = " << getFloatingBlockFreq(&BB)
<< ", int = " << getBlockFreq(&BB).getFrequency() << "\n";
for (const BlockT &BB : *F) {
OS << " - " << bfi_detail::getBlockName(&BB) << ": float = ";
getFloatingBlockFreq(&BB).print(OS, 5)
<< ", int = " << getBlockFreq(&BB).getFrequency() << "\n";
}

// Add an extra newline for readability.
OS << "\n";
Expand Down
78 changes: 58 additions & 20 deletions llvm/include/llvm/Support/BranchProbability.h
Expand Up @@ -21,31 +21,43 @@ namespace llvm {

class raw_ostream;

// This class represents Branch Probability as a non-negative fraction.
// This class represents Branch Probability as a non-negative fraction that is
// no greater than 1. It uses a fixed-point-like implementation, in which the
// denominator is always a constant value (here we use 1<<31 for maximum
// precision).
class BranchProbability {
// Numerator
uint32_t N;

// Denominator
uint32_t D;
// Denominator, which is a constant value.
static const uint32_t D = 1u << 31;

// Construct a BranchProbability with only numerator assuming the denominator
// is 1<<31. For internal use only.
explicit BranchProbability(uint32_t n) : N(n) {}

public:
BranchProbability(uint32_t Numerator, uint32_t Denominator)
: N(Numerator), D(Denominator) {
assert(D > 0 && "Denominator cannot be 0!");
assert(N <= D && "Probability cannot be bigger than 1!");
}
BranchProbability() : N(0) {}
BranchProbability(uint32_t Numerator, uint32_t Denominator);

bool isZero() const { return N == 0; }

static BranchProbability getZero() { return BranchProbability(0, 1); }
static BranchProbability getOne() { return BranchProbability(1, 1); }
static BranchProbability getZero() { return BranchProbability(0); }
static BranchProbability getOne() { return BranchProbability(D); }
// Create a BranchProbability object with the given numerator and 1<<31
// as denominator.
static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }

// Normalize given probabilties so that the sum of them becomes approximate
// one.
template <class ProbabilityList>
static void normalizeProbabilities(ProbabilityList &Probs);

uint32_t getNumerator() const { return N; }
uint32_t getDenominator() const { return D; }
static uint32_t getDenominator() { return D; }

// Return (1 - Probability).
BranchProbability getCompl() const {
return BranchProbability(D - N, D);
}
BranchProbability getCompl() const { return BranchProbability(D - N); }

raw_ostream &print(raw_ostream &OS) const;

Expand All @@ -67,15 +79,31 @@ class BranchProbability {
/// \return \c Num divided by \c this.
uint64_t scaleByInverse(uint64_t Num) const;

bool operator==(BranchProbability RHS) const {
return (uint64_t)N * RHS.D == (uint64_t)D * RHS.N;
BranchProbability &operator+=(BranchProbability RHS);
BranchProbability &operator-=(BranchProbability RHS);
BranchProbability &operator*=(BranchProbability RHS) {
N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
return *this;
}

BranchProbability operator+(BranchProbability RHS) const {
BranchProbability Prob(*this);
return Prob += RHS;
}
bool operator!=(BranchProbability RHS) const {
return !(*this == RHS);

BranchProbability operator-(BranchProbability RHS) const {
BranchProbability Prob(*this);
return Prob -= RHS;
}
bool operator<(BranchProbability RHS) const {
return (uint64_t)N * RHS.D < (uint64_t)D * RHS.N;

BranchProbability operator*(BranchProbability RHS) const {
BranchProbability Prob(*this);
return Prob *= RHS;
}

bool operator==(BranchProbability RHS) const { return N == RHS.N; }
bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
bool operator<(BranchProbability RHS) const { return N < RHS.N; }
bool operator>(BranchProbability RHS) const { return RHS < *this; }
bool operator<=(BranchProbability RHS) const { return !(RHS < *this); }
bool operator>=(BranchProbability RHS) const { return !(*this < RHS); }
Expand All @@ -85,6 +113,16 @@ inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
return Prob.print(OS);
}

template <class ProbabilityList>
void BranchProbability::normalizeProbabilities(ProbabilityList &Probs) {
uint64_t Sum = 0;
for (auto Prob : Probs)
Sum += Prob.N;
assert(Sum > 0);
for (auto &Prob : Probs)
Prob.N = (Prob.N * uint64_t(D) + Sum / 2) / Sum;
}

}

#endif
56 changes: 51 additions & 5 deletions llvm/lib/Support/BranchProbability.cpp
Expand Up @@ -15,17 +15,63 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>

using namespace llvm;

raw_ostream &BranchProbability::print(raw_ostream &OS) const {
return OS << N << " / " << D << " = "
<< format("%g%%", ((double)N / D) * 100.0);
auto GetHexDigit = [](int Val) -> char {
assert(Val < 16);
if (Val < 10)
return '0' + Val;
return 'a' + Val - 10;
};
OS << "0x";
for (int Digits = 0; Digits < 8; ++Digits)
OS << GetHexDigit(N >> (28 - Digits * 4) & 0xf);
OS << " / 0x";
for (int Digits = 0; Digits < 8; ++Digits)
OS << GetHexDigit(D >> (28 - Digits * 4) & 0xf);
OS << " = " << format("%.2f%%", ((double)N / D) * 100.0);
return OS;
}

void BranchProbability::dump() const { print(dbgs()) << '\n'; }

static uint64_t scale(uint64_t Num, uint32_t N, uint32_t D) {
BranchProbability::BranchProbability(uint32_t Numerator, uint32_t Denominator) {
assert(Denominator > 0 && "Denominator cannot be 0!");
assert(Numerator <= Denominator && "Probability cannot be bigger than 1!");
if (Denominator == D)
N = Numerator;
else {
uint64_t Prob64 =
(Numerator * static_cast<uint64_t>(D) + Denominator / 2) / Denominator;
N = static_cast<uint32_t>(Prob64);
}
}

BranchProbability &BranchProbability::operator+=(BranchProbability RHS) {
assert(N <= D - RHS.N &&
"The sum of branch probabilities should not exceed one!");
N += RHS.N;
return *this;
}

BranchProbability &BranchProbability::operator-=(BranchProbability RHS) {
assert(N >= RHS.N &&
"Can only subtract a smaller probability from a larger one!");
N -= RHS.N;
return *this;
}

// If ConstD is not zero, then replace D by ConstD so that division and modulo
// operations by D can be optimized, in case this function is not inlined by the
// compiler.
template <uint32_t ConstD>
inline uint64_t scale(uint64_t Num, uint32_t N, uint32_t D) {
if (ConstD > 0)
D = ConstD;

assert(D && "divide by 0");

// Fast path for multiplying by 1.0.
Expand Down Expand Up @@ -65,9 +111,9 @@ static uint64_t scale(uint64_t Num, uint32_t N, uint32_t D) {
}

uint64_t BranchProbability::scale(uint64_t Num) const {
return ::scale(Num, N, D);
return ::scale<D>(Num, N, D);
}

uint64_t BranchProbability::scaleByInverse(uint64_t Num) const {
return ::scale(Num, D, N);
return ::scale<0>(Num, D, N);
}
6 changes: 3 additions & 3 deletions llvm/test/Analysis/BlockFrequencyInfo/basic.ll
Expand Up @@ -104,13 +104,13 @@ for.cond1.preheader:
%x.024 = phi i32 [ 0, %entry ], [ %inc12, %for.inc11 ]
br label %for.cond4.preheader

; CHECK-NEXT: for.cond4.preheader: float = 16008001.0,
; CHECK-NEXT: for.cond4.preheader: float = 16007984.8,
for.cond4.preheader:
%y.023 = phi i32 [ 0, %for.cond1.preheader ], [ %inc9, %for.inc8 ]
%add = add i32 %y.023, %x.024
br label %for.body6

; CHECK-NEXT: for.body6: float = 64048012001.0,
; CHECK-NEXT: for.body6: float = 64047914563.9,
for.body6:
%z.022 = phi i32 [ 0, %for.cond4.preheader ], [ %inc, %for.body6 ]
%add7 = add i32 %add, %z.022
Expand All @@ -119,7 +119,7 @@ for.body6:
%cmp5 = icmp ugt i32 %inc, %a
br i1 %cmp5, label %for.inc8, label %for.body6, !prof !2

; CHECK-NEXT: for.inc8: float = 16008001.0,
; CHECK-NEXT: for.inc8: float = 16007984.8,
for.inc8:
%inc9 = add i32 %y.023, 1
%cmp2 = icmp ugt i32 %inc9, %a
Expand Down
Expand Up @@ -93,7 +93,7 @@ for.cond4: ; preds = %for.inc, %for.body3
%cmp5 = icmp slt i32 %2, 100
br i1 %cmp5, label %for.body6, label %for.end, !prof !3

; CHECK: - for.body6: float = 500000.5, int = 4000003
; CHECK: - for.body6: float = 500000.5, int = 4000004
for.body6: ; preds = %for.cond4
call void @bar()
br label %for.inc
Expand Down Expand Up @@ -143,7 +143,7 @@ for.cond16: ; preds = %for.inc19, %for.bod
%cmp17 = icmp slt i32 %8, 10000
br i1 %cmp17, label %for.body18, label %for.end21, !prof !4

; CHECK: - for.body18: float = 500000.5, int = 4000003
; CHECK: - for.body18: float = 499999.9, int = 3999998
for.body18: ; preds = %for.cond16
call void @bar()
br label %for.inc19
Expand Down Expand Up @@ -175,7 +175,7 @@ for.cond26: ; preds = %for.inc29, %for.end
%cmp27 = icmp slt i32 %12, 1000000
br i1 %cmp27, label %for.body28, label %for.end31, !prof !5

; CHECK: - for.body28: float = 500000.5, int = 4000003
; CHECK: - for.body28: float = 499995.2, int = 3999961
for.body28: ; preds = %for.cond26
call void @bar()
br label %for.inc29
Expand Down

0 comments on commit 15ea016

Please sign in to comment.