diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 8824a05e3aa6c..188faf00329ec 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -8231,14 +8231,30 @@ isImpliedCondMatchingOperands(CmpInst::Predicate LPred, return std::nullopt; } -/// Return true if "icmp LPred X, LC" implies "icmp RPred X, RC" is true. -/// Return false if "icmp LPred X, LC" implies "icmp RPred X, RC" is false. -/// Otherwise, return std::nullopt if we can't infer anything. +/// Return true if "icmp LPred X, LC" implies "icmp RPred cast(X), RC" is true. +/// Return false if "icmp LPred X, LC" implies "icmp RPred cast(X), RC" is +/// false. Otherwise, return std::nullopt if we can't infer anything. static std::optional isImpliedCondCommonOperandWithConstants( - CmpInst::Predicate LPred, const APInt &LC, CmpInst::Predicate RPred, - const APInt &RC) { + const Value *L0, CmpInst::Predicate LPred, const APInt &LC, const Value *R0, + CmpInst::Predicate RPred, const APInt &RC) { ConstantRange DomCR = ConstantRange::makeExactICmpRegion(LPred, LC); ConstantRange CR = ConstantRange::makeExactICmpRegion(RPred, RC); + + if (L0 == R0) + ; // noop + // Example: icmp eq X, 3 --> icmp sgt trunc(X), 2 + else if (match(R0, m_Trunc(m_Specific(L0)))) + DomCR = DomCR.truncate(RC.getBitWidth()); + // Example: icmp slt trunc(X), 3 --> icmp ne X, 3 + else if (match(L0, m_Trunc(m_Specific(R0)))) { + // Try to prove by negation + DomCR = DomCR.inverse(); + CR = CR.inverse(); + std::swap(DomCR, CR); + DomCR = DomCR.truncate(LC.getBitWidth()); + } else + return std::nullopt; + ConstantRange Intersection = DomCR.intersectWith(CR); ConstantRange Difference = DomCR.difference(CR); if (Intersection.isEmptySet()) @@ -8272,8 +8288,19 @@ static std::optional isImpliedCondICmps(const ICmpInst *LHS, // Can we infer anything when the 0-operands match and the 1-operands are // constants (not necessarily matching)? const APInt *LC, *RC; - if (L0 == R0 && match(L1, m_APInt(LC)) && match(R1, m_APInt(RC))) - return isImpliedCondCommonOperandWithConstants(LPred, *LC, RPred, *RC); + if (match(L1, m_APInt(LC)) && match(R1, m_APInt(RC))) { + if (auto Res = isImpliedCondCommonOperandWithConstants(L0, LPred, *LC, R0, + RPred, *RC)) + return Res; + + if (match(L0, m_Trunc(m_Specific(R0)))) { + // When L0 == trunc(R0), we use the law of excluded middle to cover some + // missing cases. + if (auto Res = isImpliedCondCommonOperandWithConstants( + L0, LPred, *LC, R0, ICmpInst::getInversePredicate(RPred), *RC)) + return !*Res; + } + } // L0 = R0 = L1 + R1, L0 >=u L1 implies R0 >=u R1, L0