diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index bfe586c0d1c975..6bd3c826c9e4e7 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -4404,16 +4404,37 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, } { - // Try to remove shared constant multiplier from equality comparison: - // X * C == Y * C (with no overflowing/aliasing) --> X == Y - Value *X, *Y; - const APInt *C; - if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 && - match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality()) - if (!C->countTrailingZeros() || - (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || - (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) - return new ICmpInst(Pred, X, Y); + // Try to remove shared multiplier from comparison: + // X * Z u{lt/le/gt/ge}/eq/ne Y * Z + Value *X, *Y, *Z; + if (Pred == ICmpInst::getUnsignedPredicate(Pred) && + ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) && + match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) || + (match(Op0, m_Mul(m_Value(Z), m_Value(X))) && + match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) { + bool NonZero; + if (ICmpInst::isEquality(Pred)) { + KnownBits ZKnown = computeKnownBits(Z, 0, &I); + // if Z % 2 != 0 + // X * Z eq/ne Y * Z -> X eq/ne Y + if (ZKnown.countMaxTrailingZeros() == 0) + return new ICmpInst(Pred, X, Y); + NonZero = !ZKnown.One.isZero() || + isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + // if Z != 0 and nsw(X * Z) and nsw(Y * Z) + // X * Z eq/ne Y * Z -> X eq/ne Y + if (NonZero && BO0 && BO1 && BO0->hasNoSignedWrap() && + BO1->hasNoSignedWrap()) + return new ICmpInst(Pred, X, Y); + } else + NonZero = isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + + // If Z != 0 and nuw(X * Z) and nuw(Y * Z) + // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y + if (NonZero && BO0 && BO1 && BO0->hasNoUnsignedWrap() && + BO1->hasNoUnsignedWrap()) + return new ICmpInst(Pred, X, Y); + } } BinaryOperator *SRem = nullptr; diff --git a/llvm/test/Transforms/InstCombine/icmp-mul.ll b/llvm/test/Transforms/InstCombine/icmp-mul.ll index e26ea68e83049c..7adcaf390f5324 100644 --- a/llvm/test/Transforms/InstCombine/icmp-mul.ll +++ b/llvm/test/Transforms/InstCombine/icmp-mul.ll @@ -1047,9 +1047,7 @@ define i1 @mul_xy_z_assumeodd_eq(i8 %x, i8 %y, i8 %z) { ; CHECK-NEXT: [[LB:%.*]] = and i8 [[Z:%.*]], 1 ; CHECK-NEXT: [[NZ:%.*]] = icmp ne i8 [[LB]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[NZ]]) -; CHECK-NEXT: [[MULX:%.*]] = mul i8 [[X:%.*]], [[Z]] -; CHECK-NEXT: [[MULY:%.*]] = mul i8 [[Z]], [[Y:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[MULX]], [[MULY]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[CMP]] ; %lb = and i8 %z, 1 @@ -1064,9 +1062,8 @@ define i1 @mul_xy_z_assumeodd_eq(i8 %x, i8 %y, i8 %z) { define <2 x i1> @reused_mul_nsw_xy_z_setnonzero_vec_ne(<2 x i8> %x, <2 x i8> %y, <2 x i8> %zi) { ; CHECK-LABEL: @reused_mul_nsw_xy_z_setnonzero_vec_ne( ; CHECK-NEXT: [[Z:%.*]] = or <2 x i8> [[ZI:%.*]], -; CHECK-NEXT: [[MULX:%.*]] = mul nsw <2 x i8> [[Z]], [[X:%.*]] ; CHECK-NEXT: [[MULY:%.*]] = mul nsw <2 x i8> [[Z]], [[Y:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[MULY]], [[MULX]] +; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[Y]], [[X:%.*]] ; CHECK-NEXT: call void @usev2xi8(<2 x i8> [[MULY]]) ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; @@ -1098,8 +1095,7 @@ define i1 @mul_nuw_xy_z_assumenonzero_uge(i8 %x, i8 %y, i8 %z) { ; CHECK-NEXT: [[NZ:%.*]] = icmp ne i8 [[Z:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[NZ]]) ; CHECK-NEXT: [[MULX:%.*]] = mul nuw i8 [[X:%.*]], [[Z]] -; CHECK-NEXT: [[MULY:%.*]] = mul nuw i8 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp uge i8 [[MULY]], [[MULX]] +; CHECK-NEXT: [[CMP:%.*]] = icmp uge i8 [[Y:%.*]], [[X]] ; CHECK-NEXT: call void @use(i8 [[MULX]]) ; CHECK-NEXT: ret i1 [[CMP]] ; @@ -1114,10 +1110,7 @@ define i1 @mul_nuw_xy_z_assumenonzero_uge(i8 %x, i8 %y, i8 %z) { define <2 x i1> @mul_nuw_xy_z_setnonzero_vec_eq(<2 x i8> %x, <2 x i8> %y, <2 x i8> %zi) { ; CHECK-LABEL: @mul_nuw_xy_z_setnonzero_vec_eq( -; CHECK-NEXT: [[Z:%.*]] = or <2 x i8> [[ZI:%.*]], -; CHECK-NEXT: [[MULX:%.*]] = mul nuw <2 x i8> [[Z]], [[X:%.*]] -; CHECK-NEXT: [[MULY:%.*]] = mul nuw <2 x i8> [[Z]], [[Y:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[MULX]], [[MULY]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %z = or <2 x i8> %zi, @@ -1132,9 +1125,7 @@ define i1 @mul_nuw_xy_z_brnonzero_ult(i8 %x, i8 %y, i8 %z) { ; CHECK-NEXT: [[NZ_NOT:%.*]] = icmp eq i8 [[Z:%.*]], 0 ; CHECK-NEXT: br i1 [[NZ_NOT]], label [[FALSE:%.*]], label [[TRUE:%.*]] ; CHECK: true: -; CHECK-NEXT: [[MULX:%.*]] = mul nuw i8 [[X:%.*]], [[Z]] -; CHECK-NEXT: [[MULY:%.*]] = mul nuw i8 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[MULY]], [[MULX]] +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[Y:%.*]], [[X:%.*]] ; CHECK-NEXT: ret i1 [[CMP]] ; CHECK: false: ; CHECK-NEXT: call void @use(i8 [[Z]])