Skip to content

Commit

Permalink
[InstSimplify] Reduce code duplication in icmp of binop folds (NFC)
Browse files Browse the repository at this point in the history
For folds where we check for the binop on both the LHS and RHS,
extract a function that expects it on the LHS and call it with
swapped order.
  • Loading branch information
nikic committed Aug 2, 2020
1 parent 8d1b950 commit a0addbb
Showing 1 changed file with 82 additions and 133 deletions.
215 changes: 82 additions & 133 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2753,14 +2753,87 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
return nullptr;
}

static Value *simplifyICmpWithBinOpOnLHS(
CmpInst::Predicate Pred, BinaryOperator *LBO, Value *RHS,
const SimplifyQuery &Q, unsigned MaxRecurse) {
Type *ITy = GetCompareTy(RHS); // The return type.

Value *Y = nullptr;
// icmp pred (or X, Y), X
if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) {
if (Pred == ICmpInst::ICMP_ULT)
return getFalse(ITy);
if (Pred == ICmpInst::ICMP_UGE)
return getTrue(ITy);

if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (RHSKnown.isNonNegative() && YKnown.isNegative())
return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
if (RHSKnown.isNegative() || YKnown.isNonNegative())
return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy);
}
}

// icmp pred (and X, Y), X
if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) {
if (Pred == ICmpInst::ICMP_UGT)
return getFalse(ITy);
if (Pred == ICmpInst::ICMP_ULE)
return getTrue(ITy);
}

// icmp pred (urem X, Y), Y
if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
switch (Pred) {
default:
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE: {
KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (!Known.isNonNegative())
break;
LLVM_FALLTHROUGH;
}
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
return getFalse(ITy);
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE: {
KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (!Known.isNonNegative())
break;
LLVM_FALLTHROUGH;
}
case ICmpInst::ICMP_NE:
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
return getTrue(ITy);
}
}

// x >> y <=u x
// x udiv y <=u x.
if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) ||
match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) {
// icmp pred (X op Y), X
if (Pred == ICmpInst::ICMP_UGT)
return getFalse(ITy);
if (Pred == ICmpInst::ICMP_ULE)
return getTrue(ITy);
}

return nullptr;
}

/// TODO: A large part of this logic is duplicated in InstCombine's
/// foldICmpBinOp(). We should be able to share that and avoid the code
/// duplication.
static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
Value *RHS, const SimplifyQuery &Q,
unsigned MaxRecurse) {
Type *ITy = GetCompareTy(LHS); // The return type.

BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS);
if (MaxRecurse && (LBO || RBO)) {
Expand Down Expand Up @@ -2831,56 +2904,14 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
}
}

{
Value *Y = nullptr;
// icmp pred (or X, Y), X
if (LBO && match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) {
if (Pred == ICmpInst::ICMP_ULT)
return getFalse(ITy);
if (Pred == ICmpInst::ICMP_UGE)
return getTrue(ITy);

if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (RHSKnown.isNonNegative() && YKnown.isNegative())
return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
if (RHSKnown.isNegative() || YKnown.isNonNegative())
return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy);
}
}
// icmp pred X, (or X, Y)
if (RBO && match(RBO, m_c_Or(m_Value(Y), m_Specific(LHS)))) {
if (Pred == ICmpInst::ICMP_ULE)
return getTrue(ITy);
if (Pred == ICmpInst::ICMP_UGT)
return getFalse(ITy);

if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) {
KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (LHSKnown.isNonNegative() && YKnown.isNegative())
return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy);
if (LHSKnown.isNegative() || YKnown.isNonNegative())
return Pred == ICmpInst::ICMP_SGT ? getFalse(ITy) : getTrue(ITy);
}
}
}
if (LBO)
if (Value *V = simplifyICmpWithBinOpOnLHS(Pred, LBO, RHS, Q, MaxRecurse))
return V;

// icmp pred (and X, Y), X
if (LBO && match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) {
if (Pred == ICmpInst::ICMP_UGT)
return getFalse(ITy);
if (Pred == ICmpInst::ICMP_ULE)
return getTrue(ITy);
}
// icmp pred X, (and X, Y)
if (RBO && match(RBO, m_c_And(m_Value(), m_Specific(LHS)))) {
if (Pred == ICmpInst::ICMP_UGE)
return getTrue(ITy);
if (Pred == ICmpInst::ICMP_ULT)
return getFalse(ITy);
}
if (RBO)
if (Value *V = simplifyICmpWithBinOpOnLHS(
ICmpInst::getSwappedPredicate(Pred), RBO, LHS, Q, MaxRecurse))
return V;

// 0 - (zext X) pred C
if (!CmpInst::isUnsigned(Pred) && match(LHS, m_Neg(m_ZExt(m_Value())))) {
Expand All @@ -2904,88 +2935,6 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
}
}

// icmp pred (urem X, Y), Y
if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
switch (Pred) {
default:
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE: {
KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (!Known.isNonNegative())
break;
LLVM_FALLTHROUGH;
}
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
return getFalse(ITy);
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE: {
KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (!Known.isNonNegative())
break;
LLVM_FALLTHROUGH;
}
case ICmpInst::ICMP_NE:
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
return getTrue(ITy);
}
}

// icmp pred X, (urem Y, X)
if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) {
switch (Pred) {
default:
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE: {
KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (!Known.isNonNegative())
break;
LLVM_FALLTHROUGH;
}
case ICmpInst::ICMP_NE:
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
return getTrue(ITy);
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE: {
KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (!Known.isNonNegative())
break;
LLVM_FALLTHROUGH;
}
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
return getFalse(ITy);
}
}

// x >> y <=u x
// x udiv y <=u x.
if (LBO && (match(LBO, m_LShr(m_Specific(RHS), m_Value())) ||
match(LBO, m_UDiv(m_Specific(RHS), m_Value())))) {
// icmp pred (X op Y), X
if (Pred == ICmpInst::ICMP_UGT)
return getFalse(ITy);
if (Pred == ICmpInst::ICMP_ULE)
return getTrue(ITy);
}

// x >=u x >> y
// x >=u x udiv y.
if (RBO && (match(RBO, m_LShr(m_Specific(LHS), m_Value())) ||
match(RBO, m_UDiv(m_Specific(LHS), m_Value())))) {
// icmp pred X, (X op Y)
if (Pred == ICmpInst::ICMP_ULT)
return getFalse(ITy);
if (Pred == ICmpInst::ICMP_UGE)
return getTrue(ITy);
}

// handle:
// CI2 << X == CI
// CI2 << X != CI
Expand Down

0 comments on commit a0addbb

Please sign in to comment.