[ValueTracking] Add support for llvm.vector.reduce.{xor,or,and} ops. #88320
Force-pushed from ec8857a to 4e949bd
The missing intrins are |
llvm/lib/Analysis/ValueTracking.cpp
Outdated
dyn_cast<FixedVectorType>(I->getOperand(0)->getType())) {
computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
KnownBits SingleKnown = Known;
for (unsigned i = 1, e = VecTy->getNumElements(); i < e; ++i) {
You can reduce the computational complexity from O(n) to O(lg n) with exponentiation by squaring.
Yeah, I figured N most likely isn't too big here and this isn't a particularly hot path, so it wasn't worth the code complexity, but no real strong opinions.
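For readers following the complexity exchange above, here is a minimal standalone sketch of the two approaches. The loop body is not shown in the quoted snippet, so this assumes it folds the per-element known bits with xor once per element. `Known8`, `xorKnown`, `reduceXorLinear`, and `reduceXorLog` are hypothetical names for illustration, not LLVM APIs.

```cpp
#include <cstdint>

// Toy 8-bit stand-in for a known-bits pair (hypothetical, not llvm::KnownBits).
struct Known8 {
  std::uint8_t Zero; // bits known to be 0
  std::uint8_t One;  // bits known to be 1
};

// Known bits of (a ^ b): a result bit is known 0 when both inputs are known
// and equal, known 1 when both are known and differ, otherwise unknown.
Known8 xorKnown(Known8 A, Known8 B) {
  return Known8{static_cast<std::uint8_t>((A.Zero & B.Zero) | (A.One & B.One)),
                static_cast<std::uint8_t>((A.Zero & B.One) | (A.One & B.Zero))};
}

// O(n): fold the shared per-element facts once per element (requires N >= 1).
Known8 reduceXorLinear(Known8 Elt, unsigned N) {
  Known8 Acc = Elt;
  for (unsigned I = 1; I < N; ++I)
    Acc = xorKnown(Acc, Elt);
  return Acc;
}

// O(lg n): square-and-multiply, valid because xorKnown is associative.
Known8 reduceXorLog(Known8 Elt, unsigned N) {
  Known8 Acc{0xFF, 0x00}; // identity element: the constant 0, fully known
  Known8 Pow = Elt;       // known bits of a xor over 2^k elements
  while (N != 0) {
    if (N & 1)
      Acc = xorKnown(Acc, Pow);
    Pow = xorKnown(Pow, Pow);
    N >>= 1;
  }
  return Acc;
}
```

In this model both functions agree for every N >= 1, and the result only depends on whether N is even or odd, which is consistent with the parity-based form the patch ends up using.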
llvm/lib/Analysis/ValueTracking.cpp
Outdated
case Intrinsic::vector_reduce_fminimum:
computeKnownFPClass(II->getArgOperand(0), Known, InterestedClasses,
Depth + 1, Q);
break;
Can you please split this into a separate PR for review by FP folks? I'm not sure things are quite as simple as that (e.g., do we need to account for denormal flush like the scalar min/max handling above)?
Done, see: #88408
llvm/lib/Analysis/ValueTracking.cpp
Outdated
@@ -2904,6 +2942,41 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
case Intrinsic::vector_reduce_smax:
case Intrinsic::vector_reduce_smin:
return isKnownNonZero(II->getArgOperand(0), Depth, Q);
Can't we easily handle "or" reduction for isKnownNonZero()?
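The reasoning behind this suggestion can be checked by brute force on a small model: if every lane is non-zero, the OR of the lanes is non-zero, while AND and XOR give no such guarantee. The sketch below uses two 8-bit lanes; the function names are hypothetical and nothing here calls into LLVM.

```cpp
#include <cstdint>

// Exhaustively checks that OR-ing any two non-zero 8-bit lanes is non-zero.
bool orOfNonZeroLanesIsNonZero() {
  for (unsigned A = 1; A < 256; ++A)
    for (unsigned B = 1; B < 256; ++B)
      if ((A | B) == 0)
        return false;
  return true;
}

// Two-lane AND reduction: non-zero lanes can still produce zero (1 & 2 == 0).
std::uint8_t andReduce2(std::uint8_t A, std::uint8_t B) {
  return static_cast<std::uint8_t>(A & B);
}

// Two-lane XOR reduction: equal non-zero lanes cancel (3 ^ 3 == 0).
std::uint8_t xorReduce2(std::uint8_t A, std::uint8_t B) {
  return static_cast<std::uint8_t>(A ^ B);
}
```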
llvm/lib/Analysis/ValueTracking.cpp
Outdated
break;
return isKnownNonZero(II->getArgOperand(0), Depth, Q);
}
break;
Yeah, I don't think this is worthwhile.
Force-pushed from 4e949bd to 95dbc18
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-analysis Author: None (goldsteinn) Changes
Full diff: https://github.com/llvm/llvm-project/pull/88320.diff 3 Files Affected:
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 3a10de72a27562..4e9d446ddb1807 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1621,14 +1621,32 @@ static void computeKnownBitsFromOperator(const Operator *I,
computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
Known = KnownBits::ssub_sat(Known, Known2);
break;
- // for min/max reduce, any bit common to each element in the input vec
- // is set in the output.
+ // for min/max/and/or reduce, any bit common to each element in the
+ // input vec is set in the output.
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_umax:
case Intrinsic::vector_reduce_umin:
case Intrinsic::vector_reduce_smax:
case Intrinsic::vector_reduce_smin:
computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
break;
+ case Intrinsic::vector_reduce_xor:
+ computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+ // The zeros common to all vecs are zero in the output.
+ // If the number of elements is odd, then the common ones remain. If the
+ // number of elements is even, then the common ones become zeros.
+ if (auto *VecTy =
+ dyn_cast<FixedVectorType>(I->getOperand(0)->getType())) {
+ // Even, so the ones become zeros.
+ if ((VecTy->getNumElements() % 2) == 0) {
+ Known.Zero |= Known.One;
+ Known.One.clearAllBits();
+ }
+
+ } else
+ Known.One.clearAllBits();
+ break;
case Intrinsic::umin:
computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q);
@@ -2898,7 +2916,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth,
II->getArgOperand(0), II->getArgOperand(1),
/*NSW=*/true, /* NUW=*/false);
- // umin/smin/smax/smin of all non-zero elements is always non-zero.
+ // umax/umin/smax/smin/or of all non-zero elements is always non-zero.
+ case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_umax:
case Intrinsic::vector_reduce_umin:
case Intrinsic::vector_reduce_smax:
diff --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index d210b19bb7faf2..6c3701a32f227f 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -999,5 +999,130 @@ define i1 @extract_value_smul_fail(i8 %xx, i8 %yy) {
ret i1 %r
}
+define i8 @known_reduce_or(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_or(
+; CHECK-NEXT: ret i8 1
+;
+ %x = or <2 x i8> %xx, <i8 5, i8 3>
+ %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+ %r = and i8 %v, 1
+ ret i8 %r
+}
+
+define i8 @known_reduce_or_fail(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_or_fail(
+; CHECK-NEXT: [[X:%.*]] = or <2 x i8> [[XX:%.*]], <i8 5, i8 3>
+; CHECK-NEXT: [[V:%.*]] = call i8 @llvm.vector.reduce.or.v2i8(<2 x i8> [[X]])
+; CHECK-NEXT: [[R:%.*]] = and i8 [[V]], 4
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or <2 x i8> %xx, <i8 5, i8 3>
+ %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+ %r = and i8 %v, 4
+ ret i8 %r
+}
+
+define i8 @known_reduce_and(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_and(
+; CHECK-NEXT: ret i8 1
+;
+ %x = or <2 x i8> %xx, <i8 5, i8 3>
+ %v = call i8 @llvm.vector.reduce.and(<2 x i8> %x)
+ %r = and i8 %v, 1
+ ret i8 %r
+}
+
+define i8 @known_reduce_and_fail(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_and_fail(
+; CHECK-NEXT: [[X:%.*]] = or <2 x i8> [[XX:%.*]], <i8 5, i8 3>
+; CHECK-NEXT: [[V:%.*]] = call i8 @llvm.vector.reduce.and.v2i8(<2 x i8> [[X]])
+; CHECK-NEXT: [[R:%.*]] = and i8 [[V]], 2
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or <2 x i8> %xx, <i8 5, i8 3>
+ %v = call i8 @llvm.vector.reduce.and(<2 x i8> %x)
+ %r = and i8 %v, 2
+ ret i8 %r
+}
+
+define i8 @known_reduce_xor_even(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_even(
+; CHECK-NEXT: ret i8 0
+;
+ %x = or <2 x i8> %xx, <i8 5, i8 3>
+ %v = call i8 @llvm.vector.reduce.xor(<2 x i8> %x)
+ %r = and i8 %v, 1
+ ret i8 %r
+}
+
+define i8 @known_reduce_xor_even2(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_even2(
+; CHECK-NEXT: ret i8 0
+;
+ %x = and <2 x i8> %xx, <i8 15, i8 15>
+ %v = call i8 @llvm.vector.reduce.xor(<2 x i8> %x)
+ %r = and i8 %v, 16
+ ret i8 %r
+}
+
+define i8 @known_reduce_xor_even_fail(<2 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_even_fail(
+; CHECK-NEXT: [[X:%.*]] = or <2 x i8> [[XX:%.*]], <i8 5, i8 3>
+; CHECK-NEXT: [[V:%.*]] = call i8 @llvm.vector.reduce.xor.v2i8(<2 x i8> [[X]])
+; CHECK-NEXT: [[R:%.*]] = and i8 [[V]], 2
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or <2 x i8> %xx, <i8 5, i8 3>
+ %v = call i8 @llvm.vector.reduce.xor(<2 x i8> %x)
+ %r = and i8 %v, 2
+ ret i8 %r
+}
+
+define i8 @known_reduce_xor_odd(<3 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_odd(
+; CHECK-NEXT: ret i8 1
+;
+ %x = or <3 x i8> %xx, <i8 5, i8 3, i8 9>
+ %v = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> %x)
+ %r = and i8 %v, 1
+ ret i8 %r
+}
+
+define i8 @known_reduce_xor_odd2(<3 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_odd2(
+; CHECK-NEXT: ret i8 0
+;
+ %x = and <3 x i8> %xx, <i8 15, i8 15, i8 31>
+ %v = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> %x)
+ %r = and i8 %v, 32
+ ret i8 %r
+}
+
+define i8 @known_reduce_xor_odd2_fail(<3 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_odd2_fail(
+; CHECK-NEXT: [[X:%.*]] = and <3 x i8> [[XX:%.*]], <i8 15, i8 15, i8 31>
+; CHECK-NEXT: [[V:%.*]] = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> [[X]])
+; CHECK-NEXT: [[R:%.*]] = and i8 [[V]], 16
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = and <3 x i8> %xx, <i8 15, i8 15, i8 31>
+ %v = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> %x)
+ %r = and i8 %v, 16
+ ret i8 %r
+}
+
+define i8 @known_reduce_xor_odd_fail(<3 x i8> %xx) {
+; CHECK-LABEL: @known_reduce_xor_odd_fail(
+; CHECK-NEXT: [[X:%.*]] = or <3 x i8> [[XX:%.*]], <i8 5, i8 3, i8 9>
+; CHECK-NEXT: [[V:%.*]] = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> [[X]])
+; CHECK-NEXT: [[R:%.*]] = and i8 [[V]], 2
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %x = or <3 x i8> %xx, <i8 5, i8 3, i8 9>
+ %v = call i8 @llvm.vector.reduce.xor.v3i8(<3 x i8> %x)
+ %r = and i8 %v, 2
+ ret i8 %r
+}
+
declare void @use(i1)
declare void @sink(i8)
diff --git a/llvm/test/Transforms/InstSimplify/known-non-zero.ll b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
index d9b8f5eed32390..fd2862eb04a24d 100644
--- a/llvm/test/Transforms/InstSimplify/known-non-zero.ll
+++ b/llvm/test/Transforms/InstSimplify/known-non-zero.ll
@@ -377,3 +377,26 @@ define <2 x i1> @insert_nonzero_any_idx_fail(<2 x i8> %xx, i8 %yy, i32 %idx) {
%r = icmp eq <2 x i8> %ins, zeroinitializer
ret <2 x i1> %r
}
+
+define i1 @nonzero_reduce_or(<2 x i8> %xx) {
+; CHECK-LABEL: @nonzero_reduce_or(
+; CHECK-NEXT: ret i1 false
+;
+ %x = add nuw <2 x i8> %xx, <i8 1, i8 1>
+ %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+ %r = icmp eq i8 %v, 0
+ ret i1 %r
+}
+
+define i1 @nonzero_reduce_or_fail(<2 x i8> %xx) {
+; CHECK-LABEL: @nonzero_reduce_or_fail(
+; CHECK-NEXT: [[X:%.*]] = add nsw <2 x i8> [[XX:%.*]], <i8 1, i8 1>
+; CHECK-NEXT: [[V:%.*]] = call i8 @llvm.vector.reduce.or.v2i8(<2 x i8> [[X]])
+; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[V]], 0
+; CHECK-NEXT: ret i1 [[R]]
+;
+ %x = add nsw <2 x i8> %xx, <i8 1, i8 1>
+ %v = call i8 @llvm.vector.reduce.or(<2 x i8> %x)
+ %r = icmp eq i8 %v, 0
+ ret i1 %r
+}
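The comment in the diff says that for min/max/and/or reductions, any bit common to every element carries through to the output, and the xor case is handled separately. A small exhaustive check on two 8-bit lanes illustrates why. This is a sketch under stated assumptions: `Known8`, `matches`, `commonBitsSurvive`, and the `*Op` helpers are hypothetical names, and the known-bits values mirror the constants used in the tests above (`or` with `<5, 3>` makes bit 0 a common one; `and` with `<15, 15>` makes the top four bits common zeros).

```cpp
#include <cstdint>

// Toy 8-bit known-bits pair (hypothetical, not llvm::KnownBits).
struct Known8 {
  std::uint8_t Zero; // bits known to be 0
  std::uint8_t One;  // bits known to be 1
};

// True if value V is consistent with the known bits K.
bool matches(std::uint8_t V, Known8 K) {
  return (V & K.Zero) == 0 && (V & K.One) == K.One;
}

using Reduce2 = std::uint8_t (*)(std::uint8_t, std::uint8_t);

std::uint8_t orOp(std::uint8_t A, std::uint8_t B) { return static_cast<std::uint8_t>(A | B); }
std::uint8_t andOp(std::uint8_t A, std::uint8_t B) { return static_cast<std::uint8_t>(A & B); }
std::uint8_t xorOp(std::uint8_t A, std::uint8_t B) { return static_cast<std::uint8_t>(A ^ B); }
std::uint8_t umaxOp(std::uint8_t A, std::uint8_t B) { return A > B ? A : B; }

// Checks every pair of lanes consistent with K: does the two-lane reduction
// still satisfy K?
bool commonBitsSurvive(Known8 K, Reduce2 Op) {
  for (unsigned A = 0; A < 256; ++A) {
    if (!matches(static_cast<std::uint8_t>(A), K))
      continue;
    for (unsigned B = 0; B < 256; ++B) {
      if (!matches(static_cast<std::uint8_t>(B), K))
        continue;
      if (!matches(Op(static_cast<std::uint8_t>(A), static_cast<std::uint8_t>(B)), K))
        return false;
    }
  }
  return true;
}
```

With bit 0 as a common one, `orOp`, `andOp`, and `umaxOp` pass the check but `xorOp` does not, since two set bits cancel. With only common zeros, `xorOp` passes as well.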
llvm/lib/Analysis/ValueTracking.cpp
Outdated
if (auto *VecTy =
dyn_cast<FixedVectorType>(I->getOperand(0)->getType())) {
Can also handle scalable if you need to know an even count?
How? IIUC the only thing we know about scalable is that sizeof(vec) == N * sizeof(scalar) where N >= MinCount. How can we get odd/even?
It's <vscale x N x element>; if you know N is even, it doesn't matter what vscale is.
Ah, thanks. Done.
Force-pushed from 95dbc18 to ed99b2a
llvm/lib/Analysis/ValueTracking.cpp
Outdated
// number of elements is even, then the common ones become zeros.
auto *VecTy = cast<VectorType>(I->getOperand(0)->getType());
// Even, so the ones become zeros.
if (VecTy->getElementCount().isKnownEven()) {
Shouldn't this have an else that sets the ones to unknown? Otherwise what about something like vscale x 3, which is not known even but might be.
(Actually, I'd just condition the whole xor code on isKnownEven, we don't care about odd-size vectors anyway.)
Yeah, you're right. Pushed fix/test.
It's very minimal code complexity to handle odd, so I left it in for now (although I think I agree it's more for completeness than practical applicability). LMK if you feel strongly about dropping it.
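The three cases discussed in this thread (known-even count, fixed odd count, and a scalable count such as vscale x 3 whose parity depends on vscale) can be sketched as follows. This is an illustrative model, not the patch itself: `Count`, `Parity`, `classify`, `Known8`, and `xorReduceKnown` are hypothetical names, and `Count` only imitates the idea of a minimum element count that may be multiplied by an unknown positive vscale.

```cpp
#include <cstdint>

// Hypothetical model of a vector element count: MinVal elements for a fixed
// vector, or vscale * MinVal elements for a scalable one (vscale >= 1).
struct Count {
  unsigned MinVal;
  bool Scalable;
};

enum class Parity { Even, Odd, Unknown };

Parity classify(Count C) {
  if (C.MinVal % 2 == 0)
    return Parity::Even; // vscale * even is even for any vscale
  // An odd fixed count is odd; vscale * odd can be either (3 or 6 for vscale x 3).
  return C.Scalable ? Parity::Unknown : Parity::Odd;
}

// Toy 8-bit known-bits pair (hypothetical, not llvm::KnownBits).
struct Known8 {
  std::uint8_t Zero; // bits known to be 0
  std::uint8_t One;  // bits known to be 1
};

// Known bits of an xor reduction, given the bits common to every element.
Known8 xorReduceKnown(Known8 Common, Count C) {
  Known8 R = Common;
  switch (classify(C)) {
  case Parity::Even: // common ones cancel pairwise and become known zeros
    R.Zero = static_cast<std::uint8_t>(R.Zero | R.One);
    R.One = 0;
    break;
  case Parity::Odd: // one copy of each common one is left over
    break;
  case Parity::Unknown: // cannot tell whether the ones cancel: drop them
    R.One = 0;
    break;
  }
  return R;
}
```

Common zeros are kept in all three cases; only the treatment of common ones depends on parity, which is why the missing else branch mattered for the vscale x 3 case.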
…ce.{or,and}`; NFC
Force-pushed from ed99b2a to c6ed849
LGTM
// If the number of elements is odd, then the common ones remain. If the
// number of elements is even, then the common ones become zeros.
auto *VecTy = cast<VectorType>(I->getOperand(0)->getType());
// Even, so the ones become zeros.
Redundant with the preceding comment?
Guess the idea is the top comment explains what the block will do. The lower comments explain what the next few lines will do. With this minimal code though yeah it does become a bit redundant, but comments are like sex so...
computeKnownBits of llvm.vector.reduce.{or,and}; NFC
computeKnownBits for llvm.vector.reduce.{or,and}
computeKnownBits of llvm.vector.reduce.xor; NFC
computeKnownBits for llvm.vector.reduce.xor
isKnownNonZero of llvm.vector.reduce.or; NFC
isKnownNonZero for llvm.vector.reduce.or