
[X86][ISel] Fix logic for optimizing movmsk(bitcast(shuffle(x))) #68369

Closed · 2 commits

Conversation

goldsteinn (Contributor) commented Oct 6, 2023

  • [X86] Add tests for incorrectly optimizing out shuffle used in movmsk; PR67287
  • [X86] Fix/improve logic for optimizing movmsk(bitcast(shuffle(x))); PR67287

Fixes #67287.

llvmbot (Collaborator) commented Oct 6, 2023

@llvm/pr-subscribers-backend-x86

Changes
  • [X86] Add tests for incorrectly optimizing out shuffle used in movmsk; PR67287
  • [X86] Fix/improve logic for optimizing movmsk(bitcast(shuffle(x))); PR67287

Full diff: https://github.com/llvm/llvm-project/pull/68369.diff

3 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+41-7)
  • (modified) llvm/test/CodeGen/X86/combine-ptest.ll (-1)
  • (modified) llvm/test/CodeGen/X86/movmsk-cmp.ll (+138)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c4cd2a672fe7b26..dab823e71aa9367 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45836,18 +45836,52 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
   // MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced.
   SmallVector<int, 32> ShuffleMask;
   SmallVector<SDValue, 2> ShuffleInputs;
+  SDValue BaseVec = peekThroughBitcasts(Vec);
   if (NumElts <= CmpBits &&
-      getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
-                             ShuffleMask, DAG) &&
+      getTargetShuffleInputs(BaseVec, ShuffleInputs, ShuffleMask, DAG) &&
       ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
       ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits()) {
     unsigned NumShuffleElts = ShuffleMask.size();
-    APInt DemandedElts = APInt::getZero(NumShuffleElts);
-    for (int M : ShuffleMask) {
-      assert(0 <= M && M < (int)NumShuffleElts && "Bad unary shuffle index");
-      DemandedElts.setBit(M);
+
+    APInt Result = APInt::getZero(NumShuffleElts);
+    APInt ImportantLocs;
+    // Since we peek through a bitcast, we need to be careful if the base vector
+    // type has smaller elements than the MOVMSK type.  In that case, even if
+    // all the elements are demanded by the shuffle mask, only the "high"
+    // elements which have highbits that align with highbits in the MOVMSK vec
+    // elements are actually demanded. A simplification of spurious operations
+    // on the "low" elements takes place during other simplifications.
+    //
+    // For example, in MOVMSK64(BITCAST(SHUF32 X, (1,0,3,2))), even though
+    // all the elements are demanded, the result can change because the
+    // "high" and "low" dwords get swapped around.
+    //
+    // To address this, we need to make sure all the "high" elements are moved
+    // to other "high" locations.
+    MVT BaseVT = BaseVec.getSimpleValueType();
+    unsigned BaseNumElts = BaseVT.getVectorNumElements();
+    if (BaseNumElts > NumElts) {
+      ImportantLocs = APInt::getZero(NumShuffleElts);
+      assert((BaseNumElts % NumElts) == 0 &&
+             "Vec with unsupported element size");
+      unsigned Scale = BaseNumElts / NumElts;
+      for (unsigned i = 0; i < BaseNumElts; ++i) {
+        if ((i % Scale) == (Scale - 1))
+          ImportantLocs.setBit(i);
+      }
+    } else {
+      ImportantLocs = APInt::getAllOnes(NumShuffleElts);
+    }
+
+    for (unsigned ShufSrc = 0; ShufSrc < ShuffleMask.size(); ++ShufSrc) {
+      int ShufDst = ShuffleMask[ShufSrc];
+      assert(0 <= ShufDst && ShufDst < (int)NumShuffleElts &&
+             "Bad unary shuffle index");
+      if (ImportantLocs[ShufSrc] && ImportantLocs[ShufDst])
+        Result.setBit(ShufSrc);
     }
-    if (DemandedElts.isAllOnes()) {
+
+    if (Result == ImportantLocs) {
       SDLoc DL(EFLAGS);
       SDValue Result = DAG.getBitcast(VecVT, ShuffleInputs[0]);
       Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
diff --git a/llvm/test/CodeGen/X86/combine-ptest.ll b/llvm/test/CodeGen/X86/combine-ptest.ll
index 337edef96beee2c..e5854564251a3e8 100644
--- a/llvm/test/CodeGen/X86/combine-ptest.ll
+++ b/llvm/test/CodeGen/X86/combine-ptest.ll
@@ -265,7 +265,6 @@ define i32 @ptestz_v2i64_signbits(<2 x i64> %c, i32 %a, i32 %b) {
 ; SSE41-LABEL: ptestz_v2i64_signbits:
 ; SSE41:       # %bb.0:
 ; SSE41-NEXT:    movl %edi, %eax
-; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[1,1,3,3]
 ; SSE41-NEXT:    movmskps %xmm0, %ecx
 ; SSE41-NEXT:    testl %ecx, %ecx
 ; SSE41-NEXT:    cmovnel %esi, %eax
diff --git a/llvm/test/CodeGen/X86/movmsk-cmp.ll b/llvm/test/CodeGen/X86/movmsk-cmp.ll
index a0901e265f5ae97..f26bbb7e5c2bdac 100644
--- a/llvm/test/CodeGen/X86/movmsk-cmp.ll
+++ b/llvm/test/CodeGen/X86/movmsk-cmp.ll
@@ -4430,3 +4430,141 @@ define i32 @PR39665_c_ray_opt(<2 x double> %x, <2 x double> %y) {
   %r = select i1 %u, i32 42, i32 99
   ret i32 %r
 }
+
+define i32 @pr67287(<2 x i64> %broadcast.splatinsert25) {
+; SSE2-LABEL: pr67287:
+; SSE2:       # %bb.0: # %entry
+; SSE2-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    pcmpeqd %xmm0, %xmm1
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[1,0,3,2]
+; SSE2-NEXT:    movmskpd %xmm0, %eax
+; SSE2-NEXT:    testl %eax, %eax
+; SSE2-NEXT:    jne .LBB97_2
+; SSE2-NEXT:  # %bb.1: # %entry
+; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    testb $1, %al
+; SSE2-NEXT:    jne .LBB97_2
+; SSE2-NEXT:  # %bb.3: # %middle.block
+; SSE2-NEXT:    xorl %eax, %eax
+; SSE2-NEXT:    retq
+; SSE2-NEXT:  .LBB97_2:
+; SSE2-NEXT:    movw $0, 0
+; SSE2-NEXT:    xorl %eax, %eax
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: pr67287:
+; SSE41:       # %bb.0: # %entry
+; SSE41-NEXT:    pxor %xmm1, %xmm1
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; SSE41-NEXT:    pcmpeqq %xmm1, %xmm0
+; SSE41-NEXT:    movmskpd %xmm0, %eax
+; SSE41-NEXT:    testl %eax, %eax
+; SSE41-NEXT:    jne .LBB97_2
+; SSE41-NEXT:  # %bb.1: # %entry
+; SSE41-NEXT:    movd %xmm0, %eax
+; SSE41-NEXT:    testb $1, %al
+; SSE41-NEXT:    jne .LBB97_2
+; SSE41-NEXT:  # %bb.3: # %middle.block
+; SSE41-NEXT:    xorl %eax, %eax
+; SSE41-NEXT:    retq
+; SSE41-NEXT:  .LBB97_2:
+; SSE41-NEXT:    movw $0, 0
+; SSE41-NEXT:    xorl %eax, %eax
+; SSE41-NEXT:    retq
+;
+; AVX1-LABEL: pr67287:
+; AVX1:       # %bb.0: # %entry
+; AVX1-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX1-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; AVX1-NEXT:    vpcmpeqq %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vtestpd %xmm0, %xmm0
+; AVX1-NEXT:    jne .LBB97_2
+; AVX1-NEXT:  # %bb.1: # %entry
+; AVX1-NEXT:    vmovd %xmm0, %eax
+; AVX1-NEXT:    testb $1, %al
+; AVX1-NEXT:    jne .LBB97_2
+; AVX1-NEXT:  # %bb.3: # %middle.block
+; AVX1-NEXT:    xorl %eax, %eax
+; AVX1-NEXT:    retq
+; AVX1-NEXT:  .LBB97_2:
+; AVX1-NEXT:    movw $0, 0
+; AVX1-NEXT:    xorl %eax, %eax
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: pr67287:
+; AVX2:       # %bb.0: # %entry
+; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX2-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; AVX2-NEXT:    vpcmpeqq %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vtestpd %xmm0, %xmm0
+; AVX2-NEXT:    jne .LBB97_2
+; AVX2-NEXT:  # %bb.1: # %entry
+; AVX2-NEXT:    vmovd %xmm0, %eax
+; AVX2-NEXT:    testb $1, %al
+; AVX2-NEXT:    jne .LBB97_2
+; AVX2-NEXT:  # %bb.3: # %middle.block
+; AVX2-NEXT:    xorl %eax, %eax
+; AVX2-NEXT:    retq
+; AVX2-NEXT:  .LBB97_2:
+; AVX2-NEXT:    movw $0, 0
+; AVX2-NEXT:    xorl %eax, %eax
+; AVX2-NEXT:    retq
+;
+; KNL-LABEL: pr67287:
+; KNL:       # %bb.0: # %entry
+; KNL-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; KNL-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; KNL-NEXT:    vptestnmq %zmm0, %zmm0, %k0
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb $3, %al
+; KNL-NEXT:    jne .LBB97_2
+; KNL-NEXT:  # %bb.1: # %entry
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb $1, %al
+; KNL-NEXT:    jne .LBB97_2
+; KNL-NEXT:  # %bb.3: # %middle.block
+; KNL-NEXT:    xorl %eax, %eax
+; KNL-NEXT:    vzeroupper
+; KNL-NEXT:    retq
+; KNL-NEXT:  .LBB97_2:
+; KNL-NEXT:    movw $0, 0
+; KNL-NEXT:    xorl %eax, %eax
+; KNL-NEXT:    vzeroupper
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: pr67287:
+; SKX:       # %bb.0: # %entry
+; SKX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; SKX-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; SKX-NEXT:    vptestnmq %xmm0, %xmm0, %k0
+; SKX-NEXT:    kortestb %k0, %k0
+; SKX-NEXT:    jne .LBB97_2
+; SKX-NEXT:  # %bb.1: # %entry
+; SKX-NEXT:    kmovd %k0, %eax
+; SKX-NEXT:    testb $1, %al
+; SKX-NEXT:    jne .LBB97_2
+; SKX-NEXT:  # %bb.3: # %middle.block
+; SKX-NEXT:    xorl %eax, %eax
+; SKX-NEXT:    retq
+; SKX-NEXT:  .LBB97_2:
+; SKX-NEXT:    movw $0, 0
+; SKX-NEXT:    xorl %eax, %eax
+; SKX-NEXT:    retq
+entry:
+  %0 = and <2 x i64> %broadcast.splatinsert25, <i64 4294967295, i64 4294967295>
+  %1 = icmp eq <2 x i64> %0, zeroinitializer
+  %shift = shufflevector <2 x i1> %1, <2 x i1> zeroinitializer, <2 x i32> <i32 1, i32 poison>
+  %2 = or <2 x i1> %1, %shift
+  %3 = extractelement <2 x i1> %2, i64 0
+  %4 = extractelement <2 x i1> %1, i64 0
+  %5 = or i1 %3, %4
+  br i1 %5, label %6, label %middle.block
+
+6:                                                ; preds = %entry
+  store i16 0, ptr null, align 2
+  br label %middle.block
+
+middle.block:                                     ; preds = %6, %entry
+  ret i32 0
+}
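For readers following along, here is a minimal standalone sketch of the miscompile (not part of the patch; movmskpd below is a hand-rolled model of the instruction, and the values are made up for illustration):

  #include <cstdint>
  #include <cstdio>
  #include <cstring>

  // Stand-in model of MOVMSKPD: collect the sign bit of each 64-bit lane.
  static unsigned movmskpd(const uint64_t V[2]) {
    return ((V[0] >> 63) & 1u) | ((unsigned)((V[1] >> 63) & 1u) << 1);
  }

  int main() {
    // X viewed as <4 x i32>, little-endian: dword 1 is the high half of
    // qword 0, so only qword 0 of X has its sign bit set.
    uint32_t X[4] = {0, 0x80000000u, 0, 0};
    // shuffle32(X, (1,0,3,2)) swaps the dword halves within each qword.
    uint32_t S[4] = {X[1], X[0], X[3], X[2]};

    uint64_t X64[2], S64[2];
    std::memcpy(X64, X, sizeof(X64)); // bitcast <4 x i32> -> <2 x i64>
    std::memcpy(S64, S, sizeof(S64));

    // The old combine treated these as equal because the mask references
    // every element, but the shuffle moves the sign-carrying dwords.
    printf("with shuffle:    %u\n", movmskpd(S64)); // 0
    printf("shuffle dropped: %u\n", movmskpd(X64)); // 1
  }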

dtcxzyw changed the title from "goldsteinn/bugfix pr67287" to "[X86][ISel] Improve logic for optimizing movmsk(bitcast(shuffle(x)))" on Oct 6, 2023
RKSimon (Collaborator) reviewed:

nice catch!

@@ -265,7 +265,6 @@ define i32 @ptestz_v2i64_signbits(<2 x i64> %c, i32 %a, i32 %b) {
; SSE41-LABEL: ptestz_v2i64_signbits:
; SSE41: # %bb.0:
; SSE41-NEXT: movl %edi, %eax
-; SSE41-NEXT: pshufd {{.*#+}} xmm0 = xmm0[1,1,3,3]
RKSimon (Collaborator):

This is definitely wrong as we've gone from just needing the signbits of a <2 x i64> to the signbits of a <4 x i32>
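Concretely (a hedged sketch, not LLVM code; movmskps is a hand-rolled model of the instruction), the pshufd [1,1,3,3] broadcasts each i64's high dword so all four mask bits track the two i64 sign bits, while without it the low dwords leak into the mask:

  #include <cstdint>
  #include <cstdio>
  #include <cstring>

  // Stand-in model of MOVMSKPS: collect the sign bit of each 32-bit lane.
  static unsigned movmskps(const uint32_t V[4]) {
    unsigned M = 0;
    for (int I = 0; I < 4; ++I)
      M |= ((V[I] >> 31) & 1u) << I;
    return M;
  }

  int main() {
    // Both i64 sign bits are clear, but a low dword has its top bit set.
    uint64_t C[2] = {0x0000000080000000ull, 0};
    uint32_t D[4];
    std::memcpy(D, C, sizeof(D)); // bitcast <2 x i64> -> <4 x i32>
    uint32_t Shuf[4] = {D[1], D[1], D[3], D[3]}; // pshufd [1,1,3,3]
    printf("with pshufd:    %u\n", movmskps(Shuf)); // 0, matches i64 signs
    printf("without pshufd: %u\n", movmskps(D));    // 1, spurious bit
  }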

goldsteinn (Author):

Err, yeah, I had flipped shuffle src/dst in the code.

   if (NumElts <= CmpBits &&
-      getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
-                             ShuffleMask, DAG) &&
+      getTargetShuffleInputs(BaseVec, ShuffleInputs, ShuffleMask, DAG) &&
       ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
       ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits()) {
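(For context on the flip: a target shuffle mask maps destination lane to source element, so the loop variable named ShufSrc in the diff above actually walks destination lanes. A hedged restatement of the convention, as a toy example:)

  #include <cstdio>

  int main() {
    // ShuffleMask[DstIdx] == SrcIdx: the element written into lane DstIdx.
    const int ShuffleMask[4] = {1, 0, 3, 2};
    for (int DstIdx = 0; DstIdx < 4; ++DstIdx)
      printf("lane %d <- element %d\n", DstIdx, ShuffleMask[DstIdx]);
  }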
RKSimon (Collaborator):

I think all we need to do is to add a scaleShuffleElements(ShuffleMask, NumElts, ScaledMask) check here to ensure that the shuffle mask can be scaled back to the original mask width:

  // MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced.
  SmallVector<int, 32> ShuffleMask, ScaledMask;
  SmallVector<SDValue, 2> ShuffleInputs;
  if (NumElts <= CmpBits &&
      getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
                             ShuffleMask, DAG) &&
      ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
      ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits() &&
      scaleShuffleElements(ShuffleMask, NumElts, ScaledMask)) {
    APInt DemandedElts = APInt::getZero(NumElts);
    for (int M : ScaledMask) {
      assert(0 <= M && M < (int)NumElts && "Bad unary shuffle index");
      DemandedElts.setBit(M);
    }
    if (DemandedElts.isAllOnes()) {
      SDLoc DL(EFLAGS);
      SDValue Result = DAG.getBitcast(VecVT, ShuffleInputs[0]);
      Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
      Result =
          DAG.getZExtOrTrunc(Result, DL, EFLAGS.getOperand(0).getValueType());
      return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
                         EFLAGS.getOperand(1));
    }
  }
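For intuition, scaleShuffleElements succeeds only when each group of narrow lanes moves together as one wider element. A minimal stand-in (illustrative only, not the LLVM helper; it ignores undef/zero sentinels and mask narrowing) behaves like this:

  #include <cstdio>
  #include <vector>

  // Succeed only if each group of Scale consecutive lanes copies one
  // whole wider element, writing the widened mask into Scaled.
  static bool scaleMask(const std::vector<int> &Mask, unsigned NumDstElts,
                        std::vector<int> &Scaled) {
    unsigned Scale = Mask.size() / NumDstElts;
    Scaled.clear();
    for (unsigned I = 0; I < NumDstElts; ++I) {
      int First = Mask[I * Scale];
      if (First % (int)Scale != 0)
        return false;
      for (unsigned J = 1; J < Scale; ++J)
        if (Mask[I * Scale + J] != First + (int)J)
          return false;
      Scaled.push_back(First / (int)Scale);
    }
    return true;
  }

  int main() {
    std::vector<int> Scaled;
    // (2,3,0,1) as v4i32 is really (1,0) as v2i64: scaling succeeds.
    printf("%d\n", scaleMask({2, 3, 0, 1}, 2, Scaled)); // 1
    // (1,0,3,2) swaps the dword halves inside each qword: scaling fails,
    // so the combine is (correctly) blocked.
    printf("%d\n", scaleMask({1, 0, 3, 2}, 2, Scaled)); // 0
  }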

goldsteinn (Author):

There are a lot of ways to fix this. Posting V2, which I think is the most precise version. If you still have reservations, I can give this a try.

RKSimon (Collaborator):

I'd prefer my proposal tbh

goldsteinn (Author):

The thing is, it's overly conservative. As long as all the "high" elements stay "high", we are fine. We don't need the proper words to stay together.

RKSimon (Collaborator) commented Oct 8, 2023

Have you any cases where all of this is necessary vs my smaller patch? This code was to handle an uncommon edge case in the first place (a movmsk variant of reduce(permute(x)) -> reduce(x)), and we're adding a lot of code to handle it. Ideally the SimplifyDemanded* calls will have simplified the shuffle anyhow if we're only demanding the sign bits.

Prior logic would remove the shuffle iff all of the elements in `x`
were used. This is incorrect.

The issue is `movmsk` only cares about the highbits, so if the width
of the elements in `x` is smaller than the width of the elements
for the `movmsk`, then the shuffle, even if it preserves all the elements,
may change which ones are used by the highbits.

For example:
`movmsk64(bitcast(shuffle32(x, (1,0,3,2))))`

Even though the shuffle mask `(1,0,3,2)` preserves all the elements, it
flips which ones are relevant to the `movmsk64` (x[1] and x[3]
before, x[0] and x[2] after).

The fix here is to ensure that the shuffle mask can be scaled to the
element width of the `movmsk` instruction. This ensures that the
"high" elements stay "high". It is overly conservative, as it
misses cases like `(1,1,3,3)` where the "high" elements stay
intact despite the mask not being scalable, but for a relatively
edge-case optimization that should generally be handled during
SimplifyDemandedBits, it seems okay.
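A hedged sketch contrasting the two conditions discussed in this thread (simplified to a unary v4i32 mask feeding a <2 x i64> MOVMSK; it ignores whether both sign-carrying sources are covered):

  #include <cstdio>

  // "Precise" condition: every odd (sign-carrying) lane is sourced from
  // an odd lane. "Scalable" condition: lane pairs move as whole qwords.
  static bool highStaysHigh(const int M[4]) {
    return (M[1] % 2) == 1 && (M[3] % 2) == 1;
  }
  static bool scalable(const int M[4]) {
    return M[0] % 2 == 0 && M[1] == M[0] + 1 && M[2] % 2 == 0 &&
           M[3] == M[2] + 1;
  }

  int main() {
    const int Masks[][4] = {{1, 0, 3, 2}, {1, 1, 3, 3}, {2, 3, 0, 1}};
    for (const auto &M : Masks)
      printf("(%d,%d,%d,%d): precise=%d scalable=%d\n", M[0], M[1], M[2],
             M[3], highStaysHigh(M), scalable(M));
    // (1,0,3,2): precise=0 scalable=0 -> combine must stay blocked
    // (1,1,3,3): precise=1 scalable=0 -> safe, but conservatively blocked
    // (2,3,0,1): precise=1 scalable=1 -> combine still fires
  }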
goldsteinn changed the title from "[X86][ISel] Improve logic for optimizing movmsk(bitcast(shuffle(x)))" to "[X86][ISel] Fix logic for optimizing movmsk(bitcast(shuffle(x)))" on Oct 8, 2023
goldsteinn (Author):

Have you any cases where all of this is necessary vs my smaller patch? This code was to handle an uncommon edge case in the first place (a movmsk variant of reduce(permute(x)) -> reduce(x)), and we're adding a lot of code to handle it. Ideally the SimplifyDemanded* calls will have simplified the shuffle anyhow if we're only demanding the sign bits.

Okay, changed to your version + updated commit/PR message

RKSimon (Collaborator) reviewed:

LGTM - cheers!

goldsteinn (Author):

Pushed: 1684c65

goldsteinn closed this on Oct 9, 2023
Closes: Wrong code at -O2 on x86_64-linux_gnu since ddfee6d (recent regression) (#67287)