diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index f4974348945c5..77be1f480eacf 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -375,6 +375,20 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction( MlirPDLPatternModule pdlModule, MlirStringRef name, MlirPDLRewriteFunction rewriteFn, void *userData); +/// This function type is used as callbacks for PDL native constraint functions. +/// Input values can be accessed by `values` with its size `nValues`; +/// output values can be added into `results` by `mlirPDLResultListPushBack*` +/// APIs. And the return value indicates whether the constraint holds. +typedef MlirLogicalResult (*MlirPDLConstraintFunction)( + MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, + MlirPDLValue *values, void *userData); + +/// Register a constraint function into the given PDL pattern module. +/// `userData` will be provided as an argument to the constraint function. +MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLConstraintFunction constraintFn, void *userData); + #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH #undef DEFINE_C_API_STRUCT diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index c53c6cf0dab1e..20392b9002706 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -40,6 +40,15 @@ static nb::object objectFromPDLValue(MlirPDLValue value) { throw std::runtime_error("unsupported PDL value type"); } +static std::vector objectsFromPDLValues(size_t nValues, + MlirPDLValue *values) { + std::vector args; + args.reserve(nValues); + for (size_t i = 0; i < nValues; ++i) + args.push_back(objectFromPDLValue(values[i])); + return args; +} + // Convert the Python object to a boolean. // If it evaluates to False, treat it as success; // otherwise, treat it as failure. @@ -74,11 +83,22 @@ class PyPDLPatternModule { size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); - std::vector args; - args.reserve(nValues); - for (size_t i = 0; i < nValues; ++i) - args.push_back(objectFromPDLValue(values[i])); - return logicalResultFromObject(f(rewriter, results, args)); + return logicalResultFromObject( + f(rewriter, results, objectsFromPDLValues(nValues, values))); + }, + fn.ptr()); + } + + void registerConstraintFunction(const std::string &name, + const nb::callable &fn) { + mlirPDLPatternModuleRegisterConstraintFunction( + get(), mlirStringRefCreate(name.data(), name.size()), + [](MlirPatternRewriter rewriter, MlirPDLResultList results, + size_t nValues, MlirPDLValue *values, + void *userData) -> MlirLogicalResult { + nb::handle f = nb::handle(static_cast(userData)); + return logicalResultFromObject( + f(rewriter, results, objectsFromPDLValues(nValues, values))); }, fn.ptr()); } @@ -199,6 +219,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { const nb::callable &fn) { self.registerRewriteFunction(name, fn); }, + nb::keep_alive<1, 3>()) + .def( + "register_constraint_function", + [](PyPDLPatternModule &self, const std::string &name, + const nb::callable &fn) { + self.registerConstraintFunction(name, fn); + }, nb::keep_alive<1, 3>()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "FrozenRewritePatternSet") diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 9ecce956a05b9..8ee6308cadf83 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -398,6 +398,15 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, unwrap(results)->push_back(unwrap(value)); } +inline std::vector wrap(ArrayRef values) { + std::vector mlirValues; + mlirValues.reserve(values.size()); + for (auto &value : values) { + mlirValues.push_back(wrap(&value)); + } + return mlirValues; +} + void mlirPDLPatternModuleRegisterRewriteFunction( MlirPDLPatternModule pdlModule, MlirStringRef name, MlirPDLRewriteFunction rewriteFn, void *userData) { @@ -405,14 +414,25 @@ void mlirPDLPatternModuleRegisterRewriteFunction( unwrap(name), [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, ArrayRef values) -> LogicalResult { - std::vector mlirValues; - mlirValues.reserve(values.size()); - for (auto &value : values) { - mlirValues.push_back(wrap(&value)); - } + std::vector mlirValues = wrap(values); return unwrap(rewriteFn(wrap(&rewriter), wrap(&results), mlirValues.size(), mlirValues.data(), userData)); }); } + +void mlirPDLPatternModuleRegisterConstraintFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLConstraintFunction constraintFn, void *userData) { + unwrap(pdlModule)->registerConstraintFunction( + unwrap(name), + [userData, constraintFn](PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef values) -> LogicalResult { + std::vector mlirValues = wrap(values); + return unwrap(constraintFn(wrap(&rewriter), wrap(&results), + mlirValues.size(), mlirValues.data(), + userData)); + }); +} #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index e85c6c77ef955..c8e6197e03842 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -153,12 +153,43 @@ def rew(): ) pdl.ReplaceOp(op0, with_op=newOp) + @pdl.pattern(benefit=1, sym_name="myint_add_zero_fold") + def pat(): + t = pdl.TypeOp(i32) + v0 = pdl.OperandOp() + v1 = pdl.OperandOp() + v = pdl.apply_native_constraint([pdl.ValueType.get()], "has_zero", [v0, v1]) + op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t]) + + @pdl.rewrite() + def rew(): + pdl.ReplaceOp(op0, with_values=[v]) + def add_fold(rewriter, results, values): a0, a1 = values results.append(IntegerAttr.get(i32, a0.value + a1.value)) + def is_zero(value): + op = value.owner + if isinstance(op, Operation): + return op.name == "myint.constant" and op.attributes["value"].value == 0 + return False + + # Check if either operand is a constant zero, + # and append the other operand to the results if so. + def has_zero(rewriter, results, values): + v0, v1 = values + if is_zero(v0): + results.append(v1) + return False + if is_zero(v1): + results.append(v0) + return False + return True + pdl_module = PDLModule(m) pdl_module.register_rewrite_function("add_fold", add_fold) + pdl_module.register_constraint_function("has_zero", has_zero) return pdl_module.freeze() @@ -181,3 +212,28 @@ def test_pdl_register_function(module_): apply_patterns_and_fold_greedily(module_, frozen) return module_ + + +# CHECK-LABEL: TEST: test_pdl_register_function_constraint +# CHECK: return %arg0 : i32 +@construct_and_print_in_module +def test_pdl_register_function_constraint(module_): + load_myint_dialect() + + module_ = Module.parse( + """ + func.func @f(%x : i32) -> i32 { + %c0 = "myint.constant"() { value = 1 }: () -> (i32) + %c1 = "myint.constant"() { value = -1 }: () -> (i32) + %a = "myint.add"(%c0, %c1): (i32, i32) -> (i32) + %b = "myint.add"(%a, %x): (i32, i32) -> (i32) + %c = "myint.add"(%b, %a): (i32, i32) -> (i32) + func.return %c : i32 + } + """ + ) + + frozen = get_pdl_pattern_fold() + apply_patterns_and_fold_greedily(module_, frozen) + + return module_