Skip to content

Commit

Permalink
[X86][SSE] Add ISD::ROTR support
Browse files Browse the repository at this point in the history
Fix issue in TargetLowering::expandROT where we only attempt to flip a rotation if the other direction has better support - this matches TargetLowering::expandFunnelShift

This allows us to enable ISD::ROTR lowering on SSE targets, which particularly simplifies/improves codegen for splat amount and AVX2 per-element shifts.
  • Loading branch information
RKSimon committed Dec 23, 2021
1 parent 0ff20f2 commit 71fc4bb
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 240 deletions.
5 changes: 3 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6657,9 +6657,10 @@ SDValue TargetLowering::expandROT(SDNode *Node, bool AllowVectorOps,
EVT ShVT = Op1.getValueType();
SDValue Zero = DAG.getConstant(0, DL, ShVT);

// If a rotate in the other direction is supported, use it.
// If a rotate in the other direction is more supported, use it.
unsigned RevRot = IsLeft ? ISD::ROTR : ISD::ROTL;
if (isOperationLegalOrCustom(RevRot, VT) && isPowerOf2_32(EltSizeInBits)) {
if (!isOperationLegalOrCustom(Node->getOpcode(), VT) &&
isOperationLegalOrCustom(RevRot, VT) && isPowerOf2_32(EltSizeInBits)) {
SDValue Sub = DAG.getNode(ISD::SUB, DL, ShVT, Zero, Op1);
return DAG.getNode(RevRot, DL, VT, Op0, Sub);
}
Expand Down
93 changes: 54 additions & 39 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1091,12 +1091,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::SRL, VT, Custom);
setOperationAction(ISD::SHL, VT, Custom);
setOperationAction(ISD::SRA, VT, Custom);
if (VT == MVT::v2i64) continue;
setOperationAction(ISD::ROTL, VT, Custom);
setOperationAction(ISD::ROTR, VT, Custom);
}

setOperationAction(ISD::ROTL, MVT::v4i32, Custom);
setOperationAction(ISD::ROTL, MVT::v8i16, Custom);
setOperationAction(ISD::ROTL, MVT::v16i8, Custom);

setOperationAction(ISD::STRICT_FSQRT, MVT::v2f64, Legal);
setOperationAction(ISD::STRICT_FADD, MVT::v2f64, Legal);
setOperationAction(ISD::STRICT_FSUB, MVT::v2f64, Legal);
Expand Down Expand Up @@ -1194,8 +1193,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,

if (!Subtarget.useSoftFloat() && Subtarget.hasXOP()) {
for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64,
MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 })
MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) {
setOperationAction(ISD::ROTL, VT, Custom);
setOperationAction(ISD::ROTR, VT, Custom);
}

// XOP can efficiently perform BITREVERSE with VPPERM.
for (auto VT : { MVT::i8, MVT::i16, MVT::i32, MVT::i64 })
Expand Down Expand Up @@ -1278,6 +1279,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::SRL, VT, Custom);
setOperationAction(ISD::SHL, VT, Custom);
setOperationAction(ISD::SRA, VT, Custom);
if (VT == MVT::v4i64) continue;
setOperationAction(ISD::ROTL, VT, Custom);
setOperationAction(ISD::ROTR, VT, Custom);
}

// These types need custom splitting if their input is a 128-bit vector.
Expand All @@ -1286,10 +1290,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::ZERO_EXTEND, MVT::v8i64, Custom);
setOperationAction(ISD::ZERO_EXTEND, MVT::v16i32, Custom);

setOperationAction(ISD::ROTL, MVT::v8i32, Custom);
setOperationAction(ISD::ROTL, MVT::v16i16, Custom);
setOperationAction(ISD::ROTL, MVT::v32i8, Custom);

setOperationAction(ISD::SELECT, MVT::v4f64, Custom);
setOperationAction(ISD::SELECT, MVT::v4i64, Custom);
setOperationAction(ISD::SELECT, MVT::v8i32, Custom);
Expand Down Expand Up @@ -1675,10 +1675,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
}

// With BWI, expanding (and promoting the shifts) is the better.
if (!Subtarget.useBWIRegs())
if (!Subtarget.useBWIRegs()) {
setOperationAction(ISD::ROTL, MVT::v32i16, Custom);
setOperationAction(ISD::ROTR, MVT::v32i16, Custom);
}

setOperationAction(ISD::ROTL, MVT::v64i8, Custom);
setOperationAction(ISD::ROTR, MVT::v64i8, Custom);

for (auto VT : { MVT::v64i8, MVT::v32i16 }) {
setOperationAction(ISD::ABS, VT, HasBWI ? Legal : Custom);
Expand Down Expand Up @@ -29847,7 +29850,19 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getNode(FunnelOpc, DL, VT, R, R, Amt);
}

assert(IsROTL && "Only ROTL supported");
SDValue Z = DAG.getConstant(0, DL, VT);

if (!IsROTL) {
// If the ISD::ROTR amount is constant, we're always better converting to
// ISD::ROTL.
if (SDValue NegAmt = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {Z, Amt}))
return DAG.getNode(ISD::ROTL, DL, VT, R, NegAmt);

// XOP targets always prefers ISD::ROTL.
if (Subtarget.hasXOP())
return DAG.getNode(ISD::ROTL, DL, VT, R,
DAG.getNode(ISD::SUB, DL, VT, Z, Amt));
}

// Split 256-bit integers on XOP/pre-AVX2 targets.
if (VT.is256BitVector() && (Subtarget.hasXOP() || !Subtarget.hasAVX2()))
Expand All @@ -29857,6 +29872,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
// +ve/-ve Amt = rotate left/right - just need to handle ISD::ROTL.
// XOP implicitly uses modulo rotation amounts.
if (Subtarget.hasXOP()) {
assert(IsROTL && "Only ROTL expected");
assert(VT.is128BitVector() && "Only rotate 128-bit vectors!");

// Attempt to rotate by immediate.
Expand Down Expand Up @@ -29885,36 +29901,27 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
(VT == MVT::v64i8 && Subtarget.useBWIRegs())) &&
"Only vXi32/vXi16/vXi8 vector rotates supported");

// Check for a hidden ISD::ROTR, splat + vXi8 lowering can handle both, but we
// currently hit infinite loops in legalization if we allow ISD::ROTR.
// FIXME: Infinite ROTL<->ROTR legalization in TargetLowering::expandROT.
SDValue HiddenROTRAmt;
if (Amt.getOpcode() == ISD::SUB &&
ISD::isBuildVectorAllZeros(Amt.getOperand(0).getNode()))
HiddenROTRAmt = Amt.getOperand(1);

MVT ExtSVT = MVT::getIntegerVT(2 * EltSizeInBits);
MVT ExtVT = MVT::getVectorVT(ExtSVT, NumElts / 2);

SDValue AmtMask = DAG.getConstant(EltSizeInBits - 1, DL, VT);
SDValue AmtMod = DAG.getNode(ISD::AND, DL, VT,
HiddenROTRAmt ? HiddenROTRAmt : Amt, AmtMask);
SDValue AmtMod = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);

// Attempt to fold as unpack(x,x) << zext(splat(y)):
// rotl(x,y) -> (unpack(x,x) << (y & (bw-1))) >> bw.
// rotr(x,y) -> (unpack(x,x) >> (y & (bw-1))).
// TODO: Handle vXi16 cases.
if (EltSizeInBits == 8 || EltSizeInBits == 32) {
if (SDValue BaseRotAmt = DAG.getSplatValue(AmtMod)) {
unsigned ShiftX86Opc = HiddenROTRAmt ? X86ISD::VSRLI : X86ISD::VSHLI;
unsigned ShiftX86Opc = IsROTL ? X86ISD::VSHLI : X86ISD::VSRLI;
SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, R, R));
SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, R, R));
BaseRotAmt = DAG.getZExtOrTrunc(BaseRotAmt, DL, MVT::i32);
Lo = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Lo, BaseRotAmt,
Subtarget, DAG);
Hi = getTargetVShiftNode(ShiftX86Opc, DL, ExtVT, Hi, BaseRotAmt,
Subtarget, DAG);
return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !HiddenROTRAmt);
return getPack(DAG, Subtarget, DL, VT, Lo, Hi, IsROTL);
}
}

Expand All @@ -29925,7 +29932,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
bool IsConstAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode());
MVT WideVT =
MVT::getVectorVT(Subtarget.hasBWI() ? MVT::i16 : MVT::i32, NumElts);
unsigned ShiftOpc = HiddenROTRAmt ? ISD::SRL : ISD::SHL;
unsigned ShiftOpc = IsROTL ? ISD::SHL : ISD::SRL;

// Attempt to fold as:
// rotl(x,y) -> (((aext(x) << bw) | zext(x)) << (y & (bw-1))) >> bw.
Expand All @@ -29942,7 +29949,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
getTargetVShiftByConstNode(X86ISD::VSHLI, DL, WideVT, R, 8, DAG));
Amt = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT, AmtMod);
R = DAG.getNode(ShiftOpc, DL, WideVT, R, Amt);
if (!HiddenROTRAmt)
if (IsROTL)
R = getTargetVShiftByConstNode(X86ISD::VSRLI, DL, WideVT, R, 8, DAG);
return DAG.getNode(ISD::TRUNCATE, DL, VT, R);
}
Expand All @@ -29952,14 +29959,13 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
// rotr(x,y) -> (unpack(x,x) >> (y & (bw-1))).
if (IsConstAmt || supportedVectorVarShift(ExtVT, Subtarget, ShiftOpc)) {
// See if we can perform this by unpacking to lo/hi vXi16.
SDValue Z = DAG.getConstant(0, DL, VT);
SDValue RLo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, R, R));
SDValue RHi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, R, R));
SDValue ALo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, AmtMod, Z));
SDValue AHi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, AmtMod, Z));
SDValue Lo = DAG.getNode(ShiftOpc, DL, ExtVT, RLo, ALo);
SDValue Hi = DAG.getNode(ShiftOpc, DL, ExtVT, RHi, AHi);
return getPack(DAG, Subtarget, DL, VT, Lo, Hi, !HiddenROTRAmt);
return getPack(DAG, Subtarget, DL, VT, Lo, Hi, IsROTL);
}
assert((VT == MVT::v16i8 || VT == MVT::v32i8) && "Unsupported vXi8 type");

Expand All @@ -29982,15 +29988,15 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getSelect(DL, SelVT, C, V0, V1);
};

// 'Hidden' ROTR is currently only profitable on AVX512 targets where we
// have VPTERNLOG.
unsigned ShiftLHS = ISD::SHL;
unsigned ShiftRHS = ISD::SRL;
if (HiddenROTRAmt && useVPTERNLOG(Subtarget, VT)) {
std::swap(ShiftLHS, ShiftRHS);
Amt = HiddenROTRAmt;
// ISD::ROTR is currently only profitable on AVX512 targets with VPTERNLOG.
if (!IsROTL && !useVPTERNLOG(Subtarget, VT)) {
Amt = DAG.getNode(ISD::SUB, DL, VT, Z, Amt);
IsROTL = true;
}

unsigned ShiftLHS = IsROTL ? ISD::SHL : ISD::SRL;
unsigned ShiftRHS = IsROTL ? ISD::SRL : ISD::SHL;

// Turn 'a' into a mask suitable for VSELECT: a = a << 5;
// We can safely do this using i16 shifts as we're only interested in
// the 3 lower bits of each byte.
Expand Down Expand Up @@ -30027,9 +30033,6 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
return SignBitSelect(VT, Amt, M, R);
}

// ISD::ROT* uses modulo rotate amounts.
Amt = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);

bool IsSplatAmt = DAG.isSplatValue(Amt);
bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode());
bool LegalVarShifts = supportedVectorVarShift(VT, Subtarget, ISD::SHL) &&
Expand All @@ -30038,13 +30041,25 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
// Fallback for splats + all supported variable shifts.
// Fallback for non-constants AVX2 vXi16 as well.
if (IsSplatAmt || LegalVarShifts || (Subtarget.hasAVX2() && !ConstantAmt)) {
Amt = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);
SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT);
AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt);
SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt);
SDValue SRL = DAG.getNode(ISD::SRL, DL, VT, R, AmtR);
SDValue SHL = DAG.getNode(IsROTL ? ISD::SHL : ISD::SRL, DL, VT, R, Amt);
SDValue SRL = DAG.getNode(IsROTL ? ISD::SRL : ISD::SHL, DL, VT, R, AmtR);
return DAG.getNode(ISD::OR, DL, VT, SHL, SRL);
}

// Everything below assumes ISD::ROTL.
if (!IsROTL) {
Amt = DAG.getNode(ISD::SUB, DL, VT, Z, Amt);
IsROTL = true;
}

// ISD::ROT* uses modulo rotate amounts.
Amt = DAG.getNode(ISD::AND, DL, VT, Amt, AmtMask);

assert(IsROTL && "Only ROTL supported");

// As with shifts, attempt to convert the rotation amount to a multiplication
// factor, fallback to general expansion.
SDValue Scale = convertShiftLeftToScale(Amt, DL, Subtarget, DAG);
Expand Down
6 changes: 2 additions & 4 deletions llvm/test/CodeGen/X86/funnel-shift-rot.ll
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,12 @@ define <4 x i32> @rotr_v4i32(<4 x i32> %x, <4 x i32> %z) nounwind {
;
; X64-AVX2-LABEL: rotr_v4i32:
; X64-AVX2: # %bb.0:
; X64-AVX2-NEXT: vpxor %xmm2, %xmm2, %xmm2
; X64-AVX2-NEXT: vpsubd %xmm1, %xmm2, %xmm1
; X64-AVX2-NEXT: vpbroadcastd {{.*#+}} xmm2 = [31,31,31,31]
; X64-AVX2-NEXT: vpand %xmm2, %xmm1, %xmm1
; X64-AVX2-NEXT: vpsllvd %xmm1, %xmm0, %xmm2
; X64-AVX2-NEXT: vpsrlvd %xmm1, %xmm0, %xmm2
; X64-AVX2-NEXT: vpbroadcastd {{.*#+}} xmm3 = [32,32,32,32]
; X64-AVX2-NEXT: vpsubd %xmm1, %xmm3, %xmm1
; X64-AVX2-NEXT: vpsrlvd %xmm1, %xmm0, %xmm0
; X64-AVX2-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
; X64-AVX2-NEXT: vpor %xmm0, %xmm2, %xmm0
; X64-AVX2-NEXT: retq
%f = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> %x, <4 x i32> %x, <4 x i32> %z)
Expand Down
Loading

0 comments on commit 71fc4bb

Please sign in to comment.