diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 4b1d1a20777af0..ead00a9d2015a1 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37889,6 +37889,74 @@ static SDValue createMMXBuildVector(BuildVectorSDNode *BV, SelectionDAG &DAG,
   return Ops[0];
 }
 
+// Recursive function that attempts to find if a bool vector node was originally
+// a vector/float/double that got truncated/extended/bitcast to/from a scalar
+// integer. If so, replace the scalar ops with bool vector equivalents back down
+// the chain.
+static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, SDLoc DL,
+                                          SelectionDAG &DAG,
+                                          const X86Subtarget &Subtarget) {
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  unsigned Opc = V.getOpcode();
+  switch (Opc) {
+  case ISD::BITCAST: {
+    // Bitcast from a vector/float/double, we can cheaply bitcast to VT.
+    SDValue Src = V.getOperand(0);
+    EVT SrcVT = Src.getValueType();
+    if (SrcVT.isVector() || SrcVT.isFloatingPoint())
+      return DAG.getBitcast(VT, Src);
+    break;
+  }
+  case ISD::TRUNCATE: {
+    // If we find a suitable source, a truncated scalar becomes a subvector.
+    SDValue Src = V.getOperand(0);
+    EVT NewSrcVT =
+        EVT::getVectorVT(*DAG.getContext(), MVT::i1, Src.getValueSizeInBits());
+    if (TLI.isTypeLegal(NewSrcVT))
+      if (SDValue N0 =
+              combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
+        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N0,
+                           DAG.getIntPtrConstant(0, DL));
+    break;
+  }
+  case ISD::ANY_EXTEND:
+  case ISD::ZERO_EXTEND: {
+    // If we find a suitable source, an extended scalar becomes a subvector.
+    SDValue Src = V.getOperand(0);
+    EVT NewSrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                    Src.getScalarValueSizeInBits());
+    if (TLI.isTypeLegal(NewSrcVT))
+      if (SDValue N0 =
+              combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
+        return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
+                           Opc == ISD::ANY_EXTEND ? DAG.getUNDEF(VT)
+                                                  : DAG.getConstant(0, DL, VT),
+                           N0, DAG.getIntPtrConstant(0, DL));
+    break;
+  }
+  case ISD::OR: {
+    // If we find suitable sources, we can just move an OR to the vector domain.
+    SDValue Src0 = V.getOperand(0);
+    SDValue Src1 = V.getOperand(1);
+    if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
+      if (SDValue N1 = combineBitcastToBoolVector(VT, Src1, DL, DAG, Subtarget))
+        return DAG.getNode(Opc, DL, VT, N0, N1);
+    break;
+  }
+  case ISD::SHL: {
+    // If we find a suitable source, a SHL becomes a KSHIFTL.
+    SDValue Src0 = V.getOperand(0);
+    if (auto *Amt = dyn_cast<ConstantSDNode>(V.getOperand(1)))
+      if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
+        return DAG.getNode(
+            X86ISD::KSHIFTL, DL, VT, N0,
+            DAG.getTargetConstant(Amt->getZExtValue(), DL, MVT::i8));
+    break;
+  }
+  }
+  return SDValue();
+}
+
 static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
                               TargetLowering::DAGCombinerInfo &DCI,
                               const X86Subtarget &Subtarget) {
@@ -37948,6 +38016,16 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
       N0 = DAG.getBitcast(MVT::i8, N0);
       return DAG.getNode(ISD::TRUNCATE, dl, VT, N0);
     }
+  } else {
+    // If we're bitcasting from iX to vXi1, see if the integer originally
+    // began as a vXi1 and whether we can remove the bitcast entirely.
+    if (VT.isVector() && VT.getScalarType() == MVT::i1 &&
+        SrcVT.isScalarInteger() &&
+        DAG.getTargetLoweringInfo().isTypeLegal(VT)) {
+      if (SDValue V =
+              combineBitcastToBoolVector(VT, N0, SDLoc(N), DAG, Subtarget))
+        return V;
+    }
   }
 
   // Look for (i8 (bitcast (v8i1 (extract_subvector (v16i1 X), 0)))) and
diff --git a/llvm/test/CodeGen/X86/avx512-intrinsics.ll b/llvm/test/CodeGen/X86/avx512-intrinsics.ll
index 2e5dc1e69c8a3e..a8222b1edab0e2 100644
--- a/llvm/test/CodeGen/X86/avx512-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512-intrinsics.ll
@@ -7496,13 +7496,7 @@ define <16 x float> @bad_mask_transition(<8 x double> %a, <8 x double> %b, <8 x
 ; X64-LABEL: bad_mask_transition:
 ; X64:       # %bb.0: # %entry
 ; X64-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k0
-; X64-NEXT:    kmovw %k0, %eax
-; X64-NEXT:    vcmplt_oqpd %zmm3, %zmm2, %k0
-; X64-NEXT:    kmovw %k0, %ecx
-; X64-NEXT:    movzbl %al, %eax
-; X64-NEXT:    movzbl %cl, %ecx
-; X64-NEXT:    kmovw %eax, %k0
-; X64-NEXT:    kmovw %ecx, %k1
+; X64-NEXT:    vcmplt_oqpd %zmm3, %zmm2, %k1
 ; X64-NEXT:    kunpckbw %k0, %k1, %k1
 ; X64-NEXT:    vblendmps %zmm5, %zmm4, %zmm0 {%k1}
 ; X64-NEXT:    retq
@@ -7518,13 +7512,7 @@ define <16 x float> @bad_mask_transition(<8 x double> %a, <8 x double> %b, <8 x
 ; X86-NEXT:    subl $64, %esp
 ; X86-NEXT:    vmovaps 72(%ebp), %zmm3
 ; X86-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k0
-; X86-NEXT:    kmovw %k0, %eax
-; X86-NEXT:    vcmplt_oqpd 8(%ebp), %zmm2, %k0
-; X86-NEXT:    kmovw %k0, %ecx
-; X86-NEXT:    movzbl %al, %eax
-; X86-NEXT:    movzbl %cl, %ecx
-; X86-NEXT:    kmovw %eax, %k0
-; X86-NEXT:    kmovw %ecx, %k1
+; X86-NEXT:    vcmplt_oqpd 8(%ebp), %zmm2, %k1
 ; X86-NEXT:    kunpckbw %k0, %k1, %k1
 ; X86-NEXT:    vmovaps 136(%ebp), %zmm3 {%k1}
 ; X86-NEXT:    vmovaps %zmm3, %zmm0
@@ -7551,10 +7539,7 @@ entry:
 define <16 x float> @bad_mask_transition_2(<8 x double> %a, <8 x double> %b, <8 x double> %c, <8 x double> %d, <16 x float> %e, <16 x float> %f) {
 ; X64-LABEL: bad_mask_transition_2:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k0
-; X64-NEXT:    kmovw %k0, %eax
-; X64-NEXT:    movzbl %al, %eax
-; X64-NEXT:    kmovw %eax, %k1
+; X64-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k1
 ; X64-NEXT:    vblendmps %zmm5, %zmm4, %zmm0 {%k1}
 ; X64-NEXT:    retq
 ;
@@ -7568,10 +7553,7 @@ define <16 x float> @bad_mask_transition_2(<8 x double> %a, <8 x double> %b, <8
 ; X86-NEXT:    andl $-64, %esp
 ; X86-NEXT:    subl $64, %esp
 ; X86-NEXT:    vmovaps 72(%ebp), %zmm2
-; X86-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k0
-; X86-NEXT:    kmovw %k0, %eax
-; X86-NEXT:    movzbl %al, %eax
-; X86-NEXT:    kmovw %eax, %k1
+; X86-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k1
 ; X86-NEXT:    vmovaps 136(%ebp), %zmm2 {%k1}
 ; X86-NEXT:    vmovaps %zmm2, %zmm0
 ; X86-NEXT:    movl %ebp, %esp
diff --git a/llvm/test/CodeGen/X86/pr41619.ll b/llvm/test/CodeGen/X86/pr41619.ll
index 13bfd910587c2c..87c62927090370 100644
--- a/llvm/test/CodeGen/X86/pr41619.ll
+++ b/llvm/test/CodeGen/X86/pr41619.ll
@@ -44,8 +44,6 @@ define i32 @bar(double %blah) nounwind {
 ; AVX512-LABEL: bar:
 ; AVX512:       ## %bb.0:
 ; AVX512-NEXT:    vmovq %xmm0, %rax
-; AVX512-NEXT:    kmovd %eax, %k0
-; AVX512-NEXT:    kmovq %k0, %rax
 ; AVX512-NEXT:    ## kill: def $eax killed $eax killed $rax
 ; AVX512-NEXT:    retq
   %z = bitcast double %blah to i64
diff --git a/llvm/test/CodeGen/X86/vector-shuffle-v1.ll b/llvm/test/CodeGen/X86/vector-shuffle-v1.ll
index 782303e97b1253..c2c5eafb9cf98a 100644
--- a/llvm/test/CodeGen/X86/vector-shuffle-v1.ll
+++ b/llvm/test/CodeGen/X86/vector-shuffle-v1.ll
@@ -891,12 +891,10 @@ define void @PR32547(<8 x float> %a, <8 x float> %b, <8 x float> %c, <8 x float>
 ; AVX512F-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
 ; AVX512F-NEXT:    vcmpltps %zmm1, %zmm0, %k0
 ; AVX512F-NEXT:    vcmpltps %zmm3, %zmm2, %k1
-; AVX512F-NEXT:    kmovw %k1, %eax
-; AVX512F-NEXT:    kmovw %k0, %ecx
-; AVX512F-NEXT:    movzbl %al, %eax
-; AVX512F-NEXT:    shll $8, %ecx
-; AVX512F-NEXT:    orl %eax, %ecx
-; AVX512F-NEXT:    kmovw %ecx, %k1
+; AVX512F-NEXT:    kshiftlw $8, %k0, %k0
+; AVX512F-NEXT:    kshiftlw $8, %k1, %k1
+; AVX512F-NEXT:    kshiftrw $8, %k1, %k1
+; AVX512F-NEXT:    korw %k1, %k0, %k1
 ; AVX512F-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX512F-NEXT:    vmovaps %zmm0, (%rdi) {%k1}
 ; AVX512F-NEXT:    vzeroupper
@@ -906,12 +904,8 @@ define void @PR32547(<8 x float> %a, <8 x float> %b, <8 x float> %c, <8 x float>
 ; AVX512VL:       # %bb.0: # %entry
 ; AVX512VL-NEXT:    vcmpltps %ymm1, %ymm0, %k0
 ; AVX512VL-NEXT:    vcmpltps %ymm3, %ymm2, %k1
-; AVX512VL-NEXT:    kmovw %k1, %eax
-; AVX512VL-NEXT:    kmovw %k0, %ecx
-; AVX512VL-NEXT:    movzbl %al, %eax
-; AVX512VL-NEXT:    shll $8, %ecx
-; AVX512VL-NEXT:    orl %eax, %ecx
-; AVX512VL-NEXT:    kmovw %ecx, %k1
+; AVX512VL-NEXT:    kshiftlw $8, %k0, %k0
+; AVX512VL-NEXT:    korw %k1, %k0, %k1
 ; AVX512VL-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX512VL-NEXT:    vmovaps %zmm0, (%rdi) {%k1}
 ; AVX512VL-NEXT:    vzeroupper
@@ -921,11 +915,8 @@ define void @PR32547(<8 x float> %a, <8 x float> %b, <8 x float> %c, <8 x float>
 ; VL_BW_DQ:       # %bb.0: # %entry
 ; VL_BW_DQ-NEXT:    vcmpltps %ymm1, %ymm0, %k0
 ; VL_BW_DQ-NEXT:    vcmpltps %ymm3, %ymm2, %k1
-; VL_BW_DQ-NEXT:    kmovd %k0, %eax
-; VL_BW_DQ-NEXT:    kmovb %k1, %ecx
-; VL_BW_DQ-NEXT:    shll $8, %eax
-; VL_BW_DQ-NEXT:    orl %ecx, %eax
-; VL_BW_DQ-NEXT:    kmovd %eax, %k1
+; VL_BW_DQ-NEXT:    kshiftlw $8, %k0, %k0
+; VL_BW_DQ-NEXT:    korw %k1, %k0, %k1
 ; VL_BW_DQ-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; VL_BW_DQ-NEXT:    vmovaps %zmm0, (%rdi) {%k1}
 ; VL_BW_DQ-NEXT:    vzeroupper
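
For illustration only (not part of the patch): a minimal LLVM IR sketch, modeled on the bad_mask_transition_2 test above, of the scalar round-trip that combineBitcastToBoolVector unwinds. The function name @mask_roundtrip is hypothetical. Before this change, the bitcast/zext/bitcast chain forced the v8i1 compare mask through a GPR (kmovw/movzbl/kmovw); with the combine, the ZERO_EXTEND becomes an INSERT_SUBVECTOR into a zero vector and the whole chain stays in mask registers.

; Hypothetical reduced test, assuming an AVX512 target so v8i1/v16i1 are legal.
define <16 x float> @mask_roundtrip(<8 x double> %a, <8 x double> %b, <16 x float> %x, <16 x float> %y) {
  %cmp = fcmp olt <8 x double> %a, %b     ; v8i1 mask, produced in a k-register
  %m8 = bitcast <8 x i1> %cmp to i8       ; mask -> scalar i8
  %m16 = zext i8 %m8 to i16               ; widened in the scalar integer domain
  %mask = bitcast i16 %m16 to <16 x i1>   ; scalar -> v16i1 mask again
  %res = select <16 x i1> %mask, <16 x float> %x, <16 x float> %y
  ret <16 x float> %res
}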