[ValueTracking] Add DemandedElts support to computeKnownBits/ComputeNumSignBits (PR36319)

This patch adds initial support for a DemandedElts mask to the internal computeKnownBits/ComputeNumSignBits methods, matching the SelectionDAG and GlobalISel equivalents.

So far only a couple of instructions have been set up to handle the DemandedElts mask; the remainder still use the existing 'all elements' default. The plan is to extend support as test coverage is added.

Differential Revision: https://reviews.llvm.org/D73435
RKSimon committed Feb 1, 2020
1 parent 1acf129 commit 105e5c9
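
As a rough illustration of what the new internal overload accepts (a sketch, not code from this commit; V, Q and the surrounding plumbing are assumed to exist inside ValueTracking.cpp):

// Hypothetical caller that only cares about lane 0 of a <4 x i32> value V.
// Scalars pass a trivial APInt(1, 1) mask instead, as the wrappers in the
// patch do.
APInt DemandedElts = APInt::getOneBitSet(/*numBits=*/4, /*BitNo=*/0);
KnownBits Known(/*BitWidth=*/32);
computeKnownBits(V, DemandedElts, Known, /*Depth=*/0, Q);
// Only the demanded lane contributes; for the instructions wired up so far
// (shufflevector, insertelement), undef mask entries or unknown values
// confined to undemanded lanes are simply skipped.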
Showing 5 changed files with 192 additions and 59 deletions.
230 changes: 184 additions & 46 deletions llvm/lib/Analysis/ValueTracking.cpp
@@ -163,8 +163,43 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
return nullptr;
}

static void computeKnownBits(const Value *V, KnownBits &Known,
unsigned Depth, const Query &Q);
static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
const APInt &DemandedElts,
APInt &DemandedLHS, APInt &DemandedRHS) {
int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements();
int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements();
DemandedLHS = DemandedRHS = APInt::getNullValue(NumElts);

for (int i = 0; i != NumMaskElts; ++i) {
if (!DemandedElts[i])
continue;
int M = Shuf->getMaskValue(i);
assert(M < (NumElts * 2) && "Invalid shuffle mask constant");

// For undef elements, we don't know anything about the common state of
// the shuffle result.
if (M == -1)
return false;
if (M < NumElts)
DemandedLHS.setBit(M % NumElts);
else
DemandedRHS.setBit(M % NumElts);
}

return true;
}

static void computeKnownBits(const Value *V, const APInt &DemandedElts,
KnownBits &Known, unsigned Depth, const Query &Q);

static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
const Query &Q) {
Type *Ty = V->getType();
APInt DemandedElts = Ty->isVectorTy()
? APInt::getAllOnesValue(Ty->getVectorNumElements())
: APInt(1, 1);
computeKnownBits(V, DemandedElts, Known, Depth, Q);
}

void llvm::computeKnownBits(const Value *V, KnownBits &Known,
const DataLayout &DL, unsigned Depth,
@@ -295,8 +330,17 @@ bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
V, Mask, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
}

static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
unsigned Depth, const Query &Q);

static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
const Query &Q);
const Query &Q) {
Type *Ty = V->getType();
APInt DemandedElts = Ty->isVectorTy()
? APInt::getAllOnesValue(Ty->getVectorNumElements())
: APInt(1, 1);
return ComputeNumSignBits(V, DemandedElts, Depth, Q);
}

unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
unsigned Depth, AssumptionCache *AC,
@@ -1039,8 +1083,10 @@ static void computeKnownBitsFromShiftOperator(
Known.setAllZero();
}

static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
unsigned Depth, const Query &Q) {
static void computeKnownBitsFromOperator(const Operator *I,
const APInt &DemandedElts,
KnownBits &Known, unsigned Depth,
const Query &Q) {
unsigned BitWidth = Known.getBitWidth();

KnownBits Known2(Known);
@@ -1654,6 +1700,63 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
}
}
break;
case Instruction::ShuffleVector: {
auto *Shuf = cast<ShuffleVectorInst>(I);
// For undef elements, we don't know anything about the common state of
// the shuffle result.
APInt DemandedLHS, DemandedRHS;
if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
Known.resetAll();
return;
}
Known.One.setAllBits();
Known.Zero.setAllBits();
if (!!DemandedLHS) {
const Value *LHS = Shuf->getOperand(0);
computeKnownBits(LHS, DemandedLHS, Known, Depth + 1, Q);
// If we don't know any bits, early out.
if (Known.isUnknown())
break;
}
if (!!DemandedRHS) {
const Value *RHS = Shuf->getOperand(1);
computeKnownBits(RHS, DemandedRHS, Known2, Depth + 1, Q);
Known.One &= Known2.One;
Known.Zero &= Known2.Zero;
}
break;
}
case Instruction::InsertElement: {
auto *IEI = cast<InsertElementInst>(I);
Value *Vec = IEI->getOperand(0);
Value *Elt = IEI->getOperand(1);
auto *CIdx = dyn_cast<ConstantInt>(IEI->getOperand(2));
// Early out if the index is non-constant or out-of-range.
unsigned NumElts = DemandedElts.getBitWidth();
if (!CIdx || CIdx->getValue().uge(NumElts)) {
Known.resetAll();
return;
}
Known.One.setAllBits();
Known.Zero.setAllBits();
unsigned EltIdx = CIdx->getZExtValue();
// Do we demand the inserted element?
if (DemandedElts[EltIdx]) {
computeKnownBits(Elt, Known, Depth + 1, Q);
// If we don't know any bits, early out.
if (Known.isUnknown())
break;
}
// We don't need the base vector element that has been inserted.
APInt DemandedVecElts = DemandedElts;
DemandedVecElts.clearBit(EltIdx);
if (!!DemandedVecElts) {
computeKnownBits(Vec, DemandedVecElts, Known2, Depth + 1, Q);
Known.One &= Known2.One;
Known.Zero &= Known2.Zero;
}
break;
}
case Instruction::ExtractElement:
// Look through extract element. At the moment we keep this simple and skip
// tracking the specific element. But at least we might find information
@@ -1688,6 +1791,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
}
}
}
break;
}
}

@@ -1713,24 +1817,34 @@ KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) {
/// type, and vectors of integers. In the case
/// where V is a vector, known zero, and known one values are the
/// same width as the vector element, and the bit is set only if it is true
/// for all of the elements in the vector.
void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
const Query &Q) {
/// for all of the demanded elements in the vector specified by DemandedElts.
void computeKnownBits(const Value *V, const APInt &DemandedElts,
KnownBits &Known, unsigned Depth, const Query &Q) {
assert(V && "No Value?");
assert(Depth <= MaxDepth && "Limit Search Depth");
unsigned BitWidth = Known.getBitWidth();

assert((V->getType()->isIntOrIntVectorTy(BitWidth) ||
V->getType()->isPtrOrPtrVectorTy()) &&
Type *Ty = V->getType();
assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
"Not integer or pointer type!");
assert(((Ty->isVectorTy() &&
Ty->getVectorNumElements() == DemandedElts.getBitWidth()) ||
(!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) &&
"Unexpected vector size");

Type *ScalarTy = V->getType()->getScalarType();
Type *ScalarTy = Ty->getScalarType();
unsigned ExpectedWidth = ScalarTy->isPointerTy() ?
Q.DL.getPointerTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy);
assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth");
(void)BitWidth;
(void)ExpectedWidth;

if (!DemandedElts) {
// No demanded elts, better to assume we don't know anything.
Known.resetAll();
return;
}

const APInt *C;
if (match(V, m_APInt(C))) {
// We know all of the bits for a scalar constant or a splat vector constant!
@@ -1746,10 +1860,15 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
// Handle a constant vector by taking the intersection of the known bits of
// each element.
if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) {
assert((!Ty->isVectorTy() ||
CDS->getNumElements() == DemandedElts.getBitWidth()) &&
"Unexpected vector size");
// We know that CDS must be a vector of integers. Take the intersection of
// each element.
Known.Zero.setAllBits(); Known.One.setAllBits();
for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) {
if (Ty->isVectorTy() && !DemandedElts[i])
continue;
APInt Elt = CDS->getElementAsAPInt(i);
Known.Zero &= ~Elt;
Known.One &= Elt;
@@ -1758,10 +1877,14 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
}

if (const auto *CV = dyn_cast<ConstantVector>(V)) {
assert(CV->getNumOperands() == DemandedElts.getBitWidth() &&
"Unexpected vector size");
// We know that CV must be a vector of integers. Take the intersection of
// each element.
Known.Zero.setAllBits(); Known.One.setAllBits();
for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
if (!DemandedElts[i])
continue;
Constant *Element = CV->getAggregateElement(i);
auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element);
if (!ElementCI) {
@@ -1800,10 +1923,10 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
}

if (const Operator *I = dyn_cast<Operator>(V))
computeKnownBitsFromOperator(I, Known, Depth, Q);
computeKnownBitsFromOperator(I, DemandedElts, Known, Depth, Q);

// Aligned pointers have trailing zeros - refine Known.Zero set
if (V->getType()->isPointerTy()) {
if (Ty->isPointerTy()) {
const MaybeAlign Align = V->getPointerAlignment(Q.DL);
if (Align)
Known.Zero.setLowBits(countTrailingZeros(Align->value()));
@@ -2429,6 +2552,7 @@ static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
/// or if any element was not analyzed; otherwise, return the count for the
/// element with the minimum number of sign bits.
static unsigned computeNumSignBitsVectorConstant(const Value *V,
const APInt &DemandedElts,
unsigned TyBits) {
const auto *CV = dyn_cast<Constant>(V);
if (!CV || !CV->getType()->isVectorTy())
@@ -2437,6 +2561,8 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V,
unsigned MinSignBits = TyBits;
unsigned NumElts = CV->getType()->getVectorNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
if (!DemandedElts[i])
continue;
// If we find a non-ConstantInt, bail out.
auto *Elt = dyn_cast_or_null<ConstantInt>(CV->getAggregateElement(i));
if (!Elt)
@@ -2448,12 +2574,22 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V,
return MinSignBits;
}

static unsigned ComputeNumSignBitsImpl(const Value *V,
const APInt &DemandedElts,
unsigned Depth, const Query &Q);

static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,
const Query &Q);
const Query &Q) {
Type *Ty = V->getType();
APInt DemandedElts = Ty->isVectorTy()
? APInt::getAllOnesValue(Ty->getVectorNumElements())
: APInt(1, 1);
return ComputeNumSignBitsImpl(V, DemandedElts, Depth, Q);
}

static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
const Query &Q) {
unsigned Result = ComputeNumSignBitsImpl(V, Depth, Q);
static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
unsigned Depth, const Query &Q) {
unsigned Result = ComputeNumSignBitsImpl(V, DemandedElts, Depth, Q);
assert(Result > 0 && "At least one sign bit needs to be present!");
return Result;
}
@@ -2463,16 +2599,24 @@ static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
/// (itself), but other cases can give us information. For example, immediately
/// after an "ashr X, 2", we know that the top 3 bits are all equal to each
/// other, so we return 3. For vectors, return the number of sign bits for the
/// vector element with the minimum number of known sign bits.
static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,
const Query &Q) {
/// vector element with the minimum number of known sign bits of the demanded
/// elements in the vector specified by DemandedElts.
static unsigned ComputeNumSignBitsImpl(const Value *V,
const APInt &DemandedElts,
unsigned Depth, const Query &Q) {
assert(Depth <= MaxDepth && "Limit Search Depth");

// We return the minimum number of sign bits that are guaranteed to be present
// in V, so for undef we have to conservatively return 1. We don't have the
// same behavior for poison though -- that's a FIXME today.

Type *ScalarTy = V->getType()->getScalarType();
Type *Ty = V->getType();
assert(((Ty->isVectorTy() &&
Ty->getVectorNumElements() == DemandedElts.getBitWidth()) ||
(!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) &&
"Unexpected vector size");

Type *ScalarTy = Ty->getScalarType();
unsigned TyBits = ScalarTy->isPointerTy() ?
Q.DL.getPointerTypeSizeInBits(ScalarTy) :
Q.DL.getTypeSizeInBits(ScalarTy);
@@ -2698,40 +2842,33 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,
return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);

case Instruction::ShuffleVector: {
// TODO: This is copied almost directly from the SelectionDAG version of
// ComputeNumSignBits. It would be better if we could share common
// code. If not, make sure that changes are translated to the DAG.

// Collect the minimum number of sign bits that are shared by every vector
// element referenced by the shuffle.
auto *Shuf = cast<ShuffleVectorInst>(U);
int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements();
int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements();
APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0);
for (int i = 0; i != NumMaskElts; ++i) {
int M = Shuf->getMaskValue(i);
assert(M < NumElts * 2 && "Invalid shuffle mask constant");
// For undef elements, we don't know anything about the common state of
// the shuffle result.
if (M == -1)
return 1;
if (M < NumElts)
DemandedLHS.setBit(M % NumElts);
else
DemandedRHS.setBit(M % NumElts);
}
APInt DemandedLHS, DemandedRHS;
// For undef elements, we don't know anything about the common state of
// the shuffle result.
if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
return 1;
Tmp = std::numeric_limits<unsigned>::max();
if (!!DemandedLHS)
Tmp = ComputeNumSignBits(Shuf->getOperand(0), Depth + 1, Q);
if (!!DemandedLHS) {
const Value *LHS = Shuf->getOperand(0);
Tmp = ComputeNumSignBits(LHS, DemandedLHS, Depth + 1, Q);
}
// If we don't know anything, early out and try computeKnownBits
// fall-back.
if (Tmp == 1)
break;
if (!!DemandedRHS) {
Tmp2 = ComputeNumSignBits(Shuf->getOperand(1), Depth + 1, Q);
const Value *RHS = Shuf->getOperand(1);
Tmp2 = ComputeNumSignBits(RHS, DemandedRHS, Depth + 1, Q);
Tmp = std::min(Tmp, Tmp2);
}
// If we don't know anything, early out and try computeKnownBits
// fall-back.
if (Tmp == 1)
break;
assert(Tmp <= V->getType()->getScalarSizeInBits() &&
assert(Tmp <= Ty->getScalarSizeInBits() &&
"Failed to determine minimum sign bits");
return Tmp;
}
@@ -2743,11 +2880,12 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,

// If we can examine all elements of a vector constant successfully, we're
// done (we can't do any better than that). If not, keep trying.
if (unsigned VecSignBits = computeNumSignBitsVectorConstant(V, TyBits))
if (unsigned VecSignBits =
computeNumSignBitsVectorConstant(V, DemandedElts, TyBits))
return VecSignBits;

KnownBits Known(TyBits);
computeKnownBits(V, Known, Depth, Q);
computeKnownBits(V, DemandedElts, Known, Depth, Q);

// If we know that the sign bit is either zero or one, determine the number of
// identical bits in the top of the input value.
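
To make the lane mapping in the new getShuffleDemandedElts helper concrete, here is a small worked example (illustration only; Shuf is assumed to be the shufflevector described in the comments):

// shufflevector <4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
// with only result lanes 0 and 1 demanded by the caller:
APInt DemandedElts(4, 0b0011);   // result lanes 0 and 1
APInt DemandedLHS, DemandedRHS;
bool OK = getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS);
// OK == true
// DemandedLHS == 0b0001   (mask value 0 selects lane 0 of %a)
// DemandedRHS == 0b0010   (mask value 5 selects lane 5 % 4 == 1 of %b)
// Had a demanded result lane carried an undef (-1) mask entry, the helper
// would return false and the callers give up (Known.resetAll() / return 1).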
4 changes: 1 addition & 3 deletions llvm/test/Transforms/InstCombine/nsw.ll
@@ -99,13 +99,11 @@ define i8 @nopreserve4(i8 %A, i8 %B) {
ret i8 %add
}

; TODO: computeKnownBits() should look through a shufflevector.

define <3 x i32> @shl_nuw_nsw_shuffle_splat_vec(<2 x i8> %x) {
; CHECK-LABEL: @shl_nuw_nsw_shuffle_splat_vec(
; CHECK-NEXT: [[T2:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
; CHECK-NEXT: [[SHUF:%.*]] = shufflevector <2 x i32> [[T2]], <2 x i32> undef, <3 x i32> <i32 1, i32 0, i32 1>
; CHECK-NEXT: [[T3:%.*]] = shl nsw <3 x i32> [[SHUF]], <i32 17, i32 17, i32 17>
; CHECK-NEXT: [[T3:%.*]] = shl nuw nsw <3 x i32> [[SHUF]], <i32 17, i32 17, i32 17>
; CHECK-NEXT: ret <3 x i32> [[T3]]
;
%t2 = zext <2 x i8> %x to <2 x i32>
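
For the record, the arithmetic behind the extra nuw in the CHECK line above (my reading of the test, not text from the commit): after the zext from i8, every demanded lane feeding the shuffle has its top 24 bits known zero, so the lane value is at most 255, and 255 << 17 == 0x1FE0000 still fits in 32 bits. A minimal sketch in KnownBits terms:

// Per demanded lane of %shuf, now that computeKnownBits looks through the
// shuffle to the zext operand (values below are what the analysis derives):
KnownBits Lane(32);
Lane.Zero.setHighBits(24);        // zext i8 -> i32: top 24 bits known zero
// Max lane value is 255; 255 << 17 == 0x1FE0000 < 2^32, so the shl cannot
// wrap unsigned and InstCombine can now add the nuw flag (nsw was already
// present in the input).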