[X86] Improve SMULO/UMULO codegen for vXi8 vectors.
The default expansion creates a MUL and either a MULHS or MULHU. Each
of those separately expands to a sequence that uses one or more
PMULLW instructions as well as additional instructions to
extend the types to vXi16. The MULHS/MULHU expansion computes the
whole 16-bit product, but only keeps the high part.

We can improve the lowering of SMULO/UMULO in some cases by using the
MULHS/MULHU expansion but keeping both the high and low parts, and
using those parts to calculate the overflow.
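
A minimal scalar sketch of the per-lane checks being vectorized
(illustrative helpers, not part of this patch): UMULO overflows exactly
when the high byte of the full 16-bit product is non-zero; SMULO
overflows when the high byte differs from the sign extension of the low
byte.

#include <cstdint>

// Scalar model of the per-lane overflow checks (names are illustrative).
// The vector lowering applies the same tests to the 16-bit products it
// builds with pmullw/pmulhw.
static bool umulo_i8(uint8_t a, uint8_t b, uint8_t &low) {
  uint16_t p = uint16_t(a) * uint16_t(b); // full 16-bit product
  low = uint8_t(p);
  return (p >> 8) != 0; // overflow iff the high byte is non-zero
}

static bool smulo_i8(int8_t a, int8_t b, int8_t &low) {
  int16_t p = int16_t(a) * int16_t(b); // full 16-bit product
  low = int8_t(p);
  // Overflow iff the high byte differs from the sign extension of the
  // low byte (0 or -1); shifts on negative values are arithmetic here.
  return (p >> 8) != (low >> 7);
}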

For AVX512 we might have vXi1 overflow outputs. We can improve those by
using vpcmpeqw to produce a k register if AVX512BW is enabled; this is a
little better than truncating the high result to use vpcmpeqb. If we don't
have AVX512BW we can extend up to v16i32 and use vpcmpeqd to produce a
k register.
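
For reference, one way such vector SMULO/UMULO nodes reach this lowering
is clang autovectorizing a checked byte multiply (a sketch; whether vXi8
nodes are actually formed depends on the vectorizer's cost model):

#include <cstdint>

// A byte-wise checked multiply the loop vectorizer can turn into
// llvm.umul.with.overflow on vXi8 lanes (subject to cost-model decisions).
void mul_checked(const uint8_t *a, const uint8_t *b, uint8_t *prod,
                 bool *ovf, int n) {
  for (int i = 0; i < n; ++i)
    ovf[i] = __builtin_mul_overflow(a[i], b[i], &prod[i]);
}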

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D97624
topperc committed Mar 31, 2021
1 parent 00c0c8c commit 437958d
Showing 4 changed files with 1,653 additions and 1,988 deletions.
281 changes: 219 additions & 62 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -931,6 +931,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::MULHU, MVT::v8i16, Legal);
setOperationAction(ISD::MULHS, MVT::v8i16, Legal);
setOperationAction(ISD::MUL, MVT::v8i16, Legal);

setOperationAction(ISD::SMULO, MVT::v16i8, Custom);
setOperationAction(ISD::UMULO, MVT::v16i8, Custom);

setOperationAction(ISD::FNEG, MVT::v2f64, Custom);
setOperationAction(ISD::FABS, MVT::v2f64, Custom);
setOperationAction(ISD::FCOPYSIGN, MVT::v2f64, Custom);
@@ -1331,6 +1335,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::MULHU, MVT::v32i8, Custom);
setOperationAction(ISD::MULHS, MVT::v32i8, Custom);

setOperationAction(ISD::SMULO, MVT::v32i8, Custom);
setOperationAction(ISD::UMULO, MVT::v32i8, Custom);

setOperationAction(ISD::ABS, MVT::v4i64, Custom);
setOperationAction(ISD::SMAX, MVT::v4i64, Custom);
setOperationAction(ISD::UMAX, MVT::v4i64, Custom);
@@ -1627,6 +1634,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::MULHS, MVT::v64i8, Custom);
setOperationAction(ISD::MULHU, MVT::v64i8, Custom);

setOperationAction(ISD::SMULO, MVT::v64i8, Custom);
setOperationAction(ISD::UMULO, MVT::v64i8, Custom);

setOperationAction(ISD::BITREVERSE, MVT::v64i8, Custom);

for (auto VT : { MVT::v64i8, MVT::v32i16, MVT::v16i32, MVT::v8i64 }) {
@@ -27385,6 +27395,94 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getNode(ISD::ADD, dl, VT, AloBlo, Hi);
}

static SDValue LowervXi8MulWithUNPCK(SDValue A, SDValue B, const SDLoc &dl,
MVT VT, bool IsSigned,
const X86Subtarget &Subtarget,
SelectionDAG &DAG,
SDValue *Low = nullptr) {
unsigned NumElts = VT.getVectorNumElements();

// For vXi8 we will unpack the low and high half of each 128-bit lane to widen
// to a vXi16 type. Do the multiplies, shift the results and pack the half
// lane results back together.

// We'll take different approaches for signed and unsigned.
// For unsigned we'll use punpcklbw/punpckhbw to zero extend the bytes and use
// pmullw to calculate the full 16-bit product.
// For signed we'll use punpcklbw/punpckhbw to place the bytes in the upper
// byte of each word with zeros in the lower byte. This allows us to use
// pmulhw to calculate the full 16-bit product. This trick means we don't
// need to sign extend the bytes to use pmullw.
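// For example, with a = -3 and b = 5 placed in the high bytes, pmulhw
// computes ((-3 << 8) * (5 << 8)) >> 16 = (-768 * 1280) >> 16 = -15, the
// full signed product.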

MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
SDValue Zero = DAG.getConstant(0, dl, VT);

SDValue ALo, AHi;
if (IsSigned) {
ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, A));
AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, A));
} else {
ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Zero));
AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Zero));
}

SDValue BLo, BHi;
if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) {
// If the RHS is a constant, manually unpackl/unpackh and extend.
SmallVector<SDValue, 16> LoOps, HiOps;
for (unsigned i = 0; i != NumElts; i += 16) {
for (unsigned j = 0; j != 8; ++j) {
SDValue LoOp = B.getOperand(i + j);
SDValue HiOp = B.getOperand(i + j + 8);

if (IsSigned) {
LoOp = DAG.getAnyExtOrTrunc(LoOp, dl, MVT::i16);
HiOp = DAG.getAnyExtOrTrunc(HiOp, dl, MVT::i16);
LoOp = DAG.getNode(ISD::SHL, dl, MVT::i16, LoOp,
DAG.getConstant(8, dl, MVT::i16));
HiOp = DAG.getNode(ISD::SHL, dl, MVT::i16, HiOp,
DAG.getConstant(8, dl, MVT::i16));
} else {
LoOp = DAG.getZExtOrTrunc(LoOp, dl, MVT::i16);
HiOp = DAG.getZExtOrTrunc(HiOp, dl, MVT::i16);
}

LoOps.push_back(LoOp);
HiOps.push_back(HiOp);
}
}

BLo = DAG.getBuildVector(ExVT, dl, LoOps);
BHi = DAG.getBuildVector(ExVT, dl, HiOps);
} else if (IsSigned) {
BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, B));
BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, B));
} else {
BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, Zero));
BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, Zero));
}

// Multiply, lshr the upper 8 bits to the lower 8 bits of the lo/hi results
// and pack back to vXi8.
unsigned MulOpc = IsSigned ? ISD::MULHS : ISD::MUL;
SDValue RLo = DAG.getNode(MulOpc, dl, ExVT, ALo, BLo);
SDValue RHi = DAG.getNode(MulOpc, dl, ExVT, AHi, BHi);

if (Low) {
// Mask the lower bits and pack the results to rejoin the halves.
SDValue Mask = DAG.getConstant(255, dl, ExVT);
SDValue LLo = DAG.getNode(ISD::AND, dl, ExVT, RLo, Mask);
SDValue LHi = DAG.getNode(ISD::AND, dl, ExVT, RHi, Mask);
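// The AND cleared the upper byte of each word, so the unsigned saturation
// in PACKUS cannot trigger and it simply narrows each word back to a byte.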
*Low = DAG.getNode(X86ISD::PACKUS, dl, VT, LLo, LHi);
}

RLo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, RLo, 8, DAG);
RHi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, RHi, 8, DAG);

// Pack the low byte of each word in RLo and RHi back together into VT.
return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi);
}

static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
SDLoc dl(Op);
@@ -27476,92 +27574,151 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,

// With SSE41 we can use sign/zero extend, but for pre-SSE41 we unpack
// and then ashr/lshr the upper bits down to the lower bits before multiply.
unsigned ExAVX = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;

if ((VT == MVT::v16i8 && Subtarget.hasInt256()) ||
(VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) {
MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
unsigned ExAVX = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
SDValue ExA = DAG.getNode(ExAVX, dl, ExVT, A);
SDValue ExB = DAG.getNode(ExAVX, dl, ExVT, B);
SDValue Mul = DAG.getNode(ISD::MUL, dl, ExVT, ExA, ExB);
Mul = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG);
return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
}

// For vXi8 we will unpack the low and high half of each 128 bit lane to widen
// to a vXi16 type. Do the multiplies, shift the results and pack the half
// lane results back together.
return LowervXi8MulWithUNPCK(A, B, dl, VT, IsSigned, Subtarget, DAG);
}

// We'll take different approaches for signed and unsigned.
// For unsigned we'll use punpcklbw/punpckhbw to zero extend the bytes and
// use pmullw to calculate the full 16-bit product.
// For signed we'll use punpcklbw/punpckhbw to extend the bytes to words by
// placing the bytes in the upper byte of each word with zeros in the lower
// byte. This allows us to use pmulhw to calculate the full 16-bit product.
// This trick means we don't need to sign extend the bytes to use pmullw.
// Custom lowering for SMULO/UMULO.
static SDValue LowerMULO(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();

MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
SDValue Zero = DAG.getConstant(0, dl, VT);
// Scalars defer to LowerXALUO.
if (!VT.isVector())
return LowerXALUO(Op, DAG);

SDValue ALo, AHi;
if (IsSigned) {
ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, A));
AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, A));
} else {
ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Zero));
AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Zero));
SDLoc dl(Op);
bool IsSigned = Op->getOpcode() == ISD::SMULO;
SDValue A = Op.getOperand(0);
SDValue B = Op.getOperand(1);
EVT OvfVT = Op->getValueType(1);

if ((VT == MVT::v32i8 && !Subtarget.hasInt256()) ||
(VT == MVT::v64i8 && !Subtarget.hasBWI())) {
// Extract the LHS Lo/Hi vectors
SDValue LHSLo, LHSHi;
std::tie(LHSLo, LHSHi) = splitVector(A, DAG, dl);

// Extract the RHS Lo/Hi vectors
SDValue RHSLo, RHSHi;
std::tie(RHSLo, RHSHi) = splitVector(B, DAG, dl);

EVT LoOvfVT, HiOvfVT;
std::tie(LoOvfVT, HiOvfVT) = DAG.GetSplitDestVTs(OvfVT);
SDVTList LoVTs = DAG.getVTList(LHSLo.getValueType(), LoOvfVT);
SDVTList HiVTs = DAG.getVTList(LHSHi.getValueType(), HiOvfVT);

// Issue the split operations.
SDValue Lo = DAG.getNode(Op.getOpcode(), dl, LoVTs, LHSLo, RHSLo);
SDValue Hi = DAG.getNode(Op.getOpcode(), dl, HiVTs, LHSHi, RHSHi);

// Join the separate data results and the overflow results.
SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi);
SDValue Ovf = DAG.getNode(ISD::CONCAT_VECTORS, dl, OvfVT, Lo.getValue(1),
Hi.getValue(1));

return DAG.getMergeValues({Res, Ovf}, dl);
}

SDValue BLo, BHi;
if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) {
// If the RHS is a constant, manually unpackl/unpackh and extend.
SmallVector<SDValue, 16> LoOps, HiOps;
for (unsigned i = 0; i != NumElts; i += 16) {
for (unsigned j = 0; j != 8; ++j) {
SDValue LoOp = B.getOperand(i + j);
SDValue HiOp = B.getOperand(i + j + 8);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
EVT SetccVT =
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);

if (IsSigned) {
LoOp = DAG.getAnyExtOrTrunc(LoOp, dl, MVT::i16);
HiOp = DAG.getAnyExtOrTrunc(HiOp, dl, MVT::i16);
LoOp = DAG.getNode(ISD::SHL, dl, MVT::i16, LoOp,
DAG.getConstant(8, dl, MVT::i16));
HiOp = DAG.getNode(ISD::SHL, dl, MVT::i16, HiOp,
DAG.getConstant(8, dl, MVT::i16));
} else {
LoOp = DAG.getZExtOrTrunc(LoOp, dl, MVT::i16);
HiOp = DAG.getZExtOrTrunc(HiOp, dl, MVT::i16);
if ((VT == MVT::v16i8 && Subtarget.hasInt256()) ||
(VT == MVT::v32i8 && Subtarget.canExtendTo512BW())) {
unsigned NumElts = VT.getVectorNumElements();
MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
unsigned ExAVX = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
SDValue ExA = DAG.getNode(ExAVX, dl, ExVT, A);
SDValue ExB = DAG.getNode(ExAVX, dl, ExVT, B);
SDValue Mul = DAG.getNode(ISD::MUL, dl, ExVT, ExA, ExB);

SDValue Low = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);

SDValue Ovf;
if (IsSigned) {
SDValue High, LowSign;
if (OvfVT.getVectorElementType() == MVT::i1 &&
(Subtarget.hasBWI() || Subtarget.canExtendTo512DQ())) {
// Rather than truncating, try to do the compare on vXi16 or vXi32.
// Shift the high down filling with sign bits.
High = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, Mul, 8, DAG);
// Fill all 16 bits with the sign bit from the low.
LowSign =
getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ExVT, Mul, 8, DAG);
LowSign = getTargetVShiftByConstNode(X86ISD::VSRAI, dl, ExVT, LowSign,
15, DAG);
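// In each word lane LowSign is now 0 or -1; it equals High (the
// sign-extended high byte) exactly when the product fits in i8.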
SetccVT = OvfVT;
if (!Subtarget.hasBWI()) {
// We can't do a vXi16 compare so sign extend to v16i32.
High = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v16i32, High);
LowSign = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v16i32, LowSign);
}
} else {
// Otherwise do the compare at vXi8.
High = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG);
High = DAG.getNode(ISD::TRUNCATE, dl, VT, High);
LowSign =
DAG.getNode(ISD::SRA, dl, VT, Low, DAG.getConstant(7, dl, VT));
}

LoOps.push_back(LoOp);
HiOps.push_back(HiOp);
Ovf = DAG.getSetCC(dl, SetccVT, LowSign, High, ISD::SETNE);
} else {
SDValue High =
getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, Mul, 8, DAG);
if (OvfVT.getVectorElementType() == MVT::i1 &&
(Subtarget.hasBWI() || Subtarget.canExtendTo512DQ())) {
// Rather than truncating, try to do the compare on vXi16 or vXi32.
SetccVT = OvfVT;
if (!Subtarget.hasBWI()) {
// We can't do a vXi16 compare so sign extend to v16i32.
High = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v16i32, High);
}
} else {
// Otherwise do the compare at vXi8.
High = DAG.getNode(ISD::TRUNCATE, dl, VT, High);
}

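// Each lane overflows iff its high byte is non-zero.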
Ovf =
DAG.getSetCC(dl, SetccVT, High,
DAG.getConstant(0, dl, High.getValueType()), ISD::SETNE);
}

BLo = DAG.getBuildVector(ExVT, dl, LoOps);
BHi = DAG.getBuildVector(ExVT, dl, HiOps);
} else if (IsSigned) {
BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, B));
BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, B));
Ovf = DAG.getSExtOrTrunc(Ovf, dl, OvfVT);

return DAG.getMergeValues({Low, Ovf}, dl);
}

SDValue Low;
SDValue High =
LowervXi8MulWithUNPCK(A, B, dl, VT, IsSigned, Subtarget, DAG, &Low);

SDValue Ovf;
if (IsSigned) {
// SMULO overflows if the high bits don't match the sign of the low.
SDValue LowSign =
DAG.getNode(ISD::SRA, dl, VT, Low, DAG.getConstant(7, dl, VT));
Ovf = DAG.getSetCC(dl, SetccVT, LowSign, High, ISD::SETNE);
} else {
BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, Zero));
BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, Zero));
// UMULO overflows if the high bits are non-zero.
Ovf =
DAG.getSetCC(dl, SetccVT, High, DAG.getConstant(0, dl, VT), ISD::SETNE);
}

// Multiply, lshr the upper 8 bits to the lower 8 bits of the lo/hi results
// and pack back to vXi8.
unsigned MulOpc = IsSigned ? ISD::MULHS : ISD::MUL;
SDValue RLo = DAG.getNode(MulOpc, dl, ExVT, ALo, BLo);
SDValue RHi = DAG.getNode(MulOpc, dl, ExVT, AHi, BHi);
RLo = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, RLo, 8, DAG);
RHi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ExVT, RHi, 8, DAG);
Ovf = DAG.getSExtOrTrunc(Ovf, dl, OvfVT);

// Pack all the even elements from Lo and Hi.
return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi);
return DAG.getMergeValues({Low, Ovf}, dl);
}

SDValue X86TargetLowering::LowerWin64_i128OP(SDValue Op, SelectionDAG &DAG) const {
@@ -30038,9 +30195,9 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::SADDO:
case ISD::UADDO:
case ISD::SSUBO:
case ISD::USUBO:
case ISD::USUBO: return LowerXALUO(Op, DAG);
case ISD::SMULO:
case ISD::UMULO: return LowerXALUO(Op, DAG);
case ISD::UMULO: return LowerMULO(Op, Subtarget, DAG);
case ISD::READCYCLECOUNTER: return LowerREADCYCLECOUNTER(Op, Subtarget,DAG);
case ISD::BITCAST: return LowerBITCAST(Op, Subtarget, DAG);
case ISD::SADDO_CARRY: