Skip to content

Commit

Permalink
[X86] Move combineToExtendBoolVectorInReg before the select combines.…
Browse files Browse the repository at this point in the history
… NFC.

Avoid the need for a forward declaration.

Cleanup prep for Issue #53760
  • Loading branch information
RKSimon committed Feb 11, 2022
1 parent efe5cba commit 48e1434
Showing 1 changed file with 98 additions and 99 deletions.
197 changes: 98 additions & 99 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43123,6 +43123,104 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
// This is more or less the reverse of combineBitcastvxi1.
static SDValue combineToExtendBoolVectorInReg(
unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) {
if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
Opcode != ISD::ANY_EXTEND)
return SDValue();
if (!DCI.isBeforeLegalizeOps())
return SDValue();
if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
return SDValue();

EVT SVT = VT.getScalarType();
EVT InSVT = N0.getValueType().getScalarType();
unsigned EltSizeInBits = SVT.getSizeInBits();

// Input type must be extending a bool vector (bit-casted from a scalar
// integer) to legal integer types.
if (!VT.isVector())
return SDValue();
if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
return SDValue();
if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
return SDValue();

SDValue N00 = N0.getOperand(0);
EVT SclVT = N00.getValueType();
if (!SclVT.isScalarInteger())
return SDValue();

SDValue Vec;
SmallVector<int> ShuffleMask;
unsigned NumElts = VT.getVectorNumElements();
assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");

// Broadcast the scalar integer to the vector elements.
if (NumElts > EltSizeInBits) {
// If the scalar integer is greater than the vector element size, then we
// must split it down into sub-sections for broadcasting. For example:
// i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
// i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
unsigned Scale = NumElts / EltSizeInBits;
EVT BroadcastVT = EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
Vec = DAG.getBitcast(VT, Vec);

for (unsigned i = 0; i != Scale; ++i)
ShuffleMask.append(EltSizeInBits, i);
Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
} else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
(SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
// If we have register broadcast instructions, use the scalar size as the
// element type for the shuffle. Then cast to the wider element type. The
// widened bits won't be used, and this might allow the use of a broadcast
// load.
assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
unsigned Scale = EltSizeInBits / NumElts;
EVT BroadcastVT =
EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
ShuffleMask.append(NumElts * Scale, 0);
Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
Vec = DAG.getBitcast(VT, Vec);
} else {
// For smaller scalar integers, we can simply any-extend it to the vector
// element size (we don't care about the upper bits) and broadcast it to all
// elements.
SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
ShuffleMask.append(NumElts, 0);
Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
}

// Now, mask the relevant bit in each element.
SmallVector<SDValue, 32> Bits;
for (unsigned i = 0; i != NumElts; ++i) {
int BitIdx = (i % EltSizeInBits);
APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
Bits.push_back(DAG.getConstant(Bit, DL, SVT));
}
SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);

// Compare against the bitmask and extend the result.
EVT CCVT = VT.changeVectorElementType(MVT::i1);
Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
Vec = DAG.getSExtOrTrunc(Vec, DL, VT);

// For SEXT, this is now done, otherwise shift the result down for
// zero-extension.
if (Opcode == ISD::SIGN_EXTEND)
return Vec;
return DAG.getNode(ISD::SRL, DL, VT, Vec,
DAG.getConstant(EltSizeInBits - 1, DL, VT));
}

/// If a vector select has an operand that is -1 or 0, try to simplify the
/// select to a bitwise logic operation.
/// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
Expand Down Expand Up @@ -50420,105 +50518,6 @@ static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) {
return Res;
}

// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
// This is more or less the reverse of combineBitcastvxi1.
static SDValue combineToExtendBoolVectorInReg(
unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) {
if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
Opcode != ISD::ANY_EXTEND)
return SDValue();
if (!DCI.isBeforeLegalizeOps())
return SDValue();
if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
return SDValue();

EVT SVT = VT.getScalarType();
EVT InSVT = N0.getValueType().getScalarType();
unsigned EltSizeInBits = SVT.getSizeInBits();

// Input type must be extending a bool vector (bit-casted from a scalar
// integer) to legal integer types.
if (!VT.isVector())
return SDValue();
if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
return SDValue();
if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
return SDValue();

SDValue N00 = N0.getOperand(0);
EVT SclVT = N00.getValueType();
if (!SclVT.isScalarInteger())
return SDValue();

SDValue Vec;
SmallVector<int> ShuffleMask;
unsigned NumElts = VT.getVectorNumElements();
assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");

// Broadcast the scalar integer to the vector elements.
if (NumElts > EltSizeInBits) {
// If the scalar integer is greater than the vector element size, then we
// must split it down into sub-sections for broadcasting. For example:
// i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
// i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
unsigned Scale = NumElts / EltSizeInBits;
EVT BroadcastVT =
EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
Vec = DAG.getBitcast(VT, Vec);

for (unsigned i = 0; i != Scale; ++i)
ShuffleMask.append(EltSizeInBits, i);
Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
} else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
(SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
// If we have register broadcast instructions, use the scalar size as the
// element type for the shuffle. Then cast to the wider element type. The
// widened bits won't be used, and this might allow the use of a broadcast
// load.
assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
unsigned Scale = EltSizeInBits / NumElts;
EVT BroadcastVT =
EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
ShuffleMask.append(NumElts * Scale, 0);
Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
Vec = DAG.getBitcast(VT, Vec);
} else {
// For smaller scalar integers, we can simply any-extend it to the vector
// element size (we don't care about the upper bits) and broadcast it to all
// elements.
SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
ShuffleMask.append(NumElts, 0);
Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
}

// Now, mask the relevant bit in each element.
SmallVector<SDValue, 32> Bits;
for (unsigned i = 0; i != NumElts; ++i) {
int BitIdx = (i % EltSizeInBits);
APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
Bits.push_back(DAG.getConstant(Bit, DL, SVT));
}
SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);

// Compare against the bitmask and extend the result.
EVT CCVT = VT.changeVectorElementType(MVT::i1);
Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
Vec = DAG.getSExtOrTrunc(Vec, DL, VT);

// For SEXT, this is now done, otherwise shift the result down for
// zero-extension.
if (Opcode == ISD::SIGN_EXTEND)
return Vec;
return DAG.getNode(ISD::SRL, DL, VT, Vec,
DAG.getConstant(EltSizeInBits - 1, DL, VT));
}

// Attempt to combine a (sext/zext (setcc)) to a setcc with a xmm/ymm/zmm
// result type.
static SDValue combineExtSetcc(SDNode *N, SelectionDAG &DAG,
Expand Down

0 comments on commit 48e1434

Please sign in to comment.