diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 022efc609e481..155d68c0b220a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -1769,15 +1769,11 @@ bool TargetLowering::SimplifyDemandedBits( // undesirable. break; - auto *ShAmt = dyn_cast(Src.getOperand(1)); - if (!ShAmt || ShAmt->getAPIntValue().uge(BitWidth)) + SDValue ShAmt = Src.getOperand(1); + auto *ShAmtC = dyn_cast(ShAmt); + if (!ShAmtC || ShAmtC->getAPIntValue().uge(BitWidth)) break; - - SDValue Shift = Src.getOperand(1); - uint64_t ShVal = ShAmt->getZExtValue(); - - if (TLO.LegalTypes()) - Shift = TLO.DAG.getConstant(ShVal, dl, getShiftAmountTy(VT, DL)); + uint64_t ShVal = ShAmtC->getZExtValue(); APInt HighBits = APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth); @@ -1787,10 +1783,12 @@ bool TargetLowering::SimplifyDemandedBits( if (!(HighBits & DemandedBits)) { // None of the shifted in bits are needed. Add a truncate of the // shift input, then shift it. + if (TLO.LegalTypes()) + ShAmt = TLO.DAG.getConstant(ShVal, dl, getShiftAmountTy(VT, DL)); SDValue NewTrunc = TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0)); return TLO.CombineTo( - Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, Shift)); + Op, TLO.DAG.getNode(ISD::SRL, dl, VT, NewTrunc, ShAmt)); } break; }