[X86][SSE] Lower 128-bit vectors to SIGN/ZERO_EXTEND_VECTOR_IN_REG ops
As described on PR31712, we miss a variety of legalization combines because we lower these to X86ISD::VSEXT/VZEXT despite them having the same functionality. This patch makes 128-bit (SSE41) SIGN/ZERO_EXTEND_VECTOR_IN_REG ops legal, adds the necessary tablegen plumbing and uses a helper 'getExtendInVec' to decide when to use SIGN/ZERO_EXTEND_VECTOR_IN_REG or VSEXT/VZEXT.

We're missing a couple of shuffle combines that will be added in a future patch for review.

Later patches can then support the AVX2 cases as a mixture of SIGN/ZERO_EXTEND and SIGN/ZERO_EXTEND_VECTOR_IN_REG, and then finally deal with the AVX512 cases.

Differential Revision: https://reviews.llvm.org/D30549

llvm-svn: 296985
RKSimon committed Mar 5, 2017
1 parent 4bc8292 commit 9f5c251
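
For context, SIGN/ZERO_EXTEND_VECTOR_INREG takes a full-width source vector and extends only its low elements to fill the result, which is what the SSE4.1 pmovsx/pmovzx forms do. A minimal standalone sketch of the semantics (illustrative only, not code from this patch):

// Models v8i16 = sign/zero_extend_vector_inreg v16i8: only the low 8 source
// lanes are used, each widened to 16 bits.
#include <cstdint>
#include <cstdio>

int main() {
  int8_t src[16];
  for (int i = 0; i < 16; ++i)
    src[i] = (int8_t)(i * 17 - 100);    // arbitrary test lanes

  int16_t sext[8];
  uint16_t zext[8];
  for (int i = 0; i < 8; ++i) {
    sext[i] = (int16_t)src[i];          // pmovsxbw-style sign extension
    zext[i] = (uint8_t)src[i];          // pmovzxbw-style zero extension
  }

  for (int i = 0; i < 8; ++i)
    printf("lane %d: src=%4d sext=%6d zext=%5u\n", i, src[i], sext[i], zext[i]);
  return 0;
}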
Showing 13 changed files with 296 additions and 246 deletions.
7 changes: 7 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -160,6 +160,10 @@ def SDTExtInreg : SDTypeProfile<1, 2, [   // sext_inreg
   SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisVT<2, OtherVT>,
   SDTCisVTSmallerThanOp<2, 1>
 ]>;
+def SDTExtInvec : SDTypeProfile<1, 1, [   // sext_invec
+  SDTCisInt<0>, SDTCisVec<0>, SDTCisInt<1>, SDTCisVec<1>,
+  SDTCisOpSmallerThanOp<1, 0>, SDTCisSameSizeAs<0,1>
+]>;
 
 def SDTSetCC : SDTypeProfile<1, 3, [   // setcc
   SDTCisInt<0>, SDTCisSameAs<1, 2>, SDTCisVT<3, OtherVT>
@@ -406,6 +410,9 @@ def umax       : SDNode<"ISD::UMAX"       , SDTIntBinOp,
                         [SDNPCommutative, SDNPAssociative]>;
 
 def sext_inreg : SDNode<"ISD::SIGN_EXTEND_INREG", SDTExtInreg>;
+def sext_invec : SDNode<"ISD::SIGN_EXTEND_VECTOR_INREG", SDTExtInvec>;
+def zext_invec : SDNode<"ISD::ZERO_EXTEND_VECTOR_INREG", SDTExtInvec>;
+
 def bitreverse : SDNode<"ISD::BITREVERSE" , SDTIntUnaryOp>;
 def bswap      : SDNode<"ISD::BSWAP"      , SDTIntUnaryOp>;
 def ctlz       : SDNode<"ISD::CTLZ"       , SDTIntUnaryOp>;
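The new SDTExtInvec profile is purely a type contract: result and operand are both integer vectors of the same total width, with strictly narrower operand elements. A hedged restatement of that check (illustrative, not LLVM API code):

// SDTCisOpSmallerThanOp<1, 0> -> operand elements narrower than result's;
// SDTCisSameSizeAs<0, 1>      -> identical total vector width.
bool isExtInvecCompatible(unsigned ResBits, unsigned ResEltBits,
                          unsigned OpBits, unsigned OpEltBits) {
  return OpEltBits < ResEltBits && OpBits == ResBits;
}
// v8i16 <- v16i8 : isExtInvecCompatible(128, 16, 128, 8)  -> true
// v4i64 <- v4i32 : isExtInvecCompatible(256, 64, 128, 32) -> false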
16 changes: 16 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2419,6 +2419,20 @@ void SelectionDAG::computeKnownBits(SDValue Op, APInt &KnownZero,
     }
     break;
   }
+  case ISD::ZERO_EXTEND_VECTOR_INREG: {
+    EVT InVT = Op.getOperand(0).getValueType();
+    unsigned InBits = InVT.getScalarSizeInBits();
+    APInt NewBits = APInt::getHighBitsSet(BitWidth, BitWidth - InBits);
+    KnownZero = KnownZero.trunc(InBits);
+    KnownOne = KnownOne.trunc(InBits);
+    computeKnownBits(Op.getOperand(0), KnownZero, KnownOne,
+                     DemandedElts.zext(InVT.getVectorNumElements()),
+                     Depth + 1);
+    KnownZero = KnownZero.zext(BitWidth);
+    KnownOne = KnownOne.zext(BitWidth);
+    KnownZero |= NewBits;
+    break;
+  }
   case ISD::ZERO_EXTEND: {
     EVT InVT = Op.getOperand(0).getValueType();
     unsigned InBits = InVT.getScalarSizeInBits();
@@ -2432,6 +2446,7 @@ void SelectionDAG::computeKnownBits(SDValue Op, APInt &KnownZero,
     KnownZero |= NewBits;
     break;
   }
+  // TODO ISD::SIGN_EXTEND_VECTOR_INREG
   case ISD::SIGN_EXTEND: {
     EVT InVT = Op.getOperand(0).getValueType();
     unsigned InBits = InVT.getScalarSizeInBits();
@@ -2859,6 +2874,7 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, unsigned Depth) const {
   }
 
   case ISD::SIGN_EXTEND:
+  case ISD::SIGN_EXTEND_VECTOR_INREG:
     Tmp = VTBits - Op.getOperand(0).getScalarValueSizeInBits();
     return ComputeNumSignBits(Op.getOperand(0), Depth+1) + Tmp;
 
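To see what the new computeKnownBits case produces, take v2i64 = zero_extend_vector_inreg v4i32 where the i32 source lanes are themselves known to be zero-extended bytes. The ComputeNumSignBits change is the same arithmetic for sign bits: VTBits - InBits extra sign-bit copies. A standalone walk-through with uint64_t masks standing in for APInt (illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
  const unsigned InBits = 32;              // v2i64 <- v4i32, BitWidth = 64
  // APInt::getHighBitsSet(64, 64 - 32): bits the extension forces to zero.
  uint64_t NewBits = ~0ull << InBits;

  // Suppose the recursive query proves the top 24 bits of each i32 source
  // lane are zero (e.g. the lanes were zero-extended from i8).
  uint64_t KnownZeroSrc = 0xFFFFFF00u;

  // Widen the mask (zext), then OR in the guaranteed-zero high bits.
  uint64_t KnownZero = KnownZeroSrc | NewBits;

  printf("KnownZero = %016llx\n", (unsigned long long)KnownZero);
  // -> ffffffffffffff00: every result bit above bit 7 is known zero.
  return 0;
}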
95 changes: 61 additions & 34 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -923,6 +923,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
 
     // SSE41 brings specific instructions for doing vector sign extend even in
     // cases where we don't have SRA.
+    setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v2i64, Legal);
+    setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v4i32, Legal);
+    setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v8i16, Legal);
+
+    setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, MVT::v2i64, Legal);
+    setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, MVT::v4i32, Legal);
+    setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, MVT::v8i16, Legal);
+
     for (MVT VT : MVT::integer_vector_valuetypes()) {
       setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i8,  Custom);
       setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i16, Custom);
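The types marked Legal here are covered by SSE4.1's pmovsx/pmovzx family. A minimal intrinsics check of the sign-extending chain (illustrative only; compile with -msse4.1):

#include <smmintrin.h>
#include <cstdio>

int main() {
  __m128i b = _mm_set1_epi8(-5);
  __m128i w = _mm_cvtepi8_epi16(b);   // pmovsxbw: v8i16 <- low half of v16i8
  __m128i d = _mm_cvtepi16_epi32(w);  // pmovsxwd: v4i32 <- low half of v8i16
  __m128i q = _mm_cvtepi32_epi64(d);  // pmovsxdq: v2i64 <- low half of v4i32

  long long out[2];
  _mm_storeu_si128((__m128i *)out, q);
  printf("%lld %lld\n", out[0], out[1]);  // -5 -5
  return 0;
}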
@@ -5137,6 +5145,26 @@ static SDValue getOnesVector(EVT VT, const X86Subtarget &Subtarget,
   return DAG.getBitcast(VT, Vec);
 }
 
+static SDValue getExtendInVec(unsigned Opc, const SDLoc &DL, EVT VT, SDValue In,
+                              SelectionDAG &DAG) {
+  EVT InVT = In.getValueType();
+  assert((X86ISD::VSEXT == Opc || X86ISD::VZEXT == Opc) && "Unexpected opcode");
+
+  if (VT.is128BitVector() && InVT.is128BitVector())
+    return X86ISD::VSEXT == Opc ? DAG.getSignExtendVectorInReg(In, DL, VT)
+                                : DAG.getZeroExtendVectorInReg(In, DL, VT);
+
+  // For 256-bit vectors, we only need the lower (128-bit) input half.
+  // For 512-bit vectors, we only need the lower input half or quarter.
+  if (VT.getSizeInBits() > 128 && InVT.getSizeInBits() > 128) {
+    int Scale = VT.getScalarSizeInBits() / InVT.getScalarSizeInBits();
+    In = extractSubVector(In, 0, DAG, DL,
+                          std::max(128, (int)VT.getSizeInBits() / Scale));
+  }
+
+  return DAG.getNode(Opc, DL, VT, In);
+}
+
 /// Generate unpacklo/unpackhi shuffle mask.
 static void createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, bool Lo,
                                     bool Unary) {
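For the wide path in getExtendInVec, the extract width follows from the element-size scale. A worked example with assumed types:

// Assume VT = v8i32 (256-bit) and In = v32i8 (256-bit):
//   Scale         = 32 / 8 = 4
//   extract width = std::max(128, 256 / 4) = 128 -> keep the low v16i8
//   result        = X86ISD::VSEXT v16i8 -> v8i32 (vpmovsxbd ymm, xmm);
// only the low 128 bits of the input are ever consumed.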
@@ -5853,6 +5881,7 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask,
     }
     return true;
   }
+  case ISD::ZERO_EXTEND_VECTOR_INREG:
   case X86ISD::VZEXT: {
     // TODO - add support for VPMOVZX with smaller input vector types.
     SDValue Src = N.getOperand(0);
@@ -9215,14 +9244,7 @@ static SDValue lowerVectorShuffleAsSpecificZeroOrAnyExtend(
     MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * Scale),
                                  NumElements / Scale);
     InputV = ShuffleOffset(InputV);
-
-    // For 256-bit vectors, we only need the lower (128-bit) input half.
-    // For 512-bit vectors, we only need the lower input half or quarter.
-    if (VT.getSizeInBits() > 128)
-      InputV = extractSubVector(InputV, 0, DAG, DL,
-                                std::max(128, (int)VT.getSizeInBits() / Scale));
-
-    InputV = DAG.getNode(X86ISD::VZEXT, DL, ExtVT, InputV);
+    InputV = getExtendInVec(X86ISD::VZEXT, DL, ExtVT, InputV, DAG);
     return DAG.getBitcast(VT, InputV);
   }
 
@@ -15647,7 +15669,7 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
     // word to byte only under BWI
     if (InVT == MVT::v16i16 && !Subtarget.hasBWI()) // v16i16 -> v16i8
       return DAG.getNode(X86ISD::VTRUNC, DL, VT,
-                         DAG.getNode(X86ISD::VSEXT, DL, MVT::v16i32, In));
+                         getExtendInVec(X86ISD::VSEXT, DL, MVT::v16i32, In, DAG));
     return DAG.getNode(X86ISD::VTRUNC, DL, VT, In);
   }
 
@@ -17625,8 +17647,8 @@ static SDValue LowerSIGN_EXTEND_AVX512(SDValue Op,
   if (VT.is512BitVector() && InVTElt != MVT::i1 &&
       (NumElts == 8 || NumElts == 16 || Subtarget.hasBWI())) {
     if (In.getOpcode() == X86ISD::VSEXT || In.getOpcode() == X86ISD::VZEXT)
-      return DAG.getNode(In.getOpcode(), dl, VT, In.getOperand(0));
-    return DAG.getNode(X86ISD::VSEXT, dl, VT, In);
+      return getExtendInVec(In.getOpcode(), dl, VT, In.getOperand(0), DAG);
+    return getExtendInVec(X86ISD::VSEXT, dl, VT, In, DAG);
   }
 
   if (InVTElt != MVT::i1)
@@ -17638,7 +17660,7 @@ static SDValue LowerSIGN_EXTEND_AVX512(SDValue Op,
 
   SDValue V;
   if (Subtarget.hasDQI()) {
-    V = DAG.getNode(X86ISD::VSEXT, dl, ExtVT, In);
+    V = getExtendInVec(X86ISD::VSEXT, dl, ExtVT, In, DAG);
     assert(!VT.is512BitVector() && "Unexpected vector type");
   } else {
     SDValue NegOne = getOnesVector(ExtVT, Subtarget, DAG, dl);
@@ -17690,11 +17712,15 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op,
   assert((Op.getOpcode() != ISD::ZERO_EXTEND_VECTOR_INREG ||
           InVT == MVT::v64i8) && "Zero extend only for v64i8 input!");
 
-  // SSE41 targets can use the pmovsx* instructions directly.
-  unsigned ExtOpc = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ?
-                      X86ISD::VSEXT : X86ISD::VZEXT;
-  if (Subtarget.hasSSE41())
+  // SSE41 targets can use the pmovsx* instructions directly for 128-bit results,
+  // so are legal and shouldn't occur here. AVX2/AVX512 pmovsx* instructions still
+  // need to be handled here for 256/512-bit results.
+  if (Subtarget.hasInt256()) {
+    assert(VT.getSizeInBits() > 128 && "Unexpected 128-bit vector extension");
+    unsigned ExtOpc = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ?
+                        X86ISD::VSEXT : X86ISD::VZEXT;
     return DAG.getNode(ExtOpc, dl, VT, In);
+  }
 
   // We should only get here for sign extend.
   assert(Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG &&
@@ -17779,8 +17805,8 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
   MVT HalfVT = MVT::getVectorVT(VT.getVectorElementType(),
                                 VT.getVectorNumElements() / 2);
 
-  OpLo = DAG.getNode(X86ISD::VSEXT, dl, HalfVT, OpLo);
-  OpHi = DAG.getNode(X86ISD::VSEXT, dl, HalfVT, OpHi);
+  OpLo = DAG.getSignExtendVectorInReg(OpLo, dl, HalfVT);
+  OpHi = DAG.getSignExtendVectorInReg(OpHi, dl, HalfVT);
 
   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi);
 }
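The pattern above splits one wide sign extend into two in-register extends plus a concat. A scalar model of that split (illustrative, types assumed, not LLVM API code):

#include <cstdint>

// Models v16i16 = sign_extend v16i8 shaped as LowerSIGN_EXTEND does:
// extend each half in-register, then concatenate.
void sext16xi8(const int8_t src[16], int16_t dst[16]) {
  for (int i = 0; i < 8; ++i)
    dst[i] = (int16_t)src[i];             // OpLo: extend-in-reg of low half
  for (int i = 0; i < 8; ++i)
    dst[8 + i] = (int16_t)src[8 + i];     // OpHi: extend-in-reg of high half
  // ISD::CONCAT_VECTORS of the two halves yields the full-width result.
}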
@@ -18095,7 +18121,7 @@ static SDValue LowerExtendedLoad(SDValue Op, const X86Subtarget &Subtarget,
     if (Ext == ISD::SEXTLOAD) {
       // If we have SSE4.1, we can directly emit a VSEXT node.
       if (Subtarget.hasSSE41()) {
-        SDValue Sext = DAG.getNode(X86ISD::VSEXT, dl, RegVT, SlicedVec);
+        SDValue Sext = getExtendInVec(X86ISD::VSEXT, dl, RegVT, SlicedVec, DAG);
         DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), TF);
         return Sext;
       }
@@ -18766,11 +18792,11 @@ static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT,
       ShAmt.getOperand(0).getSimpleValueType() == MVT::i16) {
     ShAmt = ShAmt.getOperand(0);
     ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v8i16, ShAmt);
-    ShAmt = DAG.getNode(X86ISD::VZEXT, SDLoc(ShAmt), MVT::v2i64, ShAmt);
+    ShAmt = DAG.getZeroExtendVectorInReg(ShAmt, SDLoc(ShAmt), MVT::v2i64);
   } else if (Subtarget.hasSSE41() &&
              ShAmt.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
     ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v4i32, ShAmt);
-    ShAmt = DAG.getNode(X86ISD::VZEXT, SDLoc(ShAmt), MVT::v2i64, ShAmt);
+    ShAmt = DAG.getZeroExtendVectorInReg(ShAmt, SDLoc(ShAmt), MVT::v2i64);
   } else {
     SmallVector<SDValue, 4> ShOps = {ShAmt, DAG.getConstant(0, dl, SVT),
                                      DAG.getUNDEF(SVT), DAG.getUNDEF(SVT)};
@@ -21061,8 +21087,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
     // Extract the lo parts and sign extend to i16
     SDValue ALo, BLo;
     if (Subtarget.hasSSE41()) {
-      ALo = DAG.getNode(X86ISD::VSEXT, dl, ExVT, A);
-      BLo = DAG.getNode(X86ISD::VSEXT, dl, ExVT, B);
+      ALo = DAG.getSignExtendVectorInReg(A, dl, ExVT);
+      BLo = DAG.getSignExtendVectorInReg(B, dl, ExVT);
     } else {
       const int ShufMask[] = {-1, 0, -1, 1, -1, 2, -1, 3,
                               -1, 4, -1, 5, -1, 6, -1, 7};
@@ -21081,8 +21107,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
                               -1, -1, -1, -1, -1, -1, -1, -1};
       AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
       BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
-      AHi = DAG.getNode(X86ISD::VSEXT, dl, ExVT, AHi);
-      BHi = DAG.getNode(X86ISD::VSEXT, dl, ExVT, BHi);
+      AHi = DAG.getSignExtendVectorInReg(AHi, dl, ExVT);
+      BHi = DAG.getSignExtendVectorInReg(BHi, dl, ExVT);
     } else {
       const int ShufMask[] = {-1, 8, -1, 9, -1, 10, -1, 11,
                               -1, 12, -1, 13, -1, 14, -1, 15};
@@ -21243,8 +21269,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
                        DAG.getVectorShuffle(MVT::v16i16, dl, Lo, Hi, HiMask));
   }
 
-  SDValue ExA = DAG.getNode(ExSSE41, dl, MVT::v16i16, A);
-  SDValue ExB = DAG.getNode(ExSSE41, dl, MVT::v16i16, B);
+  SDValue ExA = getExtendInVec(ExSSE41, dl, MVT::v16i16, A, DAG);
+  SDValue ExB = getExtendInVec(ExSSE41, dl, MVT::v16i16, B, DAG);
   SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v16i16, ExA, ExB);
   SDValue MulH = DAG.getNode(ISD::SRL, dl, MVT::v16i16, Mul,
                              DAG.getConstant(8, dl, MVT::v16i16));
@@ -21260,8 +21286,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
   // Extract the lo parts and zero/sign extend to i16.
   SDValue ALo, BLo;
   if (Subtarget.hasSSE41()) {
-    ALo = DAG.getNode(ExSSE41, dl, ExVT, A);
-    BLo = DAG.getNode(ExSSE41, dl, ExVT, B);
+    ALo = getExtendInVec(ExSSE41, dl, ExVT, A, DAG);
+    BLo = getExtendInVec(ExSSE41, dl, ExVT, B, DAG);
   } else {
     const int ShufMask[] = {-1, 0, -1, 1, -1, 2, -1, 3,
                             -1, 4, -1, 5, -1, 6, -1, 7};
@@ -21280,8 +21306,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
                             -1, -1, -1, -1, -1, -1, -1, -1};
     AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
     BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
-    AHi = DAG.getNode(ExSSE41, dl, ExVT, AHi);
-    BHi = DAG.getNode(ExSSE41, dl, ExVT, BHi);
+    AHi = getExtendInVec(ExSSE41, dl, ExVT, AHi, DAG);
+    BHi = getExtendInVec(ExSSE41, dl, ExVT, BHi, DAG);
   } else {
     const int ShufMask[] = {-1, 8, -1, 9, -1, 10, -1, 11,
                             -1, 12, -1, 13, -1, 14, -1, 15};
@@ -26458,7 +26484,7 @@ static bool matchUnaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
   unsigned NumMaskElts = Mask.size();
   unsigned MaskEltSize = MaskVT.getScalarSizeInBits();
 
-  // Match against a VZEXT instruction.
+  // Match against a ZERO_EXTEND_VECTOR_INREG/VZEXT instruction.
   // TODO: Add 512-bit vector support (split AVX512F and AVX512BW).
   if (AllowIntDomain && ((MaskVT.is128BitVector() && Subtarget.hasSSE41()) ||
                          (MaskVT.is256BitVector() && Subtarget.hasInt256()))) {
@@ -26477,7 +26503,8 @@ static bool matchUnaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
       V1 = extractSubVector(V1, 0, DAG, DL, SrcSize);
       DstVT = MVT::getIntegerVT(Scale * MaskEltSize);
       DstVT = MVT::getVectorVT(DstVT, NumDstElts);
-      Shuffle = X86ISD::VZEXT;
+      Shuffle = (SrcVT != MaskVT ? X86ISD::VZEXT
+                                 : ISD::ZERO_EXTEND_VECTOR_INREG);
       return true;
     }
   }
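A concrete mask the combiner can now map to the generic node (illustrative; Z denotes a zeroable lane):

// On a v16i8 unary shuffle, the mask
//   { 0, Z, 1, Z, 2, Z, 3, Z, 4, Z, 5, Z, 6, Z, 7, Z }
// is exactly v8i16 = zero_extend_vector_inreg v16i8 (pmovzxbw). When the
// source and mask types already agree, the match now yields the
// target-independent ISD::ZERO_EXTEND_VECTOR_INREG instead of X86ISD::VZEXT.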
@@ -32169,7 +32196,7 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG,
                                 Mld->getBasePtr(), NewMask, WideSrc0,
                                 Mld->getMemoryVT(), Mld->getMemOperand(),
                                 ISD::NON_EXTLOAD);
-  SDValue NewVec = DAG.getNode(X86ISD::VSEXT, dl, VT, WideLd);
+  SDValue NewVec = getExtendInVec(X86ISD::VSEXT, dl, VT, WideLd, DAG);
   return DCI.CombineTo(N, NewVec, WideLd.getValue(1), true);
 }
 