Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement foldICmpRemConstant in InstCombineCompares #77410

Closed
wants to merge 0 commits into from
Closed

Implement foldICmpRemConstant in InstCombineCompares #77410

wants to merge 0 commits into from

Conversation

Baxi-codes
Copy link

@Baxi-codes Baxi-codes commented Jan 9, 2024

Doesn't work right now

@Baxi-codes Baxi-codes requested a review from nikic as a code owner January 9, 2024 05:02
Copy link

github-actions bot commented Jan 9, 2024

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 9, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (Baxi-codes)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/77410.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+46-1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+3-1)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7c1aff445524de..0add51b8175555 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2572,6 +2572,46 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
   return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask));
 }
 
+Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
+                                                   BinaryOperator *Rem,
+                                                   const APInt &C) {
+  const ICmpInst::Predicate Pred = Cmp.getPredicate();
+  Value *X = Rem->getOperand(0);
+  Value *Y = Rem->getOperand(1);
+
+  // Check if the remainder operation is in the required form.
+  if (!isa<BinaryOperator>(X) || !isa<BinaryOperator>(Y))
+    return nullptr;
+
+  BinaryOperator *MulX = cast<BinaryOperator>(X);
+  BinaryOperator *MulY = cast<BinaryOperator>(Y);
+
+  // Check if the operands are multiplication operations.
+  if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul)
+    return nullptr;
+
+  // Get the multiplication operands and constants.
+  Value *A = MulX->getOperand(0);
+  Value *C1 = MulX->getOperand(1);
+  Value *B = MulY->getOperand(0);
+  Value *C2 = MulY->getOperand(1);
+
+  const APInt *C1Value, *C2Value;
+
+  // Check if the constants satisfy the condition c1 % c2 == 0.
+  if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0)
+    return nullptr;
+
+  // Compute the new constant k = c1 / c2.
+  APInt K = C1Value->udiv(*C2Value);
+  Type *Ty = Rem->getType();
+
+  // Create a new remainder instruction (a * k) % b.
+  Value *NewRem = Builder.CreateURem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B);
+  return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
+}
+
+
 /// Fold icmp (udiv X, Y), C.
 Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
                                                     BinaryOperator *UDiv,
@@ -2963,7 +3003,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
 
   // Fold icmp pred (add X, C2), C.
   Type *Ty = Add->getType();
-
+  
   // If the add does not wrap, we can always adjust the compare by subtracting
   // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
   // are canonicalized to SGT/SLT/UGT/ULT.
@@ -3708,7 +3748,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
   case Instruction::SRem:
     if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
       return I;
+   [[fallthrough]]; 
+  case Instruction::URem:
+    if (Instruction *I = foldICmpRemConstant(Cmp, BO, C)) 
+      return I; 
     break;
+
   case Instruction::UDiv:
     if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
       return I;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 21c61bd990184d..748fe04c470e46 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -670,7 +670,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                    const APInt &C);
   Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr,
                                    const APInt &C);
-  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
+  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *SRem,
+                                    const APInt &C);
+  Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
                                     const APInt &C);
   Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
                                     const APInt &C);

Copy link

github-actions bot commented Jan 9, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff f6dbd4cc5f52b6d40f98cf09af22b276b8e1f289 1ebf04dceb32e44d90209070c3d6e6d7349d6559 -- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp llvm/lib/Transforms/InstCombine/InstCombineInternal.h
View the diff from clang-format here.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0add51b817..dadfae6f45 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2587,7 +2587,8 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
   BinaryOperator *MulY = cast<BinaryOperator>(Y);
 
   // Check if the operands are multiplication operations.
-  if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul)
+  if (MulX->getOpcode() != Instruction::Mul ||
+      MulY->getOpcode() != Instruction::Mul)
     return nullptr;
 
   // Get the multiplication operands and constants.
@@ -2599,7 +2600,8 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
   const APInt *C1Value, *C2Value;
 
   // Check if the constants satisfy the condition c1 % c2 == 0.
-  if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0)
+  if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) ||
+      C1Value->urem(*C2Value) != 0)
     return nullptr;
 
   // Compute the new constant k = c1 / c2.
@@ -2607,11 +2609,11 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
   Type *Ty = Rem->getType();
 
   // Create a new remainder instruction (a * k) % b.
-  Value *NewRem = Builder.CreateURem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B);
+  Value *NewRem = Builder.CreateURem(
+      Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B);
   return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
 }
 
-
 /// Fold icmp (udiv X, Y), C.
 Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
                                                     BinaryOperator *UDiv,
@@ -3003,7 +3005,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
 
   // Fold icmp pred (add X, C2), C.
   Type *Ty = Add->getType();
-  
+
   // If the add does not wrap, we can always adjust the compare by subtracting
   // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
   // are canonicalized to SGT/SLT/UGT/ULT.
@@ -3748,10 +3750,10 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
   case Instruction::SRem:
     if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
       return I;
-   [[fallthrough]]; 
+    [[fallthrough]];
   case Instruction::URem:
-    if (Instruction *I = foldICmpRemConstant(Cmp, BO, C)) 
-      return I; 
+    if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
+      return I;
     break;
 
   case Instruction::UDiv:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 748fe04c47..9be32f69bf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -673,7 +673,7 @@ public:
   Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *SRem,
                                     const APInt &C);
   Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
-                                    const APInt &C);
+                                   const APInt &C);
   Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
                                     const APInt &C);
   Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div,

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I applied your patch locally. It works for me:

> bin/opt -passes=instcombine test.ll -S
; ModuleID = 'test.ll'
source_filename = "test.ll"

define i1 @src(i8 noundef %0, i8 noundef %1) {
  %3 = mul i8 %0, 3
  %4 = urem i8 %3, %1
  %5 = icmp eq i8 %4, 0
  ret i1 %5
}

@@ -2963,7 +3003,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,

// Fold icmp pred (add X, C2), C.
Type *Ty = Add->getType();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop the indents here.

break;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop the newline here.
Please run git clang-format HEAD~1 and amend your last commit before pushing.

@@ -3708,7 +3748,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
case Instruction::SRem:
if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
return I;
[[fallthrough]];
case Instruction::URem:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add some tests?
You should add tests in the first commit. Then you should apply your code changes and show the test changes in the second commit.
Example: https://github.com/llvm/llvm-project/pull/76685/commits
CAUTION: Please generate check lines using llvm/utils/update_test_checks.py.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do I do this? Do I undo the changes, add tests, and commit, then do the changes again?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
BinaryOperator *Rem,
const APInt &C) {
const ICmpInst::Predicate Pred = Cmp.getPredicate();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const ICmpInst::Predicate Pred = Cmp.getPredicate();
assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && "Only for srem/urem!");
const ICmpInst::Predicate Pred = Cmp.getPredicate();

BinaryOperator *Rem,
const APInt &C) {
const ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Rem->getOperand(0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Value *X = Rem->getOperand(0);
if (!ICmpInst::isEquality(Pred) || !C.isZero())
return nullptr;
Value *X = Rem->getOperand(0);

It only holds for ==/!= 0.

Comment on lines 2582 to 2603
// Check if the remainder operation is in the required form.
if (!isa<BinaryOperator>(X) || !isa<BinaryOperator>(Y))
return nullptr;

BinaryOperator *MulX = cast<BinaryOperator>(X);
BinaryOperator *MulY = cast<BinaryOperator>(Y);

// Check if the operands are multiplication operations.
if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul)
return nullptr;

// Get the multiplication operands and constants.
Value *A = MulX->getOperand(0);
Value *C1 = MulX->getOperand(1);
Value *B = MulY->getOperand(0);
Value *C2 = MulY->getOperand(1);

const APInt *C1Value, *C2Value;

// Check if the constants satisfy the condition c1 % c2 == 0.
if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0)
return nullptr;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Check if the remainder operation is in the required form.
if (!isa<BinaryOperator>(X) || !isa<BinaryOperator>(Y))
return nullptr;
BinaryOperator *MulX = cast<BinaryOperator>(X);
BinaryOperator *MulY = cast<BinaryOperator>(Y);
// Check if the operands are multiplication operations.
if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul)
return nullptr;
// Get the multiplication operands and constants.
Value *A = MulX->getOperand(0);
Value *C1 = MulX->getOperand(1);
Value *B = MulY->getOperand(0);
Value *C2 = MulY->getOperand(1);
const APInt *C1Value, *C2Value;
// Check if the constants satisfy the condition c1 % c2 == 0.
if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0)
return nullptr;
Value *A, *B;
const APInt *C1, *C2;
if (Rem->getOpcode() == Instruction::SRem) {
if (!match(X, m_NSWMul(m_Value(A), m_APInt(C1))))
retur nullptr;
if (!match(Y, m_NSWMul(m_Value(B), m_APInt(C2))))
retur nullptr;
if (!C1->srem(*C2).isZero())
return nullptr;
}
else {
... // match with m_NUWMul and use urem
}

Be careful of nsw/nuw flags.

return nullptr;

// Compute the new constant k = c1 / c2.
APInt K = C1Value->udiv(*C2Value);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For srem, use sdiv instead.

Type *Ty = Rem->getType();

// Create a new remainder instruction (a * k) % b.
Value *NewRem = Builder.CreateURem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same problem as above.
BTW, I think we can preserve the NSW/NUW flags for the mul.

@dtcxzyw
Copy link
Member

dtcxzyw commented Jan 9, 2024

Doesn't work right now

Please update your PR description and attach alive2 proofs.

@Baxi-codes Baxi-codes closed this Jan 20, 2024
@Baxi-codes
Copy link
Author

Closed this pr, will create a new one

@Baxi-codes
Copy link
Author

Doesn't work right now

Please update your PR description and attach alive2 proofs.

nsw with srem doesn't verify
https://alive2.llvm.org/ce/z/n7wYL5

@dtcxzyw
Copy link
Member

dtcxzyw commented Jan 22, 2024

Doesn't work right now

Please update your PR description and attach alive2 proofs.

nsw with srem doesn't verify https://alive2.llvm.org/ce/z/n7wYL5

https://alive2.llvm.org/ce/z/R73iWv

sdiv INT_MIN, -1 will overflow.

@Baxi-codes
Copy link
Author

Doesn't work right now

Please update your PR description and attach alive2 proofs.

nsw with srem doesn't verify https://alive2.llvm.org/ce/z/n7wYL5

https://alive2.llvm.org/ce/z/R73iWv

sdiv INT_MIN, -1 will overflow.

I am not sure how to implement that transformation.
The original transformation turns out to be correct if both nsw and nuw flags are set
https://alive2.llvm.org/ce/z/THjHEk

@Baxi-codes
Copy link
Author

Please review the new PR
#79383

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[InstCombine] Missed optimization for (x * z) % (y * z) == 0 => x % y == 0
3 participants