[X86] MatchVectorAllZeroTest - fix bug when splitting vectors of large elements

DAG::SplitVector only works with vectors that have an even number of elements, so when splitting vectors with large (illegal) element widths we are likely to end up at a <1 x iXXX> type that cannot be split any further.

In such cases, pre-bitcast to a <X x i64> type so that splitting will always succeed.
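As a rough illustration of why the pre-bitcast helps, here is a minimal standalone C++ sketch (not SelectionDAG code; VecTy, canSplit and splitsToTestSize are made-up names that only model the even-element-count constraint): repeatedly halving <4 x i256> stalls at <1 x i256>, while the same 1024 bits viewed as <16 x i64> halve cleanly down to a 128-bit test size.

```cpp
#include <cstdio>

// Model a vector type as (element bit width, element count).
struct VecTy {
  unsigned ScalarBits;
  unsigned NumElts;
};

// DAG::SplitVector can only halve a vector with an even (>1) element count.
static bool canSplit(const VecTy &VT) {
  return VT.NumElts > 1 && VT.NumElts % 2 == 0;
}

// Repeatedly halve VT until its total width fits into TestBits.
static bool splitsToTestSize(VecTy VT, unsigned TestBits) {
  while (VT.ScalarBits * VT.NumElts > TestBits) {
    if (!canSplit(VT))
      return false; // stuck, e.g. at <1 x i256>
    VT.NumElts /= 2;
  }
  return true;
}

int main() {
  VecTy Wide{256, 4};                                    // <4 x i256>
  VecTy Recast{64, Wide.ScalarBits * Wide.NumElts / 64}; // <16 x i64>, same 1024 bits
  std::printf("<4 x i256> reaches 128 bits: %s\n",
              splitsToTestSize(Wide, 128) ? "yes" : "no");   // no
  std::printf("<16 x i64> reaches 128 bits: %s\n",
              splitsToTestSize(Recast, 128) ? "yes" : "no"); // yes
  return 0;
}
```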

Thanks to @alexfh for identifying this.
RKSimon committed Apr 6, 2023
1 parent 6bda53c commit b29ec28
Showing 2 changed files with 188 additions and 2 deletions.
18 changes: 16 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24338,12 +24338,12 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,

// Helper function for comparing all bits of two vectors.
static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
-                                   ISD::CondCode CC, const APInt &Mask,
+                                   ISD::CondCode CC, const APInt &OriginalMask,
                                    const X86Subtarget &Subtarget,
                                    SelectionDAG &DAG, X86::CondCode &X86CC) {
  EVT VT = LHS.getValueType();
  unsigned ScalarSize = VT.getScalarSizeInBits();
-  if (Mask.getBitWidth() != ScalarSize) {
+  if (OriginalMask.getBitWidth() != ScalarSize) {
    assert(ScalarSize == 1 && "Element Mask vs Vector bitwidth mismatch");
    return SDValue();
  }
@@ -24355,6 +24355,8 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
  assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
  X86CC = (CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE);

+  APInt Mask = OriginalMask;
+
  auto MaskBits = [&](SDValue Src) {
    if (Mask.isAllOnes())
      return Src;
@@ -24395,6 +24397,18 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,

  // Split down to 128/256/512-bit vector.
  unsigned TestSize = UseKORTEST ? 512 : (Subtarget.hasAVX() ? 256 : 128);
+
+  // If the input vector has vector elements wider than the target test size,
+  // then cast to <X x i64> so it will safely split.
+  if (ScalarSize > TestSize) {
+    if (!Mask.isAllOnes())
+      return SDValue();
+    VT = EVT::getVectorVT(*DAG.getContext(), MVT::i64, VT.getSizeInBits() / 64);
+    LHS = DAG.getBitcast(VT, LHS);
+    RHS = DAG.getBitcast(VT, RHS);
+    Mask = APInt::getAllOnes(64);
+  }
+
  if (VT.getSizeInBits() > TestSize) {
    KnownBits KnownRHS = DAG.computeKnownBits(RHS);
    if (KnownRHS.isConstant() && KnownRHS.getConstant() == Mask) {
172 changes: 172 additions & 0 deletions llvm/test/CodeGen/X86/setcc-wide-types.ll
@@ -596,6 +596,178 @@ define i32 @eq_i512(<8 x i64> %x, <8 x i64> %y) {
ret i32 %zext
}

define i1 @ne_v4i256(<4 x i256> %a0) {
; SSE2-LABEL: ne_v4i256:
; SSE2: # %bb.0:
; SSE2-NEXT: movq {{[0-9]+}}(%rsp), %rax
; SSE2-NEXT: movq {{[0-9]+}}(%rsp), %r10
; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %r10
; SSE2-NEXT: movq %r10, %xmm0
; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rax
; SSE2-NEXT: movq %rax, %xmm1
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rcx
; SSE2-NEXT: movq %rcx, %xmm0
; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rdx
; SSE2-NEXT: movq %rdx, %xmm2
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm0[0]
; SSE2-NEXT: por %xmm1, %xmm2
; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %r9
; SSE2-NEXT: movq %r9, %xmm0
; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %r8
; SSE2-NEXT: movq %r8, %xmm1
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rsi
; SSE2-NEXT: movq %rsi, %xmm0
; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rdi
; SSE2-NEXT: movq %rdi, %xmm3
; SSE2-NEXT: punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm0[0]
; SSE2-NEXT: por %xmm1, %xmm3
; SSE2-NEXT: por %xmm2, %xmm3
; SSE2-NEXT: pxor %xmm0, %xmm0
; SSE2-NEXT: pcmpeqd %xmm3, %xmm0
; SSE2-NEXT: movmskps %xmm0, %eax
; SSE2-NEXT: xorl $15, %eax
; SSE2-NEXT: sete %al
; SSE2-NEXT: retq
;
; SSE41-LABEL: ne_v4i256:
; SSE41: # %bb.0:
; SSE41-NEXT: movq {{[0-9]+}}(%rsp), %rax
; SSE41-NEXT: movq {{[0-9]+}}(%rsp), %r10
; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %r10
; SSE41-NEXT: movq %r10, %xmm0
; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rax
; SSE41-NEXT: movq %rax, %xmm1
; SSE41-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rcx
; SSE41-NEXT: movq %rcx, %xmm0
; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rdx
; SSE41-NEXT: movq %rdx, %xmm2
; SSE41-NEXT: punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm0[0]
; SSE41-NEXT: por %xmm1, %xmm2
; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %r9
; SSE41-NEXT: movq %r9, %xmm0
; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %r8
; SSE41-NEXT: movq %r8, %xmm1
; SSE41-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rsi
; SSE41-NEXT: movq %rsi, %xmm0
; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rdi
; SSE41-NEXT: movq %rdi, %xmm3
; SSE41-NEXT: punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm0[0]
; SSE41-NEXT: por %xmm1, %xmm3
; SSE41-NEXT: por %xmm2, %xmm3
; SSE41-NEXT: ptest %xmm3, %xmm3
; SSE41-NEXT: sete %al
; SSE41-NEXT: retq
;
; AVX1-LABEL: ne_v4i256:
; AVX1: # %bb.0:
; AVX1-NEXT: movq {{[0-9]+}}(%rsp), %rax
; AVX1-NEXT: movq {{[0-9]+}}(%rsp), %r10
; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %r10
; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rcx
; AVX1-NEXT: orq %r10, %rcx
; AVX1-NEXT: vmovq %rcx, %xmm0
; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rax
; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rdx
; AVX1-NEXT: orq %rax, %rdx
; AVX1-NEXT: vmovq %rdx, %xmm1
; AVX1-NEXT: vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %r9
; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rsi
; AVX1-NEXT: orq %r9, %rsi
; AVX1-NEXT: vmovq %rsi, %xmm1
; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %r8
; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rdi
; AVX1-NEXT: orq %r8, %rdi
; AVX1-NEXT: vmovq %rdi, %xmm2
; AVX1-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm1, %ymm0
; AVX1-NEXT: vptest %ymm0, %ymm0
; AVX1-NEXT: sete %al
; AVX1-NEXT: vzeroupper
; AVX1-NEXT: retq
;
; AVX2-LABEL: ne_v4i256:
; AVX2: # %bb.0:
; AVX2-NEXT: movq {{[0-9]+}}(%rsp), %rax
; AVX2-NEXT: movq {{[0-9]+}}(%rsp), %r10
; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %r10
; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rcx
; AVX2-NEXT: orq %r10, %rcx
; AVX2-NEXT: vmovq %rcx, %xmm0
; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rax
; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rdx
; AVX2-NEXT: orq %rax, %rdx
; AVX2-NEXT: vmovq %rdx, %xmm1
; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %r9
; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rsi
; AVX2-NEXT: orq %r9, %rsi
; AVX2-NEXT: vmovq %rsi, %xmm1
; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %r8
; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rdi
; AVX2-NEXT: orq %r8, %rdi
; AVX2-NEXT: vmovq %rdi, %xmm2
; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
; AVX2-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm0
; AVX2-NEXT: vptest %ymm0, %ymm0
; AVX2-NEXT: sete %al
; AVX2-NEXT: vzeroupper
; AVX2-NEXT: retq
;
; AVX512-LABEL: ne_v4i256:
; AVX512: # %bb.0:
; AVX512-NEXT: movq {{[0-9]+}}(%rsp), %rax
; AVX512-NEXT: movq {{[0-9]+}}(%rsp), %r10
; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rax
; AVX512-NEXT: vmovd %eax, %xmm0
; AVX512-NEXT: shrq $32, %rax
; AVX512-NEXT: vpinsrd $1, %eax, %xmm0, %xmm0
; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %r10
; AVX512-NEXT: vpinsrd $2, %r10d, %xmm0, %xmm0
; AVX512-NEXT: shrq $32, %r10
; AVX512-NEXT: vpinsrd $3, %r10d, %xmm0, %xmm0
; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %r8
; AVX512-NEXT: vmovd %r8d, %xmm1
; AVX512-NEXT: shrq $32, %r8
; AVX512-NEXT: vpinsrd $1, %r8d, %xmm1, %xmm1
; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %r9
; AVX512-NEXT: vpinsrd $2, %r9d, %xmm1, %xmm1
; AVX512-NEXT: shrq $32, %r9
; AVX512-NEXT: vpinsrd $3, %r9d, %xmm1, %xmm1
; AVX512-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm0
; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rdx
; AVX512-NEXT: vmovd %edx, %xmm1
; AVX512-NEXT: shrq $32, %rdx
; AVX512-NEXT: vpinsrd $1, %edx, %xmm1, %xmm1
; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rcx
; AVX512-NEXT: vpinsrd $2, %ecx, %xmm1, %xmm1
; AVX512-NEXT: shrq $32, %rcx
; AVX512-NEXT: vpinsrd $3, %ecx, %xmm1, %xmm1
; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rdi
; AVX512-NEXT: vmovd %edi, %xmm2
; AVX512-NEXT: shrq $32, %rdi
; AVX512-NEXT: vpinsrd $1, %edi, %xmm2, %xmm2
; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rsi
; AVX512-NEXT: vpinsrd $2, %esi, %xmm2, %xmm2
; AVX512-NEXT: shrq $32, %rsi
; AVX512-NEXT: vpinsrd $3, %esi, %xmm2, %xmm2
; AVX512-NEXT: vinserti128 $1, %xmm1, %ymm2, %ymm1
; AVX512-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k0
; AVX512-NEXT: kortestw %k0, %k0
; AVX512-NEXT: sete %al
; AVX512-NEXT: vzeroupper
; AVX512-NEXT: retq
%c = icmp ne <4 x i256> %a0, zeroinitializer
%b = bitcast <4 x i1> %c to i4
%r = icmp eq i4 %b, 0
ret i1 %r
}

; This test models the expansion of 'memcmp(a, b, 32) != 0'
; if we allowed 2 pairs of 16-byte loads per block.

