[X86][ISel] Fix logic for optimizing movmsk(bitcast(shuffle(x)))
#68369
Conversation
@llvm/pr-subscribers-backend-x86 Changes
Full diff: https://github.com/llvm/llvm-project/pull/68369.diff 3 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c4cd2a672fe7b26..dab823e71aa9367 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45836,18 +45836,52 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
// MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced.
SmallVector<int, 32> ShuffleMask;
SmallVector<SDValue, 2> ShuffleInputs;
+ SDValue BaseVec = peekThroughBitcasts(Vec);
if (NumElts <= CmpBits &&
- getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
- ShuffleMask, DAG) &&
+ getTargetShuffleInputs(BaseVec, ShuffleInputs, ShuffleMask, DAG) &&
ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits()) {
unsigned NumShuffleElts = ShuffleMask.size();
- APInt DemandedElts = APInt::getZero(NumShuffleElts);
- for (int M : ShuffleMask) {
- assert(0 <= M && M < (int)NumShuffleElts && "Bad unary shuffle index");
- DemandedElts.setBit(M);
+
+ APInt Result = APInt::getZero(NumShuffleElts);
+ APInt ImportantLocs;
+ // Since we peek through a bitcast, we need to be careful if the base vector
+ // type has smaller elements than the MOVMSK type. In that case, even if
+ // all the elements are demanded by the shuffle mask, only the "high"
+ // elements which have highbits that align with highbits in the MOVMSK vec
+  // elements are actually demanded. A simplification of spurious operations
+  // on the "low" elements takes place during other simplifications.
+ //
+ // For example:
+  //   MOVMSK64(BITCAST(SHUF32 X, (1,0,3,2))): even though all the elements
+  //   are demanded, swapping them around can change the result.
+ //
+ // To address this, we need to make sure all the "high" elements are moved
+ // to other "high" locations.
+ MVT BaseVT = BaseVec.getSimpleValueType();
+ unsigned BaseNumElts = BaseVT.getVectorNumElements();
+ if (BaseNumElts > NumElts) {
+ ImportantLocs = APInt::getZero(NumShuffleElts);
+ assert((BaseNumElts % NumElts) == 0 &&
+ "Vec with unsupported element size");
+ unsigned Scale = BaseNumElts / NumElts;
+ for (unsigned i = 0; i < BaseNumElts; ++i) {
+ if ((i % Scale) == (Scale - 1))
+ ImportantLocs.setBit(i);
+ }
+ } else {
+ ImportantLocs = APInt::getAllOnes(NumShuffleElts);
+ }
+
+ for (unsigned ShufSrc = 0; ShufSrc < ShuffleMask.size(); ++ShufSrc) {
+ int ShufDst = ShuffleMask[ShufSrc];
+ assert(0 <= ShufDst && ShufDst < (int)NumShuffleElts &&
+ "Bad unary shuffle index");
+ if (ImportantLocs[ShufSrc] && ImportantLocs[ShufDst])
+ Result.setBit(ShufSrc);
}
- if (DemandedElts.isAllOnes()) {
+
+ if (Result == ImportantLocs) {
SDLoc DL(EFLAGS);
SDValue Result = DAG.getBitcast(VecVT, ShuffleInputs[0]);
Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
diff --git a/llvm/test/CodeGen/X86/combine-ptest.ll b/llvm/test/CodeGen/X86/combine-ptest.ll
index 337edef96beee2c..e5854564251a3e8 100644
--- a/llvm/test/CodeGen/X86/combine-ptest.ll
+++ b/llvm/test/CodeGen/X86/combine-ptest.ll
@@ -265,7 +265,6 @@ define i32 @ptestz_v2i64_signbits(<2 x i64> %c, i32 %a, i32 %b) {
; SSE41-LABEL: ptestz_v2i64_signbits:
; SSE41: # %bb.0:
; SSE41-NEXT: movl %edi, %eax
-; SSE41-NEXT: pshufd {{.*#+}} xmm0 = xmm0[1,1,3,3]
; SSE41-NEXT: movmskps %xmm0, %ecx
; SSE41-NEXT: testl %ecx, %ecx
; SSE41-NEXT: cmovnel %esi, %eax
diff --git a/llvm/test/CodeGen/X86/movmsk-cmp.ll b/llvm/test/CodeGen/X86/movmsk-cmp.ll
index a0901e265f5ae97..f26bbb7e5c2bdac 100644
--- a/llvm/test/CodeGen/X86/movmsk-cmp.ll
+++ b/llvm/test/CodeGen/X86/movmsk-cmp.ll
@@ -4430,3 +4430,141 @@ define i32 @PR39665_c_ray_opt(<2 x double> %x, <2 x double> %y) {
%r = select i1 %u, i32 42, i32 99
ret i32 %r
}
+
+define i32 @pr67287(<2 x i64> %broadcast.splatinsert25) {
+; SSE2-LABEL: pr67287:
+; SSE2: # %bb.0: # %entry
+; SSE2-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT: pxor %xmm1, %xmm1
+; SSE2-NEXT: pcmpeqd %xmm0, %xmm1
+; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[1,0,3,2]
+; SSE2-NEXT: movmskpd %xmm0, %eax
+; SSE2-NEXT: testl %eax, %eax
+; SSE2-NEXT: jne .LBB97_2
+; SSE2-NEXT: # %bb.1: # %entry
+; SSE2-NEXT: movd %xmm1, %eax
+; SSE2-NEXT: testb $1, %al
+; SSE2-NEXT: jne .LBB97_2
+; SSE2-NEXT: # %bb.3: # %middle.block
+; SSE2-NEXT: xorl %eax, %eax
+; SSE2-NEXT: retq
+; SSE2-NEXT: .LBB97_2:
+; SSE2-NEXT: movw $0, 0
+; SSE2-NEXT: xorl %eax, %eax
+; SSE2-NEXT: retq
+;
+; SSE41-LABEL: pr67287:
+; SSE41: # %bb.0: # %entry
+; SSE41-NEXT: pxor %xmm1, %xmm1
+; SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; SSE41-NEXT: pcmpeqq %xmm1, %xmm0
+; SSE41-NEXT: movmskpd %xmm0, %eax
+; SSE41-NEXT: testl %eax, %eax
+; SSE41-NEXT: jne .LBB97_2
+; SSE41-NEXT: # %bb.1: # %entry
+; SSE41-NEXT: movd %xmm0, %eax
+; SSE41-NEXT: testb $1, %al
+; SSE41-NEXT: jne .LBB97_2
+; SSE41-NEXT: # %bb.3: # %middle.block
+; SSE41-NEXT: xorl %eax, %eax
+; SSE41-NEXT: retq
+; SSE41-NEXT: .LBB97_2:
+; SSE41-NEXT: movw $0, 0
+; SSE41-NEXT: xorl %eax, %eax
+; SSE41-NEXT: retq
+;
+; AVX1-LABEL: pr67287:
+; AVX1: # %bb.0: # %entry
+; AVX1-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; AVX1-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; AVX1-NEXT: vpcmpeqq %xmm1, %xmm0, %xmm0
+; AVX1-NEXT: vtestpd %xmm0, %xmm0
+; AVX1-NEXT: jne .LBB97_2
+; AVX1-NEXT: # %bb.1: # %entry
+; AVX1-NEXT: vmovd %xmm0, %eax
+; AVX1-NEXT: testb $1, %al
+; AVX1-NEXT: jne .LBB97_2
+; AVX1-NEXT: # %bb.3: # %middle.block
+; AVX1-NEXT: xorl %eax, %eax
+; AVX1-NEXT: retq
+; AVX1-NEXT: .LBB97_2:
+; AVX1-NEXT: movw $0, 0
+; AVX1-NEXT: xorl %eax, %eax
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: pr67287:
+; AVX2: # %bb.0: # %entry
+; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; AVX2-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; AVX2-NEXT: vpcmpeqq %xmm1, %xmm0, %xmm0
+; AVX2-NEXT: vtestpd %xmm0, %xmm0
+; AVX2-NEXT: jne .LBB97_2
+; AVX2-NEXT: # %bb.1: # %entry
+; AVX2-NEXT: vmovd %xmm0, %eax
+; AVX2-NEXT: testb $1, %al
+; AVX2-NEXT: jne .LBB97_2
+; AVX2-NEXT: # %bb.3: # %middle.block
+; AVX2-NEXT: xorl %eax, %eax
+; AVX2-NEXT: retq
+; AVX2-NEXT: .LBB97_2:
+; AVX2-NEXT: movw $0, 0
+; AVX2-NEXT: xorl %eax, %eax
+; AVX2-NEXT: retq
+;
+; KNL-LABEL: pr67287:
+; KNL: # %bb.0: # %entry
+; KNL-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; KNL-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; KNL-NEXT: vptestnmq %zmm0, %zmm0, %k0
+; KNL-NEXT: kmovw %k0, %eax
+; KNL-NEXT: testb $3, %al
+; KNL-NEXT: jne .LBB97_2
+; KNL-NEXT: # %bb.1: # %entry
+; KNL-NEXT: kmovw %k0, %eax
+; KNL-NEXT: testb $1, %al
+; KNL-NEXT: jne .LBB97_2
+; KNL-NEXT: # %bb.3: # %middle.block
+; KNL-NEXT: xorl %eax, %eax
+; KNL-NEXT: vzeroupper
+; KNL-NEXT: retq
+; KNL-NEXT: .LBB97_2:
+; KNL-NEXT: movw $0, 0
+; KNL-NEXT: xorl %eax, %eax
+; KNL-NEXT: vzeroupper
+; KNL-NEXT: retq
+;
+; SKX-LABEL: pr67287:
+; SKX: # %bb.0: # %entry
+; SKX-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; SKX-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; SKX-NEXT: vptestnmq %xmm0, %xmm0, %k0
+; SKX-NEXT: kortestb %k0, %k0
+; SKX-NEXT: jne .LBB97_2
+; SKX-NEXT: # %bb.1: # %entry
+; SKX-NEXT: kmovd %k0, %eax
+; SKX-NEXT: testb $1, %al
+; SKX-NEXT: jne .LBB97_2
+; SKX-NEXT: # %bb.3: # %middle.block
+; SKX-NEXT: xorl %eax, %eax
+; SKX-NEXT: retq
+; SKX-NEXT: .LBB97_2:
+; SKX-NEXT: movw $0, 0
+; SKX-NEXT: xorl %eax, %eax
+; SKX-NEXT: retq
+entry:
+ %0 = and <2 x i64> %broadcast.splatinsert25, <i64 4294967295, i64 4294967295>
+ %1 = icmp eq <2 x i64> %0, zeroinitializer
+ %shift = shufflevector <2 x i1> %1, <2 x i1> zeroinitializer, <2 x i32> <i32 1, i32 poison>
+ %2 = or <2 x i1> %1, %shift
+ %3 = extractelement <2 x i1> %2, i64 0
+ %4 = extractelement <2 x i1> %1, i64 0
+ %5 = or i1 %3, %4
+ br i1 %5, label %6, label %middle.block
+
+6: ; preds = %entry
+ store i16 0, ptr null, align 2
+ br label %middle.block
+
+middle.block: ; preds = %6, %entry
+ ret i32 0
+}
nice catch!
@@ -265,7 +265,6 @@ define i32 @ptestz_v2i64_signbits(<2 x i64> %c, i32 %a, i32 %b) {
 ; SSE41-LABEL: ptestz_v2i64_signbits:
 ; SSE41: # %bb.0:
 ; SSE41-NEXT: movl %edi, %eax
-; SSE41-NEXT: pshufd {{.*#+}} xmm0 = xmm0[1,1,3,3]
This is definitely wrong as we've gone from just needing the signbits of a <2 x i64> to the signbits of a <4 x i32>
Err, yeah, I had flipped shuffle src/dst in the code.
   if (NumElts <= CmpBits &&
-      getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
-                             ShuffleMask, DAG) &&
+      getTargetShuffleInputs(BaseVec, ShuffleInputs, ShuffleMask, DAG) &&
       ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
       ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits()) {
I think all we need to do is to add a scaleShuffleElements(ShuffleMask, NumElts, ScaledMask) check here to ensure that the shuffle mask can be scaled back to the original mask width:
  // MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced.
  SmallVector<int, 32> ShuffleMask, ScaledMask;
  SmallVector<SDValue, 2> ShuffleInputs;
  if (NumElts <= CmpBits &&
      getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
                             ShuffleMask, DAG) &&
      ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
      ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits() &&
      scaleShuffleElements(ShuffleMask, NumElts, ScaledMask)) {
    APInt DemandedElts = APInt::getZero(NumElts);
    for (int M : ScaledMask) {
      assert(0 <= M && M < (int)NumElts && "Bad unary shuffle index");
      DemandedElts.setBit(M);
    }
    if (DemandedElts.isAllOnes()) {
      SDLoc DL(EFLAGS);
      SDValue Result = DAG.getBitcast(VecVT, ShuffleInputs[0]);
      Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
      Result =
          DAG.getZExtOrTrunc(Result, DL, EFLAGS.getOperand(0).getValueType());
      return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
                         EFLAGS.getOperand(1));
    }
  }
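A standalone sketch of the property that scaleShuffleElements checks (a hypothetical simplification — the real LLVM helper also handles undef/zero sentinels, which are omitted here): a narrow-element mask is representable at a wider granularity only if each aligned group of narrow elements moves together, in order.

```cpp
#include <optional>
#include <vector>

// Returns the wide-element mask if `mask` (over numDst * scale narrow
// elements) can be rewritten over numDst wide elements, std::nullopt if not.
std::optional<std::vector<int>> scaleMask(const std::vector<int> &mask,
                                          int numDst) {
  int scale = static_cast<int>(mask.size()) / numDst;
  std::vector<int> scaled(numDst);
  for (int d = 0; d < numDst; ++d) {
    int src = mask[d * scale];
    if (src % scale != 0)
      return std::nullopt; // group not aligned to a wide element boundary
    for (int j = 1; j < scale; ++j)
      if (mask[d * scale + j] != src + j)
        return std::nullopt; // group is split or internally reordered
    scaled[d] = src / scale;
  }
  return scaled;
}
```

For example, the v4i32 mask {2,3,0,1} scales to the v2i64 mask {1,0} (a clean 64-bit lane swap), while {1,0,3,2} does not scale at all — which is exactly the mask in the miscompile.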
There are a lot of ways to fix this. Posting V2, which I think is the most precise version. If you still have reservations, you can give this a try.
I'd prefer my proposal tbh
The thing is, it's overly conservative. As long as all the "high" elements stay "high", we are fine. We don't need the whole words to stay together.
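The "high stays high" condition can be modeled as follows (a sketch under assumptions, not the exact ImportantLocs code from the V2 patch): the shuffle is droppable iff every high destination position reads from a high source position, and every high source is read by some high destination.

```cpp
#include <cstddef>
#include <vector>

// `scale` narrow elements per MOVMSK element; position p is "high" (holds
// the sign bit of its wide lane) iff p % scale == scale - 1.
bool highsStayHigh(const std::vector<int> &mask, int scale) {
  std::vector<bool> covered(mask.size() / scale, false);
  for (std::size_t dst = 0; dst < mask.size(); ++dst) {
    if (static_cast<int>(dst % scale) != scale - 1)
      continue; // only high destination positions feed MOVMSK
    if (mask[dst] % scale != scale - 1)
      return false; // a high lane would read a low element's bits
    covered[mask[dst] / scale] = true;
  }
  for (bool c : covered)
    if (!c)
      return false; // some sign bit of the input is never inspected
  return true;
}
```

This accepts {1,1,3,3} (not scalable, but each 64-bit sign bit still comes from elements 1 and 3) while rejecting {1,0,3,2} — the precision gap between the two proposals.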
Force-pushed b2bdd1f to 9de12e0
Have you any cases where all of this is necessary vs my smaller patch? This code was to handle an uncommon edge case in the first place (a movmsk variant of reduce(permute(x)) -> reduce(x)), and we're adding a lot of code to handle it. Ideally the SimplifyDemanded* calls will have simplified the shuffle anyhow if we're only demanding the sign bits.
Prior logic would remove the shuffle iff all of the elements in `x` were used. This is incorrect. The issue is that `movmsk` only cares about the highbits, so if the width of the elements in `x` is smaller than the width of the elements for the `movmsk`, then the shuffle, even if it preserves all the elements, may change which ones are used by the highbits.

For example, in `movmsk64(bitcast(shuffle32(x, (1,0,3,2))))`, even though the shuffle mask `(1,0,3,2)` preserves all the elements, it flips which of them are relevant to the `movmsk64` (x[1] and x[3] before, x[0] and x[2] after).

The fix here is to ensure that the shuffle mask can be scaled to the element width of the `movmsk` instruction. This ensures that the "high" elements stay "high". It is overly conservative, as it misses cases like `(1,1,3,3)` where the "high" elements stay intact despite the mask not being scalable, but for a relatively edge-case optimization that should generally be handled during SimplifyDemandedBits anyway, that seems okay.
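The failure mode described above can be reproduced with a small standalone model (hypothetical helper names; little-endian element numbering assumed, matching x86):

```cpp
#include <array>
#include <cstdint>

// Model a <4 x i32> vector. After a bitcast to <2 x i64>, MOVMSK reads the
// sign bit of each 64-bit lane, i.e. the top bit of 32-bit elements 1 and 3.
using V4 = std::array<uint32_t, 4>;

// r[i] = x[mask[i]], the usual shuffle semantics.
V4 shuffle32(const V4 &x, const std::array<int, 4> &mask) {
  V4 r{};
  for (int i = 0; i < 4; ++i)
    r[i] = x[mask[i]];
  return r;
}

// MOVMSK on the <2 x i64> view: bit k = sign bit of 64-bit lane k.
int movmsk64(const V4 &x) {
  int m = 0;
  for (int k = 0; k < 2; ++k)
    m |= static_cast<int>((x[2 * k + 1] >> 31) & 1) << k;
  return m;
}
```

With `x = {0, 0x80000000, 0, 0}`, `movmsk64(x)` is 1, but `movmsk64(shuffle32(x, {1,0,3,2}))` is 0: the mask references every element yet moves the only set sign bit into a "low" position, so dropping the shuffle changes the result.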
Force-pushed 9de12e0 to 6adb853
Okay, changed to your version + updated commit/PR message
LGTM - cheers!
Pushed: 1684c65
Fixes #67287.