
[InstCombine] Simplify (X / C0) * C1 + (X % C0) * C2 to (X / C0) * (C1 - C2 * C0) + X * C2 #76285

Merged
dtcxzyw merged 4 commits into llvm:main from fold-add-div-rem on Apr 24, 2024

Conversation

@dtcxzyw dtcxzyw (Member) commented Dec 23, 2023

Since DivRemPairPass runs after ReassociatePass in the optimization pipeline, I decided to do this simplification in InstCombine.

Alive2: https://alive2.llvm.org/ce/z/Jgsiqf
Fixes #76128.
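
In source terms, the fold on the first test case below (C0 = 10, C1 = 16, C2 = 1; the shl by 4 is the multiply by 16) amounts to this rewrite. A minimal sketch; the function names are illustrative, not from the patch:

#include <cassert>

// Before: one udiv plus one urem.
unsigned beforeFold(unsigned x) { return (x / 10) * 16 + (x % 10); }
// After: the urem is gone; C1 - C2 * C0 == 16 - 1 * 10 == 6.
unsigned afterFold(unsigned x) { return (x / 10) * 6 + x; }

int main() {
  for (unsigned x = 0; x < 100000; ++x)
    assert(beforeFold(x) == afterFold(x));
}

The payoff is that the remainder computation disappears entirely: a div/rem pair becomes a single division plus cheap multiplies and adds.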

@dtcxzyw dtcxzyw requested a review from nikic as a code owner December 23, 2023 13:28
@dtcxzyw dtcxzyw changed the title from "[InstCombine] Simplify (X / C0) * C1 + (X % C0) * C2 to `(X / C0) * (C1 - C2 * C0) + X * C2" to "[InstCombine] Simplify (X / C0) * C1 + (X % C0) * C2 to (X / C0) * (C1 - C2 * C0) + X * C2" Dec 23, 2023
@llvmbot llvmbot (Collaborator) commented Dec 23, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

Since DivRemPairPass runs after ReassociatePass in the optimization pipeline, I decided to do this simplification in InstCombine.

Alive2: https://alive2.llvm.org/ce/z/xxfgqo
Fixes #76128.


Full diff: https://github.com/llvm/llvm-project/pull/76285.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+29)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+4)
  • (modified) llvm/test/Transforms/InstCombine/add4.ll (+162)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 719a2678fc189a..a0d5c0943b210b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1132,6 +1132,8 @@ static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) {
 
 // Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1)
 // does not overflow.
+// Simplifies (X / C0) * C1 + (X % C0) * C2 to
+// (X / C0) * (C1 - C2 * C0) + X * C2
 Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   Value *X, *MulOpV;
@@ -1159,6 +1161,33 @@ Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) {
     }
   }
 
+  // Match I = (X / C0) * C1 + (X % C0) * C2
+  Value *Div, *Rem;
+  APInt C1, C2;
+  if (!LHS->hasOneUse() || !MatchMul(LHS, Div, C1))
+    Div = LHS, C1 = APInt(I.getType()->getScalarSizeInBits(), 1);
+  if (!RHS->hasOneUse() || !MatchMul(RHS, Rem, C2))
+    Rem = RHS, C2 = APInt(I.getType()->getScalarSizeInBits(), 1);
+  if (match(Div, m_IRem(m_Value(), m_Value()))) {
+    std::swap(Div, Rem);
+    std::swap(C1, C2);
+  }
+  Value *DivOpV;
+  APInt DivOpC;
+  if (MatchRem(Rem, X, C0, IsSigned) &&
+      MatchDiv(Div, DivOpV, DivOpC, IsSigned) && X == DivOpV && C0 == DivOpC) {
+    if (!isGuaranteedNotToBeUndef(X, &AC, &I, &DT))
+      return nullptr;
+    APInt NewC = C1 - C2 * C0;
+    if (!NewC.isZero() && !Rem->hasOneUse())
+      return nullptr;
+    Value *MulXC2 = Builder.CreateMul(X, ConstantInt::get(X->getType(), C2));
+    if (NewC.isZero())
+      return MulXC2;
+    return Builder.CreateAdd(
+        Builder.CreateMul(Div, ConstantInt::get(X->getType(), NewC)), MulXC2);
+  }
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 5e362f4117d051..281a335ef78c25 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3872,6 +3872,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
     }
   }
 
+  if (cast<PossiblyDisjointInst>(I).isDisjoint())
+    if (Value *V = SimplifyAddWithRemainder(I))
+      return replaceInstUsesWith(I, V);
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/add4.ll b/llvm/test/Transforms/InstCombine/add4.ll
index 7fc164c8b9a7c9..77f7fc7b35cd44 100644
--- a/llvm/test/Transforms/InstCombine/add4.ll
+++ b/llvm/test/Transforms/InstCombine/add4.ll
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
+declare void @use(i32)
+
 define i64 @match_unsigned(i64 %x) {
 ; CHECK-LABEL: @match_unsigned(
 ; CHECK-NEXT:    [[UREM:%.*]] = urem i64 [[X:%.*]], 19136
@@ -127,3 +129,163 @@ define i32 @not_match_overflow(i32 %x) {
   %t4 = add i32 %t, %t3
   ret i32 %t4
 }
+
+; Tests from PR76128.
+define i32 @fold_add_udiv_urem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_sdiv_srem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_sdiv_srem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = sdiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nsw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = sdiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = srem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_to_mul(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_to_mul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = mul i32 [[VAL:%.*]], 3
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 7
+  %mul1 = mul i32 %div, 21
+  %rem = urem i32 %val, 7
+  %mul2 = mul i32 %rem, 3
+  %add = add i32 %mul1, %mul2
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_to_mul_multiuse(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_to_mul_multiuse(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL:%.*]], 7
+; CHECK-NEXT:    call void @use(i32 [[REM]])
+; CHECK-NEXT:    [[ADD:%.*]] = mul i32 [[VAL]], 3
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 7
+  %mul1 = mul i32 %div, 21
+  %rem = urem i32 %val, 7
+  call void @use(i32 %rem)
+  %mul2 = mul i32 %rem, 3
+  %add = add i32 %mul1, %mul2
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_commuted(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_commuted(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %rem, %shl
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_or_disjoint(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_or_disjoint(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw i32 [[DIV]], 6
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[TMP0]], [[VAL]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = or disjoint i32 %shl, %rem
+  ret i32 %add
+}
+; Negative tests
+define i32 @fold_add_udiv_urem_without_noundef(i32 %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_without_noundef(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_multiuse_mul(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_urem_multiuse_mul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    call void @use(i32 [[SHL]])
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  call void @use(i32 %shl)
+  %rem = urem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_srem(i32 noundef %val) {
+; CHECK-LABEL: @fold_add_udiv_srem(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], 10
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = srem i32 [[VAL]], 10
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, 10
+  %shl = shl i32 %div, 4
+  %rem = srem i32 %val, 10
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
+define i32 @fold_add_udiv_urem_non_constant(i32 noundef %val, i32 noundef %c) {
+; CHECK-LABEL: @fold_add_udiv_urem_non_constant(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[VAL:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[DIV]], 4
+; CHECK-NEXT:    [[REM:%.*]] = urem i32 [[VAL]], [[C]]
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[SHL]], [[REM]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %div = udiv i32 %val, %c
+  %shl = shl i32 %div, 4
+  %rem = urem i32 %val, %c
+  %add = add i32 %shl, %rem
+  ret i32 %add
+}
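
One detail worth calling out in the new matching code: when NewC (that is, C1 - C2 * C0) is zero, the division term vanishes and the whole expression collapses to X * C2, which is exactly what the fold_add_udiv_urem_to_mul test above exercises. Worked out for its constants:

\begin{aligned}
(X / 7) \cdot 21 + (X \bmod 7) \cdot 3
  &= (X / 7) \cdot (21 - 3 \cdot 7) + X \cdot 3 \\
  &= X \cdot 3 .
\end{aligned}

This is also why the code only requires Rem to be one-use when NewC is nonzero: in the collapse-to-mul case the rem instruction may keep other users, as the multiuse test checks.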

@dtcxzyw dtcxzyw (Member Author) commented on the new hunk in InstCombineAndOrXor.cpp:

if (cast<PossiblyDisjointInst>(I).isDisjoint())
It is ugly, but I don't know the best approach to treat `or disjoint` as `add nsw nuw`.

It is a phase-ordering issue: we cannot assume that ReassociatePass will revert the add -> or disjoint canonicalization. BTW, this change achieves higher coverage in my benchmark (in openexr/ImfTimeCode.cpp and z3/mpf.cpp).
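
For reference, the reason `or disjoint` can stand in for `add` here: the disjoint flag promises the operands share no set bits, so ORing them generates no carries and equals their sum (which also cannot overflow). A minimal standalone check, not LLVM code:

#include <cassert>
#include <cstdint>

int main() {
  // `or disjoint a, b` promises (a & b) == 0; with no overlapping bits
  // there are no carries, so OR and ADD produce the same result.
  uint32_t a = 0x1230, b = 0x000F;
  assert((a & b) == 0);
  assert((a | b) == a + b);
}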

@nikic nikic (Contributor) replied:

We should probably have something like visitAddLike, called from both visitAdd and visitOr, where we can place folds that we want to apply to both add and or disjoint?
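
A rough shape of that suggestion follows. It is hypothetical: no visitAddLike existed at the time, and the name comes from this review comment rather than the LLVM API.

// Hypothetical common entry point, called from both visitAdd and
// visitOr (for disjoint ors), so shared folds live in one place.
Instruction *InstCombinerImpl::visitAddLike(BinaryOperator &I) {
  if (Value *V = SimplifyAddWithRemainder(I))
    return replaceInstUsesWith(I, V);
  // ... other folds valid for both add and or disjoint ...
  return nullptr;
}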

Another Contributor replied:

A less ugly way, at least, would be to use the m_AddLike matcher. But I agree with nikic.
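
For illustration, the matcher route could look like the sketch below, assuming the m_AddLike helper from llvm/IR/PatternMatch.h (it matches both add and or disjoint); the merged patch kept the explicit isDisjoint() check instead.

// Sketch: m_AddLike accepts `add X, Y` as well as `or disjoint X, Y`,
// so the fold's entry point could handle both root forms directly.
Value *X, *Y;
if (match(&I, m_AddLike(m_Value(X), m_Value(Y)))) {
  // ... apply the div/rem fold to X and Y ...
}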

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Dec 23, 2023
@dtcxzyw dtcxzyw (Member Author) commented Jan 31, 2024

Ping.

github-actions (bot) commented Jan 31, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@nikic nikic (Contributor) commented Apr 19, 2024

As it took me a while to get it: The core of the transform is that x % c is x - x / c * c (https://alive2.llvm.org/ce/z/I8CBWy).
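
Spelled out with that substitution, the algebra is a two-step rewrite (valid when every use of X observes the same value, which is why the patch requires isGuaranteedNotToBeUndef):

\begin{aligned}
(X / C_0) \cdot C_1 + (X \bmod C_0) \cdot C_2
  &= (X / C_0) \cdot C_1 + \bigl(X - (X / C_0) \cdot C_0\bigr) \cdot C_2 \\
  &= (X / C_0) \cdot (C_1 - C_2 \cdot C_0) + X \cdot C_2 .
\end{aligned}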

@nikic nikic (Contributor) left a comment

This looks fine to me but needs a rebase.

@dtcxzyw dtcxzyw (Member Author) commented Apr 19, 2024

> This looks fine to me but needs a rebase.

Done.

@nikic nikic (Contributor) left a comment

LGTM

@dtcxzyw dtcxzyw merged commit 945eeb2 into llvm:main Apr 24, 2024
2 of 4 checks passed
@dtcxzyw dtcxzyw deleted the fold-add-div-rem branch April 24, 2024 09:01
Successfully merging this pull request may close these issues.

[Reassociate] Missing optimization: fold div(v, a) * b + rem(v, a) to div(v, a) * (b - a) + v