[X86] LowerShiftByScalarVariable - find splat patterns with getSplatSourceVector instead of getSplatValue

This completes the removal of uses of SelectionDAG::getSplatValue started in D119090. By avoiding extraction of the splatted element we make it much easier to zero-extend the bottom 64 bits of the shift amount, and we fix issues we had on 32-bit targets where i64 isn't legal.
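For context, a condensed before/after sketch of the LowerShiftByScalarVariable change (taken from the X86ISelLowering.cpp hunk below; surrounding code and the old pre-call widening of the scalar are elided, so this is illustrative rather than compilable on its own):

    // Before: extract the splatted scalar, then manually widen it to a legal
    // integer type - awkward for i64 shift amounts on 32-bit targets.
    if (SDValue BaseShAmt = DAG.getSplatValue(Amt))
      if (supportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode))
        return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, Subtarget, DAG);

    // After: keep the splat in vector form and pass its index, letting the
    // shift-node helper zero-extend the bottom 64 bits of the amount itself.
    int BaseShAmtIdx = -1;
    if (SDValue BaseShAmt = DAG.getSplatSourceVector(Amt, BaseShAmtIdx))
      if (supportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode))
        return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, BaseShAmtIdx,
                                   Subtarget, DAG);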

I've removed the old version of getTargetVShiftNode that took a scalar shift amount argument, and LowerRotate can finally handle vXi16 rotates-by-scalar efficiently (using the same code as general funnel shifts).
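The vXi16 rotate-by-splat path now reuses the funnel-shift lowering; excerpted here from the LowerRotate hunk below (IsROTL, R and Amt are as defined in that function):

    // rotl(x,y) == fshl(x,x,y) and rotr(x,y) == fshr(x,x,y), so splatted
    // vXi16 rotate amounts can go through the general funnel-shift lowering
    // on SSE4.1+ targets.
    if (EltSizeInBits == 16 && Subtarget.hasSSE41()) {
      unsigned FunnelOpc = IsROTL ? ISD::FSHL : ISD::FSHR;
      return DAG.getNode(FunnelOpc, DL, VT, R, R, Amt);
    }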

The only regression we see is in the X86-AVX2 PR52719 test case in vector-shift-ashr-256.ll: it now hits the same problem as the X86-AVX1 case (failure to simplify a multi-use X86ISD::VBROADCAST_LOAD), which I intend to address in a follow-up patch.
RKSimon committed Mar 4, 2022
1 parent 85c53c7 commit 147cfcb
Showing 27 changed files with 619 additions and 992 deletions.
100 changes: 11 additions & 89 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -25810,72 +25810,6 @@ static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT,
return DAG.getNode(Opc, dl, VT, SrcOp, ShAmt);
}

/// Handle vector element shifts where the shift amount may or may not be a
/// constant. Takes immediate version of shift as input.
/// TODO: Replace with vector + (splat) idx to avoid extract_element nodes.
static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT,
SDValue SrcOp, SDValue ShAmt,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
MVT SVT = ShAmt.getSimpleValueType();
assert((SVT == MVT::i32 || SVT == MVT::i64) && "Unexpected value type!");

// Change opcode to non-immediate version.
Opc = getTargetVShiftUniformOpcode(Opc, true);

// Need to build a vector containing shift amount.
// SSE/AVX packed shifts only use the lower 64-bit of the shift count.
// +====================+============+=======================================+
// | ShAmt is | HasSSE4.1? | Construct ShAmt vector as |
// +====================+============+=======================================+
// | i64 | Yes, No | Use ShAmt as lowest elt |
// | i32 | Yes | zero-extend in-reg |
// | (i32 zext(i16/i8)) | Yes | zero-extend in-reg |
// | (i32 zext(i16/i8)) | No | byte-shift-in-reg |
// | i16/i32 | No | v4i32 build_vector(ShAmt, 0, ud, ud)) |
// +====================+============+=======================================+

if (SVT == MVT::i64)
ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v2i64, ShAmt);
else if (ShAmt.getOpcode() == ISD::ZERO_EXTEND &&
ShAmt.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
(ShAmt.getOperand(0).getSimpleValueType() == MVT::i16 ||
ShAmt.getOperand(0).getSimpleValueType() == MVT::i8)) {
ShAmt = ShAmt.getOperand(0);
MVT AmtTy = ShAmt.getSimpleValueType() == MVT::i8 ? MVT::v16i8 : MVT::v8i16;
ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), AmtTy, ShAmt);
if (Subtarget.hasSSE41())
ShAmt = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(ShAmt),
MVT::v2i64, ShAmt);
else {
SDValue ByteShift = DAG.getTargetConstant(
(128 - AmtTy.getScalarSizeInBits()) / 8, SDLoc(ShAmt), MVT::i8);
ShAmt = DAG.getBitcast(MVT::v16i8, ShAmt);
ShAmt = DAG.getNode(X86ISD::VSHLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt,
ByteShift);
ShAmt = DAG.getNode(X86ISD::VSRLDQ, SDLoc(ShAmt), MVT::v16i8, ShAmt,
ByteShift);
}
} else if (Subtarget.hasSSE41() &&
ShAmt.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v4i32, ShAmt);
ShAmt = DAG.getNode(ISD::ZERO_EXTEND_VECTOR_INREG, SDLoc(ShAmt),
MVT::v2i64, ShAmt);
} else {
SDValue ShOps[4] = {ShAmt, DAG.getConstant(0, dl, SVT), DAG.getUNDEF(SVT),
DAG.getUNDEF(SVT)};
ShAmt = DAG.getBuildVector(MVT::v4i32, dl, ShOps);
}

// The return type has to be a 128-bit type with the same element
// type as the input type.
MVT EltVT = VT.getVectorElementType();
MVT ShVT = MVT::getVectorVT(EltVT, 128 / EltVT.getSizeInBits());

ShAmt = DAG.getBitcast(ShVT, ShAmt);
return DAG.getNode(Opc, dl, VT, SrcOp, ShAmt);
}

/// Return Mask with the necessary casting or extending
/// for \p Mask according to \p MaskVT when lowering masking intrinsics
static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
@@ -29341,22 +29275,12 @@ static SDValue LowerShiftByScalarVariable(SDValue Op, SelectionDAG &DAG,
unsigned Opcode = Op.getOpcode();
unsigned X86OpcI = getTargetVShiftUniformOpcode(Opcode, false);

// TODO: Use getSplatSourceVector.
if (SDValue BaseShAmt = DAG.getSplatValue(Amt)) {
if (supportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode)) {
MVT EltVT = VT.getVectorElementType();
assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!");
if (EltVT != MVT::i64 && EltVT.bitsGT(MVT::i32))
BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, BaseShAmt);
else if (EltVT.bitsLT(MVT::i32))
BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, BaseShAmt);

return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, Subtarget, DAG);
}
}

int BaseShAmtIdx = -1;
if (SDValue BaseShAmt = DAG.getSplatSourceVector(Amt, BaseShAmtIdx)) {
if (supportedVectorShiftWithBaseAmnt(VT, Subtarget, Opcode))
return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, BaseShAmtIdx,
Subtarget, DAG);

// vXi8 shifts - shift as v8i16 + mask result.
if (((VT == MVT::v16i8 && !Subtarget.canExtendTo512DQ()) ||
(VT == MVT::v32i8 && !Subtarget.canExtendTo512BW()) ||
@@ -30217,11 +30141,13 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
// Attempt to fold as unpack(x,x) << zext(splat(y)):
// rotl(x,y) -> (unpack(x,x) << (y & (bw-1))) >> bw.
// rotr(x,y) -> (unpack(x,x) >> (y & (bw-1))).
// TODO: Handle vXi16 cases on all targets.
if (EltSizeInBits == 8 || EltSizeInBits == 32 ||
(EltSizeInBits == 16 && !Subtarget.hasSSE41())) {
if (EltSizeInBits == 8 || EltSizeInBits == 16 || EltSizeInBits == 32) {
int BaseRotAmtIdx = -1;
if (SDValue BaseRotAmt = DAG.getSplatSourceVector(AmtMod, BaseRotAmtIdx)) {
if (EltSizeInBits == 16 && Subtarget.hasSSE41()) {
unsigned FunnelOpc = IsROTL ? ISD::FSHL : ISD::FSHR;
return DAG.getNode(FunnelOpc, DL, VT, R, R, Amt);
}
unsigned ShiftX86Opc = IsROTL ? X86ISD::VSHLI : X86ISD::VSRLI;
SDValue Lo = DAG.getBitcast(ExtVT, getUnpackl(DAG, DL, VT, R, R));
SDValue Hi = DAG.getBitcast(ExtVT, getUnpackh(DAG, DL, VT, R, R));
@@ -41560,12 +41486,8 @@ bool X86TargetLowering::isSplatValueForTargetNode(SDValue Op,
switch (Opc) {
case X86ISD::VBROADCAST:
case X86ISD::VBROADCAST_LOAD:
// TODO: Permit vXi64 types on 32-bit targets.
if (isTypeLegal(Op.getValueType().getVectorElementType())) {
UndefElts = APInt::getNullValue(NumElts);
return true;
}
return false;
UndefElts = APInt::getNullValue(NumElts);
return true;
}

return TargetLowering::isSplatValueForTargetNode(Op, DemandedElts, UndefElts,
20 changes: 10 additions & 10 deletions llvm/test/CodeGen/X86/pr15296.ll
@@ -62,11 +62,11 @@ allocas:
define <4 x i64> @shiftInput___64in32bitmode(<4 x i64> %input, i64 %shiftval) nounwind {
; X86-LABEL: shiftInput___64in32bitmode:
; X86: # %bb.0: # %allocas
; X86-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero
; X86-NEXT: vextractf128 $1, %ymm0, %xmm2
; X86-NEXT: vpsrlq %xmm1, %xmm2, %xmm2
; X86-NEXT: vpsrlq %xmm1, %xmm0, %xmm0
; X86-NEXT: vinsertf128 $1, %xmm2, %ymm0, %ymm0
; X86-NEXT: vextractf128 $1, %ymm0, %xmm1
; X86-NEXT: vmovq {{.*#+}} xmm2 = mem[0],zero
; X86-NEXT: vpsrlq %xmm2, %xmm1, %xmm1
; X86-NEXT: vpsrlq %xmm2, %xmm0, %xmm0
; X86-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
; X86-NEXT: retl
;
; X64-LABEL: shiftInput___64in32bitmode:
@@ -87,11 +87,11 @@ allocas:
define <4 x i64> @shiftInput___2x32bitcast(<4 x i64> %input, i32 %shiftval) nounwind {
; X86-LABEL: shiftInput___2x32bitcast:
; X86: # %bb.0: # %allocas
; X86-NEXT: vmovd {{.*#+}} xmm1 = mem[0],zero,zero,zero
; X86-NEXT: vextractf128 $1, %ymm0, %xmm2
; X86-NEXT: vpsrlq %xmm1, %xmm2, %xmm2
; X86-NEXT: vpsrlq %xmm1, %xmm0, %xmm0
; X86-NEXT: vinsertf128 $1, %xmm2, %ymm0, %ymm0
; X86-NEXT: vextractf128 $1, %ymm0, %xmm1
; X86-NEXT: vmovd {{.*#+}} xmm2 = mem[0],zero,zero,zero
; X86-NEXT: vpsrlq %xmm2, %xmm1, %xmm1
; X86-NEXT: vpsrlq %xmm2, %xmm0, %xmm0
; X86-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
; X86-NEXT: retl
;
; X64-LABEL: shiftInput___2x32bitcast:
116 changes: 36 additions & 80 deletions llvm/test/CodeGen/X86/vector-fshl-128.ll
@@ -1156,60 +1156,32 @@ define <4 x i32> @splatvar_funnnel_v4i32(<4 x i32> %x, <4 x i32> %y, <4 x i32> %
}

define <8 x i16> @splatvar_funnnel_v8i16(<8 x i16> %x, <8 x i16> %y, <8 x i16> %amt) nounwind {
; SSE2-LABEL: splatvar_funnnel_v8i16:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa {{.*#+}} xmm3 = [15,15,15,15,15,15,15,15]
; SSE2-NEXT: movdqa %xmm2, %xmm4
; SSE2-NEXT: pandn %xmm3, %xmm4
; SSE2-NEXT: pslldq {{.*#+}} xmm4 = zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,xmm4[0,1]
; SSE2-NEXT: psrldq {{.*#+}} xmm4 = xmm4[14,15],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
; SSE2-NEXT: psrlw $1, %xmm1
; SSE2-NEXT: psrlw %xmm4, %xmm1
; SSE2-NEXT: pand %xmm3, %xmm2
; SSE2-NEXT: pslldq {{.*#+}} xmm2 = zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,xmm2[0,1]
; SSE2-NEXT: psrldq {{.*#+}} xmm2 = xmm2[14,15],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
; SSE2-NEXT: psllw %xmm2, %xmm0
; SSE2-NEXT: por %xmm1, %xmm0
; SSE2-NEXT: retq
;
; SSE41-LABEL: splatvar_funnnel_v8i16:
; SSE41: # %bb.0:
; SSE41-NEXT: movdqa {{.*#+}} xmm3 = [15,0,0,0]
; SSE41-NEXT: movdqa %xmm2, %xmm4
; SSE41-NEXT: pandn %xmm3, %xmm4
; SSE41-NEXT: psrlw $1, %xmm1
; SSE41-NEXT: psrlw %xmm4, %xmm1
; SSE41-NEXT: pand %xmm3, %xmm2
; SSE41-NEXT: psllw %xmm2, %xmm0
; SSE41-NEXT: por %xmm1, %xmm0
; SSE41-NEXT: retq
;
; AVX1-LABEL: splatvar_funnnel_v8i16:
; AVX1: # %bb.0:
; AVX1-NEXT: vmovddup {{.*#+}} xmm3 = [15,15]
; AVX1-NEXT: # xmm3 = mem[0,0]
; AVX1-NEXT: vandnps %xmm3, %xmm2, %xmm4
; AVX1-NEXT: vpsrlw $1, %xmm1, %xmm1
; AVX1-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
; AVX1-NEXT: vandps %xmm3, %xmm2, %xmm2
; AVX1-NEXT: vpsllw %xmm2, %xmm0, %xmm0
; AVX1-NEXT: vpor %xmm1, %xmm0, %xmm0
; AVX1-NEXT: retq
; SSE-LABEL: splatvar_funnnel_v8i16:
; SSE: # %bb.0:
; SSE-NEXT: movdqa {{.*#+}} xmm3 = [15,0,0,0]
; SSE-NEXT: movdqa %xmm2, %xmm4
; SSE-NEXT: pandn %xmm3, %xmm4
; SSE-NEXT: psrlw $1, %xmm1
; SSE-NEXT: psrlw %xmm4, %xmm1
; SSE-NEXT: pand %xmm3, %xmm2
; SSE-NEXT: psllw %xmm2, %xmm0
; SSE-NEXT: por %xmm1, %xmm0
; SSE-NEXT: retq
;
; AVX2-LABEL: splatvar_funnnel_v8i16:
; AVX2: # %bb.0:
; AVX2-NEXT: vpbroadcastq {{.*#+}} xmm3 = [15,15]
; AVX2-NEXT: vpandn %xmm3, %xmm2, %xmm4
; AVX2-NEXT: vpsrlw $1, %xmm1, %xmm1
; AVX2-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
; AVX2-NEXT: vpand %xmm3, %xmm2, %xmm2
; AVX2-NEXT: vpsllw %xmm2, %xmm0, %xmm0
; AVX2-NEXT: vpor %xmm1, %xmm0, %xmm0
; AVX2-NEXT: retq
; AVX-LABEL: splatvar_funnnel_v8i16:
; AVX: # %bb.0:
; AVX-NEXT: vmovdqa {{.*#+}} xmm3 = [15,0,0,0]
; AVX-NEXT: vpandn %xmm3, %xmm2, %xmm4
; AVX-NEXT: vpsrlw $1, %xmm1, %xmm1
; AVX-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
; AVX-NEXT: vpand %xmm3, %xmm2, %xmm2
; AVX-NEXT: vpsllw %xmm2, %xmm0, %xmm0
; AVX-NEXT: vpor %xmm1, %xmm0, %xmm0
; AVX-NEXT: retq
;
; AVX512F-LABEL: splatvar_funnnel_v8i16:
; AVX512F: # %bb.0:
; AVX512F-NEXT: vpbroadcastq {{.*#+}} xmm3 = [15,15]
; AVX512F-NEXT: vmovdqa {{.*#+}} xmm3 = [15,0,0,0]
; AVX512F-NEXT: vpandn %xmm3, %xmm2, %xmm4
; AVX512F-NEXT: vpsrlw $1, %xmm1, %xmm1
; AVX512F-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
@@ -1220,7 +1192,7 @@ define <8 x i16> @splatvar_funnnel_v8i16(<8 x i16> %x, <8 x i16> %y, <8 x i16> %
;
; AVX512VL-LABEL: splatvar_funnnel_v8i16:
; AVX512VL: # %bb.0:
; AVX512VL-NEXT: vpbroadcastq {{.*#+}} xmm3 = [15,15]
; AVX512VL-NEXT: vmovdqa {{.*#+}} xmm3 = [15,0,0,0]
; AVX512VL-NEXT: vpandn %xmm3, %xmm2, %xmm4
; AVX512VL-NEXT: vpsrlw $1, %xmm1, %xmm1
; AVX512VL-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
@@ -1231,7 +1203,7 @@ define <8 x i16> @splatvar_funnnel_v8i16(<8 x i16> %x, <8 x i16> %y, <8 x i16> %
;
; AVX512BW-LABEL: splatvar_funnnel_v8i16:
; AVX512BW: # %bb.0:
; AVX512BW-NEXT: vpbroadcastq {{.*#+}} xmm3 = [15,15]
; AVX512BW-NEXT: vmovdqa {{.*#+}} xmm3 = [15,0,0,0]
; AVX512BW-NEXT: vpandn %xmm3, %xmm2, %xmm4
; AVX512BW-NEXT: vpsrlw $1, %xmm1, %xmm1
; AVX512BW-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
@@ -1252,7 +1224,7 @@ define <8 x i16> @splatvar_funnnel_v8i16(<8 x i16> %x, <8 x i16> %y, <8 x i16> %
;
; AVX512VLBW-LABEL: splatvar_funnnel_v8i16:
; AVX512VLBW: # %bb.0:
; AVX512VLBW-NEXT: vpbroadcastq {{.*#+}} xmm3 = [15,15]
; AVX512VLBW-NEXT: vmovdqa {{.*#+}} xmm3 = [15,0,0,0]
; AVX512VLBW-NEXT: vpandn %xmm3, %xmm2, %xmm4
; AVX512VLBW-NEXT: vpsrlw $1, %xmm1, %xmm1
; AVX512VLBW-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
@@ -1267,41 +1239,25 @@ define <8 x i16> @splatvar_funnnel_v8i16(<8 x i16> %x, <8 x i16> %y, <8 x i16> %
; AVX512VLVBMI2-NEXT: vpshldvw %xmm2, %xmm1, %xmm0
; AVX512VLVBMI2-NEXT: retq
;
; XOPAVX1-LABEL: splatvar_funnnel_v8i16:
; XOPAVX1: # %bb.0:
; XOPAVX1-NEXT: vmovddup {{.*#+}} xmm3 = [15,15]
; XOPAVX1-NEXT: # xmm3 = mem[0,0]
; XOPAVX1-NEXT: vandnps %xmm3, %xmm2, %xmm4
; XOPAVX1-NEXT: vpsrlw $1, %xmm1, %xmm1
; XOPAVX1-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
; XOPAVX1-NEXT: vandps %xmm3, %xmm2, %xmm2
; XOPAVX1-NEXT: vpsllw %xmm2, %xmm0, %xmm0
; XOPAVX1-NEXT: vpor %xmm1, %xmm0, %xmm0
; XOPAVX1-NEXT: retq
;
; XOPAVX2-LABEL: splatvar_funnnel_v8i16:
; XOPAVX2: # %bb.0:
; XOPAVX2-NEXT: vpbroadcastq {{.*#+}} xmm3 = [15,15]
; XOPAVX2-NEXT: vpandn %xmm3, %xmm2, %xmm4
; XOPAVX2-NEXT: vpsrlw $1, %xmm1, %xmm1
; XOPAVX2-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
; XOPAVX2-NEXT: vpand %xmm3, %xmm2, %xmm2
; XOPAVX2-NEXT: vpsllw %xmm2, %xmm0, %xmm0
; XOPAVX2-NEXT: vpor %xmm1, %xmm0, %xmm0
; XOPAVX2-NEXT: retq
; XOP-LABEL: splatvar_funnnel_v8i16:
; XOP: # %bb.0:
; XOP-NEXT: vmovdqa {{.*#+}} xmm3 = [15,0,0,0]
; XOP-NEXT: vpandn %xmm3, %xmm2, %xmm4
; XOP-NEXT: vpsrlw $1, %xmm1, %xmm1
; XOP-NEXT: vpsrlw %xmm4, %xmm1, %xmm1
; XOP-NEXT: vpand %xmm3, %xmm2, %xmm2
; XOP-NEXT: vpsllw %xmm2, %xmm0, %xmm0
; XOP-NEXT: vpor %xmm1, %xmm0, %xmm0
; XOP-NEXT: retq
;
; X86-SSE2-LABEL: splatvar_funnnel_v8i16:
; X86-SSE2: # %bb.0:
; X86-SSE2-NEXT: movdqa {{.*#+}} xmm3 = [15,15,15,15,15,15,15,15]
; X86-SSE2-NEXT: movdqa {{.*#+}} xmm3 = [15,0,0,0]
; X86-SSE2-NEXT: movdqa %xmm2, %xmm4
; X86-SSE2-NEXT: pandn %xmm3, %xmm4
; X86-SSE2-NEXT: pslldq {{.*#+}} xmm4 = zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,xmm4[0,1]
; X86-SSE2-NEXT: psrldq {{.*#+}} xmm4 = xmm4[14,15],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
; X86-SSE2-NEXT: psrlw $1, %xmm1
; X86-SSE2-NEXT: psrlw %xmm4, %xmm1
; X86-SSE2-NEXT: pand %xmm3, %xmm2
; X86-SSE2-NEXT: pslldq {{.*#+}} xmm2 = zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,xmm2[0,1]
; X86-SSE2-NEXT: psrldq {{.*#+}} xmm2 = xmm2[14,15],zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero,zero
; X86-SSE2-NEXT: psllw %xmm2, %xmm0
; X86-SSE2-NEXT: por %xmm1, %xmm0
; X86-SSE2-NEXT: retl