Skip to content

Commit

Permalink
[PatternMatching] Add generic API for matching constants using custom…
Browse files Browse the repository at this point in the history
… conditions

The new API is:
    `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)`
        - Matches non-undef constants s.t `Lambda(ele)` is true for all
          elements.
    `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)`
        - Matches constants/undef s.t `Lambda(ele)` is true for all
          elements.

The goal with these is to be able to replace the common usage of:
```
    match(X, m_APInt(C)) && CustomCheck(C)
```
with
```
    match(X, m_CheckedInt(C, CustomChecks);
```

The rationale if we often ignore non-splat vectors because there are
no good APIs to handle them with and its not worth increasing code
complexity for such cases.

The hope is the API creates a common method handling
scalars/splat-vecs/non-splat-vecs to essentially make this a
non-issue.
  • Loading branch information
goldsteinn committed May 3, 2024
1 parent 285dbed commit d8428df
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 0 deletions.
33 changes: 33 additions & 0 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,39 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
//
///////////////////////////////////////////////////////////////////////////////

template <typename APTy> struct custom_checkfn {
function_ref<bool(const APTy &)> CheckFn;
bool isValue(const APTy &C) { return CheckFn(C); }
};

/// Match an integer or vector where CheckFn(ele) for each element is true.
/// For vectors, poison elements are assumed to match.
inline cst_pred_ty<custom_checkfn<APInt>>
m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
}

inline api_pred_ty<custom_checkfn<APInt>>
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
api_pred_ty<custom_checkfn<APInt>> P(V);
P.CheckFn = CheckFn;
return P;
}

/// Match a float or vector where CheckFn(ele) for each element is true.
/// For vectors, poison elements are assumed to match.
inline cstfp_pred_ty<custom_checkfn<APFloat>>
m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
}

inline apf_pred_ty<custom_checkfn<APFloat>>
m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
apf_pred_ty<custom_checkfn<APFloat>> P(V);
P.CheckFn = CheckFn;
return P;
}

struct is_any_apint {
bool isValue(const APInt &C) { return true; }
};
Expand Down
177 changes: 177 additions & 0 deletions llvm/unittests/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,134 @@ TEST_F(PatternMatchTest, BitCast) {
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
}

TEST_F(PatternMatchTest, CheckedInt) {
Type *I8Ty = IRB.getInt8Ty();
const APInt *Res = nullptr;

auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
auto CheckTrue = [](const APInt &) { return true; };
auto CheckFalse = [](const APInt &) { return false; };
auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };

auto DoScalarCheck = [&](int8_t Val) {
APInt APVal(8, Val);
Constant *C = ConstantInt::get(I8Ty, Val);

Res = nullptr;
EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
EXPECT_EQ(*Res, APVal);

Res = nullptr;
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));

Res = nullptr;
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
if (CheckUgt1(APVal)) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, APVal);
}

Res = nullptr;
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
if (CheckNonZero(APVal)) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, APVal);
}

Res = nullptr;
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
if (CheckPow2(APVal)) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, APVal);
}

};

DoScalarCheck(0);
DoScalarCheck(1);
DoScalarCheck(2);
DoScalarCheck(3);

EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
EXPECT_EQ(Res, nullptr);

EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
EXPECT_EQ(Res, nullptr);

EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
EXPECT_EQ(Res, nullptr);

EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
EXPECT_EQ(Res, nullptr);

auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
function_ref<bool(const APInt &)> CheckFn,
bool UndefAsPoison) {
SmallVector<Constant *> VecElems;
std::optional<bool> Okay;
bool AllSame = true;
bool HasUndef = false;
std::optional<APInt> First;
for (const std::optional<int8_t> &Val : Vals) {
if (!Val.has_value()) {
VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
: UndefValue::get(I8Ty));
HasUndef = true;
} else {
if (!Okay.has_value())
Okay = true;
APInt APVal(8, *Val);
if (!First.has_value())
First = APVal;
else
AllSame &= First->eq(APVal);
Okay = *Okay && CheckFn(APVal);
VecElems.push_back(ConstantInt::get(I8Ty, *Val));
}
}

Constant *C = ConstantVector::get(VecElems);
EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
m_CheckedInt(CheckFn).match(C));

Res = nullptr;
bool Expec =
!(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
if (Expec) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, *First);
}
};
auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false);
DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false);
DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true);
DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true);
DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false);
DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false);
DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false);
};

DoVecCheck({0, 1});
DoVecCheck({1, 1});
DoVecCheck({1, 2});
DoVecCheck({1, std::nullopt});
DoVecCheck({1, std::nullopt, 1});
DoVecCheck({1, std::nullopt, 2});
DoVecCheck({std::nullopt, std::nullopt, std::nullopt});
}

TEST_F(PatternMatchTest, Power2) {
Value *C128 = IRB.getInt32(128);
Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
Expand Down Expand Up @@ -1397,21 +1525,58 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
EXPECT_FALSE(match(VectorInfPoison, m_Finite()));
EXPECT_FALSE(match(VectorNaNPoison, m_Finite()));

auto CheckTrue = [](const APFloat &) { return true; };
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue)));
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckTrue)));
EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue)));
EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue)));
EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue)));
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue)));
EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckTrue)));
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue)));
EXPECT_TRUE(match(VectorNaNPoison, m_CheckedFp(CheckTrue)));

auto CheckFalse = [](const APFloat &) { return false; };
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse)));
EXPECT_FALSE(match(VectorZeroPoison, m_CheckedFp(CheckFalse)));
EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse)));
EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse)));
EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse)));
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse)));
EXPECT_FALSE(match(VectorInfPoison, m_CheckedFp(CheckFalse)));
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse)));
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckFalse)));

auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); };
EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN)));
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckNonNaN)));
EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN)));
EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN)));
EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN)));
EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN)));
EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckNonNaN)));
EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN)));
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));

const APFloat *C;
// Regardless of whether poison is allowed,
// a fully undef/poison constant does not match.
EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));

// We can always match simple constants and simple splats.
C = nullptr;
Expand All @@ -1432,6 +1597,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
EXPECT_TRUE(C->isZero());
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
EXPECT_TRUE(C->isZero());
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
EXPECT_TRUE(C->isZero());

// Splats with undef are never allowed.
// Whether splats with poison can be matched depends on the matcher.
Expand All @@ -1456,6 +1627,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
C = nullptr;
EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
EXPECT_TRUE(C->isZero());
C = nullptr;
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
EXPECT_TRUE(C->isZero());
C = nullptr;
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
EXPECT_TRUE(C->isZero());
}

TEST_F(PatternMatchTest, FloatingPointFNeg) {
Expand Down

0 comments on commit d8428df

Please sign in to comment.