diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 71dfd090142e7..0d152d65022c2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -6393,6 +6393,37 @@ SDValue SplitOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
 }
 
+// Helper function that extends a non-512-bit vector op to 512-bits on non-VLX
+// targets.
+static SDValue getAVX512Node(unsigned Opcode, const SDLoc &DL, MVT VT,
+                             ArrayRef<SDValue> Ops, SelectionDAG &DAG,
+                             const X86Subtarget &Subtarget) {
+  assert(Subtarget.hasAVX512() && "AVX512 target expected");
+
+  // If we have VLX or the type is already 512-bits, then create the node
+  // directly.
+  if (Subtarget.hasVLX() || VT.is512BitVector())
+    return DAG.getNode(Opcode, DL, VT, Ops);
+
+  // Widen the vector ops.
+  MVT SVT = VT.getScalarType();
+  MVT WideVT = MVT::getVectorVT(SVT, 512 / SVT.getSizeInBits());
+  SmallVector<SDValue> WideOps(Ops.begin(), Ops.end());
+  for (SDValue &Op : WideOps) {
+    MVT OpVT = Op.getSimpleValueType();
+    // Just pass through scalar operands.
+    if (!OpVT.isVector())
+      continue;
+    assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
+           "Vector size mismatch");
+    Op = widenSubVector(Op, false, Subtarget, DAG, DL, 512);
+  }
+
+  // Perform the 512-bit op then extract the bottom subvector.
+  SDValue Res = DAG.getNode(Opcode, DL, WideVT, WideOps);
+  return extractSubVector(Res, 0, DAG, DL, VT.getSizeInBits());
+}
+
 /// Insert i1-subvector to i1-vector.
 static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
@@ -29593,29 +29624,15 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
     if (IsFSHR)
       std::swap(Op0, Op1);
 
-    // With AVX512, but not VLX we need to widen to get a 512-bit result type.
-    if (!Subtarget.hasVLX() && !VT.is512BitVector()) {
-      Op0 = widenSubVector(Op0, false, Subtarget, DAG, DL, 512);
-      Op1 = widenSubVector(Op1, false, Subtarget, DAG, DL, 512);
-    }
-
-    SDValue Funnel;
     APInt APIntShiftAmt;
-    MVT ResultVT = Op0.getSimpleValueType();
     if (X86::isConstantSplat(Amt, APIntShiftAmt)) {
       uint64_t ShiftAmt = APIntShiftAmt.urem(VT.getScalarSizeInBits());
-      Funnel =
-          DAG.getNode(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, ResultVT, Op0,
-                      Op1, DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
-    } else {
-      if (!Subtarget.hasVLX() && !VT.is512BitVector())
-        Amt = widenSubVector(Amt, false, Subtarget, DAG, DL, 512);
-      Funnel = DAG.getNode(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL,
-                           ResultVT, Op0, Op1, Amt);
-    }
-    if (!Subtarget.hasVLX() && !VT.is512BitVector())
-      Funnel = extractSubVector(Funnel, 0, DAG, DL, VT.getSizeInBits());
-    return Funnel;
+      SDValue Imm = DAG.getTargetConstant(ShiftAmt, DL, MVT::i8);
+      return getAVX512Node(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT,
+                           {Op0, Op1, Imm}, DAG, Subtarget);
+    }
+    return getAVX512Node(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL, VT,
+                         {Op0, Op1, Amt}, DAG, Subtarget);
   }
   assert(
       (VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) &&