diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp index 311ba19c727ad8..7f1d27abcdc124 100644 --- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp +++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -102,7 +102,7 @@ static const uint32_t LBH_UNLIKELY_WEIGHT = 62; /// /// This is the probability for a branch being taken to a block that terminates /// (eventually) in unreachable. These are predicted as unlikely as possible. -/// All reachable probability will equally share the remaining part. +/// All reachable probability will proportionally share the remaining part. static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1); /// Weight for a branch taken going into a cold block. @@ -349,35 +349,69 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { // Examine the metadata against unreachable heuristic. // If the unreachable heuristic is more strong then we use it for this edge. - if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) { - auto UnreachableProb = UR_TAKEN_PROB; - for (auto I : UnreachableIdxs) - if (UnreachableProb < BP[I]) { - BP[I] = UnreachableProb; - } + if (UnreachableIdxs.size() == 0 || ReachableIdxs.size() == 0) { + setEdgeProbability(BB, BP); + return true; + } + + auto UnreachableProb = UR_TAKEN_PROB; + for (auto I : UnreachableIdxs) + if (UnreachableProb < BP[I]) { + BP[I] = UnreachableProb; + } - // Because of possible rounding errors and the above fix up for - // the unreachable heuristic the sum of probabilities of all edges may be - // less than 1.0. Distribute the remaining probability (calculated as - // 1.0 - (sum of BP[i])) evenly among all the reachable edges. - auto ToDistribute = BranchProbability::getOne(); - for (auto &P : BP) - ToDistribute -= P; - - // If we modified the probability of some edges then we must distribute - // the difference between reachable blocks. - // TODO: This spreads ToDistribute evenly upon the reachable edges. A better - // distribution would be proportional. So the relation between weights of - // the reachable edges would be kept unchanged. That is for any reachable - // edges i and j: - // newBP[i] / newBP[j] == oldBP[i] / oldBP[j] - // newBP[i] / oldBP[i] == newBP[j] / oldBP[j] == - // == Denominator / (Denominator - ToDistribute) - // newBP[i] = oldBP[i] * Denominator / (Denominator - ToDistribute) - BranchProbability PerEdge = ToDistribute / ReachableIdxs.size(); - if (PerEdge > BranchProbability::getZero()) + // Sum of all edge probabilities must be 1.0. If we modified the probability + // of some edges then we must distribute the introduced difference over the + // reachable blocks. + // + // Proportional distribution: the relation between probabilities of the + // reachable edges is kept unchanged. That is for any reachable edges i and j: + // newBP[i] / newBP[j] == oldBP[i] / oldBP[j] => + // newBP[i] / oldBP[i] == newBP[j] / oldBP[j] == K + // Where K is independent of i,j. + // newBP[i] == oldBP[i] * K + // We need to find K. + // Make sum of all reachables of the left and right parts: + // sum_of_reachable(newBP) == K * sum_of_reachable(oldBP) + // Sum of newBP must be equal to 1.0: + // sum_of_reachable(newBP) + sum_of_unreachable(newBP) == 1.0 => + // sum_of_reachable(newBP) = 1.0 - sum_of_unreachable(newBP) + // Where sum_of_unreachable(newBP) is what has been just changed. + // Finally: + // K == sum_of_reachable(newBP) / sum_of_reachable(oldBP) => + // K == (1.0 - sum_of_unreachable(newBP)) / sum_of_reachable(oldBP) + BranchProbability NewUnreachableSum = BranchProbability::getZero(); + for (auto I : UnreachableIdxs) + NewUnreachableSum += BP[I]; + + BranchProbability NewReachableSum = + BranchProbability::getOne() - NewUnreachableSum; + + BranchProbability OldReachableSum = BranchProbability::getZero(); + for (auto I : ReachableIdxs) + OldReachableSum += BP[I]; + + if (OldReachableSum != NewReachableSum) { // Anything to dsitribute? + if (OldReachableSum.isZero()) { + // If all oldBP[i] are zeroes then the proportional distribution results + // in all zero probabilities and the error stays big. In this case we + // evenly spread NewReachableSum over the reachable edges. + BranchProbability PerEdge = NewReachableSum / ReachableIdxs.size(); for (auto I : ReachableIdxs) - BP[I] += PerEdge; + BP[I] = PerEdge; + } else { + for (auto I : ReachableIdxs) { + // We use uint64_t to avoid double rounding error of the following + // calculation: BP[i] = BP[i] * NewReachableSum / OldReachableSum + // The formula is taken from the private constructor + // BranchProbability(uint32_t Numerator, uint32_t Denominator) + uint64_t Mul = static_cast(NewReachableSum.getNumerator()) * + BP[I].getNumerator(); + uint32_t Div = static_cast( + divideNearest(Mul, OldReachableSum.getNumerator())); + BP[I] = BranchProbability::getRaw(Div); + } + } } setEdgeProbability(BB, BP); diff --git a/llvm/test/Analysis/BranchProbabilityInfo/basic.ll b/llvm/test/Analysis/BranchProbabilityInfo/basic.ll index 73720b0b411f0c..debec866d7159a 100644 --- a/llvm/test/Analysis/BranchProbabilityInfo/basic.ll +++ b/llvm/test/Analysis/BranchProbabilityInfo/basic.ll @@ -469,11 +469,12 @@ entry: i32 2, label %case_c i32 3, label %case_d i32 4, label %case_e ], !prof !8 +; Reachable probabilities keep their relation: 4/64/4/4 = 5.26% / 84.21% / 5.26% / 5.26%. ; CHECK: edge entry -> case_a probability is 0x00000001 / 0x80000000 = 0.00% -; CHECK: edge entry -> case_b probability is 0x07ffffff / 0x80000000 = 6.25% -; CHECK: edge entry -> case_c probability is 0x67ffffff / 0x80000000 = 81.25% [HOT edge] -; CHECK: edge entry -> case_d probability is 0x07ffffff / 0x80000000 = 6.25% -; CHECK: edge entry -> case_e probability is 0x07ffffff / 0x80000000 = 6.25% +; CHECK: edge entry -> case_b probability is 0x06bca1af / 0x80000000 = 5.26% +; CHECK: edge entry -> case_c probability is 0x6bca1af3 / 0x80000000 = 84.21% [HOT edge] +; CHECK: edge entry -> case_d probability is 0x06bca1af / 0x80000000 = 5.26% +; CHECK: edge entry -> case_e probability is 0x06bca1af / 0x80000000 = 5.26% case_a: unreachable @@ -511,11 +512,13 @@ entry: i32 2, label %case_c i32 3, label %case_d i32 4, label %case_e ], !prof !9 +; Reachable probabilities keep their relation: 64/4/4 = 88.89% / 5.56% / 5.56%. ; CHECK: edge entry -> case_a probability is 0x00000001 / 0x80000000 = 0.00% ; CHECK: edge entry -> case_b probability is 0x00000001 / 0x80000000 = 0.00% -; CHECK: edge entry -> case_c probability is 0x6aaaaaaa / 0x80000000 = 83.33% [HOT edge] -; CHECK: edge entry -> case_d probability is 0x0aaaaaaa / 0x80000000 = 8.33% -; CHECK: edge entry -> case_e probability is 0x0aaaaaaa / 0x80000000 = 8.33% +; CHECK: edge entry -> case_c probability is 0x71c71c71 / 0x80000000 = 88.89% [HOT edge] +; CHECK: edge entry -> case_d probability is 0x071c71c7 / 0x80000000 = 5.56% +; CHECK: edge entry -> case_e probability is 0x071c71c7 / 0x80000000 = 5.56% + case_a: unreachable @@ -551,11 +554,12 @@ entry: i32 2, label %case_c i32 3, label %case_d i32 4, label %case_e ], !prof !10 +; Reachable probabilities keep their relation: 64/4/4 = 88.89% / 5.56% / 5.56%. ; CHECK: edge entry -> case_a probability is 0x00000000 / 0x80000000 = 0.00% ; CHECK: edge entry -> case_b probability is 0x00000001 / 0x80000000 = 0.00% -; CHECK: edge entry -> case_c probability is 0x6e08fb82 / 0x80000000 = 85.96% [HOT edge] -; CHECK: edge entry -> case_d probability is 0x08fb823e / 0x80000000 = 7.02% -; CHECK: edge entry -> case_e probability is 0x08fb823e / 0x80000000 = 7.02% +; CHECK: edge entry -> case_c probability is 0x71c71c71 / 0x80000000 = 88.89% [HOT edge] +; CHECK: edge entry -> case_d probability is 0x071c71c7 / 0x80000000 = 5.56% +; CHECK: edge entry -> case_e probability is 0x071c71c7 / 0x80000000 = 5.56% case_a: unreachable