diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index a31b8070deb82c..ded156a69d2b60 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -3261,38 +3261,31 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts, break; } case ISD::INSERT_VECTOR_ELT: { + // If we know the element index, split the demand between the + // source vector and the inserted element, otherwise assume we need + // the original demanded vector elements and the value. SDValue InVec = Op.getOperand(0); SDValue InVal = Op.getOperand(1); SDValue EltNo = Op.getOperand(2); - - ConstantSDNode *CEltNo = dyn_cast(EltNo); + bool DemandedVal = true; + APInt DemandedVecElts = DemandedElts; + auto *CEltNo = dyn_cast(EltNo); if (CEltNo && CEltNo->getAPIntValue().ult(NumElts)) { - // If we know the element index, split the demand between the - // source vector and the inserted element. - Known.Zero = Known.One = APInt::getAllOnesValue(BitWidth); unsigned EltIdx = CEltNo->getZExtValue(); - - // If we demand the inserted element then add its common known bits. - if (DemandedElts[EltIdx]) { - Known2 = computeKnownBits(InVal, Depth + 1); - Known.One &= Known2.One.zextOrTrunc(Known.One.getBitWidth()); - Known.Zero &= Known2.Zero.zextOrTrunc(Known.Zero.getBitWidth()); - } - - // If we demand the source vector then add its common known bits, ensuring - // that we don't demand the inserted element. - APInt VectorElts = DemandedElts & ~(APInt::getOneBitSet(NumElts, EltIdx)); - if (!!VectorElts) { - Known2 = computeKnownBits(InVec, VectorElts, Depth + 1); - Known.One &= Known2.One; - Known.Zero &= Known2.Zero; - } - } else { - // Unknown element index, so ignore DemandedElts and demand them all. - Known = computeKnownBits(InVec, Depth + 1); + DemandedVal = !!DemandedElts[EltIdx]; + DemandedVecElts.clearBit(EltIdx); + } + Known.One.setAllBits(); + Known.Zero.setAllBits(); + if (DemandedVal) { Known2 = computeKnownBits(InVal, Depth + 1); - Known.One &= Known2.One.zextOrTrunc(Known.One.getBitWidth()); - Known.Zero &= Known2.Zero.zextOrTrunc(Known.Zero.getBitWidth()); + Known.One &= Known2.One.zextOrTrunc(BitWidth); + Known.Zero &= Known2.Zero.zextOrTrunc(BitWidth); + } + if (!!DemandedVecElts) { + Known2 = computeKnownBits(InVec, DemandedVecElts, Depth + 1); + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; } break; } @@ -3850,39 +3843,32 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, return std::max(std::min(KnownSign - rIndex * BitWidth, BitWidth), 0); } case ISD::INSERT_VECTOR_ELT: { + // If we know the element index, split the demand between the + // source vector and the inserted element, otherwise assume we need + // the original demanded vector elements and the value. SDValue InVec = Op.getOperand(0); SDValue InVal = Op.getOperand(1); SDValue EltNo = Op.getOperand(2); - - ConstantSDNode *CEltNo = dyn_cast(EltNo); + bool DemandedVal = true; + APInt DemandedVecElts = DemandedElts; + auto *CEltNo = dyn_cast(EltNo); if (CEltNo && CEltNo->getAPIntValue().ult(NumElts)) { - // If we know the element index, split the demand between the - // source vector and the inserted element. unsigned EltIdx = CEltNo->getZExtValue(); - - // If we demand the inserted element then get its sign bits. - Tmp = std::numeric_limits::max(); - if (DemandedElts[EltIdx]) { - // TODO - handle implicit truncation of inserted elements. - if (InVal.getScalarValueSizeInBits() != VTBits) - break; - Tmp = ComputeNumSignBits(InVal, Depth + 1); - } - - // If we demand the source vector then get its sign bits, and determine - // the minimum. - APInt VectorElts = DemandedElts; - VectorElts.clearBit(EltIdx); - if (!!VectorElts) { - Tmp2 = ComputeNumSignBits(InVec, VectorElts, Depth + 1); - Tmp = std::min(Tmp, Tmp2); - } - } else { - // Unknown element index, so ignore DemandedElts and demand them all. - Tmp = ComputeNumSignBits(InVec, Depth + 1); + DemandedVal = !!DemandedElts[EltIdx]; + DemandedVecElts.clearBit(EltIdx); + } + Tmp = std::numeric_limits::max(); + if (DemandedVal) { + // TODO - handle implicit truncation of inserted elements. + if (InVal.getScalarValueSizeInBits() != VTBits) + break; Tmp2 = ComputeNumSignBits(InVal, Depth + 1); Tmp = std::min(Tmp, Tmp2); } + if (!!DemandedVecElts) { + Tmp2 = ComputeNumSignBits(InVec, DemandedVecElts, Depth + 1); + Tmp = std::min(Tmp, Tmp2); + } assert(Tmp <= VTBits && "Failed to determine minimum sign bits"); return Tmp; }