Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR83993 to 18.x #84298

Merged
merged 1 commit into from
Mar 13, 2024
Merged

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Mar 7, 2024

Backport #83993

It is an alternative to #84021.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 7, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

Changes

Backport #83993

It is an alternative to #84021.


Full diff: https://github.com/llvm/llvm-project/pull/84298.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/VectorUtils.h (+5)
  • (modified) llvm/lib/Analysis/VectorUtils.cpp (+25)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+8-5)
  • (modified) llvm/test/Transforms/InstCombine/masked_intrinsics.ll (+5-1)
  • (added) llvm/test/Transforms/InstCombine/pr83947.ll (+67)
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7a92e62b53c53d..c6eb66cc9660ca 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask);
 /// lanes can be assumed active.
 bool maskIsAllOneOrUndef(Value *Mask);
 
+/// Given a mask vector of i1, Return true if any of the elements of this
+/// predicate mask are known to be true or undef.  That is, return true if at
+/// least one lane can be assumed active.
+bool maskContainsAllOneOrUndef(Value *Mask);
+
 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 73facc76a92b2c..bf7bc0ba84a033 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) {
   return true;
 }
 
+bool llvm::maskContainsAllOneOrUndef(Value *Mask) {
+  assert(isa<VectorType>(Mask->getType()) &&
+         isa<IntegerType>(Mask->getType()->getScalarType()) &&
+         cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
+             1 &&
+         "Mask must be a vector of i1");
+
+  auto *ConstMask = dyn_cast<Constant>(Mask);
+  if (!ConstMask)
+    return false;
+  if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
+    return true;
+  if (isa<ScalableVectorType>(ConstMask->getType()))
+    return false;
+  for (unsigned
+           I = 0,
+           E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
+       I != E; ++I) {
+    if (auto *MaskElt = ConstMask->getAggregateElement(I))
+      if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
+        return true;
+  }
+  return false;
+}
+
 /// TODO: This is a lot like known bits, but for
 /// vectors.  Is there something we can common this with?
 APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index a647be2d26c761..bc43edb5e62065 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -412,11 +412,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
     // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
     if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
-      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
-      StoreInst *S =
-          new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
-      S->copyMetadata(II);
-      return S;
+      if (maskContainsAllOneOrUndef(ConstMask)) {
+        Align Alignment =
+            cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+        StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false,
+                                     Alignment);
+        S->copyMetadata(II);
+        return S;
+      }
     }
     // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
     // lastlane), ptr
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
index 2704905f7a358d..c87c1199f727ea 100644
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -292,7 +292,11 @@ entry:
 define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(ptr %dst, i16 %val) {
 ; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    store i16 [[VAL:%.*]], ptr [[DST:%.*]], align 2
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[DST:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x ptr> [[BROADCAST_SPLATINSERT]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[BROADCAST_VALUE:%.*]] = insertelement <vscale x 4 x i16> poison, i16 [[VAL:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLATVALUE:%.*]] = shufflevector <vscale x 4 x i16> [[BROADCAST_VALUE]], <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i16.nxv4p0(<vscale x 4 x i16> [[BROADCAST_SPLATVALUE]], <vscale x 4 x ptr> [[BROADCAST_SPLAT]], i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer, i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret void
 ;
 entry:
diff --git a/llvm/test/Transforms/InstCombine/pr83947.ll b/llvm/test/Transforms/InstCombine/pr83947.ll
new file mode 100644
index 00000000000000..c1d601ff637183
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pr83947.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+@c = global i32 0, align 4
+@b = global i32 0, align 4
+
+define void @masked_scatter1() {
+; CHECK-LABEL: define void @masked_scatter1() {
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @c, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> splat (ptr @c), i32 4, <vscale x 4 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+  ret void
+}
+
+define void @masked_scatter2() {
+; CHECK-LABEL: define void @masked_scatter2() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true))
+  ret void
+}
+
+define void @masked_scatter3() {
+; CHECK-LABEL: define void @masked_scatter3() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef)
+  ret void
+}
+
+define void @masked_scatter4() {
+; CHECK-LABEL: define void @masked_scatter4() {
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false))
+  ret void
+}
+
+define void @masked_scatter5() {
+; CHECK-LABEL: define void @masked_scatter5() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @masked_scatter6() {
+; CHECK-LABEL: define void @masked_scatter6() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>)
+  ret void
+}
+
+define void @masked_scatter7() {
+; CHECK-LABEL: define void @masked_scatter7() {
+; CHECK-NEXT:    call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> <ptr @c, ptr @c>, i32 4, <2 x i1> <i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c)>)
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+  ret void
+}

https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
@tstellar tstellar merged commit 3f8711f into llvm:release/18.x Mar 13, 2024
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

None yet

4 participants