Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release/18.x: [InstCombine] Fix miscompilation in PR83947 (#83993) #84021

Closed
wants to merge 1 commit into from

Conversation

llvmbot
Copy link
Collaborator

@llvmbot llvmbot commented Mar 5, 2024

Backport a1a590e

Requested by: @dtcxzyw

https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.

(cherry picked from commit a1a590e)
@llvmbot llvmbot requested a review from nikic as a code owner March 5, 2024 14:43
@llvmbot llvmbot added this to the LLVM 18.X Release milestone Mar 5, 2024
@llvmbot
Copy link
Collaborator Author

llvmbot commented Mar 5, 2024

@nikic What do you think about merging this PR to the release branch?

@llvmbot
Copy link
Collaborator Author

llvmbot commented Mar 5, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: None (llvmbot)

Changes

Backport a1a590e

Requested by: @dtcxzyw


Full diff: https://github.com/llvm/llvm-project/pull/84021.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/VectorUtils.h (+5)
  • (modified) llvm/lib/Analysis/VectorUtils.cpp (+25)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+8-5)
  • (added) llvm/test/Transforms/InstCombine/pr83947.ll (+67)
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7a92e62b53c53d..c6eb66cc9660ca 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask);
 /// lanes can be assumed active.
 bool maskIsAllOneOrUndef(Value *Mask);
 
+/// Given a mask vector of i1, Return true if any of the elements of this
+/// predicate mask are known to be true or undef.  That is, return true if at
+/// least one lane can be assumed active.
+bool maskContainsAllOneOrUndef(Value *Mask);
+
 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 73facc76a92b2c..bf7bc0ba84a033 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) {
   return true;
 }
 
+bool llvm::maskContainsAllOneOrUndef(Value *Mask) {
+  assert(isa<VectorType>(Mask->getType()) &&
+         isa<IntegerType>(Mask->getType()->getScalarType()) &&
+         cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
+             1 &&
+         "Mask must be a vector of i1");
+
+  auto *ConstMask = dyn_cast<Constant>(Mask);
+  if (!ConstMask)
+    return false;
+  if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
+    return true;
+  if (isa<ScalableVectorType>(ConstMask->getType()))
+    return false;
+  for (unsigned
+           I = 0,
+           E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
+       I != E; ++I) {
+    if (auto *MaskElt = ConstMask->getAggregateElement(I))
+      if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
+        return true;
+  }
+  return false;
+}
+
 /// TODO: This is a lot like known bits, but for
 /// vectors.  Is there something we can common this with?
 APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index a647be2d26c761..bc43edb5e62065 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -412,11 +412,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
     // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
     if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
-      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
-      StoreInst *S =
-          new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
-      S->copyMetadata(II);
-      return S;
+      if (maskContainsAllOneOrUndef(ConstMask)) {
+        Align Alignment =
+            cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+        StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false,
+                                     Alignment);
+        S->copyMetadata(II);
+        return S;
+      }
     }
     // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
     // lastlane), ptr
diff --git a/llvm/test/Transforms/InstCombine/pr83947.ll b/llvm/test/Transforms/InstCombine/pr83947.ll
new file mode 100644
index 00000000000000..c1d601ff637183
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pr83947.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+@c = global i32 0, align 4
+@b = global i32 0, align 4
+
+define void @masked_scatter1() {
+; CHECK-LABEL: define void @masked_scatter1() {
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @c, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> splat (ptr @c), i32 4, <vscale x 4 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+  ret void
+}
+
+define void @masked_scatter2() {
+; CHECK-LABEL: define void @masked_scatter2() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true))
+  ret void
+}
+
+define void @masked_scatter3() {
+; CHECK-LABEL: define void @masked_scatter3() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef)
+  ret void
+}
+
+define void @masked_scatter4() {
+; CHECK-LABEL: define void @masked_scatter4() {
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false))
+  ret void
+}
+
+define void @masked_scatter5() {
+; CHECK-LABEL: define void @masked_scatter5() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @masked_scatter6() {
+; CHECK-LABEL: define void @masked_scatter6() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>)
+  ret void
+}
+
+define void @masked_scatter7() {
+; CHECK-LABEL: define void @masked_scatter7() {
+; CHECK-NEXT:    call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> <ptr @c, ptr @c>, i32 4, <2 x i1> <i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c)>)
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+  ret void
+}

@nikic
Copy link
Contributor

nikic commented Mar 6, 2024

There is a test failure.

@dtcxzyw dtcxzyw self-assigned this Mar 7, 2024
@dtcxzyw
Copy link
Member

dtcxzyw commented Mar 7, 2024

@nikic Do you know how to add a commit to this PR?

diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index 2704905f7a35..c87c1199f727 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -292,7 +292,11 @@ entry:
 define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(ptr %dst, i16 %val) {
 ; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    store i16 [[VAL:%.*]], ptr [[DST:%.*]], align 2
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[DST:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x ptr> [[BROADCAST_SPLATINSERT]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[BROADCAST_VALUE:%.*]] = insertelement <vscale x 4 x i16> poison, i16 [[VAL:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLATVALUE:%.*]] = shufflevector <vscale x 4 x i16> [[BROADCAST_VALUE]], <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i16.nxv4p0(<vscale x 4 x i16> [[BROADCAST_SPLATVALUE]], <vscale x 4 x ptr> [[BROADCAST_SPLAT]], i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer, i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret void
 ;
 entry:

It is a regression. But I think it should be fixed by fd07b8f.

@nikic
Copy link
Contributor

nikic commented Mar 7, 2024

@dtcxzyw You should be able to push to the branch in the llvmbot repo.

@dtcxzyw
Copy link
Member

dtcxzyw commented Mar 7, 2024

@nikic Do you know how to add a commit to this PR?

diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index 2704905f7a35..c87c1199f727 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -292,7 +292,11 @@ entry:
 define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(ptr %dst, i16 %val) {
 ; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    store i16 [[VAL:%.*]], ptr [[DST:%.*]], align 2
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[DST:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x ptr> [[BROADCAST_SPLATINSERT]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[BROADCAST_VALUE:%.*]] = insertelement <vscale x 4 x i16> poison, i16 [[VAL:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLATVALUE:%.*]] = shufflevector <vscale x 4 x i16> [[BROADCAST_VALUE]], <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i16.nxv4p0(<vscale x 4 x i16> [[BROADCAST_SPLATVALUE]], <vscale x 4 x ptr> [[BROADCAST_SPLAT]], i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer, i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret void
 ;
 entry:

It is a regression. But I think it should be fixed by fd07b8f.

I am trying to fix it in the GitHub codespace.

@dtcxzyw
Copy link
Member

dtcxzyw commented Mar 7, 2024

remote: Permission to llvmbot/llvm-project.git denied to dtcxzyw.
fatal: unable to access 'https://github.com/llvmbot/llvm-project/': The requested URL returned error: 403

@dtcxzyw dtcxzyw closed this Mar 7, 2024
@tstellar
Copy link
Collaborator

This is fixed in #84298.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

None yet

4 participants