[InstCombine] Fix #163110: Fold icmp (shl X, L), (add (shl Y, L), 1<<L) to icmp X, (Y + 1) #165975
base: main
Conversation
…= ((Y << Log2) + K) into X == (Y + 1)
@llvm/pr-subscribers-llvm-transforms

Author: 陈子昂 (Michael-Chen-NJU)

Changes

This patch implements a missing optimization identified in issue #163110. It folds the scaled equality comparison

(X << Log2) == ((Y << Log2) + K)

into the simpler form

X == (Y + 1)

where K = 2^Log2 (i.e., K = 1 << Log2).

Rationale

This is a valid algebraic simplification under the nsw (No Signed Wrap) constraint, which allows the entire equation to be safely divided (right-shifted) by 2^Log2.

Implementation Details

- Uses m_c_NSWAdd and m_c_ICmp to handle commutativity.
- Uses APInt::getOneBitSet for an elegant check of the constant K = 1 << Log2.

The patch includes tests covering eq/ne predicates and various integer bit widths.

Fixes: #163110

Full diff: https://github.com/llvm/llvm-project/pull/165975.diff

2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fba1ccf2c8c9b..28d3c772acdcc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6001,6 +6001,26 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
const CmpInst::Predicate Pred = I.getPredicate();
+
+  // icmp (shl nsw X, Log2), (add nsw (shl nsw Y, Log2), K) -> icmp X, (add nsw Y, 1)
+ Value *X, *Y;
+ ConstantInt *CLog2M0, *CLog2M1, *CVal;
+ auto M0 = m_NSWShl(m_Value(X), m_ConstantInt(CLog2M0));
+  auto M1 = m_NSWAdd(m_NSWShl(m_Value(Y), m_ConstantInt(CLog2M1)),
+ m_ConstantInt(CVal));
+
+ if (match(&I, m_c_ICmp(M0, M1)) && CLog2M0->getValue() == CLog2M1->getValue()) {
+ unsigned BitWidth = CLog2M0->getBitWidth();
+ unsigned ShAmt = (unsigned)CLog2M0->getLimitedValue(BitWidth);
+ APInt ExpectedK = APInt::getOneBitSet(BitWidth, ShAmt);
+ if (CVal->getValue() == ExpectedK) {
+ Value *NewRHS = Builder.CreateAdd(
+ Y, ConstantInt::get(Y->getType(), 1),
+ "", /*HasNUW=*/false, /*HasNSW=*/true);
+ return new ICmpInst(Pred, X, NewRHS);
+ }
+ }
+
Value *A, *B, *C, *D;
if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) {
if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0
diff --git a/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll b/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll
new file mode 100644
index 0000000000000..0f375a05528a2
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll
@@ -0,0 +1,108 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Test case: Fold (X << 5) == ((Y << 5) + 32) into X == (Y + 1).
+; This corresponds to the provided alive2 proof.
+
+define i1 @shl_add_const_eq_base(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_base(
+; CHECK-NEXT: [[V5:%.*]] = add nsw i64 [[V3:%.*]], 1
+; CHECK-NEXT: [[V6:%.*]] = icmp eq i64 [[V1:%.*]], [[V5]]
+; CHECK-NEXT: ret i1 [[V6]]
+;
+ %v1 = shl nsw i64 %v0, 5
+ %v4 = shl nsw i64 %v3, 5
+ %v5 = add nsw i64 %v4, 32
+ %v6 = icmp eq i64 %v1, %v5
+ ret i1 %v6
+}
+
+; Test: icmp ne
+define i1 @shl_add_const_ne(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_ne(
+; CHECK-NEXT: [[V5:%.*]] = add nsw i64 [[V3:%.*]], 1
+; CHECK-NEXT: [[V6:%.*]] = icmp ne i64 [[V1:%.*]], [[V5]]
+; CHECK-NEXT: ret i1 [[V6]]
+;
+ %v1 = shl nsw i64 %v0, 5
+ %v4 = shl nsw i64 %v3, 5
+ %v5 = add nsw i64 %v4, 32
+ %v6 = icmp ne i64 %v1, %v5 ; Note: icmp ne
+ ret i1 %v6
+}
+
+; Test: shl amounts do not match (5 vs 4).
+define i1 @shl_add_const_eq_mismatch_shl_amt(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_mismatch_shl_amt(
+; CHECK-NEXT: [[V1:%.*]] = shl nsw i64 [[V0:%.*]], 5
+; CHECK-NEXT: [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 4
+; CHECK-NEXT: [[V5:%.*]] = add nsw i64 [[V4]], 16
+; CHECK-NEXT: [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT: ret i1 [[V6]]
+;
+ %v1 = shl nsw i64 %v0, 5
+ %v4 = shl nsw i64 %v3, 4 ; Shift amount mismatch
+ %v5 = add nsw i64 %v4, 16
+ %v6 = icmp eq i64 %v1, %v5
+ ret i1 %v6
+}
+
+; Test: Constant is wrong (32 vs 64).
+define i1 @shl_add_const_eq_wrong_constant(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_wrong_constant(
+; CHECK-NEXT: [[V1:%.*]] = shl nsw i64 [[V0:%.*]], 5
+; CHECK-NEXT: [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 5
+; CHECK-NEXT: [[V5:%.*]] = add nsw i64 [[V4]], 64
+; CHECK-NEXT: [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT: ret i1 [[V6]]
+;
+ %v1 = shl nsw i64 %v0, 5
+ %v4 = shl nsw i64 %v3, 5
+ %v5 = add nsw i64 %v4, 64 ; Constant mismatch
+ %v6 = icmp eq i64 %v1, %v5
+ ret i1 %v6
+}
+
+; Test: Missing NSW flag on one of the shl instructions.
+define i1 @shl_add_const_eq_no_nsw_on_v1(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_no_nsw_on_v1(
+; CHECK-NEXT: [[V1:%.*]] = shl i64 [[V0:%.*]], 5
+; CHECK-NEXT: [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 5
+; CHECK-NEXT: [[V5:%.*]] = add nsw i64 [[V4]], 32
+; CHECK-NEXT: [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT: ret i1 [[V6]]
+;
+ %v1 = shl i64 %v0, 5 ; Missing nsw
+ %v4 = shl nsw i64 %v3, 5
+ %v5 = add nsw i64 %v4, 32
+ %v6 = icmp eq i64 %v1, %v5
+ ret i1 %v6
+}
+
+; Test: Lower bit width (i8) and different shift amount (3). Constant is 8.
+define i1 @shl_add_const_eq_i8(i8 %v0, i8 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_i8(
+; CHECK-NEXT: [[TMP1:%.*]] = add nsw i8 [[V3:%.*]], 1
+; CHECK-NEXT: [[V6:%.*]] = icmp eq i8 [[V0:%.*]], [[TMP1]]
+; CHECK-NEXT: ret i1 [[V6]]
+;
+ %v1 = shl nsw i8 %v0, 3
+ %v4 = shl nsw i8 %v3, 3
+ %v5 = add nsw i8 %v4, 8 ; 2^3 = 8
+ %v6 = icmp eq i8 %v1, %v5
+ ret i1 %v6
+}
+
+; Test: i32 bit width and larger shift amount (10). Constant is 1024.
+define i1 @shl_add_const_eq_i32(i32 %v0, i32 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_i32(
+; CHECK-NEXT: [[TMP1:%.*]] = add nsw i32 [[V3:%.*]], 1
+; CHECK-NEXT: [[V6:%.*]] = icmp eq i32 [[V0:%.*]], [[TMP1]]
+; CHECK-NEXT: ret i1 [[V6]]
+;
+ %v1 = shl nsw i32 %v0, 10
+ %v4 = shl nsw i32 %v3, 10
+ %v5 = add nsw i32 %v4, 1024 ; 2^10 = 1024
+ %v6 = icmp eq i32 %v1, %v5
+ ret i1 %v6
+}
✅ With the latest revision this PR passed the C/C++ code formatter.
dtcxzyw left a comment
Can we perform this fold with canEvaluateShifted + getShiftedValue? You can use the known bits to estimate the number of bits to be shifted.
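For readers unfamiliar with those helpers, below is a rough, self-contained sketch of the known-bits idea behind the suggestion. It is not the in-tree canEvaluateShifted/getShiftedValue machinery; the helper name and the direct use of computeKnownBits are illustrative assumptions only.

```cpp
// Illustrative only: if both sides of an equality compare have at least
// ShAmt low bits known to be zero, then
//   Op0 == Op1  <=>  (Op0 >> ShAmt) == (Op1 >> ShAmt),
// so the compare can be performed on the shifted-down values instead.
// A real InstCombine change would only rewrite when the shifted operands
// simplify (the role of canEvaluateShifted/getShiftedValue); this sketch
// materializes the shifts just to show the shape of the transform.
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/KnownBits.h"
#include <algorithm>

using namespace llvm;

// Hypothetical helper, not part of InstCombinerImpl.
static Instruction *foldEqCmpByStrippingKnownLowZeros(ICmpInst &I,
                                                      const DataLayout &DL,
                                                      IRBuilder<> &Builder) {
  if (!I.isEquality())
    return nullptr;

  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  KnownBits K0 = computeKnownBits(Op0, DL);
  KnownBits K1 = computeKnownBits(Op1, DL);

  // Low bits known zero on *both* sides; this is where known bits
  // "estimate the number of bits to be shifted".
  unsigned ShAmt =
      std::min(K0.countMinTrailingZeros(), K1.countMinTrailingZeros());
  if (ShAmt == 0)
    return nullptr;

  Value *NewLHS = Builder.CreateLShr(Op0, ShAmt);
  Value *NewRHS = Builder.CreateLShr(Op1, ShAmt);
  return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
}
```

On the pattern in this PR, X << 5 and (Y << 5) + 32 both have five trailing zeros (given nsw on the shifts and the add), so such an approach would reduce the compare to the shifted-out values and let existing folds finish the job.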
…icmp-shl-add optimization
Regarding your suggestion to use canEvaluateShifted and getShiftedValue: my current approach matches the pattern (X << Log2) == ((Y << Log2) + K) directly. Are you suggesting we should aim for a more general, recursive shift distribution through these APIs? If so, I'm a bit unclear whether that generality is needed if the goal is just to implement the fold from #163110.
This patch implements a missing optimization identified in issue #163110.
It folds the scaled equality comparison

(X << Log2) == ((Y << Log2) + K)

into the simpler form

X == (Y + 1)

where $K = 2^{\text{Log2}}$ (i.e., $K = 1 \ll \text{Log2}$).

Rationale

This is a valid algebraic simplification under the nsw (No Signed Wrap) constraint, which allows the entire equation to be safely divided (right-shifted) by $2^{\text{Log2}}$.

Implementation Details

- Uses m_c_NSWAdd and m_c_ICmp to handle commutativity.
- Uses APInt::getOneBitSet for an elegant check of the constant $K = 1 \ll \text{Log2}$.

The patch includes tests covering eq/ne predicates and various integer bit widths.

Fixes: #163110
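As an editorial aside, here is a short sketch of the correctness argument the Rationale appeals to, assuming an n-bit signed type, a shift amount $L = \text{Log2}$, and nsw on the shl and add instructions (informal, not taken verbatim from the PR):

```latex
% All equalities below are over the mathematical integers, which is
% justified because every nsw operation involved is exact (it cannot wrap).
\begin{align*}
  (X \ll L) = (Y \ll L) + 2^{L}
    &\iff 2^{L} X = 2^{L} Y + 2^{L}
      && \text{shl/add nsw are exact, so this is integer arithmetic} \\
    &\iff 2^{L} X = 2^{L}(Y + 1)
      && \text{factor out } 2^{L} \\
    &\iff X = Y + 1
      && \text{cancel the nonzero factor } 2^{L}
\end{align*}
```

The replacement add nsw Y, 1 also cannot wrap: if $2^{L}(Y+1)$ is representable in the type, so is $Y+1$. The same argument applies to the ne predicate, since it is the negation of eq.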