diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index f79bc84930f7fb..c407efb543bd1f 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -39597,6 +39597,7 @@ static SDValue combineOrShiftToFunnelShift(SDNode *N, SelectionDAG &DAG, SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64) return SDValue(); @@ -39620,11 +39621,13 @@ static SDValue combineOrShiftToFunnelShift(SDNode *N, SelectionDAG &DAG, if (!N0.hasOneUse() || !N1.hasOneUse()) return SDValue(); + EVT ShiftVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout()); + SDValue ShAmt0 = N0.getOperand(1); - if (ShAmt0.getValueType() != MVT::i8) + if (ShAmt0.getValueType() != ShiftVT) return SDValue(); SDValue ShAmt1 = N1.getOperand(1); - if (ShAmt1.getValueType() != MVT::i8) + if (ShAmt1.getValueType() != ShiftVT) return SDValue(); // Peek through any modulo shift masks. @@ -39659,12 +39662,12 @@ static SDValue combineOrShiftToFunnelShift(SDNode *N, SelectionDAG &DAG, std::swap(ShMsk0, ShMsk1); } - auto GetFunnelShift = [&DAG, &DL, VT, Opc](SDValue Op0, SDValue Op1, - SDValue Amt) { + auto GetFunnelShift = [&DAG, &DL, VT, Opc, &ShiftVT](SDValue Op0, SDValue Op1, + SDValue Amt) { if (Opc == ISD::FSHR) std::swap(Op0, Op1); return DAG.getNode(Opc, DL, VT, Op0, Op1, - DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Amt)); + DAG.getNode(ISD::TRUNCATE, DL, ShiftVT, Amt)); }; // OR( SHL( X, C ), SRL( Y, 32 - C ) ) -> FSHL( X, Y, C )