[InstCombine] Missed optimization for (x * z) % (y * z) == 0 => x % y == 0 #76585

Open
Kmeakin opened this issue Dec 29, 2023 · 8 comments

@Kmeakin
Contributor

Kmeakin commented Dec 29, 2023

(x * z) % (y * z) == 0 => x % y == 0 when the multiplications/modulus do not overflow:

Alive2 proof:

define dso_local i1 @src1(i8 noundef %0, i8 noundef %1, i8 noundef %2) {
  %4 = mul nuw i8 %2, %0
  %5 = mul nuw i8 %2, %1
  %6 = urem i8 %4, %5
  %7 = icmp eq i8 %6, 0
  ret i1 %7
}

define dso_local i1 @tgt1(i8 noundef %0, i8 noundef %1, i8 noundef %2) {
  %4 = urem i8 %0, %1
  %5 = icmp eq i8 %4, 0
  ret i1 %5
}

define dso_local i1 @src2(i8 noundef %0, i8 noundef %1, i8 noundef %2) {
  %4 = icmp ne i8 %1, -1
  tail call void @llvm.assume(i1 %4)
  %5 = mul nsw i8 %2, %0
  %6 = mul nsw i8 %2, %1
  %7 = srem i8 %5, %6
  %8 = icmp eq i8 %7, 0
  ret i1 %8
}


define dso_local i1 @tgt2(i8 noundef %0, i8 noundef %1, i8 noundef %2) {
  %4 = srem i8 %0, %1
  %5 = icmp eq i8 %4, 0
  ret i1 %5
}

declare void @llvm.assume(i1 noundef)
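As a rough cross-check of the two Alive2 proofs above (not a substitute for them), the following plain C++ sketch brute-forces every i8 input. It assumes a simplified model of the IR. An input is skipped if an exact product does not fit in 8 bits (where `mul nuw`/`mul nsw` would yield poison), if the divisor is zero or the `srem` would overflow (undefined behavior), or, for `@src2`, if `%1 == -1` (the `llvm.assume`). Any target result refines a poison or UB source, so only the remaining inputs are compared. The function names are illustrative.

```cpp
#include <cassert>

// Counts i8 inputs on which @src1 is well-defined but disagrees with @tgt1.
// x, y, z stand for %0, %1, %2; values cover the unsigned range 0..255.
int countUnsignedMismatches() {
  int mismatches = 0;
  for (int x = 0; x < 256; ++x)
    for (int y = 0; y < 256; ++y)
      for (int z = 0; z < 256; ++z) {
        int xz = z * x, yz = z * y;           // exact, non-wrapping products
        if (xz > 255 || yz > 255) continue;   // `mul nuw` would be poison
        if (yz == 0) continue;                // urem by zero is UB
        bool src = xz % yz == 0;
        bool tgt = x % y == 0;                // y != 0 because yz != 0
        if (src != tgt) ++mismatches;
      }
  return mismatches;
}

// Same check for @src2/@tgt2 over the signed range -128..127.
int countSignedMismatches() {
  int mismatches = 0;
  for (int x = -128; x < 128; ++x)
    for (int y = -128; y < 128; ++y) {
      if (y == -1) continue;                  // excluded by llvm.assume in @src2
      for (int z = -128; z < 128; ++z) {
        int xz = z * x, yz = z * y;
        if (xz < -128 || xz > 127 || yz < -128 || yz > 127)
          continue;                           // `mul nsw` would be poison
        if (yz == 0) continue;                // srem by zero is UB
        if (xz == -128 && yz == -1) continue; // srem overflow is UB
        bool src = xz % yz == 0;              // C++ % truncates, like srem
        bool tgt = x % y == 0;
        if (src != tgt) ++mismatches;
      }
    }
  return mismatches;
}
```

Under this model both counters should return 0: when z is nonzero and no product wraps, y*z divides x*z exactly when y divides x.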
@RKSimon
Collaborator

RKSimon commented Dec 30, 2023

We fail to remove constants with common factors as well:

define  i1 @src(i8 noundef %0, i8 noundef %1) {
  %3 = mul nuw i8 9, %0
  %4 = mul nuw i8 3, %1
  %5 = urem i8 %3, %4
  %6 = icmp eq i8 %5, 0
  ret i1 %6
}

define  i1 @tgt(i8 noundef %0, i8 noundef %1) {
  %3 = mul nuw i8 3, %0
  %4 = urem i8 %3, %1
  %5 = icmp eq i8 %4, 0
  ret i1 %5
}
Transformation seems to be correct!
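The constant-factor case can be checked the same way. Again, this is an illustrative brute force in plain C++ under the same simplified poison/UB model, not a reimplementation of Alive2. The function name is hypothetical.

```cpp
#include <cassert>

// Brute-forces @src/@tgt from the comment above over all i8 inputs:
//   src: (9 * x) urem (3 * y) == 0   with both muls nuw
//   tgt: (3 * x) urem y == 0         with the mul nuw
// Returns the number of inputs where src is well-defined and the results differ.
int countConstantFactorMismatches() {
  int mismatches = 0;
  for (int x = 0; x < 256; ++x)
    for (int y = 0; y < 256; ++y) {
      int x9 = 9 * x, y3 = 3 * y;
      if (x9 > 255 || y3 > 255) continue;  // nuw violated: src is poison
      if (y3 == 0) continue;               // urem by zero: src is UB
      int x3 = 3 * x;                      // cannot wrap: 3x <= 9x <= 255
      bool src = x9 % y3 == 0;
      bool tgt = x3 % y == 0;              // y != 0 because y3 != 0
      if (src != tgt) ++mismatches;
    }
  return mismatches;
}
```

Since 9 = 3 * 3, 9x is a multiple of 3y exactly when 3x is a multiple of y, so the count should be 0.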

@Baxi-codes

Hello, I would like to work on this issue. Can someone guide me through it?

@dtcxzyw
Member

dtcxzyw commented Jan 4, 2024

> Hello, I would like to work on this issue. Can someone guide me through it?

You can implement the folding here:

  case Instruction::SRem:
    if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
      return I;
    break;
  case Instruction::UDiv:
    if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
      return I;
    [[fallthrough]];
  case Instruction::SDiv:
    if (Instruction *I = foldICmpDivConstant(Cmp, BO, C))
      return I;
    break;

A possible change:

 case Instruction::SRem:
   if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
     return I;
-  break;
+  [[fallthrough]];
+case Instruction::URem:
+  if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
+    return I;
+  break;
 case Instruction::UDiv:
   if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
     return I;
   [[fallthrough]];
 case Instruction::SDiv:
   if (Instruction *I = foldICmpDivConstant(Cmp, BO, C))
     return I;
   break;

@Baxi-codes

Baxi-codes commented Jan 7, 2024

Sorry for the late question; please confirm whether I understand this correctly. I have to implement foldICmpRemConstant(Cmp, BO, C), which folds (a*c1) % (b*c2) == c3 into (a*k) % b == c3 when c1 % c2 is 0 and k = c1/c2?

@dtcxzyw
Member

dtcxzyw commented Jan 7, 2024

> Sorry for the late question; please confirm whether I understand this correctly. I have to implement foldICmpRemConstant(Cmp, BO, C), which folds (a*c1) % (b*c2) == c3 into (a*k) % b == c3 when c1 % c2 is 0 and k = c1/c2?

Yeah, you are right. Be careful with the nsw/nuw flags in mul :)
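To make the two caveats in this exchange concrete, here is a small arithmetic sketch in plain C++ (the helper names are hypothetical). With c1 = k * c2 and no wrap-around, (a*c1) % (b*c2) equals c2 * ((a*k) % b), so the rewrite preserves the comparison as written only when c3 is 0; for a nonzero c3 the constant would have to be adjusted too. For example, with a = 5, b = 2, c1 = 9, c2 = 3 (so k = 3), the left side is 45 % 6 = 3 while (a*k) % b is 15 % 2 = 1. The second helper evaluates the original pattern with i8 wrap-around, i.e. a plain `mul` without `nuw`.

```cpp
#include <cassert>
#include <cstdint>

// Exact check of the identity behind the fold, with c1 = k * c2:
//   (a * c1) % (b * c2) == c2 * ((a * k) % b)   for b > 0 and c2 > 0.
// Only meant for small inputs, so that nothing overflows 32 bits.
bool identityHolds(uint32_t a, uint32_t b, uint32_t k, uint32_t c2) {
  uint32_t c1 = k * c2;
  return (a * c1) % (b * c2) == c2 * ((a * k) % b);
}

// Evaluates (x * z) % (y * z) with i8 wrap-around, i.e. `mul` without nuw.
// The caller must ensure that the wrapped y * z is nonzero.
uint8_t wrappingRemOfMuls(uint8_t x, uint8_t y, uint8_t z) {
  uint8_t xz = static_cast<uint8_t>(x * z);
  uint8_t yz = static_cast<uint8_t>(y * z);
  return static_cast<uint8_t>(xz % yz);
}
```

With x = 2, y = 3, z = 128 the wrapped products are 0 and 128, so the remainder is 0 even though 2 % 3 is 2; this is why the fold depends on the `nuw` (or `nsw`) flags.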

@Baxi-codes

I managed to write something, but it didn't work at all. I ran it on:

define i1 @src(i8 noundef %0, i8 noundef %1) {
  %3 = mul nuw i8 9, %0
  %4 = mul nuw i8 3, %1
  %5 = urem i8 %3, %4
  %6 = icmp eq i8 %5, 0
  ret i1 %6
}

and it doesn't change anything:

; ModuleID = '/home/dhairya/dev/tmp/test.ll'
source_filename = "/home/dhairya/dev/tmp/test.ll"

define i1 @src(i8 noundef %0, i8 noundef %1) {
  %3 = mul nuw i8 9, %0
  %4 = mul nuw i8 3, %1
  %5 = urem i8 %3, %4
  %6 = icmp eq i8 %5, 0
  ret i1 %6
}

Is something wrong with the match?

Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
                                                   BinaryOperator *Rem,
                                                   const APInt &C) {
  const ICmpInst::Predicate Pred = Cmp.getPredicate();
  Value *X = Rem->getOperand(0);
  Value *Y = Rem->getOperand(1);

  // Check if the remainder operation is in the required form.
  if (!isa<BinaryOperator>(X) || !isa<BinaryOperator>(Y))
    return nullptr;

  BinaryOperator *MulX = cast<BinaryOperator>(X);
  BinaryOperator *MulY = cast<BinaryOperator>(Y);

  // Check if the operands are multiplication operations.
  if (MulX->getOpcode() != Instruction::Mul || MulY->getOpcode() != Instruction::Mul)
    return nullptr;

  // Get the multiplication operands and constants.
  Value *A = MulX->getOperand(0);
  Value *C1 = MulX->getOperand(1);
  Value *B = MulY->getOperand(0);
  Value *C2 = MulY->getOperand(1);

  const APInt *C1Value, *C2Value;

  // Check if the constants satisfy the condition c1 % c2 == 0.
  if (!match(C1, m_APInt(C1Value)) || !match(C2, m_APInt(C2Value)) || C1Value->urem(*C2Value) != 0)
    return nullptr;

  // Compute the new constant k = c1 / c2.
  APInt K = C1Value->udiv(*C2Value);
  Type *Ty = Rem->getType();

  // Create a new remainder instruction (a * k) % b.
  Value *NewRem = Builder.CreateURem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K)), B);
  return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
}

Please be patient with me, as I am still in my learning phase 😅

Also, I found that in InstCombineInternal.h, it had

Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
                                    const APInt &C);

instead of

Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *SRem,
                                    const APInt &C);

So I have changed it.
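The draft `foldICmpRemConstant` above does not look at the `nuw` flags or at the value of C, and it always builds a `URem` even though the suggested diff also routes `SRem` into it. For comparison, here is a plain-C++ model of the checks an unsigned version of the fold would plausibly need before rewriting. It deliberately avoids the LLVM API: `UMulInfo`, `Pred`, `canFoldURemOfMuls` and their fields are hypothetical stand-ins for what real code would read off the `mul` instructions (for example via `hasNoUnsignedWrap()`), and the signed (`srem`/`nsw`) variant is left out.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical summary of one `mul V, Const` operand of the remainder.
struct UMulInfo {
  uint64_t constant;  // the constant multiplier (c1 or c2)
  bool nuw;           // whether the mul carries the nuw flag
};

// Hypothetical subset of icmp predicates, enough to show the equality check.
enum class Pred { EQ, NE, ULT, UGT };

// Returns k = c1 / c2 if `icmp pred (urem (mul A, c1), (mul B, c2)), C` can be
// rewritten as `icmp pred (urem (mul A, k), B), C`, and std::nullopt otherwise.
std::optional<uint64_t> canFoldURemOfMuls(bool remIsUnsigned, UMulInfo lhs,
                                          UMulInfo rhs, Pred pred, uint64_t c) {
  if (!remIsUnsigned) return std::nullopt;  // the rewrite always emits urem
  if (pred != Pred::EQ && pred != Pred::NE) return std::nullopt;
  if (c != 0) return std::nullopt;          // C is reused unchanged
  if (!lhs.nuw || !rhs.nuw) return std::nullopt;  // be conservative about wrap
  if (rhs.constant == 0) return std::nullopt;     // avoid dividing by zero
  if (lhs.constant % rhs.constant != 0) return std::nullopt;  // need c2 | c1
  return lhs.constant / rhs.constant;
}
```

If the fold fires, the new `mul A, k` can keep `nuw`: c2 is at least 1, so k <= c1 and A*k <= A*c1, which did not wrap.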

@dtcxzyw
Member

dtcxzyw commented Jan 8, 2024

Please create a pull request. I cannot leave reviews on the issue comment.

@Baxi-codes

PR

4 participants