[X86][SSE] Add initial FSHL/FSHR vXi8 lowering support
This is very similar to the existing ROTL/ROTR support for scalar shifts in LowerRotate. I think as time goes on we should be able to share much of this code in helpers between funnel shift and rotation lowering.
RKSimon committed Jan 8, 2022
1 parent 9cf9ed9 commit b5d2e23
Showing 7 changed files with 1,050 additions and 1,336 deletions.
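For orientation: the per-element operation being lowered here is LLVM's funnel shift, which treats the two i8 operands as one 16-bit value and shifts by the amount modulo 8. A minimal scalar sketch of the i8 semantics (illustrative only, not code from this commit):

```cpp
#include <cstdint>

// Reference semantics of LLVM's i8 funnel shifts (illustrative sketch).
// fshl(x, y, z): concatenate x (high byte) and y (low byte), shift the
// 16-bit value left by z % 8, and return the high byte.
static uint8_t fshl8(uint8_t X, uint8_t Y, uint8_t Z) {
  unsigned Amt = Z & 7;                   // amount is taken modulo 8
  unsigned Wide = (unsigned(X) << 8) | Y; // "unpack(y,x)": x in the high byte
  return uint8_t((Wide << Amt) >> 8);     // shift, then keep the high byte
}

// fshr(x, y, z): same concatenation, shifted right; return the low byte.
static uint8_t fshr8(uint8_t X, uint8_t Y, uint8_t Z) {
  unsigned Amt = Z & 7;
  unsigned Wide = (unsigned(X) << 8) | Y;
  return uint8_t(Wide >> Amt);
}
```

This is exactly the unpack(y,x)-then-shift form that the new LowerFunnelShift code below applies across whole vectors of i8 elements.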
73 changes: 61 additions & 12 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1096,6 +1096,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::ROTR, VT, Custom);
     }
 
+    setOperationAction(ISD::FSHL, MVT::v16i8, Custom);
+    setOperationAction(ISD::FSHR, MVT::v16i8, Custom);
+
     setOperationAction(ISD::STRICT_FSQRT, MVT::v2f64, Legal);
     setOperationAction(ISD::STRICT_FADD, MVT::v2f64, Legal);
     setOperationAction(ISD::STRICT_FSUB, MVT::v2f64, Legal);
@@ -1284,6 +1287,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::ROTR, VT, Custom);
     }
 
+    setOperationAction(ISD::FSHL, MVT::v32i8, Custom);
+    setOperationAction(ISD::FSHR, MVT::v32i8, Custom);
+
     // These types need custom splitting if their input is a 128-bit vector.
     setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64, Custom);
     setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom);
@@ -1688,6 +1694,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
       setOperationAction(ISD::SSUBSAT, VT, HasBWI ? Legal : Custom);
     }
 
+    setOperationAction(ISD::FSHL, MVT::v64i8, Custom);
+    setOperationAction(ISD::FSHR, MVT::v64i8, Custom);
+
     if (Subtarget.hasDQI()) {
       setOperationAction(ISD::SINT_TO_FP, MVT::v8i64, Legal);
       setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal);
@@ -29740,20 +29749,60 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
   bool IsFSHR = Op.getOpcode() == ISD::FSHR;
 
   if (VT.isVector()) {
-    assert(Subtarget.hasVBMI2() && "Expected VBMI2");
-    if (IsFSHR)
-      std::swap(Op0, Op1);
+    APInt APIntShiftAmt;
+    bool IsCstSplat = X86::isConstantSplat(Amt, APIntShiftAmt);
 
-    APInt APIntShiftAmt;
-    if (X86::isConstantSplat(Amt, APIntShiftAmt)) {
-      uint64_t ShiftAmt = APIntShiftAmt.urem(EltSizeInBits);
-      SDValue Imm = DAG.getTargetConstant(ShiftAmt, DL, MVT::i8);
-      return getAVX512Node(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT,
-                           {Op0, Op1, Imm}, DAG, Subtarget);
-    }
-    return getAVX512Node(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL, VT,
-                         {Op0, Op1, Amt}, DAG, Subtarget);
+    if (Subtarget.hasVBMI2() && EltSizeInBits > 8) {
+      if (IsFSHR)
+        std::swap(Op0, Op1);
+
+      if (IsCstSplat) {
+        uint64_t ShiftAmt = APIntShiftAmt.urem(EltSizeInBits);
+        SDValue Imm = DAG.getTargetConstant(ShiftAmt, DL, MVT::i8);
+        return getAVX512Node(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT,
+                             {Op0, Op1, Imm}, DAG, Subtarget);
+      }
+      return getAVX512Node(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL, VT,
+                           {Op0, Op1, Amt}, DAG, Subtarget);
+    }
+
+    assert((VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8) &&
+           "Unexpected funnel shift type!");
+
+    // fshl(x,y,z) -> (unpack(y,x) << (z & (bw-1))) >> bw.
+    // fshr(x,y,z) -> unpack(y,x) >> (z & (bw-1)).
+    if (IsCstSplat)
+      return SDValue();
+
+    SDValue AmtMask = DAG.getConstant(EltSizeInBits - 1, DL, VT);
+    SDValue AmtMod = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);
+
+    MVT ExtSVT = MVT::getIntegerVT(2 * EltSizeInBits);
+    MVT ExtVT = MVT::getVectorVT(ExtSVT, VT.getVectorNumElements() / 2);
+
+    // Split 256-bit integers on XOP/pre-AVX2 targets.
+    // Split 512-bit integers on non 512-bit BWI targets.
+    if ((VT.is256BitVector() && (Subtarget.hasXOP() || !Subtarget.hasAVX2())) ||
+        (VT.is512BitVector() && !Subtarget.useBWIRegs())) {
+      // Pre-mask the amount modulo using the wider vector.
+      Op = DAG.getNode(Op.getOpcode(), DL, VT, Op0, Op1, AmtMod);
+      return splitVectorOp(Op, DAG);
+    }
+
+    // Attempt to fold scalar shift as unpack(y,x) << zext(splat(z)).
+    if (SDValue ScalarAmt = DAG.getSplatValue(AmtMod)) {
+      unsigned ShiftX86Opc = IsFSHR ? X86ISD::VSRLI : X86ISD::VSHLI;
+      SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, Op1, Op0));
+      SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, Op1, Op0));
+      ScalarAmt = DAG.getZExtOrTrunc(ScalarAmt, DL, MVT::i32);
+      Lo = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Lo, ScalarAmt, Subtarget,
+                               DAG);
+      Hi = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Hi, ScalarAmt, Subtarget,
+                               DAG);
+      return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !IsFSHR);
+    }
+
+    // Fallback to generic expansion.
+    return SDValue();
   }
   assert(
       (VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) &&
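The splat-amount path in the hunk above maps naturally onto SSE2. A hand-written sketch of the same unpack/wide-shift/pack data flow for the v16i8 FSHL case (the helper name is hypothetical, not part of the commit):

```cpp
#include <emmintrin.h>

// Sketch of the v16i8 FSHL lowering when the shift amount is a uniform
// (splat) value, using SSE2 intrinsics. Illustrative, not commit code.
static __m128i fshl_v16i8_splat(__m128i X, __m128i Y, unsigned Amt) {
  __m128i Cnt = _mm_cvtsi32_si128((int)(Amt & 7)); // amount modulo bit width
  __m128i Lo = _mm_unpacklo_epi8(Y, X); // each i16 lane becomes (x << 8) | y
  __m128i Hi = _mm_unpackhi_epi8(Y, X);
  Lo = _mm_sll_epi16(Lo, Cnt);          // one wide 16-bit shift per half
  Hi = _mm_sll_epi16(Hi, Cnt);
  Lo = _mm_srli_epi16(Lo, 8);           // FSHL result is each lane's high byte
  Hi = _mm_srli_epi16(Hi, 8);
  return _mm_packus_epi16(Lo, Hi);      // narrow back to 16 x i8
}
```

The actual lowering keeps the amount in a vector register via getTargetVShiftNode and selects the high or low half of each 16-bit lane in getPack depending on the shift direction (the `!IsFSHR` argument), but the data flow is the same.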
