Skip to content

Commit

Permalink
[X86] MatchVectorAllZeroTest - handle icmp_eq(bitcast(vXi1 trunc(Y)),…
Browse files Browse the repository at this point in the history
…0) style reduction patterns

If we've truncated from a wider vector, then perform the all vector comparison on that with a suitable mask

There's a minor pre-SSE41 regression due to a missing movmsk(icmp_eq(and(x,c1pow2),c1pow2)) -> movmsk(shl(x,c2)) fold that will be addressed in a followup commit
  • Loading branch information
RKSimon committed Apr 3, 2023
1 parent 9b5ff44 commit 39d7bf6
Show file tree
Hide file tree
Showing 2 changed files with 280 additions and 340 deletions.
41 changes: 29 additions & 12 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24396,8 +24396,9 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
// Split down to 128/256/512-bit vector.
unsigned TestSize = UseKORTEST ? 512 : (Subtarget.hasAVX() ? 256 : 128);
if (VT.getSizeInBits() > TestSize) {
if (isAllOnesOrAllOnesSplat(RHS)) {
// If ICMP(LHS,-1) - reduce using AND splits.
KnownBits KnownRHS = DAG.computeKnownBits(RHS);
if (KnownRHS.isConstant() && KnownRHS.getConstant() == Mask) {
// If ICMP(AND(LHS,MASK),MASK) - reduce using AND splits.
while (VT.getSizeInBits() > TestSize) {
auto Split = DAG.SplitVector(LHS, DL);
VT = Split.first.getValueType();
Expand Down Expand Up @@ -24529,24 +24530,40 @@ static SDValue MatchVectorAllEqualTest(SDValue LHS, SDValue RHS,
}
}

// Match icmp(bitcast(icmp_ne(X,Y)),0) reduction patterns.
// Match icmp(bitcast(icmp_eq(X,Y)),-1) reduction patterns.
if (Mask.isAllOnes()) {
assert(!Op.getValueType().isVector() &&
"Illegal vector type for reduction pattern");
SDValue Src = peekThroughBitcasts(Op);
if (Src.getOpcode() == ISD::SETCC &&
Src.getValueType().isFixedLengthVector() &&
if (Src.getValueType().isFixedLengthVector() &&
Src.getValueType().getScalarType() == MVT::i1) {
ISD::CondCode SrcCC = cast<CondCodeSDNode>(Src.getOperand(2))->get();
if (SrcCC == (CmpNull ? ISD::SETNE : ISD::SETEQ)) {
// Match icmp(bitcast(icmp_ne(X,Y)),0) reduction patterns.
// Match icmp(bitcast(icmp_eq(X,Y)),-1) reduction patterns.
if (Src.getOpcode() == ISD::SETCC) {
SDValue LHS = Src.getOperand(0);
SDValue RHS = Src.getOperand(1);
EVT LHSVT = LHS.getValueType();
APInt SrcMask = APInt::getAllOnes(LHSVT.getScalarSizeInBits());
if (SDValue V = LowerVectorAllEqual(DL, LHS, RHS, CC, SrcMask,
Subtarget, DAG, X86CC))
return V;
ISD::CondCode SrcCC = cast<CondCodeSDNode>(Src.getOperand(2))->get();
if (SrcCC == (CmpNull ? ISD::SETNE : ISD::SETEQ) &&
llvm::has_single_bit<uint32_t>(LHSVT.getSizeInBits())) {
APInt SrcMask = APInt::getAllOnes(LHSVT.getScalarSizeInBits());
return LowerVectorAllEqual(DL, LHS, RHS, CC, SrcMask, Subtarget, DAG,
X86CC);
}
}
// Match icmp(bitcast(vXi1 trunc(Y)),0) reduction patterns.
// Match icmp(bitcast(vXi1 trunc(Y)),-1) reduction patterns.
// Peek through truncation, mask the LSB and compare against zero/LSB.
if (Src.getOpcode() == ISD::TRUNCATE) {
SDValue Inner = Src.getOperand(0);
EVT InnerVT = Inner.getValueType();
if (llvm::has_single_bit<uint32_t>(InnerVT.getSizeInBits())) {
unsigned BW = InnerVT.getScalarSizeInBits();
APInt SrcMask = APInt(BW, 1);
APInt Cmp = CmpNull ? APInt::getZero(BW) : SrcMask;
return LowerVectorAllEqual(DL, Inner,
DAG.getConstant(Cmp, DL, InnerVT), CC,
SrcMask, Subtarget, DAG, X86CC);
}
}
}
}
Expand Down
Loading

0 comments on commit 39d7bf6

Please sign in to comment.