-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[InstCombine] Factorise add/sub and max/min using distributivity #101507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR with a comment. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-transforms Author: Jorge Botto (jf-botto) Changes: This PR fixes part of Issue 92433. The alive proof sometimes times out. I have to add `noundef` to certain variables to make it not time out more consistently. Full diff: https://github.com/llvm/llvm-project/pull/101507.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index cc68fd4cf1c1b..3267d27d703e3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1505,6 +1505,80 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
ConstantInt::getTrue(ZeroUndef->getType()));
}
+/// Return whether "X LOp (Y ROp Z)" is always equal to
+/// "(X LOp Y) ROp (X LOp Z)".
+static bool leftDistributesOverRightIntrinsic(Intrinsic::ID LOp,
+ Intrinsic::ID ROp) {
+ switch (LOp) {
+ case Intrinsic::umax:
+ return ROp == Intrinsic::umin;
+ case Intrinsic::smax:
+ return ROp == Intrinsic::smin;
+ case Intrinsic::umin:
+ return ROp == Intrinsic::umax;
+ case Intrinsic::smin:
+ return ROp == Intrinsic::smax;
+ case Intrinsic::uadd_sat:
+ return ROp == Intrinsic::umax || ROp == Intrinsic::umin;
+ case Intrinsic::sadd_sat:
+ return ROp == Intrinsic::smax || ROp == Intrinsic::smin;
+ default:
+ return false;
+ }
+}
+
+// Attempts to factorise a common term
+// in an instruction that has the form "(A op' B) op (C op' D)
+static Instruction *
+foldCallUsingDistributiveLaws(CallInst *II, InstCombiner::BuilderTy &Builder) {
+ Value *LHS = II->getOperand(0), *RHS = II->getOperand(1);
+ Intrinsic::ID TopLevelOpcode = II->getCalledFunction()->getIntrinsicID();
+
+ if (LHS && RHS) {
+ CallInst *Op0 = dyn_cast<CallInst>(LHS);
+ CallInst *Op1 = dyn_cast<CallInst>(RHS);
+
+ if (!Op0 || !Op1)
+ return nullptr;
+
+ if (Op0->getCalledFunction()->getIntrinsicID() !=
+ Op1->getCalledFunction()->getIntrinsicID())
+ return nullptr;
+
+ Intrinsic::ID InnerOpcode = Op0->getCalledFunction()->getIntrinsicID();
+
+ bool InnerCommutative = Op0->isCommutative();
+ bool Distributive =
+ leftDistributesOverRightIntrinsic(InnerOpcode, TopLevelOpcode);
+
+ Value *A = Op0->getOperand(0);
+ Value *B = Op0->getOperand(1);
+ Value *C = Op1->getOperand(0);
+ Value *D = Op1->getOperand(1);
+
+ if (Distributive && (A == C || (InnerCommutative && A == D))) {
+ if (A != C)
+ std::swap(C, D);
+
+ Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, B, D);
+ Function *F = Intrinsic::getDeclaration(II->getModule(), InnerOpcode,
+ II->getType());
+ return CallInst::Create(F, {NewIntrinsic, A});
+ }
+
+ if (Distributive && InnerCommutative && (B == D || B == C)) {
+ if (B != D)
+ std::swap(C, D);
+
+ Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, A, C);
+ Function *F = Intrinsic::getDeclaration(II->getModule(), InnerOpcode,
+ II->getType());
+ return CallInst::Create(F, {NewIntrinsic, B});
+ }
+ }
+ return nullptr;
+}
+
/// CallInst simplification. This mostly only handles folding of intrinsic
/// instructions. For normal calls, it allows visitCallBase to do the heavy
/// lifting.
@@ -1731,6 +1805,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
foldMinimumOverTrailingOrLeadingZeroCount<Intrinsic::ctlz>(
I0, I1, DL, Builder))
return replaceInstUsesWith(*II, FoldedCtlz);
+
+ if (Instruction *I = foldCallUsingDistributiveLaws(II, Builder)) {
+ return I;
+ }
+
[[fallthrough]];
}
case Intrinsic::umax: {
@@ -1751,9 +1830,18 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
// If both operands of unsigned min/max are sign-extended, it is still ok
// to narrow the operation.
+
+ if (Instruction *I = foldCallUsingDistributiveLaws(II, Builder))
+ return I;
+
+ [[fallthrough]];
+ }
+ case Intrinsic::smax: {
+ if (Instruction *I = foldCallUsingDistributiveLaws(II, Builder))
+ return I;
+
[[fallthrough]];
}
- case Intrinsic::smax:
case Intrinsic::smin: {
Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1);
Value *X, *Y;
@@ -1929,6 +2017,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
+ if (Instruction *I = foldCallUsingDistributiveLaws(II, Builder)) {
+ return I;
+ }
+
break;
}
case Intrinsic::bitreverse: {
diff --git a/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll b/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
new file mode 100644
index 0000000000000..10f4d8bbc7a0d
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
@@ -0,0 +1,210 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine < %s 2>&1 | FileCheck %s
+
+define i32 @umin_of_umax(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umin_of_umax(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umax.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MIN]]
+;
+ %max1 = call i32 @llvm.umax.i32(i32 %x, i32 %z)
+ %max2 = call i32 @llvm.umax.i32(i32 %y, i32 %z)
+ %min = call i32 @llvm.umin.i32(i32 %max1, i32 %max2)
+ ret i32 %min
+}
+
+define i32 @umin_of_umax_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umin_of_umax_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umax.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MIN]]
+;
+ %max1 = call i32 @llvm.umax.i32(i32 %z, i32 %x)
+ %max2 = call i32 @llvm.umax.i32(i32 %z, i32 %y)
+ %min = call i32 @llvm.umin.i32(i32 %max1, i32 %max2)
+ ret i32 %min
+}
+
+define i32 @smin_of_smax(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smin_of_smax(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MIN]]
+;
+ %max1 = call i32 @llvm.smax.i32(i32 %x, i32 %z)
+ %max2 = call i32 @llvm.smax.i32(i32 %y, i32 %z)
+ %min = call i32 @llvm.smin.i32(i32 %max1, i32 %max2)
+ ret i32 %min
+}
+
+define i32 @smin_of_smax_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smin_of_smax_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smax.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MIN]]
+;
+ %max1 = call i32 @llvm.smax.i32(i32 %z, i32 %x)
+ %max2 = call i32 @llvm.smax.i32(i32 %z, i32 %y)
+ %min = call i32 @llvm.smin.i32(i32 %max1, i32 %max2)
+ ret i32 %min
+}
+
+define i32 @umax_of_umin(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umax_of_umin(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umin.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MAX]]
+;
+ %min1 = call i32 @llvm.umin.i32(i32 %x, i32 %z)
+ %min2 = call i32 @llvm.umin.i32(i32 %y, i32 %z)
+ %max = call i32 @llvm.umax.i32(i32 %min1, i32 %min2)
+ ret i32 %max
+}
+
+define i32 @umax_of_umin_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umax_of_umin_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umin.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MAX]]
+;
+ %min1 = call i32 @llvm.umin.i32(i32 %z, i32 %x)
+ %min2 = call i32 @llvm.umin.i32(i32 %z, i32 %y)
+ %max = call i32 @llvm.umax.i32(i32 %min1, i32 %min2)
+ ret i32 %max
+}
+
+define i32 @smax_of_smin(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smax_of_smin(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MAX]]
+;
+ %min1 = call i32 @llvm.smin.i32(i32 %x, i32 %z)
+ %min2 = call i32 @llvm.smin.i32(i32 %y, i32 %z)
+ %max = call i32 @llvm.smax.i32(i32 %min1, i32 %min2)
+ ret i32 %max
+}
+
+define i32 @smax_of_smin_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smax_of_smin_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MAX]]
+;
+ %min1 = call i32 @llvm.smin.i32(i32 %z, i32 %x)
+ %min2 = call i32 @llvm.smin.i32(i32 %z, i32 %y)
+ %max = call i32 @llvm.smax.i32(i32 %min1, i32 %min2)
+ ret i32 %max
+}
+
+define i32 @umax_of_uadd_sat(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umax_of_uadd_sat(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MAX]]
+;
+ %add1 = call i32 @llvm.uadd.sat.i32(i32 %x, i32 %z)
+ %add2 = call i32 @llvm.uadd.sat.i32(i32 %y, i32 %z)
+ %max = call i32 @llvm.umax.i32(i32 %add1, i32 %add2)
+ ret i32 %max
+}
+
+define i32 @umax_of_uadd_sat_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umax_of_uadd_sat_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MAX]]
+;
+ %add1 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %x)
+ %add2 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %y)
+ %max = call i32 @llvm.umax.i32(i32 %add1, i32 %add2)
+ ret i32 %max
+}
+
+define i32 @umin_of_uadd_sat(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umin_of_uadd_sat(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MIN]]
+;
+ %add1 = call i32 @llvm.uadd.sat.i32(i32 %x, i32 %z)
+ %add2 = call i32 @llvm.uadd.sat.i32(i32 %y, i32 %z)
+ %min = call i32 @llvm.umin.i32(i32 %add1, i32 %add2)
+ ret i32 %min
+}
+
+define i32 @umin_of_uadd_sat_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @umin_of_uadd_sat_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MIN]]
+;
+ %add1 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %x)
+ %add2 = call i32 @llvm.uadd.sat.i32(i32 %z, i32 %y)
+ %min = call i32 @llvm.umin.i32(i32 %add1, i32 %add2)
+ ret i32 %min
+}
+
+define i32 @smax_of_sadd_sat(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smax_of_sadd_sat(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MAX]]
+;
+ %add1 = call i32 @llvm.sadd.sat.i32(i32 %x, i32 %z)
+ %add2 = call i32 @llvm.sadd.sat.i32(i32 %y, i32 %z)
+ %max = call i32 @llvm.smax.i32(i32 %add1, i32 %add2)
+ ret i32 %max
+}
+
+define i32 @smax_of_sadd_sat_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smax_of_sadd_sat_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.smax.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MAX]]
+;
+ %add1 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %x)
+ %add2 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %y)
+ %max = call i32 @llvm.smax.i32(i32 %add1, i32 %add2)
+ ret i32 %max
+}
+
+define i32 @smin_of_sadd_sat(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smin_of_sadd_sat(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MIN]]
+;
+ %add1 = call i32 @llvm.sadd.sat.i32(i32 %x, i32 %z)
+ %add2 = call i32 @llvm.sadd.sat.i32(i32 %y, i32 %z)
+ %min = call i32 @llvm.smin.i32(i32 %add1, i32 %add2)
+ ret i32 %min
+}
+
+define i32 @smin_of_sadd_sat_comm(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: define i32 @smin_of_sadd_sat_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[Z]])
+; CHECK-NEXT: ret i32 [[MIN]]
+;
+ %add1 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %x)
+ %add2 = call i32 @llvm.sadd.sat.i32(i32 %z, i32 %y)
+ %min = call i32 @llvm.smin.i32(i32 %add1, i32 %add2)
+ ret i32 %min
+}
|
Try replacing all the i32s with i8. That usually helps with timeouts. |
In this case even |
Thank you! That really helps. I will update the alive link. I may have to separate the tests into 3-4 links. |
Two of the proofs are buggy: https://alive2.llvm.org/ce/z/Xz2-Vq |
It works for individual tests.
Thank you. Just noticed that. Will fix it. |
case Intrinsic::smax: { | ||
if (Instruction *I = foldCallUsingDistributiveLaws(II, Builder)) | ||
return I; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't think you need the call in the `umin`/`umax`/`smax` cases — they all fall through to `smin` (and it seems a bit silly to potentially fail 4x on the same call).
Thank you @goldsteinn for the time and effort you've put into reviewing my code. I've addressed all of your points. |
With one-use check, this patch doesn't change anything :( I think we should handle the interesting cases I listed in #92433 (comment) first. |
@dtcxzyw, I was planning on making a second PR built upon this one to handle the cases you mention. Would you rather I use this PR just for those cases? |
Please file a separate PR. |
This PR fixes part of Issue 92433.
The alive proof sometimes times out. I have to add `noundef` to certain variables to make it not time out more consistently. Given how min/max/add preserve undef, I believe these optimisations to be correct when undef values are passed through. Proofs: