diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 426e3143ac9b2..0f1cb5f1e2366 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -48539,13 +48539,28 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC, } // MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced. - SmallVector ShuffleMask; + // Since we peek through a bitcast, we need to be careful if the base vector + // type has smaller elements than the MOVMSK type. In that case, even if + // all the elements are demanded by the shuffle mask, only the "high" + // elements which have highbits that align with highbits in the MOVMSK vec + // elements are actually demanded. A simplification of spurious operations + // on the "low" elements takes place during other simplifications. + // + // For example: + // In MOVMSK64(BITCAST(SHUF32 X, (1,0,3,2))), even though all the elements + // are demanded, the result can change because we are swapping them around. + // + // To address this, we check that we can scale the shuffle mask to MOVMSK + // element width (this will ensure "high" elements match). It's slightly + // overly conservative, but fine for an edge case fold. 
+ SmallVector ShuffleMask, ScaledMaskUnused; SmallVector ShuffleInputs; if (NumElts <= CmpBits && getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs, ShuffleMask, DAG) && ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) && - ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits()) { + ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits() && + scaleShuffleElements(ShuffleMask, NumElts, ScaledMaskUnused)) { unsigned NumShuffleElts = ShuffleMask.size(); APInt DemandedElts = APInt::getZero(NumShuffleElts); for (int M : ShuffleMask) { diff --git a/llvm/test/CodeGen/X86/movmsk-cmp.ll b/llvm/test/CodeGen/X86/movmsk-cmp.ll index f7ba49ce0e127..278a6a8b128eb 100644 --- a/llvm/test/CodeGen/X86/movmsk-cmp.ll +++ b/llvm/test/CodeGen/X86/movmsk-cmp.ll @@ -4458,13 +4458,14 @@ define i32 @PR39665_c_ray_opt(<2 x double> %x, <2 x double> %y) { define i32 @pr67287(<2 x i64> %broadcast.splatinsert25) { ; SSE2-LABEL: pr67287: ; SSE2: # %bb.0: # %entry -; SSE2-NEXT: movl $3, %eax -; SSE2-NEXT: testl %eax, %eax -; SSE2-NEXT: jne .LBB97_2 -; SSE2-NEXT: # %bb.1: # %entry ; SSE2-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 ; SSE2-NEXT: pxor %xmm1, %xmm1 ; SSE2-NEXT: pcmpeqd %xmm0, %xmm1 +; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[1,0,3,2] +; SSE2-NEXT: movmskpd %xmm0, %eax +; SSE2-NEXT: testl %eax, %eax +; SSE2-NEXT: jne .LBB97_2 +; SSE2-NEXT: # %bb.1: # %entry ; SSE2-NEXT: movd %xmm1, %eax ; SSE2-NEXT: testb $1, %al ; SSE2-NEXT: jne .LBB97_2