Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ValueTracking] Add support for llvm.vector.reduce.{xor,or,and} ops. #88320

Closed

Conversation

goldsteinn
Copy link
Contributor

@goldsteinn goldsteinn commented Apr 10, 2024

  • [ValueTracking] Add tests for computeKnownBits of llvm.vector.reduce.{or,and}; NFC
  • [ValueTracking] Implement computeKnownBits for llvm.vector.reduce.{or,and}
  • [ValueTracking] Add tests for computeKnownBits of llvm.vector.reduce.xor; NFC
  • [ValueTracking] Implement computeKnownBits for llvm.vector.reduce.xor
  • [ValueTracking] Add tests for isKnownNonZero of llvm.vector.reduce.or; NFC
  • [ValueTracking] Implement isKnownNonZero for llvm.vector.reduce.or

@goldsteinn goldsteinn requested a review from nikic as a code owner April 10, 2024 21:10
@goldsteinn goldsteinn changed the title goldsteinn/vec reduce support [ValueTracking] Add support for most llvm.vector.reduce.* ops. Apr 10, 2024
@goldsteinn goldsteinn force-pushed the goldsteinn/vec-reduce-support branch from ec8857a to 4e949bd Compare April 10, 2024 21:14
@goldsteinn
Copy link
Contributor Author

The missing intrins are fadd and fmul which i don't think are so simple to implement.
Also note that add/mul non-zero logic might be more complex than is warranted (proofs in the commit). Happy to drop.

dyn_cast<FixedVectorType>(I->getOperand(0)->getType())) {
computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
KnownBits SingleKnown = Known;
for (unsigned i = 1, e = VecTy->getNumElements(); i < e; ++i) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can reduce the computational complexity from O(n) to O(lgn) with exponentiation by squaring.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I figured most likely N isn't too big here and this isn't a particularly hot path so it wasn't worth to code complexity, but no real strong opinions.

case Intrinsic::vector_reduce_fminimum:
computeKnownFPClass(II->getArgOperand(0), Known, InterestedClasses,
Depth + 1, Q);
break;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please split this into a separate PR for review by FP folks? I'm not sure things are quite as simple as that (e.g., do we need to account for denormal flush like the scalar min/max handling above)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, see: #88408

@@ -2904,6 +2942,41 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
case Intrinsic::vector_reduce_smax:
case Intrinsic::vector_reduce_smin:
return isKnownNonZero(II->getArgOperand(0), Depth, Q);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we easily handle "or" reduction for isKnownNonZero()?

break;
return isKnownNonZero(II->getArgOperand(0), Depth, Q);
}
break;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I don't think this is worthwhile.

llvm/lib/Analysis/ValueTracking.cpp Show resolved Hide resolved
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 11, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: None (goldsteinn)

Changes
  • [ValueTracking] Add tests for computeKnownBits of llvm.vector.reduce.{or,and}; NFC
  • [ValueTracking] Implement computeKnownBits for llvm.vector.reduce.{or,and}
  • [ValueTracking] Add tests for computeKnownBits of llvm.vector.reduce.xor; NFC
  • [ValueTracking] Implement computeKnownBits for llvm.vector.reduce.xor
  • [ValueTracking] Add tests for computeKnownBits of llvm.vector.reduce.{add,mul}; NFC
  • [ValueTracking] Implement computeKnownBits for llvm.vector.reduce.{add,mul}
  • [ValueTracking] Add tests for computeKnownFPClass of llvm.vector.reduce.{fmin,fmax,fmaximum,fminimum}; NFC
  • [ValueTracking] Implement computeKnownFPClass for llvm.vector.reduce.{fmin,fmax,fmaximum,fminimum}
  • [ValueTracking] Add tests for isKnownNonZero of llvm.vector.reduce.{add,mul}; NFC
  • [ValueTracking] Implement isKnownNonZero for llvm.vector.reduce.{add,mul}

Full diff: https://github.com/llvm/llvm-project/pull/88320.diff

3 Files Affected:

  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+22-3)
  • (modified) llvm/test/Transforms/InstCombine/known-bits.ll (+125)
  • (modified) llvm/test/Transforms/InstSimplify/known-non-zero.ll (+23)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 3a10de72a27562..4e9d446ddb1807 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1621,14 +1621,32 @@ static void computeKnownBitsFromOperator(const Operator *I,
         computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
         Known = KnownBits::ssub_sat(Known, Known2);
         break;
-        // for min/max reduce, any bit common to each element in the input vec
-        // is set in the output.
+        // for min/max/and/or reduce, any bit common to each element in the
+        // input vec is set in the output.
+      case Intrinsic::vector_reduce_and:
+      case Intrinsic::vector_reduce_or:
       case Intrinsic::vector_reduce_umax:
       case Intrinsic::vector_reduce_umin:
       case Intrinsic::vector_reduce_smax:
       case Intrinsic::vector_reduce_smin:
         computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
         break;
+      case Intrinsic::vector_reduce_xor:
+        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+        // The zeros common to all vecs are zero in the output.
+        // If the number of elements is odd, then the common ones remain. If the
+        // number of elements is even, then the common ones becomes zeros.
+        if (auto *VecTy =
+                dyn_cast<FixedVectorType>(I->getOperand(0)->getType())) {
+          // Even, so the ones become zeros.
+          if ((VecTy->getNumElements() % 2) == 0) {
+            Known.Zero |= Known.One;
+            Known.One.clearAllBits();
+          }
+
+        } else
+          Known.One.clearAllBits();
+        break;
       case Intrinsic::umin:
         computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
         computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
@@ -2898,7 +2916,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
         return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth,
                             II->getArgOperand(0), II->getArgOperand(1),
                             /*NSW=*/true, /* NUW=*/false);
-        // umin/smin/smax/smin of all non-zero elements is always non-zero.
+        // umin/smin/smax/smin/or of all non-zero elements is always non-zero.
+      case Intrinsic::vector_reduce_or:
       case Intrinsic::vector_reduce_umax:
       case Intrinsic::vector_reduce_umin:
       case Intrinsic::vector_reduce_smax:
diff --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index d210b19bb7faf2..6c3701a32f227f 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -999,5 +999,130 @@ define i1 @extract_value_smul_fail(i8 %xx, i8 %yy) {
   ret i1 %r
 }
 
+define i8 @known_reduce_or(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_or(
+; CHECK-NEXT:    ret i8 1
+;
+  %x = or <2 x i8> %xx, <i8 5, i8 3>
+  %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+  %r = and i8 %v, 1
+  ret i8 %r
+}
+
+define i8 @known_reduce_or_fail(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_or_fail(
+; CHECK-NEXT:    [[X:%.*]] = or <2 x i8> [[XX:%.*]], <i8 5, i8 3>
+; CHECK-NEXT:    [[V:%.*]] = call i8 @llvm.vector.reduce.or.v2i8(<2 x i8> [[X]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[V]], 4
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or <2 x i8> %xx, <i8 5, i8 3>
+  %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+  %r = and i8 %v, 4
+  ret i8 %r
+}
+
+define i8 @known_reduce_and(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_and(
+; CHECK-NEXT:    ret i8 1
+;
+  %x = or <2 x i8> %xx, <i8 5, i8 3>
+  %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+  %r = and i8 %v, 1
+  ret i8 %r
+}
+
+define i8 @known_reduce_and_fail(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_and_fail(
+; CHECK-NEXT:    [[X:%.*]] = or <2 x i8> [[XX:%.*]], <i8 5, i8 3>
+; CHECK-NEXT:    [[V:%.*]] = call i8 @llvm.vector.reduce.or.v2i8(<2 x i8> [[X]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[V]], 2
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or <2 x i8> %xx, <i8 5, i8 3>
+  %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+  %r = and i8 %v, 2
+  ret i8 %r
+}
+
+define i8 @known_reduce_xor_even(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_even(
+; CHECK-NEXT:    ret i8 0
+;
+  %x = or <2 x i8> %xx, <i8 5, i8 3>
+  %v = call i8 @llvm.vector.reduce.xor(<2 x i8> %x)
+  %r = and i8 %v, 1
+  ret i8 %r
+}
+
+define i8 @known_reduce_xor_even2(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_even2(
+; CHECK-NEXT:    ret i8 0
+;
+  %x = and <2 x i8> %xx, <i8 15, i8 15>
+  %v = call i8 @llvm.vector.reduce.xor(<2 x i8> %x)
+  %r = and i8 %v, 16
+  ret i8 %r
+}
+
+define i8 @known_reduce_xor_even_fail(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_even_fail(
+; CHECK-NEXT:    [[X:%.*]] = or <2 x i8> [[XX:%.*]], <i8 5, i8 3>
+; CHECK-NEXT:    [[V:%.*]] = call i8 @llvm.vector.reduce.xor.v2i8(<2 x i8> [[X]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[V]], 2
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or <2 x i8> %xx, <i8 5, i8 3>
+  %v = call i8 @llvm.vector.reduce.xor(<2 x i8> %x)
+  %r = and i8 %v, 2
+  ret i8 %r
+}
+
+define i8 @known_reduce_xor_odd(<3 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_odd(
+; CHECK-NEXT:    ret i8 1
+;
+  %x = or <3 x i8> %xx, <i8 5, i8 3, i8 9>
+  %v = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> %x)
+  %r = and i8 %v, 1
+  ret i8 %r
+}
+
+define i8 @known_reduce_xor_odd2(<3 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_odd2(
+; CHECK-NEXT:    ret i8 0
+;
+  %x = and <3 x i8> %xx, <i8 15, i8 15, i8 31>
+  %v = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> %x)
+  %r = and i8 %v, 32
+  ret i8 %r
+}
+
+define i8 @known_reduce_xor_odd2_fail(<3 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_odd2_fail(
+; CHECK-NEXT:    [[X:%.*]] = and <3 x i8> [[XX:%.*]], <i8 15, i8 15, i8 31>
+; CHECK-NEXT:    [[V:%.*]] = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> [[X]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[V]], 16
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = and <3 x i8> %xx, <i8 15, i8 15, i8 31>
+  %v = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> %x)
+  %r = and i8 %v, 16
+  ret i8 %r
+}
+
+define i8 @known_reduce_xor_odd_fail(<3 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_odd_fail(
+; CHECK-NEXT:    [[X:%.*]] = or <3 x i8> [[XX:%.*]], <i8 5, i8 3, i8 9>
+; CHECK-NEXT:    [[V:%.*]] = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> [[X]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[V]], 2
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x = or <3 x i8> %xx, <i8 5, i8 3, i8 9>
+  %v = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> %x)
+  %r = and i8 %v, 2
+  ret i8 %r
+}
+
 declare void @use(i1)
 declare void @sink(i8)
diff --git a/llvm/test/Transforms/InstSimplify/known-non-zero.ll b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
index d9b8f5eed32390..fd2862eb04a24d 100644
--- a/llvm/test/Transforms/InstSimplify/known-non-zero.ll
+++ b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
@@ -377,3 +377,26 @@ define <2 x i1> @insert_nonzero_any_idx_fail(<2 x i8> %xx, i8 %yy, i32 %idx) {
   %r = icmp eq <2 x i8> %ins, zeroinitializer
   ret <2 x i1> %r
 }
+
+define i1 @nonzero_reduce_or(<2 x i8> %xx) {
+; CHECK-LABEL: @nonzero_reduce_or(
+; CHECK-NEXT:    ret i1 false
+;
+  %x = add nuw <2 x i8> %xx, <i8 1, i8 1>
+  %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+  %r = icmp eq i8 %v, 0
+  ret i1 %r
+}
+
+define i1 @nonzero_reduce_or_fail(<2 x i8> %xx) {
+; CHECK-LABEL: @nonzero_reduce_or_fail(
+; CHECK-NEXT:    [[X:%.*]] = add nsw <2 x i8> [[XX:%.*]], <i8 1, i8 1>
+; CHECK-NEXT:    [[V:%.*]] = call i8 @llvm.vector.reduce.or.v2i8(<2 x i8> [[X]])
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[V]], 0
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %x = add nsw <2 x i8> %xx, <i8 1, i8 1>
+  %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+  %r = icmp eq i8 %v, 0
+  ret i1 %r
+}

@goldsteinn goldsteinn changed the title [ValueTracking] Add support for most llvm.vector.reduce.* ops. [ValueTracking] Add support for most llvm.vector.reduce.{xor,or,and} ops. Apr 11, 2024
@goldsteinn goldsteinn changed the title [ValueTracking] Add support for most llvm.vector.reduce.{xor,or,and} ops. [ValueTracking] Add support for llvm.vector.reduce.{xor,or,and} ops. Apr 11, 2024
Comment on lines 1639 to 1640
if (auto *VecTy =
dyn_cast<FixedVectorType>(I->getOperand(0)->getType())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can also handle scalable if you need to know an even count?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How? IIUC the only thing we know about scalable is that sizeof(vec) == N * sizeof(scalar) where N >= MinCount. How can we get odd/even?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's <scale * N * element>, if you know N is even it doesn't matter what scale is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks. Done.

@goldsteinn goldsteinn force-pushed the goldsteinn/vec-reduce-support branch from 95dbc18 to ed99b2a Compare April 12, 2024 20:09
// number of elements is even, then the common ones becomes zeros.
auto *VecTy = cast<VectorType>(I->getOperand(0)->getType());
// Even, so the ones become zeros.
if (VecTy->getElementCount().isKnownEven()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this have an else that sets the ones to unknown? Otherwise what about something like vscale x 3 which is not known even but might be.

(Actually, I'd just condition the whole xor code on isKnownEven, we don't care about odd-size vectors anyway.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you're right. Pushed fix/test.

Its very minimal code complexity to handle odd, so left it in for now (although think I agree its moreso for completeness than practical applicability). LMK if you feel strongly about dropping.

@goldsteinn goldsteinn force-pushed the goldsteinn/vec-reduce-support branch from ed99b2a to c6ed849 Compare April 13, 2024 05:42
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// If the number of elements is odd, then the common ones remain. If the
// number of elements is even, then the common ones becomes zeros.
auto *VecTy = cast<VectorType>(I->getOperand(0)->getType());
// Even, so the ones become zeros.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant with the preceding comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guess the idea is the top comment explains what the block will do. The lower comments explain what the next few lines will do. With this minimal code though yeah it does become a bit redundant, but comments are like sex so...

bazuzi pushed a commit to bazuzi/llvm-project that referenced this pull request Apr 15, 2024
aniplcc pushed a commit to aniplcc/llvm-project that referenced this pull request Apr 15, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants