[X86] Move combineVectorSizedSetCCEquality above MatchVectorAllZeroTest. NFC.

The plan is to merge most of the functionality of combineVectorSizedSetCCEquality and MatchVectorAllZeroTest into a single 'match vector sized data' function.
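For context, a minimal sketch (an assumed illustration, not part of this commit or of the LLVM sources): the oversized integer equality compares that combineVectorSizedSetCCEquality targets typically come from the memcmp expansion pass. Assuming a 16-byte comparison compiled at -O2 for an SSE2-capable x86-64 target, a reproducer would look like:

// Assumed example: a 16-byte memcmp equality test is expanded into a single
// i128 setcc, which combineVectorSizedSetCCEquality then lowers to
// PCMPEQB + PMOVMSKB + a compare of the mask against 0xFFFF (or PXOR + PTEST
// once SSE4.1 is available).
#include <cstring>

bool equal16(const void *a, const void *b) {
  return std::memcmp(a, b, 16) == 0;
}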
RKSimon committed Mar 27, 2023
1 parent da92f2f commit 0f76fb9
Showing 1 changed file with 188 additions and 186 deletions.
374 changes: 188 additions & 186 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24079,6 +24079,194 @@ static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
DAG.getTargetConstant(Cond, dl, MVT::i8), EFLAGS);
}

/// Recursive helper for combineVectorSizedSetCCEquality() to see if we have a
/// recognizable memcmp expansion.
static bool isOrXorXorTree(SDValue X, bool Root = true) {
if (X.getOpcode() == ISD::OR)
return isOrXorXorTree(X.getOperand(0), false) &&
isOrXorXorTree(X.getOperand(1), false);
if (Root)
return false;
return X.getOpcode() == ISD::XOR;
}

/// Recursive helper for combineVectorSizedSetCCEquality() to emit the memcmp
/// expansion.
template <typename F>
static SDValue emitOrXorXorTree(SDValue X, const SDLoc &DL, SelectionDAG &DAG,
EVT VecVT, EVT CmpVT, bool HasPT, F SToV) {
SDValue Op0 = X.getOperand(0);
SDValue Op1 = X.getOperand(1);
if (X.getOpcode() == ISD::OR) {
SDValue A = emitOrXorXorTree(Op0, DL, DAG, VecVT, CmpVT, HasPT, SToV);
SDValue B = emitOrXorXorTree(Op1, DL, DAG, VecVT, CmpVT, HasPT, SToV);
if (VecVT != CmpVT)
return DAG.getNode(ISD::OR, DL, CmpVT, A, B);
if (HasPT)
return DAG.getNode(ISD::OR, DL, VecVT, A, B);
return DAG.getNode(ISD::AND, DL, CmpVT, A, B);
}
if (X.getOpcode() == ISD::XOR) {
SDValue A = SToV(Op0);
SDValue B = SToV(Op1);
if (VecVT != CmpVT)
return DAG.getSetCC(DL, CmpVT, A, B, ISD::SETNE);
if (HasPT)
return DAG.getNode(ISD::XOR, DL, VecVT, A, B);
return DAG.getSetCC(DL, CmpVT, A, B, ISD::SETEQ);
}
llvm_unreachable("Impossible");
}

/// Try to map a 128-bit or larger integer comparison to vector instructions
/// before type legalization splits it up into chunks.
static SDValue combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y,
ISD::CondCode CC,
const SDLoc &DL,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
assert((CC == ISD::SETNE || CC == ISD::SETEQ) && "Bad comparison predicate");

// We're looking for an oversized integer equality comparison.
EVT OpVT = X.getValueType();
unsigned OpSize = OpVT.getSizeInBits();
if (!OpVT.isScalarInteger() || OpSize < 128)
return SDValue();

// Ignore a comparison with zero because that gets special treatment in
// EmitTest(). But make an exception for the special case of a pair of
// logically-combined vector-sized operands compared to zero. This pattern may
// be generated by the memcmp expansion pass with oversized integer compares
// (see PR33325).
bool IsOrXorXorTreeCCZero = isNullConstant(Y) && isOrXorXorTree(X);
if (isNullConstant(Y) && !IsOrXorXorTreeCCZero)
return SDValue();

// Don't perform this combine if constructing the vector will be expensive.
auto IsVectorBitCastCheap = [](SDValue X) {
X = peekThroughBitcasts(X);
return isa<ConstantSDNode>(X) || X.getValueType().isVector() ||
X.getOpcode() == ISD::LOAD;
};
if ((!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y)) &&
!IsOrXorXorTreeCCZero)
return SDValue();

// Use XOR (plus OR) and PTEST after SSE4.1 for 128/256-bit operands.
// Use PCMPNEQ (plus OR) and KORTEST for 512-bit operands.
// Otherwise use PCMPEQ (plus AND) and mask testing.
bool NoImplicitFloatOps =
DAG.getMachineFunction().getFunction().hasFnAttribute(
Attribute::NoImplicitFloat);
if (!Subtarget.useSoftFloat() && !NoImplicitFloatOps &&
((OpSize == 128 && Subtarget.hasSSE2()) ||
(OpSize == 256 && Subtarget.hasAVX()) ||
(OpSize == 512 && Subtarget.useAVX512Regs()))) {
bool HasPT = Subtarget.hasSSE41();

// PTEST and MOVMSK are slow on Knights Landing and Knights Mill and widened
// vector registers are essentially free. (Technically, widening registers
// prevents load folding, but the tradeoff is worth it.)
bool PreferKOT = Subtarget.preferMaskRegisters();
bool NeedZExt = PreferKOT && !Subtarget.hasVLX() && OpSize != 512;

EVT VecVT = MVT::v16i8;
EVT CmpVT = PreferKOT ? MVT::v16i1 : VecVT;
if (OpSize == 256) {
VecVT = MVT::v32i8;
CmpVT = PreferKOT ? MVT::v32i1 : VecVT;
}
EVT CastVT = VecVT;
bool NeedsAVX512FCast = false;
if (OpSize == 512 || NeedZExt) {
if (Subtarget.hasBWI()) {
VecVT = MVT::v64i8;
CmpVT = MVT::v64i1;
if (OpSize == 512)
CastVT = VecVT;
} else {
VecVT = MVT::v16i32;
CmpVT = MVT::v16i1;
CastVT = OpSize == 512 ? VecVT
: OpSize == 256 ? MVT::v8i32
: MVT::v4i32;
NeedsAVX512FCast = true;
}
}

auto ScalarToVector = [&](SDValue X) -> SDValue {
bool TmpZext = false;
EVT TmpCastVT = CastVT;
if (X.getOpcode() == ISD::ZERO_EXTEND) {
SDValue OrigX = X.getOperand(0);
unsigned OrigSize = OrigX.getScalarValueSizeInBits();
if (OrigSize < OpSize) {
if (OrigSize == 128) {
TmpCastVT = NeedsAVX512FCast ? MVT::v4i32 : MVT::v16i8;
X = OrigX;
TmpZext = true;
} else if (OrigSize == 256) {
TmpCastVT = NeedsAVX512FCast ? MVT::v8i32 : MVT::v32i8;
X = OrigX;
TmpZext = true;
}
}
}
X = DAG.getBitcast(TmpCastVT, X);
if (!NeedZExt && !TmpZext)
return X;
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT,
DAG.getConstant(0, DL, VecVT), X,
DAG.getVectorIdxConstant(0, DL));
};

SDValue Cmp;
if (IsOrXorXorTreeCCZero) {
// This is a bitwise-combined equality comparison of 2 pairs of vectors:
// setcc i128 (or (xor A, B), (xor C, D)), 0, eq|ne
// Use 2 vector equality compares and 'and' the results before doing a
// MOVMSK.
Cmp = emitOrXorXorTree(X, DL, DAG, VecVT, CmpVT, HasPT, ScalarToVector);
} else {
SDValue VecX = ScalarToVector(X);
SDValue VecY = ScalarToVector(Y);
if (VecVT != CmpVT) {
Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE);
} else if (HasPT) {
Cmp = DAG.getNode(ISD::XOR, DL, VecVT, VecX, VecY);
} else {
Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETEQ);
}
}
// AVX512 should emit a setcc that will lower to kortest.
if (VecVT != CmpVT) {
EVT KRegVT = CmpVT == MVT::v64i1 ? MVT::i64
: CmpVT == MVT::v32i1 ? MVT::i32
: MVT::i16;
return DAG.getSetCC(DL, VT, DAG.getBitcast(KRegVT, Cmp),
DAG.getConstant(0, DL, KRegVT), CC);
}
if (HasPT) {
SDValue BCCmp =
DAG.getBitcast(OpSize == 256 ? MVT::v4i64 : MVT::v2i64, Cmp);
SDValue PT = DAG.getNode(X86ISD::PTEST, DL, MVT::i32, BCCmp, BCCmp);
X86::CondCode X86CC = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE;
SDValue X86SetCC = getSETCC(X86CC, PT, DL, DAG);
return DAG.getNode(ISD::TRUNCATE, DL, VT, X86SetCC.getValue(0));
}
// If all bytes match (bitmask is 0x(FFFF)FFFF), that's equality.
// setcc i128 X, Y, eq --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, eq
// setcc i128 X, Y, ne --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, ne
assert(Cmp.getValueType() == MVT::v16i8 &&
"Non 128-bit vector on pre-SSE41 target");
SDValue MovMsk = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Cmp);
SDValue FFFFs = DAG.getConstant(0xFFFF, DL, MVT::i32);
return DAG.getSetCC(DL, VT, MovMsk, FFFFs, CC);
}

return SDValue();
}

/// Helper for matching OR(EXTRACTELT(X,0),OR(EXTRACTELT(X,1),...))
/// style scalarized (associative) reduction patterns. Partial reductions
/// are supported when the pointer SrcMask is non-null.
@@ -53952,192 +54140,6 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

/// Recursive helper for combineVectorSizedSetCCEquality() to see if we have a
/// recognizable memcmp expansion.
static bool isOrXorXorTree(SDValue X, bool Root = true) {
if (X.getOpcode() == ISD::OR)
return isOrXorXorTree(X.getOperand(0), false) &&
isOrXorXorTree(X.getOperand(1), false);
if (Root)
return false;
return X.getOpcode() == ISD::XOR;
}

/// Recursive helper for combineVectorSizedSetCCEquality() to emit the memcmp
/// expansion.
template <typename F>
static SDValue emitOrXorXorTree(SDValue X, const SDLoc &DL, SelectionDAG &DAG,
EVT VecVT, EVT CmpVT, bool HasPT, F SToV) {
SDValue Op0 = X.getOperand(0);
SDValue Op1 = X.getOperand(1);
if (X.getOpcode() == ISD::OR) {
SDValue A = emitOrXorXorTree(Op0, DL, DAG, VecVT, CmpVT, HasPT, SToV);
SDValue B = emitOrXorXorTree(Op1, DL, DAG, VecVT, CmpVT, HasPT, SToV);
if (VecVT != CmpVT)
return DAG.getNode(ISD::OR, DL, CmpVT, A, B);
if (HasPT)
return DAG.getNode(ISD::OR, DL, VecVT, A, B);
return DAG.getNode(ISD::AND, DL, CmpVT, A, B);
}
if (X.getOpcode() == ISD::XOR) {
SDValue A = SToV(Op0);
SDValue B = SToV(Op1);
if (VecVT != CmpVT)
return DAG.getSetCC(DL, CmpVT, A, B, ISD::SETNE);
if (HasPT)
return DAG.getNode(ISD::XOR, DL, VecVT, A, B);
return DAG.getSetCC(DL, CmpVT, A, B, ISD::SETEQ);
}
llvm_unreachable("Impossible");
}

/// Try to map a 128-bit or larger integer comparison to vector instructions
/// before type legalization splits it up into chunks.
static SDValue combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y,
ISD::CondCode CC,
const SDLoc &DL,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
assert((CC == ISD::SETNE || CC == ISD::SETEQ) && "Bad comparison predicate");

// We're looking for an oversized integer equality comparison.
EVT OpVT = X.getValueType();
unsigned OpSize = OpVT.getSizeInBits();
if (!OpVT.isScalarInteger() || OpSize < 128)
return SDValue();

// Ignore a comparison with zero because that gets special treatment in
// EmitTest(). But make an exception for the special case of a pair of
// logically-combined vector-sized operands compared to zero. This pattern may
// be generated by the memcmp expansion pass with oversized integer compares
// (see PR33325).
bool IsOrXorXorTreeCCZero = isNullConstant(Y) && isOrXorXorTree(X);
if (isNullConstant(Y) && !IsOrXorXorTreeCCZero)
return SDValue();

// Don't perform this combine if constructing the vector will be expensive.
auto IsVectorBitCastCheap = [](SDValue X) {
X = peekThroughBitcasts(X);
return isa<ConstantSDNode>(X) || X.getValueType().isVector() ||
X.getOpcode() == ISD::LOAD;
};
if ((!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y)) &&
!IsOrXorXorTreeCCZero)
return SDValue();

// Use XOR (plus OR) and PTEST after SSE4.1 for 128/256-bit operands.
// Use PCMPNEQ (plus OR) and KORTEST for 512-bit operands.
// Otherwise use PCMPEQ (plus AND) and mask testing.
bool NoImplicitFloatOps =
DAG.getMachineFunction().getFunction().hasFnAttribute(
Attribute::NoImplicitFloat);
if (!Subtarget.useSoftFloat() && !NoImplicitFloatOps &&
((OpSize == 128 && Subtarget.hasSSE2()) ||
(OpSize == 256 && Subtarget.hasAVX()) ||
(OpSize == 512 && Subtarget.useAVX512Regs()))) {
bool HasPT = Subtarget.hasSSE41();

// PTEST and MOVMSK are slow on Knights Landing and Knights Mill and widened
// vector registers are essentially free. (Technically, widening registers
// prevents load folding, but the tradeoff is worth it.)
bool PreferKOT = Subtarget.preferMaskRegisters();
bool NeedZExt = PreferKOT && !Subtarget.hasVLX() && OpSize != 512;

EVT VecVT = MVT::v16i8;
EVT CmpVT = PreferKOT ? MVT::v16i1 : VecVT;
if (OpSize == 256) {
VecVT = MVT::v32i8;
CmpVT = PreferKOT ? MVT::v32i1 : VecVT;
}
EVT CastVT = VecVT;
bool NeedsAVX512FCast = false;
if (OpSize == 512 || NeedZExt) {
if (Subtarget.hasBWI()) {
VecVT = MVT::v64i8;
CmpVT = MVT::v64i1;
if (OpSize == 512)
CastVT = VecVT;
} else {
VecVT = MVT::v16i32;
CmpVT = MVT::v16i1;
CastVT = OpSize == 512 ? VecVT :
OpSize == 256 ? MVT::v8i32 : MVT::v4i32;
NeedsAVX512FCast = true;
}
}

auto ScalarToVector = [&](SDValue X) -> SDValue {
bool TmpZext = false;
EVT TmpCastVT = CastVT;
if (X.getOpcode() == ISD::ZERO_EXTEND) {
SDValue OrigX = X.getOperand(0);
unsigned OrigSize = OrigX.getScalarValueSizeInBits();
if (OrigSize < OpSize) {
if (OrigSize == 128) {
TmpCastVT = NeedsAVX512FCast ? MVT::v4i32 : MVT::v16i8;
X = OrigX;
TmpZext = true;
} else if (OrigSize == 256) {
TmpCastVT = NeedsAVX512FCast ? MVT::v8i32 : MVT::v32i8;
X = OrigX;
TmpZext = true;
}
}
}
X = DAG.getBitcast(TmpCastVT, X);
if (!NeedZExt && !TmpZext)
return X;
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT,
DAG.getConstant(0, DL, VecVT), X,
DAG.getVectorIdxConstant(0, DL));
};

SDValue Cmp;
if (IsOrXorXorTreeCCZero) {
// This is a bitwise-combined equality comparison of 2 pairs of vectors:
// setcc i128 (or (xor A, B), (xor C, D)), 0, eq|ne
// Use 2 vector equality compares and 'and' the results before doing a
// MOVMSK.
Cmp = emitOrXorXorTree(X, DL, DAG, VecVT, CmpVT, HasPT, ScalarToVector);
} else {
SDValue VecX = ScalarToVector(X);
SDValue VecY = ScalarToVector(Y);
if (VecVT != CmpVT) {
Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE);
} else if (HasPT) {
Cmp = DAG.getNode(ISD::XOR, DL, VecVT, VecX, VecY);
} else {
Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETEQ);
}
}
// AVX512 should emit a setcc that will lower to kortest.
if (VecVT != CmpVT) {
EVT KRegVT = CmpVT == MVT::v64i1 ? MVT::i64 :
CmpVT == MVT::v32i1 ? MVT::i32 : MVT::i16;
return DAG.getSetCC(DL, VT, DAG.getBitcast(KRegVT, Cmp),
DAG.getConstant(0, DL, KRegVT), CC);
}
if (HasPT) {
SDValue BCCmp = DAG.getBitcast(OpSize == 256 ? MVT::v4i64 : MVT::v2i64,
Cmp);
SDValue PT = DAG.getNode(X86ISD::PTEST, DL, MVT::i32, BCCmp, BCCmp);
X86::CondCode X86CC = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE;
SDValue X86SetCC = getSETCC(X86CC, PT, DL, DAG);
return DAG.getNode(ISD::TRUNCATE, DL, VT, X86SetCC.getValue(0));
}
// If all bytes match (bitmask is 0x(FFFF)FFFF), that's equality.
// setcc i128 X, Y, eq --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, eq
// setcc i128 X, Y, ne --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, ne
assert(Cmp.getValueType() == MVT::v16i8 &&
"Non 128-bit vector on pre-SSE41 target");
SDValue MovMsk = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Cmp);
SDValue FFFFs = DAG.getConstant(0xFFFF, DL, MVT::i32);
return DAG.getSetCC(DL, VT, MovMsk, FFFFs, CC);
}

return SDValue();
}

/// If we have AVX512, but not BWI and this is a vXi16/vXi8 setcc, just
/// pre-promote its result type since vXi1 vectors don't get promoted
/// during type legalization.
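For reference, a second assumed illustration (not part of the commit): the or-of-xors tree matched by isOrXorXorTree/emitOrXorXorTree is what the memcmp expansion pass produces for larger compares against zero (see PR33325), i.e. setcc i128 (or (xor A, B), (xor C, D)), 0, eq|ne. Assuming a 32-byte comparison at -O2 on a pre-SSE4.1 x86-64 target, a reproducer would be:

// Assumed example: a 32-byte memcmp equality test is split into two 128-bit
// chunks per side; each pair of chunks is XORed, the XOR results are ORed, and
// the i128 OR is compared against zero. emitOrXorXorTree turns that into two
// vector equality compares whose results are ANDed before a single MOVMSK.
#include <cstring>

bool equal32(const void *a, const void *b) {
  return std::memcmp(a, b, 32) == 0;
}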