Skip to content

Commit

Permalink
[X86][XOP] Added support for the lowering of 128-bit vector shifts to…
Browse files Browse the repository at this point in the history
… XOP shift instructions

The XOP shifts just have logical/arithmetic versions and the left/right shifts are controlled by whether the value is positive/negative. Because of this I've added new X86ISD nodes instead of trying to force them to use the existing shift nodes.

Additionally Excavator cores (bdver4) support XOP and AVX2 - meaning that it should use the AVX2 shifts when it can and fall back to XOP in other cases.

Differential Revision: http://reviews.llvm.org/D8690

llvm-svn: 248878
  • Loading branch information
RKSimon committed Sep 30, 2015
1 parent 82d705e commit 3d11c99
Show file tree
Hide file tree
Showing 13 changed files with 1,391 additions and 46 deletions.
52 changes: 37 additions & 15 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -17893,18 +17893,28 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,

// i64 SRA needs to be performed as partial shifts.
if ((VT == MVT::v2i64 || (Subtarget->hasInt256() && VT == MVT::v4i64)) &&
Op.getOpcode() == ISD::SRA)
Op.getOpcode() == ISD::SRA && !Subtarget->hasXOP())
return ArithmeticShiftRight64(ShiftAmt);

if (VT == MVT::v16i8 || (Subtarget->hasInt256() && VT == MVT::v32i8)) {
unsigned NumElts = VT.getVectorNumElements();
MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);

if (Op.getOpcode() == ISD::SHL) {
// Simple i8 add case
if (ShiftAmt == 1)
return DAG.getNode(ISD::ADD, dl, VT, R, R);
// Simple i8 add case
if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1)
return DAG.getNode(ISD::ADD, dl, VT, R, R);

// ashr(R, 7) === cmp_slt(R, 0)
if (Op.getOpcode() == ISD::SRA && ShiftAmt == 7) {
SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl);
return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R);
}

// XOP can shift v16i8 directly instead of as shift v8i16 + mask.
if (VT == MVT::v16i8 && Subtarget->hasXOP())
return SDValue();

if (Op.getOpcode() == ISD::SHL) {
// Make a large shift.
SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ShiftVT,
R, ShiftAmt, DAG);
Expand All @@ -17927,12 +17937,6 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
DAG.getNode(ISD::BUILD_VECTOR, dl, VT, V));
}
if (Op.getOpcode() == ISD::SRA) {
if (ShiftAmt == 7) {
// ashr(R, 7) === cmp_slt(R, 0)
SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl);
return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R);
}

// ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask)
SDValue Res = DAG.getNode(ISD::SRL, dl, VT, R, Amt);
SmallVector<SDValue, 32> V(NumElts,
Expand All @@ -17949,7 +17953,7 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
}

// Special case in 32-bit mode, where i64 is expanded into high and low parts.
if (!Subtarget->is64Bit() &&
if (!Subtarget->is64Bit() && !Subtarget->hasXOP() &&
(VT == MVT::v2i64 || (Subtarget->hasInt256() && VT == MVT::v4i64))) {

// Peek through any splat that was introduced for i64 shift vectorization.
Expand Down Expand Up @@ -18103,11 +18107,26 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
return V;

if (SDValue V = LowerScalarVariableShift(Op, DAG, Subtarget))
return V;
return V;

if (SupportedVectorVarShift(VT, Subtarget, Op.getOpcode()))
return Op;

// XOP has 128-bit variable logical/arithmetic shifts.
// +ve/-ve Amt = shift left/right.
if (Subtarget->hasXOP() &&
(VT == MVT::v2i64 || VT == MVT::v4i32 ||
VT == MVT::v8i16 || VT == MVT::v16i8)) {
if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) {
SDValue Zero = getZeroVector(VT, Subtarget, DAG, dl);
Amt = DAG.getNode(ISD::SUB, dl, VT, Zero, Amt);
}
if (Op.getOpcode() == ISD::SHL || Op.getOpcode() == ISD::SRL)
return DAG.getNode(X86ISD::VPSHL, dl, VT, R, Amt);
if (Op.getOpcode() == ISD::SRA)
return DAG.getNode(X86ISD::VPSHA, dl, VT, R, Amt);
}

// 2i64 vector logical shifts can efficiently avoid scalarization - do the
// shifts per-lane and then shuffle the partial results back together.
if (VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) {
Expand Down Expand Up @@ -18296,7 +18315,8 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
return DAG.getVectorShuffle(VT, dl, R02, R13, {0, 5, 2, 7});
}

if (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget->hasInt256())) {
if (VT == MVT::v16i8 ||
(VT == MVT::v32i8 && Subtarget->hasInt256() && !Subtarget->hasXOP())) {
MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
unsigned ShiftOpcode = Op->getOpcode();

Expand Down Expand Up @@ -18416,7 +18436,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
DAG.getNode(Op.getOpcode(), dl, ExtVT, R, Amt));
}

if (Subtarget->hasInt256() && VT == MVT::v16i16) {
if (Subtarget->hasInt256() && !Subtarget->hasXOP() && VT == MVT::v16i16) {
MVT ExtVT = MVT::v8i32;
SDValue Z = getZeroVector(VT, Subtarget, DAG, dl);
SDValue ALo = DAG.getNode(X86ISD::UNPCKL, dl, VT, Amt, Z);
Expand Down Expand Up @@ -19820,6 +19840,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
case X86ISD::RDSEED: return "X86ISD::RDSEED";
case X86ISD::VPMADDUBSW: return "X86ISD::VPMADDUBSW";
case X86ISD::VPMADDWD: return "X86ISD::VPMADDWD";
case X86ISD::VPSHA: return "X86ISD::VPSHA";
case X86ISD::VPSHL: return "X86ISD::VPSHL";
case X86ISD::FMADD: return "X86ISD::FMADD";
case X86ISD::FMSUB: return "X86ISD::FMSUB";
case X86ISD::FNMADD: return "X86ISD::FNMADD";
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.h
Expand Up @@ -410,6 +410,9 @@ namespace llvm {
/// SSE4A Extraction and Insertion.
EXTRQI, INSERTQI,

// XOP arithmetic/logical shifts
VPSHA, VPSHL,

// Vector multiply packed unsigned doubleword integers
PMULUDQ,
// Vector multiply packed signed doubleword integers
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
Expand Up @@ -215,6 +215,13 @@ def X86vshli : SDNode<"X86ISD::VSHLI", SDTIntShiftOp>;
def X86vsrli : SDNode<"X86ISD::VSRLI", SDTIntShiftOp>;
def X86vsrai : SDNode<"X86ISD::VSRAI", SDTIntShiftOp>;

def X86vpshl : SDNode<"X86ISD::VPSHL",
SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
SDTCisVec<2>]>>;
def X86vpsha : SDNode<"X86ISD::VPSHA",
SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
SDTCisVec<2>]>>;

def SDTX86CmpPTest : SDTypeProfile<1, 2, [SDTCisVT<0, i32>,
SDTCisVec<1>,
SDTCisSameAs<2, 1>]>;
Expand Down
53 changes: 40 additions & 13 deletions llvm/lib/Target/X86/X86InstrXOP.td
Expand Up @@ -83,7 +83,42 @@ let ExeDomain = SSEPackedDouble in {
defm VFRCZPD : xop2op256<0x81, "vfrczpd", int_x86_xop_vfrcz_pd_256, loadv4f64>;
}

multiclass xop3op<bits<8> opc, string OpcodeStr, Intrinsic Int> {
multiclass xop3op<bits<8> opc, string OpcodeStr, SDNode OpNode,
ValueType vt128> {
def rr : IXOP<opc, MRMSrcReg, (outs VR128:$dst),
(ins VR128:$src1, VR128:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR128:$dst,
(vt128 (OpNode (vt128 VR128:$src1), (vt128 VR128:$src2))))]>,
XOP_4VOp3, Sched<[WriteVarVecShift]>;
def rm : IXOP<opc, MRMSrcMem, (outs VR128:$dst),
(ins VR128:$src1, i128mem:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR128:$dst,
(vt128 (OpNode (vt128 VR128:$src1),
(vt128 (bitconvert (loadv2i64 addr:$src2))))))]>,
XOP_4V, VEX_W, Sched<[WriteVarVecShift, ReadAfterLd]>;
def mr : IXOP<opc, MRMSrcMem, (outs VR128:$dst),
(ins i128mem:$src1, VR128:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set VR128:$dst,
(vt128 (OpNode (vt128 (bitconvert (loadv2i64 addr:$src1))),
(vt128 VR128:$src2))))]>,
XOP_4VOp3, Sched<[WriteVarVecShift, ReadAfterLd]>;
}

let ExeDomain = SSEPackedInt in {
defm VPSHAB : xop3op<0x98, "vpshab", X86vpsha, v16i8>;
defm VPSHAD : xop3op<0x9A, "vpshad", X86vpsha, v4i32>;
defm VPSHAQ : xop3op<0x9B, "vpshaq", X86vpsha, v2i64>;
defm VPSHAW : xop3op<0x99, "vpshaw", X86vpsha, v8i16>;
defm VPSHLB : xop3op<0x94, "vpshlb", X86vpshl, v16i8>;
defm VPSHLD : xop3op<0x96, "vpshld", X86vpshl, v4i32>;
defm VPSHLQ : xop3op<0x97, "vpshlq", X86vpshl, v2i64>;
defm VPSHLW : xop3op<0x95, "vpshlw", X86vpshl, v8i16>;
}

multiclass xop3op_int<bits<8> opc, string OpcodeStr, Intrinsic Int> {
def rr : IXOP<opc, MRMSrcReg, (outs VR128:$dst),
(ins VR128:$src1, VR128:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
Expand All @@ -103,18 +138,10 @@ multiclass xop3op<bits<8> opc, string OpcodeStr, Intrinsic Int> {
}

let ExeDomain = SSEPackedInt in {
defm VPSHLW : xop3op<0x95, "vpshlw", int_x86_xop_vpshlw>;
defm VPSHLQ : xop3op<0x97, "vpshlq", int_x86_xop_vpshlq>;
defm VPSHLD : xop3op<0x96, "vpshld", int_x86_xop_vpshld>;
defm VPSHLB : xop3op<0x94, "vpshlb", int_x86_xop_vpshlb>;
defm VPSHAW : xop3op<0x99, "vpshaw", int_x86_xop_vpshaw>;
defm VPSHAQ : xop3op<0x9B, "vpshaq", int_x86_xop_vpshaq>;
defm VPSHAD : xop3op<0x9A, "vpshad", int_x86_xop_vpshad>;
defm VPSHAB : xop3op<0x98, "vpshab", int_x86_xop_vpshab>;
defm VPROTW : xop3op<0x91, "vprotw", int_x86_xop_vprotw>;
defm VPROTQ : xop3op<0x93, "vprotq", int_x86_xop_vprotq>;
defm VPROTD : xop3op<0x92, "vprotd", int_x86_xop_vprotd>;
defm VPROTB : xop3op<0x90, "vprotb", int_x86_xop_vprotb>;
defm VPROTW : xop3op_int<0x91, "vprotw", int_x86_xop_vprotw>;
defm VPROTQ : xop3op_int<0x93, "vprotq", int_x86_xop_vprotq>;
defm VPROTD : xop3op_int<0x92, "vprotd", int_x86_xop_vprotd>;
defm VPROTB : xop3op_int<0x90, "vprotb", int_x86_xop_vprotb>;
}

multiclass xop3opimm<bits<8> opc, string OpcodeStr, Intrinsic Int> {
Expand Down
10 changes: 9 additions & 1 deletion llvm/lib/Target/X86/X86IntrinsicsInfo.h
Expand Up @@ -1661,7 +1661,15 @@ static const IntrinsicData IntrinsicsWithoutChain[] = {
X86_INTRINSIC_DATA(ssse3_pshuf_b_128, INTR_TYPE_2OP, X86ISD::PSHUFB, 0),
X86_INTRINSIC_DATA(ssse3_psign_b_128, INTR_TYPE_2OP, X86ISD::PSIGN, 0),
X86_INTRINSIC_DATA(ssse3_psign_d_128, INTR_TYPE_2OP, X86ISD::PSIGN, 0),
X86_INTRINSIC_DATA(ssse3_psign_w_128, INTR_TYPE_2OP, X86ISD::PSIGN, 0)
X86_INTRINSIC_DATA(ssse3_psign_w_128, INTR_TYPE_2OP, X86ISD::PSIGN, 0),
X86_INTRINSIC_DATA(xop_vpshab, INTR_TYPE_2OP, X86ISD::VPSHA, 0),
X86_INTRINSIC_DATA(xop_vpshad, INTR_TYPE_2OP, X86ISD::VPSHA, 0),
X86_INTRINSIC_DATA(xop_vpshaq, INTR_TYPE_2OP, X86ISD::VPSHA, 0),
X86_INTRINSIC_DATA(xop_vpshaw, INTR_TYPE_2OP, X86ISD::VPSHA, 0),
X86_INTRINSIC_DATA(xop_vpshlb, INTR_TYPE_2OP, X86ISD::VPSHL, 0),
X86_INTRINSIC_DATA(xop_vpshld, INTR_TYPE_2OP, X86ISD::VPSHL, 0),
X86_INTRINSIC_DATA(xop_vpshlq, INTR_TYPE_2OP, X86ISD::VPSHL, 0),
X86_INTRINSIC_DATA(xop_vpshlw, INTR_TYPE_2OP, X86ISD::VPSHL, 0)
};

/*
Expand Down
76 changes: 61 additions & 15 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
Expand Up @@ -140,6 +140,12 @@ int X86TTIImpl::getArithmeticInstrCost(
{ ISD::SRA, MVT::v8i64, 1 },
};

if (ST->hasAVX512()) {
int Idx = CostTableLookup(AVX512CostTable, ISD, LT.second);
if (Idx != -1)
return LT.first * AVX512CostTable[Idx].Cost;
}

static const CostTblEntry<MVT::SimpleValueType> AVX2CostTable[] = {
// Shifts on v4i64/v8i32 on AVX2 is legal even though we declare to
// customize them to detect the cases where shift amount is a scalar one.
Expand All @@ -153,7 +159,59 @@ int X86TTIImpl::getArithmeticInstrCost(
{ ISD::SRL, MVT::v2i64, 1 },
{ ISD::SHL, MVT::v4i64, 1 },
{ ISD::SRL, MVT::v4i64, 1 },
};

// Look for AVX2 lowering tricks.
if (ST->hasAVX2()) {
if (ISD == ISD::SHL && LT.second == MVT::v16i16 &&
(Op2Info == TargetTransformInfo::OK_UniformConstantValue ||
Op2Info == TargetTransformInfo::OK_NonUniformConstantValue))
// On AVX2, a packed v16i16 shift left by a constant build_vector
// is lowered into a vector multiply (vpmullw).
return LT.first;

int Idx = CostTableLookup(AVX2CostTable, ISD, LT.second);
if (Idx != -1)
return LT.first * AVX2CostTable[Idx].Cost;
}

static const CostTblEntry<MVT::SimpleValueType> XOPCostTable[] = {
// 128bit shifts take 1cy, but right shifts require negation beforehand.
{ ISD::SHL, MVT::v16i8, 1 },
{ ISD::SRL, MVT::v16i8, 2 },
{ ISD::SRA, MVT::v16i8, 2 },
{ ISD::SHL, MVT::v8i16, 1 },
{ ISD::SRL, MVT::v8i16, 2 },
{ ISD::SRA, MVT::v8i16, 2 },
{ ISD::SHL, MVT::v4i32, 1 },
{ ISD::SRL, MVT::v4i32, 2 },
{ ISD::SRA, MVT::v4i32, 2 },
{ ISD::SHL, MVT::v2i64, 1 },
{ ISD::SRL, MVT::v2i64, 2 },
{ ISD::SRA, MVT::v2i64, 2 },
// 256bit shifts require splitting if AVX2 didn't catch them above.
{ ISD::SHL, MVT::v32i8, 2 },
{ ISD::SRL, MVT::v32i8, 4 },
{ ISD::SRA, MVT::v32i8, 4 },
{ ISD::SHL, MVT::v16i16, 2 },
{ ISD::SRL, MVT::v16i16, 4 },
{ ISD::SRA, MVT::v16i16, 4 },
{ ISD::SHL, MVT::v8i32, 2 },
{ ISD::SRL, MVT::v8i32, 4 },
{ ISD::SRA, MVT::v8i32, 4 },
{ ISD::SHL, MVT::v4i64, 2 },
{ ISD::SRL, MVT::v4i64, 4 },
{ ISD::SRA, MVT::v4i64, 4 },
};

// Look for XOP lowering tricks.
if (ST->hasXOP()) {
int Idx = CostTableLookup(XOPCostTable, ISD, LT.second);
if (Idx != -1)
return LT.first * XOPCostTable[Idx].Cost;
}

static const CostTblEntry<MVT::SimpleValueType> AVX2CustomCostTable[] = {
{ ISD::SHL, MVT::v32i8, 11 }, // vpblendvb sequence.
{ ISD::SHL, MVT::v16i16, 10 }, // extend/vpsrlvd/pack sequence.

Expand All @@ -176,23 +234,11 @@ int X86TTIImpl::getArithmeticInstrCost(
{ ISD::UDIV, MVT::v4i64, 4*20 },
};

if (ST->hasAVX512()) {
int Idx = CostTableLookup(AVX512CostTable, ISD, LT.second);
if (Idx != -1)
return LT.first * AVX512CostTable[Idx].Cost;
}
// Look for AVX2 lowering tricks.
// Look for AVX2 lowering tricks for custom cases.
if (ST->hasAVX2()) {
if (ISD == ISD::SHL && LT.second == MVT::v16i16 &&
(Op2Info == TargetTransformInfo::OK_UniformConstantValue ||
Op2Info == TargetTransformInfo::OK_NonUniformConstantValue))
// On AVX2, a packed v16i16 shift left by a constant build_vector
// is lowered into a vector multiply (vpmullw).
return LT.first;

int Idx = CostTableLookup(AVX2CostTable, ISD, LT.second);
int Idx = CostTableLookup(AVX2CustomCostTable, ISD, LT.second);
if (Idx != -1)
return LT.first * AVX2CostTable[Idx].Cost;
return LT.first * AVX2CustomCostTable[Idx].Cost;
}

static const CostTblEntry<MVT::SimpleValueType>
Expand Down

0 comments on commit 3d11c99

Please sign in to comment.