-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement foldICmpRemConstant in InstCombineCompares #77410
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-transforms Author: None (Baxi-codes) ChangesFull diff: https://github.com/llvm/llvm-project/pull/77410.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7c1aff445524de..0add51b8175555 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2572,6 +2572,46 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask));
}
+Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
+ BinaryOperator *Rem,
+ const APInt &C) {
+ const ICmpInst::Predicate Pred = Cmp.getPredicate();
+ Value *X = Rem->getOperand(0);
+ Value *Y = Rem->getOperand(1);
+
+ // Check if the remainder operation is in the required form.
+ if (!isa<BinaryOperator>(X) || !isa<BinaryOperator>(Y))
+ return nullptr;
+
+ BinaryOperator *MulX = cast<BinaryOperator>(X);
+ BinaryOperator *MulY = cast<BinaryOperator>(Y);
+
+ // Check if the operands are multiplication operations.
+ if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul)
+ return nullptr;
+
+ // Get the multiplication operands and constants.
+ Value *A = MulX->getOperand(0);
+ Value *C1 = MulX->getOperand(1);
+ Value *B = MulY->getOperand(0);
+ Value *C2 = MulY->getOperand(1);
+
+ const APInt *C1Value, *C2Value;
+
+ // Check if the constants satisfy the condition c1 % c2 == 0.
+ if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0)
+ return nullptr;
+
+ // Compute the new constant k = c1 / c2.
+ APInt K = C1Value->udiv(*C2Value);
+ Type *Ty = Rem->getType();
+
+ // Create a new remainder instruction (a * k) % b.
+ Value *NewRem = Builder.CreateURem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B);
+ return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
+}
+
+
/// Fold icmp (udiv X, Y), C.
Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
BinaryOperator *UDiv,
@@ -2963,7 +3003,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
// Fold icmp pred (add X, C2), C.
Type *Ty = Add->getType();
-
+
// If the add does not wrap, we can always adjust the compare by subtracting
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
// are canonicalized to SGT/SLT/UGT/ULT.
@@ -3708,7 +3748,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
case Instruction::SRem:
if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
return I;
+ [[fallthrough]];
+ case Instruction::URem:
+ if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
+ return I;
break;
+
case Instruction::UDiv:
if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
return I;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 21c61bd990184d..748fe04c470e46 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -670,7 +670,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
const APInt &C);
Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr,
const APInt &C);
- Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
+ Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *SRem,
+ const APInt &C);
+ Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
const APInt &C);
Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
const APInt &C);
|
You can test this locally with the following command:git-clang-format --diff f6dbd4cc5f52b6d40f98cf09af22b276b8e1f289 1ebf04dceb32e44d90209070c3d6e6d7349d6559 -- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp llvm/lib/Transforms/InstCombine/InstCombineInternal.h View the diff from clang-format here.diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0add51b817..dadfae6f45 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2587,7 +2587,8 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
BinaryOperator *MulY = cast<BinaryOperator>(Y);
// Check if the operands are multiplication operations.
- if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul)
+ if (MulX->getOpcode() != Instruction::Mul ||
+ MulY->getOpcode() != Instruction::Mul)
return nullptr;
// Get the multiplication operands and constants.
@@ -2599,7 +2600,8 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
const APInt *C1Value, *C2Value;
// Check if the constants satisfy the condition c1 % c2 == 0.
- if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0)
+ if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) ||
+ C1Value->urem(*C2Value) != 0)
return nullptr;
// Compute the new constant k = c1 / c2.
@@ -2607,11 +2609,11 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
Type *Ty = Rem->getType();
// Create a new remainder instruction (a * k) % b.
- Value *NewRem = Builder.CreateURem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B);
+ Value *NewRem = Builder.CreateURem(
+ Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B);
return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
}
-
/// Fold icmp (udiv X, Y), C.
Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
BinaryOperator *UDiv,
@@ -3003,7 +3005,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
// Fold icmp pred (add X, C2), C.
Type *Ty = Add->getType();
-
+
// If the add does not wrap, we can always adjust the compare by subtracting
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
// are canonicalized to SGT/SLT/UGT/ULT.
@@ -3748,10 +3750,10 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
case Instruction::SRem:
if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
return I;
- [[fallthrough]];
+ [[fallthrough]];
case Instruction::URem:
- if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
- return I;
+ if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
+ return I;
break;
case Instruction::UDiv:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 748fe04c47..9be32f69bf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -673,7 +673,7 @@ public:
Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *SRem,
const APInt &C);
Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
- const APInt &C);
+ const APInt &C);
Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
const APInt &C);
Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I applied your patch locally. It works for me:
> bin/opt -passes=instcombine test.ll -S
; ModuleID = 'test.ll'
source_filename = "test.ll"
define i1 @src(i8 noundef %0, i8 noundef %1) {
%3 = mul i8 %0, 3
%4 = urem i8 %3, %1
%5 = icmp eq i8 %4, 0
ret i1 %5
}
@@ -2963,7 +3003,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, | |||
|
|||
// Fold icmp pred (add X, C2), C. | |||
Type *Ty = Add->getType(); | |||
|
|||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop the indents here.
break; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop the newline here.
Please run git clang-format HEAD~1
and amend your last commit before pushing.
@@ -3708,7 +3748,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, | |||
case Instruction::SRem: | |||
if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C)) | |||
return I; | |||
[[fallthrough]]; | |||
case Instruction::URem: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add some tests?
You should add tests in the first commit. Then you should apply your code changes and show the test changes in the second commit.
Example: https://github.com/llvm/llvm-project/pull/76685/commits
CAUTION: Please generate check lines using llvm/utils/update_test_checks.py
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do I do this? Do I undo the changes, add tests, and commit, then do the changes again?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.
Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp, | ||
BinaryOperator *Rem, | ||
const APInt &C) { | ||
const ICmpInst::Predicate Pred = Cmp.getPredicate(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const ICmpInst::Predicate Pred = Cmp.getPredicate(); | |
assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && "Only for srem/urem!"); | |
const ICmpInst::Predicate Pred = Cmp.getPredicate(); |
BinaryOperator *Rem, | ||
const APInt &C) { | ||
const ICmpInst::Predicate Pred = Cmp.getPredicate(); | ||
Value *X = Rem->getOperand(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Value *X = Rem->getOperand(0); | |
if (!ICmpInst::isEquality(Pred) || !C.isZero()) | |
return nullptr; | |
Value *X = Rem->getOperand(0); |
It only holds for ==/!= 0
.
// Check if the remainder operation is in the required form. | ||
if (!isa<BinaryOperator>(X) || !isa<BinaryOperator>(Y)) | ||
return nullptr; | ||
|
||
BinaryOperator *MulX = cast<BinaryOperator>(X); | ||
BinaryOperator *MulY = cast<BinaryOperator>(Y); | ||
|
||
// Check if the operands are multiplication operations. | ||
if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul) | ||
return nullptr; | ||
|
||
// Get the multiplication operands and constants. | ||
Value *A = MulX->getOperand(0); | ||
Value *C1 = MulX->getOperand(1); | ||
Value *B = MulY->getOperand(0); | ||
Value *C2 = MulY->getOperand(1); | ||
|
||
const APInt *C1Value, *C2Value; | ||
|
||
// Check if the constants satisfy the condition c1 % c2 == 0. | ||
if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0) | ||
return nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Check if the remainder operation is in the required form. | |
if (!isa<BinaryOperator>(X) || !isa<BinaryOperator>(Y)) | |
return nullptr; | |
BinaryOperator *MulX = cast<BinaryOperator>(X); | |
BinaryOperator *MulY = cast<BinaryOperator>(Y); | |
// Check if the operands are multiplication operations. | |
if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul) | |
return nullptr; | |
// Get the multiplication operands and constants. | |
Value *A = MulX->getOperand(0); | |
Value *C1 = MulX->getOperand(1); | |
Value *B = MulY->getOperand(0); | |
Value *C2 = MulY->getOperand(1); | |
const APInt *C1Value, *C2Value; | |
// Check if the constants satisfy the condition c1 % c2 == 0. | |
if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0) | |
return nullptr; | |
Value *A, *B; | |
const APInt *C1, *C2; | |
if (Rem->getOpcode() == Instruction::SRem) { | |
if (!match(X, m_NSWMul(m_Value(A), m_APInt(C1)))) | |
retur nullptr; | |
if (!match(Y, m_NSWMul(m_Value(B), m_APInt(C2)))) | |
retur nullptr; | |
if (!C1->srem(*C2).isZero()) | |
return nullptr; | |
} | |
else { | |
... // match with m_NUWMul and use urem | |
} |
Be careful of nsw/nuw
flags.
return nullptr; | ||
|
||
// Compute the new constant k = c1 / c2. | ||
APInt K = C1Value->udiv(*C2Value); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For srem, use sdiv
instead.
Type *Ty = Rem->getType(); | ||
|
||
// Create a new remainder instruction (a * k) % b. | ||
Value *NewRem = Builder.CreateURem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same problem as above.
BTW, I think we can preserve the NSW/NUW flags for the mul.
Please update your PR description and attach alive2 proofs. |
Closed this pr, will create a new one |
nsw with srem doesn't verify |
https://alive2.llvm.org/ce/z/R73iWv
|
I am not sure how to implement that transformation. |
Please review the new PR |
Doesn't work right now