From 35189d52218511153c6ad3a027599bb814818779 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Tue, 22 Aug 2017 23:54:13 +0000 Subject: [PATCH] [SelectionDAG] Make ISD::isConstantSplatVector always return an element sized APInt. This partially reverts r311429 in favor of making ISD::isConstantSplatVector do something not confusing. Turns out the only other user of it was also having to deal with the weird property of it returning a smaller size. So rather than continue to deal with this quirk everywhere, just make the interface do something sane. Differential Revision: https://reviews.llvm.org/D37039 llvm-svn: 311510 --- llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 5 +---- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 15 ++++++++------- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 8 +++----- llvm/lib/Target/X86/X86ISelLowering.cpp | 11 ++++------- 4 files changed, 16 insertions(+), 23 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 051c93601d3fe..db42fb6c170c0 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -85,10 +85,7 @@ namespace ISD { /// If N is a BUILD_VECTOR node whose elements are all the same constant or /// undefined, return true and return the constant value in \p SplatValue. - /// This sets \p SplatValue to the smallest possible splat unless AllowShrink - /// is set to false. - bool isConstantSplatVector(const SDNode *N, APInt &SplatValue, - bool AllowShrink = true); + bool isConstantSplatVector(const SDNode *N, APInt &SplatValue); /// Return true if the specified node is a BUILD_VECTOR where all of the /// elements are ~0 or undef. diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index b8f830852990b..ab459cd7ea034 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2588,6 +2588,12 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { N0IsConst = ISD::isConstantSplatVector(N0.getNode(), ConstValue0); N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1); + assert((!N0IsConst || + ConstValue0.getBitWidth() == VT.getScalarSizeInBits()) && + "Splat APInt should be element width"); + assert((!N1IsConst || + ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) && + "Splat APInt should be element width"); } else { N0IsConst = isa(N0); if (N0IsConst) { @@ -2613,12 +2619,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { // fold (mul x, 0) -> 0 if (N1IsConst && ConstValue1.isNullValue()) return N1; - // We require a splat of the entire scalar bit width for non-contiguous - // bit patterns. - bool IsFullSplat = - ConstValue1.getBitWidth() == VT.getScalarSizeInBits(); // fold (mul x, 1) -> x - if (N1IsConst && ConstValue1.isOneValue() && IsFullSplat) + if (N1IsConst && ConstValue1.isOneValue()) return N0; if (SDValue NewSel = foldBinOpIntoSelect(N)) @@ -2643,8 +2645,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc); } // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c - if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2() && - IsFullSplat) { + if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2()) { unsigned Log2Val = (-ConstValue1).logBase2(); SDLoc DL(N); // FIXME: If the input is something that is easily negated (e.g. a diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index d555315c5232b..b4154210329e4 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -125,8 +125,7 @@ bool ConstantFPSDNode::isValueValidForType(EVT VT, // ISD Namespace //===----------------------------------------------------------------------===// -bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal, - bool AllowShrink) { +bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) { auto *BV = dyn_cast(N); if (!BV) return false; @@ -135,10 +134,9 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal, unsigned SplatBitSize; bool HasUndefs; unsigned EltSize = N->getValueType(0).getVectorElementType().getSizeInBits(); - unsigned MinSplatBits = AllowShrink ? 0 : EltSize; return BV->isConstantSplat(SplatVal, SplatUndef, SplatBitSize, HasUndefs, - MinSplatBits) && - EltSize >= SplatBitSize; + EltSize) && + EltSize == SplatBitSize; } // FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 42c33953507a2..7e283d1d01b03 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -29567,8 +29567,7 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, // In SetLT case, The second operand of the comparison can be either 1 or 0. APInt SplatVal; if ((CC == ISD::SETLT) && - !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal, - /*AllowShrink*/false) && + !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal) && SplatVal.isOneValue()) || (ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode())))) return false; @@ -32084,8 +32083,7 @@ static SDValue combineAndMaskToShift(SDNode *N, SelectionDAG &DAG, return SDValue(); APInt SplatVal; - if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal, - /*AllowShrink*/false) || + if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal) || !SplatVal.isMask()) return SDValue(); @@ -32669,8 +32667,7 @@ static SDValue detectUSatPattern(SDValue In, EVT VT) { "Unexpected types for truncate operation"); APInt C; - if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C, - /*AllowShrink*/false)) { + if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C)) { // C should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according // the element size of the destination type. return C.isMask(VT.getScalarSizeInBits()) ? In.getOperand(0) : @@ -35377,7 +35374,7 @@ static SDValue combineIncDecVector(SDNode *N, SelectionDAG &DAG) { SDNode *N1 = N->getOperand(1).getNode(); APInt SplatVal; - if (!ISD::isConstantSplatVector(N1, SplatVal, /*AllowShrink*/false) || + if (!ISD::isConstantSplatVector(N1, SplatVal) || !SplatVal.isOneValue()) return SDValue();