diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 374d2fb78de88..f4974348945c5 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void); DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); +DEFINE_C_API_STRUCT(MlirPatternRewriter, void); //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder @@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( #if MLIR_ENABLE_PDL_IN_PATTERNMATCH DEFINE_C_API_STRUCT(MlirPDLPatternModule, void); +DEFINE_C_API_STRUCT(MlirPDLValue, const void); +DEFINE_C_API_STRUCT(MlirPDLResultList, void); MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op); @@ -323,6 +326,55 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op); MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); + +/// Cast the MlirPDLValue to an MlirValue. +/// Return a null value if the cast fails, just like llvm::dyn_cast. +MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirType. +/// Return a null value if the cast fails, just like llvm::dyn_cast. +MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirOperation. +/// Return a null value if the cast fails, just like llvm::dyn_cast. +MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirAttribute. +/// Return a null value if the cast fails, just like llvm::dyn_cast. +MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value); + +/// Push the MlirValue into the given MlirPDLResultList. +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value); + +/// Push the MlirType into the given MlirPDLResultList. +MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results, + MlirType value); + +/// Push the MlirOperation into the given MlirPDLResultList. +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackOperation(MlirPDLResultList results, + MlirOperation value); + +/// Push the MlirAttribute into the given MlirPDLResultList. +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackAttribute(MlirPDLResultList results, + MlirAttribute value); + +/// This function type is used as callbacks for PDL native rewrite functions. +/// Input values can be accessed by `values` with its size `nValues`; +/// output values can be added into `results` by `mlirPDLResultListPushBack*` +/// APIs. And the return value indicates whether the rewrite succeeds. +typedef MlirLogicalResult (*MlirPDLRewriteFunction)( + MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, + MlirPDLValue *values, void *userData); + +/// Register a rewrite function into the given PDL pattern module. +/// `userData` will be provided as an argument to the rewrite function. +MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLRewriteFunction rewriteFn, void *userData); + #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH #undef DEFINE_C_API_STRUCT diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc index 96ba98a850de0..d5fb57d7c360d 100644 --- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc +++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc @@ -53,7 +53,7 @@ public: /// value is not an instance of `T`. template ::value, T, std::optional>> + std::is_constructible_v, T, std::optional>> ResultT dyn_cast() const { return isa() ? castImpl() : ResultT(); } diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 96f00cface64f..72062f660458b 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -9,12 +9,15 @@ #include "Rewrite.h" #include "IRModule.h" +#include "mlir-c/IR.h" #include "mlir-c/Rewrite.h" +#include "mlir-c/Support.h" // clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. // clang-format on #include "mlir/Config/mlir-config.h" +#include "nanobind/nanobind.h" namespace nb = nanobind; using namespace mlir; @@ -24,6 +27,31 @@ using namespace mlir::python; namespace { #if MLIR_ENABLE_PDL_IN_PATTERNMATCH +static nb::object objectFromPDLValue(MlirPDLValue value) { + if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) + return nb::cast(v); + if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v)) + return nb::cast(v); + if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v)) + return nb::cast(v); + if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v)) + return nb::cast(v); + + throw std::runtime_error("unsupported PDL value type"); +} + +// Convert the Python object to a boolean. +// If it evaluates to False, treat it as success; +// otherwise, treat it as failure. +// Note that None is considered success. +static MlirLogicalResult logicalResultFromObject(const nb::object &obj) { + if (obj.is_none()) + return mlirLogicalResultSuccess(); + + return nb::cast(obj) ? mlirLogicalResultFailure() + : mlirLogicalResultSuccess(); +} + /// Owning Wrapper around a PDLPatternModule. class PyPDLPatternModule { public: @@ -38,6 +66,23 @@ class PyPDLPatternModule { } MlirPDLPatternModule get() { return module; } + void registerRewriteFunction(const std::string &name, + const nb::callable &fn) { + mlirPDLPatternModuleRegisterRewriteFunction( + get(), mlirStringRefCreate(name.data(), name.size()), + [](MlirPatternRewriter rewriter, MlirPDLResultList results, + size_t nValues, MlirPDLValue *values, + void *userData) -> MlirLogicalResult { + nb::handle f = nb::handle(static_cast(userData)); + std::vector args; + args.reserve(nValues); + for (size_t i = 0; i < nValues; ++i) + args.push_back(objectFromPDLValue(values[i])); + return logicalResultFromObject(f(rewriter, results, args)); + }, + fn.ptr()); + } + private: MlirPDLPatternModule module; }; @@ -78,10 +123,48 @@ class PyFrozenRewritePatternSet { /// Create the `mlir.rewrite` here. void mlir::python::populateRewriteSubmodule(nb::module_ &m) { + nb::class_(m, "PatternRewriter"); //---------------------------------------------------------------------------- - // Mapping of the top-level PassManager + // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH + nb::class_(m, "PDLResultList") + .def( + "append", + [](MlirPDLResultList results, const PyValue &value) { + mlirPDLResultListPushBackValue(results, value); + }, + // clang-format off + nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")") + // clang-format on + ) + .def( + "append", + [](MlirPDLResultList results, const PyOperation &op) { + mlirPDLResultListPushBackOperation(results, op); + }, + // clang-format off + nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")") + // clang-format on + ) + .def( + "append", + [](MlirPDLResultList results, const PyType &type) { + mlirPDLResultListPushBackType(results, type); + }, + // clang-format off + nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")") + // clang-format on + ) + .def( + "append", + [](MlirPDLResultList results, const PyAttribute &attr) { + mlirPDLResultListPushBackAttribute(results, attr); + }, + // clang-format off + nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")") + // clang-format on + ); nb::class_(m, "PDLModule") .def( "__init__", @@ -93,10 +176,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"), // clang-format on "module"_a, "Create a PDL module from the given module.") - .def("freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( - mlirRewritePatternSetFromPDLPatternModule(self.get()))); - }); + .def( + "freeze", + [](PyPDLPatternModule &self) { + return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + mlirRewritePatternSetFromPDLPatternModule(self.get()))); + }, + nb::keep_alive<0, 1>()) + .def( + "register_rewrite_function", + [](PyPDLPatternModule &self, const std::string &name, + const nb::callable &fn) { + self.registerRewriteFunction(name, fn); + }, + nb::keep_alive<1, 3>()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "FrozenRewritePatternSet") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 6f85357a14a18..9ecce956a05b9 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -13,6 +13,8 @@ #include "mlir/CAPI/Rewrite.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PDLPatternMatch.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } +//===----------------------------------------------------------------------===// +/// PatternRewriter API +//===----------------------------------------------------------------------===// + +inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { + assert(rewriter.ptr && "unexpected null rewriter"); + return static_cast(rewriter.ptr); +} + +inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { + return {rewriter}; +} + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// @@ -331,4 +346,73 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { op.ptr = nullptr; return wrap(m); } + +inline const mlir::PDLValue *unwrap(MlirPDLValue value) { + assert(value.ptr && "unexpected null PDL value"); + return static_cast(value.ptr); +} + +inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } + +inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { + assert(results.ptr && "unexpected null PDL results"); + return static_cast(results.ptr); +} + +inline MlirPDLResultList wrap(mlir::PDLResultList *results) { + return {results}; +} + +MlirValue mlirPDLValueAsValue(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast()); +} + +MlirType mlirPDLValueAsType(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast()); +} + +MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast()); +} + +MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast()); +} + +void mlirPDLResultListPushBackValue(MlirPDLResultList results, + MlirValue value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackOperation(MlirPDLResultList results, + MlirOperation value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, + MlirAttribute value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLPatternModuleRegisterRewriteFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLRewriteFunction rewriteFn, void *userData) { + unwrap(pdlModule)->registerRewriteFunction( + unwrap(name), + [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, + ArrayRef values) -> LogicalResult { + std::vector mlirValues; + mlirValues.reserve(values.size()); + for (auto &value : values) { + mlirValues.push_back(wrap(&value)); + } + return unwrap(rewriteFn(wrap(&rewriter), wrap(&results), + mlirValues.size(), mlirValues.data(), + userData)); + }); +} #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index dd6c74ce622c8..e85c6c77ef955 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -86,3 +86,98 @@ def add_func(a, b): frozen = get_pdl_patterns() apply_patterns_and_fold_greedily(module_.operation, frozen) return module_ + + +# If we use arith.constant and arith.addi here, +# these C++-defined folding/canonicalization will be applied +# implicitly in the greedy pattern rewrite driver to +# make our Python-defined folding useless, +# so here we define a new dialect to workaround this. +def load_myint_dialect(): + from mlir.dialects import irdl + + m = Module.create() + with InsertionPoint(m.body): + myint = irdl.dialect("myint") + with InsertionPoint(myint.body): + constant = irdl.operation_("constant") + with InsertionPoint(constant.body): + iattr = irdl.base(base_name="#builtin.integer") + i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32))) + irdl.attributes_([iattr], ["value"]) + irdl.results_([i32], ["cst"], [irdl.Variadicity.single]) + add = irdl.operation_("add") + with InsertionPoint(add.body): + i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32))) + irdl.operands_( + [i32, i32], + ["lhs", "rhs"], + [irdl.Variadicity.single, irdl.Variadicity.single], + ) + irdl.results_([i32], ["res"], [irdl.Variadicity.single]) + + m.operation.verify() + irdl.load_dialects(m) + + +# This PDL pattern is to fold constant additions, +# i.e. add(constant0, constant1) -> constant2 +# where constant2 = constant0 + constant1. +def get_pdl_pattern_fold(): + m = Module.create() + i32 = IntegerType.get_signless(32) + with InsertionPoint(m.body): + + @pdl.pattern(benefit=1, sym_name="myint_add_fold") + def pat(): + t = pdl.TypeOp(i32) + a0 = pdl.AttributeOp() + a1 = pdl.AttributeOp() + c0 = pdl.OperationOp( + name="myint.constant", attributes={"value": a0}, types=[t] + ) + c1 = pdl.OperationOp( + name="myint.constant", attributes={"value": a1}, types=[t] + ) + v0 = pdl.ResultOp(c0, 0) + v1 = pdl.ResultOp(c1, 0) + op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t]) + + @pdl.rewrite() + def rew(): + sum = pdl.apply_native_rewrite( + [pdl.AttributeType.get()], "add_fold", [a0, a1] + ) + newOp = pdl.OperationOp( + name="myint.constant", attributes={"value": sum}, types=[t] + ) + pdl.ReplaceOp(op0, with_op=newOp) + + def add_fold(rewriter, results, values): + a0, a1 = values + results.append(IntegerAttr.get(i32, a0.value + a1.value)) + + pdl_module = PDLModule(m) + pdl_module.register_rewrite_function("add_fold", add_fold) + return pdl_module.freeze() + + +# CHECK-LABEL: TEST: test_pdl_register_function +# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32 +@construct_and_print_in_module +def test_pdl_register_function(module_): + load_myint_dialect() + + module_ = Module.parse( + """ + %c0 = "myint.constant"() { value = 2 }: () -> (i32) + %c1 = "myint.constant"() { value = 3 }: () -> (i32) + %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32) + "myint.add"(%x, %c1): (i32, i32) -> (i32) + """ + ) + + frozen = get_pdl_pattern_fold() + apply_patterns_and_fold_greedily(module_, frozen) + + return module_