Skip to content

Commit

Permalink
[SelectionDAG] Add demanded elts variants to isConstOrConstSplat help…
Browse files Browse the repository at this point in the history
…ers. NFCI.

These helpers extend the existing isConstOrConstSplat helper checks to support DemandedElts masks as well.

We already had a local version of this in SelectionDAG that computeKnownBits/ComputeNumSignBits made use of, but this adds the functionality directly to the BuildVectorSDNode node and extends isConstOrConstSplat etc. to use that.

This will allow us to reuse the functionality in SimplifyDemandedVectorElts/SimplifyDemandedBits.

Differential Revision: https://reviews.llvm.org/D58503

llvm-svn: 354797
  • Loading branch information
RKSimon committed Feb 25, 2019
1 parent 8a7f4c9 commit 80d0e9c
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 37 deletions.
39 changes: 39 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Expand Up @@ -1629,9 +1629,19 @@ bool isBitwiseNot(SDValue V);
/// Returns the SDNode if it is a constant splat BuildVector or constant int.
ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false);

/// Returns the SDNode if it is a demanded constant splat BuildVector or
/// constant int.
ConstantSDNode *isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
bool AllowUndefs = false);

/// Returns the SDNode if it is a constant splat BuildVector or constant float.
ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, bool AllowUndefs = false);

/// Returns the SDNode if it is a demanded constant splat BuildVector or
/// constant float.
ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, const APInt &DemandedElts,
bool AllowUndefs = false);

/// Return true if the value is a constant 0 integer or a splatted vector of
/// a constant 0 integer (with no undefs by default).
/// Build vector implicit truncation is not an issue for null values.
Expand Down Expand Up @@ -1868,12 +1878,31 @@ class BuildVectorSDNode : public SDNode {
unsigned MinSplatBits = 0,
bool isBigEndian = false) const;

/// Returns the demanded splatted value or a null value if this is not a
/// splat.
///
/// The DemandedElts mask indicates the elements that must be in the splat.
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where elements are undef.
SDValue getSplatValue(const APInt &DemandedElts,
BitVector *UndefElements = nullptr) const;

/// Returns the splatted value or a null value if this is not a splat.
///
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where elements are undef.
SDValue getSplatValue(BitVector *UndefElements = nullptr) const;

/// Returns the demanded splatted constant or null if this is not a constant
/// splat.
///
/// The DemandedElts mask indicates the elements that must be in the splat.
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where elements are undef.
ConstantSDNode *
getConstantSplatNode(const APInt &DemandedElts,
BitVector *UndefElements = nullptr) const;

/// Returns the splatted constant or null if this is not a constant
/// splat.
///
Expand All @@ -1882,6 +1911,16 @@ class BuildVectorSDNode : public SDNode {
ConstantSDNode *
getConstantSplatNode(BitVector *UndefElements = nullptr) const;

/// Returns the demanded splatted constant FP or null if this is not a
/// constant FP splat.
///
/// The DemandedElts mask indicates the elements that must be in the splat.
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where elements are undef.
ConstantFPSDNode *
getConstantFPSplatNode(const APInt &DemandedElts,
BitVector *UndefElements = nullptr) const;

/// Returns the splatted constant FP or null if this is not a constant
/// FP splat.
///
Expand Down
111 changes: 74 additions & 37 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Expand Up @@ -2253,30 +2253,6 @@ bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) {
(AllowUndefs || !UndefElts);
}

/// Helper function that checks to see if a node is a constant or a
/// build vector of splat constants at least within the demanded elts.
static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N,
const APInt &DemandedElts) {
if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
return CN;
if (N.getOpcode() != ISD::BUILD_VECTOR)
return nullptr;
EVT VT = N.getValueType();
ConstantSDNode *Cst = nullptr;
unsigned NumElts = VT.getVectorNumElements();
assert(DemandedElts.getBitWidth() == NumElts && "Unexpected vector size");
for (unsigned i = 0; i != NumElts; ++i) {
if (!DemandedElts[i])
continue;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(i));
if (!C || (Cst && Cst->getAPIntValue() != C->getAPIntValue()) ||
C->getValueType(0) != VT.getScalarType())
return nullptr;
Cst = C;
}
return Cst;
}

/// If a SHL/SRA/SRL node has a constant or splat constant shift amount that
/// is less than the element bit-width of the shift node, return it.
static const APInt *getValidShiftAmountConstant(SDValue V) {
Expand Down Expand Up @@ -2717,8 +2693,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
case ISD::FSHL:
case ISD::FSHR:
if (ConstantSDNode *C =
isConstOrDemandedConstSplat(Op.getOperand(2), DemandedElts)) {
if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(2), DemandedElts)) {
unsigned Amt = C->getAPIntValue().urem(BitWidth);

// For fshl, 0-shift returns the 1st arg.
Expand Down Expand Up @@ -3155,10 +3130,10 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
// the minimum of the clamp min/max range.
bool IsMax = (Opcode == ISD::SMAX);
ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)))
if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1),
DemandedElts);
CstHigh =
isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
if (CstLow && CstHigh) {
if (!IsMax)
std::swap(CstLow, CstHigh);
Expand Down Expand Up @@ -3439,15 +3414,15 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
// SRA X, C -> adds C sign bits.
if (ConstantSDNode *C =
isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) {
isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
APInt ShiftVal = C->getAPIntValue();
ShiftVal += Tmp;
Tmp = ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue();
}
return Tmp;
case ISD::SHL:
if (ConstantSDNode *C =
isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) {
isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
// shl destroys sign bits.
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
if (C->getAPIntValue().uge(VTBits) || // Bad shift.
Expand Down Expand Up @@ -3487,10 +3462,10 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
// the minimum of the clamp min/max range.
bool IsMax = (Opcode == ISD::SMAX);
ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)))
if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1),
DemandedElts);
CstHigh =
isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
if (CstLow && CstHigh) {
if (!IsMax)
std::swap(CstLow, CstHigh);
Expand Down Expand Up @@ -8593,6 +8568,24 @@ ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) {
return nullptr;
}

ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
bool AllowUndefs) {
if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
return CN;

if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
BitVector UndefElements;
ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements);

// BuildVectors can truncate their operands. Ignore that case here.
if (CN && (UndefElements.none() || AllowUndefs) &&
CN->getValueType(0) == N.getValueType().getScalarType())
return CN;
}

return nullptr;
}

ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) {
if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
return CN;
Expand All @@ -8607,6 +8600,23 @@ ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) {
return nullptr;
}

ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N,
const APInt &DemandedElts,
bool AllowUndefs) {
if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
return CN;

if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
BitVector UndefElements;
ConstantFPSDNode *CN =
BV->getConstantFPSplatNode(DemandedElts, &UndefElements);
if (CN && (UndefElements.none() || AllowUndefs))
return CN;
}

return nullptr;
}

bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) {
// TODO: may want to use peekThroughBitcast() here.
ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
Expand Down Expand Up @@ -9193,13 +9203,20 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue, APInt &SplatUndef,
return true;
}

SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
SDValue BuildVectorSDNode::getSplatValue(const APInt &DemandedElts,
BitVector *UndefElements) const {
if (UndefElements) {
UndefElements->clear();
UndefElements->resize(getNumOperands());
}
assert(getNumOperands() == DemandedElts.getBitWidth() &&
"Unexpected vector size");
if (!DemandedElts)
return SDValue();
SDValue Splatted;
for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
if (!DemandedElts[i])
continue;
SDValue Op = getOperand(i);
if (Op.isUndef()) {
if (UndefElements)
Expand All @@ -9212,19 +9229,39 @@ SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
}

if (!Splatted) {
assert(getOperand(0).isUndef() &&
unsigned FirstDemandedIdx = DemandedElts.countTrailingZeros();
assert(getOperand(FirstDemandedIdx).isUndef() &&
"Can only have a splat without a constant for all undefs.");
return getOperand(0);
return getOperand(FirstDemandedIdx);
}

return Splatted;
}

SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
APInt DemandedElts = APInt::getAllOnesValue(getNumOperands());
return getSplatValue(DemandedElts, UndefElements);
}

ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
BitVector *UndefElements) const {
return dyn_cast_or_null<ConstantSDNode>(
getSplatValue(DemandedElts, UndefElements));
}

ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(BitVector *UndefElements) const {
return dyn_cast_or_null<ConstantSDNode>(getSplatValue(UndefElements));
}

ConstantFPSDNode *
BuildVectorSDNode::getConstantFPSplatNode(const APInt &DemandedElts,
BitVector *UndefElements) const {
return dyn_cast_or_null<ConstantFPSDNode>(
getSplatValue(DemandedElts, UndefElements));
}

ConstantFPSDNode *
BuildVectorSDNode::getConstantFPSplatNode(BitVector *UndefElements) const {
return dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements));
Expand Down

0 comments on commit 80d0e9c

Please sign in to comment.