Skip to content

Commit

Permalink
[InstCombine] Optimize and of icmps with power-of-2 and contiguous masks
Browse files Browse the repository at this point in the history
Add an InstCombine (instruction combining) optimization for expressions of the form:

(%arg u< C1) & ((%arg & C2) != C2) -> %arg u< C2

Where C1 is a power-of-2 and C2 is a contiguous mask starting 1 bit below
C1. This commit resolves GitHub missed-optimization issue #54856.

Validation of scalar tests:
  - https://alive2.llvm.org/ce/z/JfKjiU
  - https://alive2.llvm.org/ce/z/AruHY_
  - https://alive2.llvm.org/ce/z/JAiR6t
  - https://alive2.llvm.org/ce/z/S2X2e5
  - https://alive2.llvm.org/ce/z/4cycdE
  - https://alive2.llvm.org/ce/z/NcDiLP

Validation of vector tests:
  - https://alive2.llvm.org/ce/z/ABY6tE
  - https://alive2.llvm.org/ce/z/BTJi3s
  - https://alive2.llvm.org/ce/z/3BKWpu
  - https://alive2.llvm.org/ce/z/RrAbkj
  - https://alive2.llvm.org/ce/z/nM6fsN

Reviewed By: goldstein.w.n

Differential Revision: https://reviews.llvm.org/D125717
  • Loading branch information
jmciver committed Jun 9, 2023
1 parent ea868d5 commit 1001f90
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 96 deletions.
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/PatternMatch.h
Expand Up @@ -445,6 +445,14 @@ inline cst_pred_ty<is_any_apint> m_AnyIntegralConstant() {
return cst_pred_ty<is_any_apint>();
}

struct is_shifted_mask {
bool isValue(const APInt &C) { return C.isShiftedMask(); }
};

/// Build a matcher that applies the is_shifted_mask predicate to a constant.
inline cst_pred_ty<is_shifted_mask> m_ShiftedMask() {
  using MatcherTy = cst_pred_ty<is_shifted_mask>;
  return MatcherTy();
}

struct is_all_ones {
bool isValue(const APInt &C) { return C.isAllOnes(); }
};
Expand Down
105 changes: 105 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Expand Up @@ -955,6 +955,108 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
return nullptr;
}

/// Attempt to rewrite (icmp(A & B) == 0) & (icmp(A & D) != E) as
/// (icmp A u< D).
///
/// The rewrite is performed only when all of the following hold:
///  - PredL is ICMP_EQ and PredR is ICMP_NE;
///  - B is a negated power of 2, i.e. a run of ones beginning at the most
///    significant bit;
///  - D and E are equal shifted masks;
///  - the run of ones in D begins at the most significant zero bit of B, which
///    is checked as "leading ones of B == leading zeros of D".
///
/// Scalars and fixed-width vectors of constants are handled; vectors are
/// validated element by element. Any other vector form yields no fold.
///
/// \returns the new unsigned-less-than compare, or nullptr when the pattern
/// does not apply.
static Value *foldNegativePower2AndShiftedMask(
    Value *A, Value *B, Value *D, Value *E, ICmpInst::Predicate PredL,
    ICmpInst::Predicate PredR, InstCombiner::BuilderTy &Builder) {
  assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
         "Expected equality predicates for masked type of icmps.");

  // Only the "== 0" on the left, "!= E" on the right arrangement is foldable.
  const bool IsEqThenNe =
      PredL == ICmpInst::ICMP_EQ && PredR == ICmpInst::ICMP_NE;
  if (!IsEqThenNe)
    return nullptr;

  // Coarse shape check on the three constants before looking at exact bits.
  const bool HasExpectedShapes = match(B, m_NegatedPower2()) &&
                                 match(D, m_ShiftedMask()) &&
                                 match(E, m_ShiftedMask());
  if (!HasExpectedShapes)
    return nullptr;

  // Per-value bit check. Because the negated power of 2 always has at least
  // one leading one, and the shifted mask is never zero, the degenerate
  // pairing of an all-ones value (-1) with a zero mask can never satisfy the
  // "leading ones == leading zeros" comparison as 0 == 0.
  auto IsFoldableTriple = [](const Value *NegPow2, const Value *Mask,
                             const Value *Expected) {
    const APInt *NegPow2Cst, *MaskCst, *ExpectedCst;
    if (!match(NegPow2, m_APIntAllowUndef(NegPow2Cst)))
      return false;
    if (!match(Mask, m_APInt(MaskCst)))
      return false;
    if (!match(Expected, m_APInt(ExpectedCst)))
      return false;
    // The compared-against constant must be the mask itself.
    if (!(*MaskCst == *ExpectedCst))
      return false;
    // NOTE(review): this alternative is reached only after the
    // m_APIntAllowUndef match above has already succeeded; confirm that
    // matcher can succeed for an undef value, otherwise this is unreachable.
    if (isa<UndefValue>(NegPow2))
      return true;
    return NegPow2Cst->countLeadingOnes() == MaskCst->countLeadingZeros();
  };

  if (!isa<VectorType>(B->getType())) {
    // Scalar operands: a single check decides the fold.
    if (!IsFoldableTriple(B, D, E))
      return nullptr;
  } else {
    // Vector operands: require a fixed element count and constant operands so
    // that every lane can be inspected individually.
    const auto *FixedTy = dyn_cast<FixedVectorType>(B->getType());
    const auto *NegPow2Vec = dyn_cast<Constant>(B);
    const auto *MaskVec = dyn_cast<Constant>(D);
    const auto *ExpectedVec = dyn_cast<Constant>(E);
    if (!FixedTy || !NegPow2Vec || !MaskVec || !ExpectedVec)
      return nullptr;

    for (unsigned Lane = 0, NumLanes = FixedTy->getNumElements();
         Lane != NumLanes; ++Lane) {
      const auto *NegPow2Elt = NegPow2Vec->getAggregateElement(Lane);
      const auto *MaskElt = MaskVec->getAggregateElement(Lane);
      const auto *ExpectedElt = ExpectedVec->getAggregateElement(Lane);

      // A missing element or a lane that fails the bit check blocks the fold.
      if (!NegPow2Elt || !MaskElt || !ExpectedElt ||
          !IsFoldableTriple(NegPow2Elt, MaskElt, ExpectedElt))
        return nullptr;
    }
  }

  return Builder.CreateICmp(ICmpInst::ICMP_ULT, A, D);
}

/// Attempt to rewrite ((icmp X u< P) & (icmp(X & M) != M)) or
/// ((icmp X s> -1) & (icmp(X & M) != M)) as (icmp X u< M).
///
/// P is a power of 2, M < P, and M is a contiguous shifted mask whose ones
/// begin at the most significant zero bit of P. The signed form is accepted
/// because, when P is the largest representable power of 2, an earlier
/// optimization has already turned the first compare into (icmp X s> -1).
/// P may contain undef/poison in scalar or vector form.
///
/// Only compares joined by 'and' are considered.
///
/// \returns the replacement value, or nullptr when no fold applies.
static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
                                         bool JoinedByAnd,
                                         InstCombiner::BuilderTy &Builder) {
  if (!JoinedByAnd)
    return nullptr;

  Value *A = nullptr;
  Value *B = nullptr;
  Value *C = nullptr;
  Value *D = nullptr;
  Value *E = nullptr;
  ICmpInst::Predicate Pred0 = Cmp0->getPredicate();
  ICmpInst::Predicate Pred1 = Cmp1->getPredicate();

  // With P == 2^n, getMaskedTypeForICmpPair normalizes (icmp X u< 2^n) to
  // (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) to
  // (icmp (X & SignMask) == 0), so both inputs arrive in masked form.
  std::optional<std::pair<unsigned, unsigned>> Classified =
      getMaskedTypeForICmpPair(A, B, C, D, E, Cmp0, Cmp1, Pred0, Pred1);
  if (!Classified)
    return nullptr;

  const auto ShiftedMaskKind = BMask_NotMixed | BMask_NotAllOnes;
  const unsigned Kind0 = Classified->first;
  const unsigned Kind1 = Classified->second;

  // Cmp0 is the "== 0" test and Cmp1 is the "!= mask" test.
  if ((Kind0 & Mask_AllZeros) && (Kind1 == ShiftedMaskKind))
    return foldNegativePower2AndShiftedMask(A, B, D, E, Pred0, Pred1, Builder);

  // Mirrored roles: swap the masks and predicates, and compare against C.
  if ((Kind0 == ShiftedMaskKind) && (Kind1 & Mask_AllZeros))
    return foldNegativePower2AndShiftedMask(A, D, B, C, Pred1, Pred0, Builder);

  return nullptr;
}

/// Commuted variants are assumed to be handled by calling this function again
/// with the parameters swapped.
static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
Expand Down Expand Up @@ -2925,6 +3027,9 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder))
return V;

if (Value *V = foldPowerOf2AndShiftedMask(LHS, RHS, IsAnd, Builder))
return V;

// TODO: Verify whether this is safe for logical and/or.
if (!IsLogical) {
if (Value *X = foldUnsignedUnderflowCheck(LHS, RHS, IsAnd, Q, Builder))
Expand Down

0 comments on commit 1001f90

Please sign in to comment.