diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index a7bcbf010d1bf..f1c2f38c74afd 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -18,6 +18,8 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/Compiler.h" +#include +#include namespace llvm { struct MDProfLabels { @@ -216,9 +218,13 @@ LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T); /// branch weights B1 and B2, respectively. In both B1 and B2, the first /// position (index 0) is for the 'true' branch, and the second position (index /// 1) is for the 'false' branch. +template && std::is_arithmetic_v && + sizeof(T1) <= sizeof(uint64_t) && sizeof(T2) <= sizeof(uint64_t)>> inline SmallVector -getDisjunctionWeights(const SmallVector &B1, - const SmallVector &B2) { +getDisjunctionWeights(const SmallVector &B1, + const SmallVector &B2) { // For the first conditional branch, the probability the "true" case is taken // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is // p(not b1) = B1[1] / (B1[0] + B1[1]). @@ -235,8 +241,8 @@ getDisjunctionWeights(const SmallVector &B1, // the product of sums, the subtracted one cancels out). assert(B1.size() == 2); assert(B2.size() == 2); - auto FalseWeight = B1[1] * B2[1]; - auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0]; + uint64_t FalseWeight = B1[1] * B2[1]; + uint64_t TrueWeight = B1[0] * (B2[0] + B2[1]) + B1[1] * B2[0]; return {TrueWeight, FalseWeight}; } } // namespace llvm