diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h index 42bcadfc7dcdb6..afa4f06d3e7951 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h @@ -17,6 +17,8 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; @@ -104,6 +106,16 @@ class TruncInstCombine { /// to be reduced. Type *getBestTruncatedType(); + KnownBits computeKnownBits(const Value *V) const { + return llvm::computeKnownBits(V, DL, /*Depth=*/0, /*AC=*/nullptr, + /*CtxI=*/nullptr, &DT); + } + + unsigned ComputeNumSignBits(const Value *V) const { + return llvm::ComputeNumSignBits(V, DL, /*Depth=*/0, /*AC=*/nullptr, + /*CtxI=*/nullptr, &DT); + } + /// Given a \p V value and a \p SclTy scalar type return the generated reduced /// value of \p V based on the type \p SclTy. /// diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp index 5d66533f04e0fb..25ca5885f83157 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -288,19 +288,19 @@ Type *TruncInstCombine::getBestTruncatedType() { for (auto &Itr : InstInfoMap) { Instruction *I = Itr.first; if (I->isShift()) { - KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL); + KnownBits KnownRHS = computeKnownBits(I->getOperand(1)); unsigned MinBitWidth = KnownRHS.getMaxValue() .uadd_sat(APInt(OrigBitWidth, 1)) .getLimitedValue(OrigBitWidth); if (MinBitWidth == OrigBitWidth) return nullptr; if (I->getOpcode() == Instruction::LShr) { - KnownBits KnownLHS = computeKnownBits(I->getOperand(0), DL); + KnownBits KnownLHS = computeKnownBits(I->getOperand(0)); MinBitWidth = std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits()); } if (I->getOpcode() == Instruction::AShr) { - unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0), DL); + unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0)); MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1); } if (MinBitWidth >= OrigBitWidth)