[InstCombine] Fold @llvm.experimental.get.vector.length when cnt <= max_lanes #169293
Conversation
On RISC-V, some loops that the loop vectorizer vectorizes pre-LTO may turn out to have the exact trip count exposed after LTO, see llvm#164762. If the trip count is small enough we can fold away the @llvm.experimental.get.vector.length intrinsic based on this corollary from the LangRef:

> If %cnt is less than or equal to %max_lanes, the return value is equal to %cnt.

This on its own doesn't remove the @llvm.experimental.get.vector.length in llvm#164762, since we also need to teach computeKnownBits about @llvm.experimental.get.vector.length and the sub recurrence, but this PR is a starting point.

I've added this in InstCombine rather than InstSimplify since we may need to insert a truncation (@llvm.experimental.get.vector.length can take an i64 %cnt argument, but the result is always truncated to i32).

Note that something similar was done in VPlan in llvm#167647 for when the loop vectorizer knows the trip count.
@llvm/pr-subscribers-llvm-transforms

Author: Luke Lau (lukel97)

Full diff: https://github.com/llvm/llvm-project/pull/169293.diff

2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 8e4edefec42fd..247f615ed0b54 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -4005,6 +4005,22 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
break;
}
+ case Intrinsic::experimental_get_vector_length: {
+ // get.vector.length(Cnt, MaxLanes) --> Cnt when Cnt <= MaxLanes
+ ConstantRange Cnt = computeConstantRangeIncludingKnownBits(
+ II->getArgOperand(0), false, SQ.getWithInstruction(II));
+ ConstantRange MaxLanes = cast<ConstantInt>(II->getArgOperand(1))
+ ->getValue()
+ .zext(Cnt.getBitWidth());
+ if (cast<ConstantInt>(II->getArgOperand(2))->getZExtValue())
+ MaxLanes = MaxLanes.multiply(
+ getVScaleRange(II->getFunction(), Cnt.getBitWidth()));
+
+ if (Cnt.icmp(CmpInst::ICMP_ULE, MaxLanes))
+ return replaceInstUsesWith(
+ *II, Builder.CreateTrunc(II->getArgOperand(0), II->getType()));
+ return nullptr;
+ }
default: {
// Handle target specific intrinsics
std::optional<Instruction *> V = targetInstCombineIntrinsic(*II);
diff --git a/llvm/test/Transforms/InstCombine/get_vector_length.ll b/llvm/test/Transforms/InstCombine/get_vector_length.ll
new file mode 100644
index 0000000000000..96a7f3058c43c
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/get_vector_length.ll
@@ -0,0 +1,80 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt < %s -passes=instcombine,verify -S | FileCheck %s
+
+define i32 @cnt_known_lt() {
+; CHECK-LABEL: define i32 @cnt_known_lt() {
+; CHECK-NEXT: ret i32 1
+;
+ %x = call i32 @llvm.experimental.get.vector.length(i32 1, i32 2, i1 false)
+ ret i32 %x
+}
+
+define i32 @cnt_not_known_lt() {
+; CHECK-LABEL: define i32 @cnt_not_known_lt() {
+; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 2, i32 1, i1 false)
+; CHECK-NEXT: ret i32 [[X]]
+;
+ %x = call i32 @llvm.experimental.get.vector.length(i32 2, i32 1, i1 false)
+ ret i32 %x
+}
+
+define i32 @cnt_known_lt_scalable() vscale_range(2, 4) {
+; CHECK-LABEL: define i32 @cnt_known_lt_scalable(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: ret i32 2
+;
+ %x = call i32 @llvm.experimental.get.vector.length(i32 2, i32 1, i1 true)
+ ret i32 %x
+}
+
+define i32 @cnt_not_known_lt_scalable() {
+; CHECK-LABEL: define i32 @cnt_not_known_lt_scalable() {
+; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 2, i32 1, i1 true)
+; CHECK-NEXT: ret i32 [[X]]
+;
+ %x = call i32 @llvm.experimental.get.vector.length(i32 2, i32 1, i1 true)
+ ret i32 %x
+}
+
+define i32 @cnt_known_lt_runtime(i32 %x) {
+; CHECK-LABEL: define i32 @cnt_known_lt_runtime(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[ICMP:%.*]] = icmp ult i32 [[X]], 4
+; CHECK-NEXT: call void @llvm.assume(i1 [[ICMP]])
+; CHECK-NEXT: ret i32 [[X]]
+;
+ %icmp = icmp ule i32 %x, 3
+ call void @llvm.assume(i1 %icmp)
+ %y = call i32 @llvm.experimental.get.vector.length(i32 %x, i32 3, i1 false)
+ ret i32 %y
+}
+
+define i32 @cnt_known_lt_runtime_trunc(i64 %x) {
+; CHECK-LABEL: define i32 @cnt_known_lt_runtime_trunc(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT: [[ICMP:%.*]] = icmp ult i64 [[X]], 4
+; CHECK-NEXT: call void @llvm.assume(i1 [[ICMP]])
+; CHECK-NEXT: [[Y:%.*]] = trunc nuw nsw i64 [[X]] to i32
+; CHECK-NEXT: ret i32 [[Y]]
+;
+ %icmp = icmp ule i64 %x, 3
+ call void @llvm.assume(i1 %icmp)
+ %y = call i32 @llvm.experimental.get.vector.length(i64 %x, i32 3, i1 false)
+ ret i32 %y
+}
+
+; FIXME: We should be able to deduce the constant range from AssumptionCache
+; rather than relying on KnownBits, which in this case only knows x <= 3.
+define i32 @cnt_known_lt_runtime_assumption(i32 %x) {
+; CHECK-LABEL: define i32 @cnt_known_lt_runtime_assumption(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[ICMP:%.*]] = icmp ult i32 [[X]], 3
+; CHECK-NEXT: call void @llvm.assume(i1 [[ICMP]])
+; CHECK-NEXT: [[Y:%.*]] = call i32 @llvm.experimental.get.vector.length.i32(i32 [[X]], i32 2, i1 false)
+; CHECK-NEXT: ret i32 [[Y]]
+;
+ %icmp = icmp ule i32 %x, 2
+ call void @llvm.assume(i1 %icmp)
+ %y = call i32 @llvm.experimental.get.vector.length(i32 %x, i32 2, i1 false)
+ ret i32 %y
+}
if (Cnt.icmp(CmpInst::ICMP_ULE, MaxLanes))
  return replaceInstUsesWith(
      *II, Builder.CreateTrunc(II->getArgOperand(0), II->getType()));
Use CreateZExtOrTrunc? The cnt parameter has llvm_anyint_ty in Intrinsics.td
Thanks, done in 2f71e8a
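
For reference, a minimal sketch of the suggested form (an illustration of the suggestion, not the contents of commit 2f71e8a):

// CreateZExtOrTrunc handles both a narrower and a wider %cnt, since the cnt
// parameter is llvm_anyint_ty and need not match the intrinsic's i32 result.
if (Cnt.icmp(CmpInst::ICMP_ULE, MaxLanes))
  return replaceInstUsesWith(
      *II, Builder.CreateZExtOrTrunc(II->getArgOperand(0), II->getType()));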
  ret i32 %y
}

; FIXME: We should be able to deduce the constant range from AssumptionCache
Why doesn't this work?
llvm-project/llvm/lib/Analysis/ValueTracking.cpp, lines 10303 to 10328 in 1580f4b:

  if (CtxI && AC) {
    // Try to restrict the range based on information from assumptions.
    for (auto &AssumeVH : AC->assumptionsFor(V)) {
      if (!AssumeVH)
        continue;
      CallInst *I = cast<CallInst>(AssumeVH);
      assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
             "Got assumption for the wrong function!");
      assert(I->getIntrinsicID() == Intrinsic::assume &&
             "must be an assume intrinsic");
      if (!isValidAssumeForContext(I, CtxI, DT))
        continue;
      Value *Arg = I->getArgOperand(0);
      ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
      // Currently we just use information from comparisons.
      if (!Cmp || Cmp->getOperand(0) != V)
        continue;
      // TODO: Set "ForSigned" parameter via Cmp->isSigned()?
      ConstantRange RHS =
          computeConstantRange(Cmp->getOperand(1), /* ForSigned */ false,
                               UseInstrInfo, AC, I, DT, Depth + 1);
      CR = CR.intersectWith(
          ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
    }
  }
I think computeConstantRangeIncludingKnownBits doesn't pass in the AssumptionCache to computeConstantRange currently. I presume it's a simple fix in another PR?
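
If that is the gap, a minimal sketch of the fix (an assumption about the current plumbing, not a verified patch) would be to forward the SimplifyQuery's assumption cache and context when calling computeConstantRange:

// Hypothetical plumbing inside computeConstantRangeIncludingKnownBits: pass
// the AssumptionCache, context instruction and DominatorTree from the
// SimplifyQuery so the assume-based refinement quoted above can fire.
ConstantRange CR = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo,
                                        SQ.AC, SQ.CxtI, SQ.DT);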
Btw, the reason for using computeConstantRangeIncludingKnownBits and not just computeConstantRange directly is that for the motivating case I think we need to be able to handle a simple PHI recurrence, which only computeKnownBits seems to be able to handle at the moment.
The get.vector.length being in the recurrence makes it no longer "simple", doesn't it?
From what I can tell that is correct; I didn't see anywhere in computeKnownBits where get.vector.length is currently handled. Here's a test for this that could be added to get_vector_length.ll:
define void @loop_get_len() {
; CHECK-LABEL: define void @loop_get_len() {
; CHECK-NEXT: [[ENTRY:.*]]:
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ 16, %[[ENTRY]] ], [ [[REM:%.*]], %[[LOOP]] ]
; CHECK-NEXT: [[GETLEN:%.*]] = tail call i32 @llvm.experimental.get.vector.length.i32(i32 [[PHI]], i32 16, i1 false)
; CHECK-NEXT: [[REM]] = sub i32 [[PHI]], [[GETLEN]]
; CHECK-NEXT: [[EXIT_COND:%.*]] = icmp eq i32 [[REM]], 0
; CHECK-NEXT: br i1 [[EXIT_COND]], label %[[EXIT:.*]], label %[[LOOP]]
; CHECK: [[EXIT]]:
; CHECK-NEXT: ret void
;
entry:
br label %loop
loop:
%phi = phi i32 [ 16, %entry ], [ %rem, %loop ]
%getlen = tail call i32 @llvm.experimental.get.vector.length.i32(i32 %phi, i32 16, i1 false)
%rem = sub i32 %phi, %getlen
%exit_cond = icmp eq i32 %rem, 0
br i1 %exit_cond, label %exit, label %loop
exit:
ret void
}
> The get.vector.length being in the recurrence makes it no longer "simple", doesn't it?

matchSimpleRecurrence doesn't care what the step is, so it can still handle get.vector.length in the recurrence. The only issue is that we currently don't deduce anything apart from the low zero bits in computeKnownBitsFromOperator.

> From what I can tell that is correct, I didn't see anywhere in computeKnownBits where get.vector.length is handled currently

Yes, we should separately add support for that too.
But for the reduced example in #164762 (comment) I think we only need to teach computeKnownBitsFromOperator that sub recurrences with nuw should only decrease, so the number of leading zeros should increase. I gave this a super quick go and it seems to work; I haven't checked if it's generally correct yet:
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index dbceb8e55784..c2ba85d4bfac 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1857,6 +1857,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
                                        Known3.countMinTrailingZeros()));
         auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(BO);
+
+        if (!OverflowOp)
+          break;
+
+        if (Opcode == Instruction::Sub && Q.IIQ.hasNoUnsignedWrap(OverflowOp)) {
+          Known.Zero.setHighBits(std::min(Known2.countMinLeadingZeros(),
+                                          Known3.countMinLeadingZeros()));
+        }
+

That diff above in combination with this patch is enough to flatten the loop.
Actually, Known2 and Known3 weren't what I thought they were; I don't think the above is correct.
ConstantRange MaxLanes = cast<ConstantInt>(II->getArgOperand(1))
                             ->getValue()
                             .zext(Cnt.getBitWidth());
if (cast<ConstantInt>(II->getArgOperand(2))->getZExtValue())
When checking to see if a boolean ConstantInt is true, is there a preference for using getZExtValue() versus isOne()?
I'm not sure but I think isOne() is a bit cleaner, thanks!
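
For illustration, the cleaner form would look roughly like this (same operand layout as the diff above):

// isOne() states the intent ("the vscale flag is set") more directly than
// getZExtValue() does.
if (cast<ConstantInt>(II->getArgOperand(2))->isOne())
  MaxLanes = MaxLanes.multiply(
      getVScaleRange(II->getFunction(), Cnt.getBitWidth()));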
We have some optimizations around this in our downstream. I can extract them after Thanksgiving.
dtcxzyw left a comment:
LG