diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 5dd285ee076c4..2db1d84cd1d89 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirPatternRewriter, void); +DEFINE_C_API_STRUCT(MlirRewritePattern, const void); //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder @@ -302,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter); /// FrozenRewritePatternSet API //===----------------------------------------------------------------------===// +/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet. +/// Note that the ownership of the input set is transferred into the frozen set +/// after this call. MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet -mlirFreezeRewritePattern(MlirRewritePatternSet op); +mlirFreezeRewritePattern(MlirRewritePatternSet set); +/// Destroy the given MlirFrozenRewritePatternSet. MLIR_CAPI_EXPORTED void -mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); +mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set); MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp( MlirOperation op, MlirFrozenRewritePatternSet patterns, @@ -324,6 +329,51 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter); +//===----------------------------------------------------------------------===// +/// RewritePattern API +//===----------------------------------------------------------------------===// + +/// Callbacks to construct a rewrite pattern. +typedef struct { + /// Optional constructor for the user data. + /// Set to nullptr to disable it. + void (*construct)(void *userData); + /// Optional destructor for the user data. + /// Set to nullptr to disable it. + void (*destruct)(void *userData); + /// The callback function to match against code rooted at the specified + /// operation, and perform the rewrite if the match is successful, + /// corresponding to RewritePattern::matchAndRewrite. + MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern, + MlirOperation op, + MlirPatternRewriter rewriter, + void *userData); +} MlirRewritePatternCallbacks; + +/// Create a rewrite pattern that matches the operation +/// with the given rootName, corresponding to mlir::OpRewritePattern. +MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( + MlirStringRef rootName, unsigned benefit, MlirContext context, + MlirRewritePatternCallbacks callbacks, void *userData, + size_t nGeneratedNames, MlirStringRef *generatedNames); + +//===----------------------------------------------------------------------===// +/// RewritePatternSet API +//===----------------------------------------------------------------------===// + +/// Create an empty MlirRewritePatternSet. +MLIR_CAPI_EXPORTED MlirRewritePatternSet +mlirRewritePatternSetCreate(MlirContext context); + +/// Destruct the given MlirRewritePatternSet. +MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set); + +/// Add the given MlirRewritePattern into a MlirRewritePatternSet. +/// Note that the ownership of the pattern is transferred to the set after this +/// call. +MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, + MlirRewritePattern pattern); + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h index 1038c0a575cf2..8cd51edf0837b 100644 --- a/mlir/include/mlir/CAPI/Rewrite.h +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -20,5 +20,7 @@ #include "mlir/IR/PatternMatch.h" DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase) +DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern) +DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet) #endif // MLIR_CAPIREWRITER_H diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 9e3d9703c82e8..d506b7fc9bc7b 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -45,6 +45,16 @@ class PyPatternRewriter { return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } + void replaceOp(MlirOperation op, MlirOperation newOp) { + mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); + } + + void replaceOp(MlirOperation op, const std::vector &values) { + mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); + } + + void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); } + private: MlirRewriterBase base; PyMlirContextRef ctx; @@ -165,13 +175,116 @@ class PyFrozenRewritePatternSet { MlirFrozenRewritePatternSet set; }; +class PyRewritePatternSet { +public: + PyRewritePatternSet(MlirContext ctx) + : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} + ~PyRewritePatternSet() { + if (set.ptr) + mlirRewritePatternSetDestroy(set); + } + + void add(MlirStringRef rootName, unsigned benefit, + const nb::callable &matchAndRewrite) { + MlirRewritePatternCallbacks callbacks; + callbacks.construct = [](void *userData) { + nb::handle(static_cast(userData)).inc_ref(); + }; + callbacks.destruct = [](void *userData) { + nb::handle(static_cast(userData)).dec_ref(); + }; + callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, + MlirPatternRewriter rewriter, + void *userData) -> MlirLogicalResult { + nb::handle f(static_cast(userData)); + nb::object res = f(op, PyPatternRewriter(rewriter)); + return logicalResultFromObject(res); + }; + MlirRewritePattern pattern = mlirOpRewritePattenCreate( + rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), + /* nGeneratedNames */ 0, + /* generatedNames */ nullptr); + mlirRewritePatternSetAdd(set, pattern); + } + + PyFrozenRewritePatternSet freeze() { + MlirRewritePatternSet s = set; + set.ptr = nullptr; + return mlirFreezeRewritePattern(s); + } + +private: + MlirRewritePatternSet set; + MlirContext ctx; +}; + } // namespace /// Create the `mlir.rewrite` here. void mlir::python::populateRewriteSubmodule(nb::module_ &m) { - nb::class_(m, "PatternRewriter") - .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, - "The current insertion point of the PatternRewriter."); + //---------------------------------------------------------------------------- + // Mapping of the PatternRewriter + //---------------------------------------------------------------------------- + nb:: + class_(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter.") + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + MlirOperation newOp) { self.replaceOp(op, newOp); }, + "Replace an operation with a new operation.", nb::arg("op"), + nb::arg("new_op"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ) + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + const std::vector &values) { + self.replaceOp(op, values); + }, + "Replace an operation with a list of values.", nb::arg("op"), + nb::arg("values"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") + // clang-format on + ) + .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", + nb::arg("op"), + // clang-format off + nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ); + + //---------------------------------------------------------------------------- + // Mapping of the RewritePatternSet + //---------------------------------------------------------------------------- + nb::class_(m, "RewritePattern"); + nb::class_(m, "RewritePatternSet") + .def( + "__init__", + [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { + new (&self) PyRewritePatternSet(context.get()->get()); + }, + "context"_a = nb::none()) + .def( + "add", + [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, + unsigned benefit) { + std::string opName = + nb::cast(root.attr("OPERATION_NAME")); + self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, + fn); + }, + "root"_a, "fn"_a, "benefit"_a = 1, + "Add a new rewrite pattern on the given root operation with the " + "callable as the matching and rewriting function and the given " + "benefit.") + .def("freeze", &PyRewritePatternSet::freeze, + "Freeze the pattern set into a frozen one."); + //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- @@ -237,7 +350,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def( "freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index c15a73b991f5d..70dee598c9535 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -270,15 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return *(static_cast(module.ptr)); -} - -static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { - return {module}; -} - static inline mlir::FrozenRewritePatternSet * unwrap(MlirFrozenRewritePatternSet module) { assert(module.ptr && "unexpected null module"); @@ -290,15 +281,16 @@ wrap(mlir::FrozenRewritePatternSet *module) { return {module}; } -MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { - auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op))); - op.ptr = nullptr; +MlirFrozenRewritePatternSet +mlirFreezeRewritePattern(MlirRewritePatternSet set) { + auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); + set.ptr = nullptr; return wrap(m); } -void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { - delete unwrap(op); - op.ptr = nullptr; +void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) { + delete unwrap(set); + set.ptr = nullptr; } MlirLogicalResult @@ -332,6 +324,77 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { return wrap(static_cast(unwrap(rewriter))); } +//===----------------------------------------------------------------------===// +/// RewritePattern API +//===----------------------------------------------------------------------===// + +namespace mlir { + +class ExternalRewritePattern : public mlir::RewritePattern { +public: + ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData, + StringRef rootName, PatternBenefit benefit, + MLIRContext *context, + ArrayRef generatedNames) + : RewritePattern(rootName, benefit, context, generatedNames), + callbacks(callbacks), userData(userData) { + if (callbacks.construct) + callbacks.construct(userData); + } + + ~ExternalRewritePattern() { + if (callbacks.destruct) + callbacks.destruct(userData); + } + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + return unwrap(callbacks.matchAndRewrite( + wrap(static_cast(this)), wrap(op), + wrap(&rewriter), userData)); + } + +private: + MlirRewritePatternCallbacks callbacks; + void *userData; +}; + +} // namespace mlir + +MlirRewritePattern mlirOpRewritePattenCreate( + MlirStringRef rootName, unsigned benefit, MlirContext context, + MlirRewritePatternCallbacks callbacks, void *userData, + size_t nGeneratedNames, MlirStringRef *generatedNames) { + std::vector generatedNamesVec; + generatedNamesVec.reserve(nGeneratedNames); + for (size_t i = 0; i < nGeneratedNames; ++i) { + generatedNamesVec.push_back(unwrap(generatedNames[i])); + } + return wrap(new mlir::ExternalRewritePattern( + callbacks, userData, unwrap(rootName), PatternBenefit(benefit), + unwrap(context), generatedNamesVec)); +} + +//===----------------------------------------------------------------------===// +/// RewritePatternSet API +//===----------------------------------------------------------------------===// + +MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) { + return wrap(new mlir::RewritePatternSet(unwrap(context))); +} + +void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) { + delete unwrap(set); +} + +void mlirRewritePatternSetAdd(MlirRewritePatternSet set, + MlirRewritePattern pattern) { + std::unique_ptr patternPtr( + const_cast(unwrap(pattern))); + pattern.ptr = nullptr; + unwrap(set)->add(std::move(patternPtr)); +} + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py new file mode 100644 index 0000000000000..acf7db23db914 --- /dev/null +++ b/mlir/test/python/rewrite.py @@ -0,0 +1,69 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +from mlir.ir import * +from mlir.passmanager import * +from mlir.dialects.builtin import ModuleOp +from mlir.dialects import arith +from mlir.rewrite import * + + +def run(f): + print("\nTEST:", f.__name__) + f() + + +# CHECK-LABEL: TEST: testRewritePattern +@run +def testRewritePattern(): + def to_muli(op, rewriter): + with rewriter.ip: + new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location) + rewriter.replace_op(op, new_op.owner) + + def constant_1_to_2(op, rewriter): + c = op.attributes["value"].value + if c != 1: + return True # failed to match + with rewriter.ip: + new_op = arith.constant(op.result.type, 2, loc=op.location) + rewriter.replace_op(op, [new_op]) + + with Context(): + patterns = RewritePatternSet() + patterns.add(arith.AddIOp, to_muli) + patterns.add(arith.ConstantOp, constant_1_to_2) + frozen = patterns.freeze() + + module = ModuleOp.parse( + r""" + module { + func.func @add(%a: i64, %b: i64) -> i64 { + %sum = arith.addi %a, %b : i64 + return %sum : i64 + } + } + """ + ) + + apply_patterns_and_fold_greedily(module, frozen) + # CHECK: %0 = arith.muli %arg0, %arg1 : i64 + # CHECK: return %0 : i64 + print(module) + + module = ModuleOp.parse( + r""" + module { + func.func @const() -> (i64, i64) { + %0 = arith.constant 1 : i64 + %1 = arith.constant 3 : i64 + return %0, %1 : i64, i64 + } + } + """ + ) + + apply_patterns_and_fold_greedily(module, frozen) + # CHECK: %c2_i64 = arith.constant 2 : i64 + # CHECK: %c3_i64 = arith.constant 3 : i64 + # CHECK: return %c2_i64, %c3_i64 : i64, i64 + print(module)