
Conversation

@Michael-Chen-NJU
Contributor

This patch implements a missing optimization identified in issue #163110.

It folds the scaled equality comparison (X << Log2) == ((Y << Log2) + K) into the simpler form X == (Y + 1),

where $K = 2^{\text{Log2}}$ (i.e., $K = 1 \ll \text{Log2}$). For example, with Log2 = 5 we have K = 32, so (X << 5) == ((Y << 5) + 32) folds to X == (Y + 1).

Rationale

This is a valid algebraic simplification under the nsw (No Signed Wrap) constraint, which allows the entire equation to be safely divided (right-shifted) by $2^{\text{Log2}}$.
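
Spelled out (writing $L$ for Log2), the nsw flags let both sides be read as exact integer quantities:

$$
(X \ll L) = (Y \ll L) + 2^L
\;\Longleftrightarrow\;
2^L \cdot X = 2^L \cdot Y + 2^L
\;\Longleftrightarrow\;
X = Y + 1.
$$

The first step holds because nsw guarantees that neither shift (nor the add) wrapped in the signed domain, so dividing every term by $2^L$ loses no information.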

Implementation Details

  1. Uses the m_NSWAdd and m_c_ICmp matchers, so the compare is matched regardless of icmp operand order.
  2. Uses APInt::getOneBitSet to check that the constant $K$ equals $2^{\text{Log2}}$.

The patch includes tests covering eq/ne predicates and various integer bit widths.

Fixes: #163110

@llvmbot added the llvm:instcombine (Covers the InstCombine, InstSimplify and AggressiveInstCombine passes) and llvm:transforms labels on Nov 1, 2025
@llvmbot (Member) commented Nov 1, 2025

@llvm/pr-subscribers-llvm-transforms

Author: 陈子昂 (Michael-Chen-NJU)

Full diff: https://github.com/llvm/llvm-project/pull/165975.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+20)
  • (added) llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll (+108)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fba1ccf2c8c9b..28d3c772acdcc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6001,6 +6001,26 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   const CmpInst::Predicate Pred = I.getPredicate();
+
+  // icmp (shl nsw X, Log2), (add nsw (shl nsw Y, Log2), K) -> icmp X, (add nsw Y, 1)
+  Value *X, *Y;
+  ConstantInt *CLog2M0, *CLog2M1, *CVal;
+  auto M0 = m_NSWShl(m_Value(X), m_ConstantInt(CLog2M0));
+  auto M1 = m_NSWAdd(m_NSWShl(m_Value(Y), m_ConstantInt(CLog2M1)),
+                      m_ConstantInt(CVal));
+
+  if (match(&I, m_c_ICmp(M0, M1)) && CLog2M0->getValue() == CLog2M1->getValue()) {
+    unsigned BitWidth = CLog2M0->getBitWidth();
+    unsigned ShAmt = (unsigned)CLog2M0->getLimitedValue(BitWidth - 1);
+    APInt ExpectedK = APInt::getOneBitSet(BitWidth, ShAmt);
+    if (CVal->getValue() == ExpectedK) {
+      Value *NewRHS = Builder.CreateAdd(
+          Y, ConstantInt::get(Y->getType(), 1),
+          "", /*HasNUW=*/false, /*HasNSW=*/true);
+      return new ICmpInst(Pred, X, NewRHS);
+    }
+  }
+
   Value *A, *B, *C, *D;
   if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) {
     if (A == Op1 || B == Op1) { // (A^B) == A  ->  B == 0
diff --git a/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll b/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll
new file mode 100644
index 0000000000000..0f375a05528a2
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll
@@ -0,0 +1,108 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Test case: Fold (X << 5) == ((Y << 5) + 32) into X == (Y + 1).
+; This corresponds to the provided alive2 proof.
+
+define i1 @shl_add_const_eq_base(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_base(
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V1:%.*]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 32
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: icmp ne
+define i1 @shl_add_const_ne(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_ne(
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp ne i64 [[V1:%.*]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 32
+  %v6 = icmp ne i64 %v1, %v5 ; Note: icmp ne
+  ret i1 %v6
+}
+
+; Test: shl amounts do not match (5 vs 4).
+define i1 @shl_add_const_eq_mismatch_shl_amt(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_mismatch_shl_amt(
+; CHECK-NEXT:    [[V1:%.*]] = shl nsw i64 [[V0:%.*]], 5
+; CHECK-NEXT:    [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 4
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V4]], 16
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  %v4 = shl nsw i64 %v3, 4  ; Shift amount mismatch
+  %v5 = add nsw i64 %v4, 16
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: Constant is wrong (32 vs 64).
+define i1 @shl_add_const_eq_wrong_constant(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_wrong_constant(
+; CHECK-NEXT:    [[V1:%.*]] = shl nsw i64 [[V0:%.*]], 5
+; CHECK-NEXT:    [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 5
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V4]], 64
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 64  ; Constant mismatch
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: Missing NSW flag on one of the shl instructions.
+define i1 @shl_add_const_eq_no_nsw_on_v1(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_no_nsw_on_v1(
+; CHECK-NEXT:    [[V1:%.*]] = shl i64 [[V0:%.*]], 5
+; CHECK-NEXT:    [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 5
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V4]], 32
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl i64 %v0, 5 ; Missing nsw
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 32
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: Lower bit width (i8) and different shift amount (3). Constant is 8.
+define i1 @shl_add_const_eq_i8(i8 %v0, i8 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i8 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i8 [[V0:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i8 %v0, 3
+  %v4 = shl nsw i8 %v3, 3
+  %v5 = add nsw i8 %v4, 8 ; 2^3 = 8
+  %v6 = icmp eq i8 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: i32 bit width and larger shift amount (10). Constant is 1024.
+define i1 @shl_add_const_eq_i32(i32 %v0, i32 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_i32(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i32 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i32 [[V0:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i32 %v0, 10
+  %v4 = shl nsw i32 %v3, 10
+  %v5 = add nsw i32 %v4, 1024 ; 2^10 = 1024
+  %v6 = icmp eq i32 %v1, %v5
+  ret i1 %v6
+}

@github-actions

github-actions bot commented Nov 1, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@dtcxzyw (Member) left a comment

Can we perform this fold with canEvaluateShifted + getShiftedValue? You can use the known bits to estimate the number of bits to be shifted.
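
A hypothetical neighboring case (not taken from this PR) that a known-bits generalization along these lines could also catch, but which the posted pattern misses because it requires exactly $K = 2^{\text{Log2}}$. Here K = 64 = 2 * 32, and with nsw throughout both sides can be divided by 32, giving X == (Y + 2):

; Hypothetical test, not part of this PR's test file.
; (X << 5) == ((Y << 5) + 64)  -->  X == (Y + 2)
define i1 @shl_add_const_eq_k64(i64 %x, i64 %y) {
  %a = shl nsw i64 %x, 5
  %b = shl nsw i64 %y, 5
  %c = add nsw i64 %b, 64
  %r = icmp eq i64 %a, %c
  ret i1 %r
}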

@Michael-Chen-NJU
Contributor Author

> Can we perform this fold with canEvaluateShifted + getShiftedValue? You can use the known bits to estimate the number of bits to be shifted.

Regarding your suggestion to use canEvaluateShifted and getShiftedValue:

My current approach for folding the pattern (X << L) == ((Y << L) + K) into X == (Y + 1) is a direct pattern match and substitution in foldICmpEquality. Since this optimization relies on the exact structure where $K = 2^L$ and requires the nsw flag for the algebraic simplification (effectively dividing by $2^L$), I thought a targeted pattern match was the most straightforward way to implement this specific, proven optimization.

Are you suggesting we should aim for a more general, recursive shift distribution through these APIs? If so, I'm a bit unclear on two points:

  1. Is this specific optimization expected to be handled by a more generic shift-distribution rule?
  2. Or is getShiftedValue needed to handle other, potentially more complex but unproven patterns that contain this specific case as a sub-pattern?

If the goal is just to implement the $X = Y + 1$ simplification for this exact pattern, does my current code not already meet the requirement? Any further clarification on the intended scope would be super helpful!
