diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 71e643e8c93702..aafbe7b716c5ba 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -44324,6 +44324,8 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, SDLoc dl(InputVector); bool IsPextr = N->getOpcode() != ISD::EXTRACT_VECTOR_ELT; unsigned NumSrcElts = SrcVT.getVectorNumElements(); + unsigned NumEltBits = VT.getScalarSizeInBits(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (CIdx && CIdx->getAPIntValue().uge(NumSrcElts)) return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT); @@ -44338,15 +44340,26 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, uint64_t Idx = CIdx->getZExtValue(); if (UndefVecElts[Idx]) return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT); - return DAG.getConstant(EltBits[Idx].zext(VT.getScalarSizeInBits()), dl, - VT); + return DAG.getConstant(EltBits[Idx].zext(NumEltBits), dl, VT); + } + + // Convert extract_element(bitcast(X)) -> bitcast(extract_subvector(X)). + // Improves lowering of bool masks in Rust, which splits them into a byte array. + if (InputVector.getOpcode() == ISD::BITCAST && (NumEltBits % 8) == 0) { + SDValue Src = peekThroughBitcasts(InputVector); + if (Src.getValueType().getScalarType() == MVT::i1 && + TLI.isTypeLegal(Src.getValueType())) { + MVT SubVT = MVT::getVectorVT(MVT::i1, NumEltBits); + SDValue Sub = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, Src, + DAG.getIntPtrConstant(CIdx->getZExtValue() * NumEltBits, dl)); + return DAG.getBitcast(VT, Sub); + } } } if (IsPextr) { - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (TLI.SimplifyDemandedBits(SDValue(N, 0), - APInt::getAllOnes(VT.getSizeInBits()), DCI)) + if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits), + DCI)) return SDValue(N, 0); // PEXTR*(PINSR*(v, s, c), c) -> s (with implicit zext handling). 
diff --git a/llvm/test/CodeGen/X86/bitcast-vector-bool.ll b/llvm/test/CodeGen/X86/bitcast-vector-bool.ll index 94901e665d566e..de132c1c7ef436 100644 --- a/llvm/test/CodeGen/X86/bitcast-vector-bool.ll +++ b/llvm/test/CodeGen/X86/bitcast-vector-bool.ll @@ -123,14 +123,24 @@ define i8 @bitcast_v16i8_to_v2i8(<16 x i8> %a0) nounwind { ; SSE2-SSSE3-NEXT: addb -{{[0-9]+}}(%rsp), %al ; SSE2-SSSE3-NEXT: retq ; -; AVX-LABEL: bitcast_v16i8_to_v2i8: -; AVX: # %bb.0: -; AVX-NEXT: vpmovmskb %xmm0, %ecx -; AVX-NEXT: movl %ecx, %eax -; AVX-NEXT: shrl $8, %eax -; AVX-NEXT: addb %cl, %al -; AVX-NEXT: # kill: def $al killed $al killed $eax -; AVX-NEXT: retq +; AVX12-LABEL: bitcast_v16i8_to_v2i8: +; AVX12: # %bb.0: +; AVX12-NEXT: vpmovmskb %xmm0, %ecx +; AVX12-NEXT: movl %ecx, %eax +; AVX12-NEXT: shrl $8, %eax +; AVX12-NEXT: addb %cl, %al +; AVX12-NEXT: # kill: def $al killed $al killed $eax +; AVX12-NEXT: retq +; +; AVX512-LABEL: bitcast_v16i8_to_v2i8: +; AVX512: # %bb.0: +; AVX512-NEXT: vpmovb2m %xmm0, %k0 +; AVX512-NEXT: kshiftrw $8, %k0, %k1 +; AVX512-NEXT: kmovd %k0, %ecx +; AVX512-NEXT: kmovd %k1, %eax +; AVX512-NEXT: addb %cl, %al +; AVX512-NEXT: # kill: def $al killed $al killed $eax +; AVX512-NEXT: retq %1 = icmp slt <16 x i8> %a0, zeroinitializer %2 = bitcast <16 x i1> %1 to <2 x i8> %3 = extractelement <2 x i8> %2, i32 0 @@ -242,10 +252,9 @@ define i8 @bitcast_v16i16_to_v2i8(<16 x i16> %a0) nounwind { ; AVX512-LABEL: bitcast_v16i16_to_v2i8: ; AVX512: # %bb.0: ; AVX512-NEXT: vpmovw2m %ymm0, %k0 -; AVX512-NEXT: kmovw %k0, -{{[0-9]+}}(%rsp) -; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0 -; AVX512-NEXT: vmovd %xmm0, %ecx -; AVX512-NEXT: vpextrb $1, %xmm0, %eax +; AVX512-NEXT: kshiftrw $8, %k0, %k1 +; AVX512-NEXT: kmovd %k0, %ecx +; AVX512-NEXT: kmovd %k1, %eax ; AVX512-NEXT: addb %cl, %al ; AVX512-NEXT: # kill: def $al killed $al killed $eax ; AVX512-NEXT: vzeroupper @@ -289,9 +298,10 @@ define i16 @bitcast_v32i8_to_v2i16(<32 x i8> %a0) nounwind { ; ; AVX512-LABEL: 
bitcast_v32i8_to_v2i16: ; AVX512: # %bb.0: -; AVX512-NEXT: vpmovmskb %ymm0, %ecx -; AVX512-NEXT: movl %ecx, %eax -; AVX512-NEXT: shrl $16, %eax +; AVX512-NEXT: vpmovb2m %ymm0, %k0 +; AVX512-NEXT: kshiftrd $16, %k0, %k1 +; AVX512-NEXT: kmovd %k0, %ecx +; AVX512-NEXT: kmovd %k1, %eax ; AVX512-NEXT: addl %ecx, %eax ; AVX512-NEXT: # kill: def $ax killed $ax killed $eax ; AVX512-NEXT: vzeroupper @@ -424,10 +434,9 @@ define i8 @bitcast_v16i32_to_v2i8(<16 x i32> %a0) nounwind { ; AVX512: # %bb.0: ; AVX512-NEXT: vpxor %xmm1, %xmm1, %xmm1 ; AVX512-NEXT: vpcmpgtd %zmm0, %zmm1, %k0 -; AVX512-NEXT: kmovw %k0, -{{[0-9]+}}(%rsp) -; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0 -; AVX512-NEXT: vmovd %xmm0, %ecx -; AVX512-NEXT: vpextrb $1, %xmm0, %eax +; AVX512-NEXT: kshiftrw $8, %k0, %k1 +; AVX512-NEXT: kmovd %k0, %ecx +; AVX512-NEXT: kmovd %k1, %eax ; AVX512-NEXT: addb %cl, %al ; AVX512-NEXT: # kill: def $al killed $al killed $eax ; AVX512-NEXT: vzeroupper @@ -479,10 +488,9 @@ define i16 @bitcast_v32i16_to_v2i16(<32 x i16> %a0) nounwind { ; AVX512-LABEL: bitcast_v32i16_to_v2i16: ; AVX512: # %bb.0: ; AVX512-NEXT: vpmovw2m %zmm0, %k0 -; AVX512-NEXT: kmovd %k0, -{{[0-9]+}}(%rsp) -; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0 -; AVX512-NEXT: vmovd %xmm0, %ecx -; AVX512-NEXT: vpextrw $1, %xmm0, %eax +; AVX512-NEXT: kshiftrd $16, %k0, %k1 +; AVX512-NEXT: kmovd %k0, %ecx +; AVX512-NEXT: kmovd %k1, %eax ; AVX512-NEXT: addl %ecx, %eax ; AVX512-NEXT: # kill: def $ax killed $ax killed $eax ; AVX512-NEXT: vzeroupper @@ -541,9 +549,10 @@ define i32 @bitcast_v64i8_to_v2i32(<64 x i8> %a0) nounwind { ; AVX512-LABEL: bitcast_v64i8_to_v2i32: ; AVX512: # %bb.0: ; AVX512-NEXT: vpmovb2m %zmm0, %k0 -; AVX512-NEXT: kmovq %k0, -{{[0-9]+}}(%rsp) -; AVX512-NEXT: movl -{{[0-9]+}}(%rsp), %eax -; AVX512-NEXT: addl -{{[0-9]+}}(%rsp), %eax +; AVX512-NEXT: kshiftrq $32, %k0, %k1 +; AVX512-NEXT: kmovd %k0, %ecx +; AVX512-NEXT: kmovd %k1, %eax +; AVX512-NEXT: addl %ecx, %eax ; AVX512-NEXT: vzeroupper ; 
AVX512-NEXT: retq %1 = icmp slt <64 x i8> %a0, zeroinitializer @@ -698,10 +707,9 @@ define [2 x i8] @PR58546(<16 x float> %a0) { ; AVX512: # %bb.0: ; AVX512-NEXT: vxorps %xmm1, %xmm1, %xmm1 ; AVX512-NEXT: vcmpunordps %zmm1, %zmm0, %k0 -; AVX512-NEXT: kmovw %k0, -{{[0-9]+}}(%rsp) -; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0 -; AVX512-NEXT: vmovd %xmm0, %eax -; AVX512-NEXT: vpextrb $1, %xmm0, %edx +; AVX512-NEXT: kshiftrw $8, %k0, %k1 +; AVX512-NEXT: kmovd %k0, %eax +; AVX512-NEXT: kmovd %k1, %edx ; AVX512-NEXT: # kill: def $al killed $al killed $eax ; AVX512-NEXT: # kill: def $dl killed $dl killed $edx ; AVX512-NEXT: vzeroupper