diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index f93225d44fe12a..1e345caa1e17b7 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -1364,6 +1364,7 @@ bool TargetLowering::SimplifyDemandedBits( case ISD::SHL: { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); + EVT ShiftVT = Op1.getValueType(); if (const APInt *SA = TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) { @@ -1388,7 +1389,7 @@ bool TargetLowering::SimplifyDemandedBits( Opc = ISD::SRL; } - SDValue NewSA = TLO.DAG.getConstant(Diff, dl, Op1.getValueType()); + SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT); return TLO.CombineTo( Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA)); } @@ -1396,19 +1397,9 @@ bool TargetLowering::SimplifyDemandedBits( } } - APInt InDemandedMask = DemandedBits.lshr(ShAmt); - if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, - Depth + 1)) - return true; - - // Try shrinking the operation as long as the shift amount will still be - // in range. - if ((ShAmt < DemandedBits.getActiveBits()) && - ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) - return true; - // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits // are not demanded. This will likely allow the anyext to be folded away. + // TODO - support non-uniform vector amounts. if (Op0.getOpcode() == ISD::ANY_EXTEND) { SDValue InnerOp = Op0.getOperand(0); EVT InnerVT = InnerOp.getValueType(); @@ -1424,11 +1415,13 @@ bool TargetLowering::SimplifyDemandedBits( return TLO.CombineTo( Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl)); } + // Repeat the SHL optimization above in cases where an extension // intervenes: (shl (anyext (shr x, c1)), c2) to // (shl (anyext x), c2-c1). This requires that the bottom c1 bits // aren't demanded (as above) and that the shifted upper c1 bits of // x aren't demanded. + // TODO - support non-uniform vector amounts. if (Op0.hasOneUse() && InnerOp.getOpcode() == ISD::SRL && InnerOp.hasOneUse()) { if (const APInt *SA2 = @@ -1438,8 +1431,8 @@ bool TargetLowering::SimplifyDemandedBits( DemandedBits.getActiveBits() <= (InnerBits - InnerShAmt + ShAmt) && DemandedBits.countTrailingZeros() >= ShAmt) { - SDValue NewSA = TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, - Op1.getValueType()); + SDValue NewSA = + TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, ShiftVT); SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, InnerOp.getOperand(0)); return TLO.CombineTo( @@ -1449,16 +1442,28 @@ bool TargetLowering::SimplifyDemandedBits( } } + APInt InDemandedMask = DemandedBits.lshr(ShAmt); + if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, + Depth + 1)) + return true; + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero <<= ShAmt; Known.One <<= ShAmt; // low bits known zero. Known.Zero.setLowBits(ShAmt); + + // Try shrinking the operation as long as the shift amount will still be + // in range. + if ((ShAmt < DemandedBits.getActiveBits()) && + ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) + return true; } break; } case ISD::SRL: { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); + EVT ShiftVT = Op1.getValueType(); if (const APInt *SA = TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) { @@ -1466,14 +1471,6 @@ bool TargetLowering::SimplifyDemandedBits( if (ShAmt == 0) return TLO.CombineTo(Op, Op0); - EVT ShiftVT = Op1.getValueType(); - APInt InDemandedMask = (DemandedBits << ShAmt); - - // If the shift is exact, then it does demand the low bits (and knows that - // they are zero). - if (Op->getFlags().hasExact()) - InDemandedMask.setLowBits(ShAmt); - // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a // single shift. We can do this if the top bits (which are shifted out) // are never demanded. @@ -1500,6 +1497,13 @@ bool TargetLowering::SimplifyDemandedBits( } } + APInt InDemandedMask = (DemandedBits << ShAmt); + + // If the shift is exact, then it does demand the low bits (and knows that + // they are zero). + if (Op->getFlags().hasExact()) + InDemandedMask.setLowBits(ShAmt); + // Compute the new bits that are at the top now. if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) @@ -1515,6 +1519,7 @@ bool TargetLowering::SimplifyDemandedBits( case ISD::SRA: { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); + EVT ShiftVT = Op1.getValueType(); // If we only want bits that already match the signbit then we don't need // to shift. @@ -1568,8 +1573,7 @@ bool TargetLowering::SimplifyDemandedBits( int Log2 = DemandedBits.exactLogBase2(); if (Log2 >= 0) { // The bit must come from the sign. - SDValue NewSA = - TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, Op1.getValueType()); + SDValue NewSA = TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, ShiftVT); return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA)); }