From 118e4c7da941001295148006c3a45193ff078259 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 20 Sep 2025 01:47:41 +0800 Subject: [PATCH 01/12] [MLIR][Python] Add bindings for PDL native function registering --- mlir/include/mlir-c/Rewrite.h | 32 +++++++ mlir/lib/Bindings/Python/Rewrite.cpp | 74 +++++++++++++-- mlir/lib/CAPI/Transforms/Rewrite.cpp | 99 ++++++++++++++++++++ mlir/test/python/integration/dialects/pdl.py | 82 ++++++++++++++++ 4 files changed, 281 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 374d2fb78de88..c20558fc8f9d9 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void); DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); +DEFINE_C_API_STRUCT(MlirPatternRewriter, void); //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder @@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( #if MLIR_ENABLE_PDL_IN_PATTERNMATCH DEFINE_C_API_STRUCT(MlirPDLPatternModule, void); +DEFINE_C_API_STRUCT(MlirPDLValue, const void); +DEFINE_C_API_STRUCT(MlirPDLResultList, void); MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op); @@ -323,6 +326,35 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op); MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); + +MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value); +MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value); +MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value); +MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value); +MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value); +MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value); +MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value); +MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value); + +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value); +MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results, + MlirType value); +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackOperation(MlirPDLResultList results, + MlirOperation value); +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackAttribute(MlirPDLResultList results, + MlirAttribute value); + +typedef MlirLogicalResult (*MlirPDLRewriteFunction)( + MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, + MlirPDLValue *values, void *userData); + +MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction( + MlirPDLPatternModule module, MlirStringRef name, + MlirPDLRewriteFunction rewriteFn, void *userData); + #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH #undef DEFINE_C_API_STRUCT diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 5b7de50f02e6a..89dda560702ba 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -9,10 +9,13 @@ #include "Rewrite.h" #include "IRModule.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +#include "mlir-c/IR.h" #include "mlir-c/Rewrite.h" +#include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir/Config/mlir-config.h" +#include "nanobind/nanobind.h" namespace nb = nanobind; using namespace mlir; @@ -36,6 +39,22 @@ class PyPDLPatternModule { } MlirPDLPatternModule get() { return module; } + void registerRewriteFunction(const std::string &name, + const nb::callable &fn) { + mlirPDLPatternModuleRegisterRewriteFunction( + get(), mlirStringRefCreate(name.data(), name.size()), + [](MlirPatternRewriter rewriter, MlirPDLResultList results, + size_t nValues, MlirPDLValue *values, + void *userData) -> MlirLogicalResult { + auto f = nb::handle(static_cast(userData)); + auto valueVec = std::vector(values, values + nValues); + return nb::cast(f(rewriter, results, valueVec)) + ? mlirLogicalResultSuccess() + : mlirLogicalResultFailure(); + }, + fn.ptr()); + } + private: MlirPDLPatternModule module; }; @@ -76,10 +95,43 @@ class PyFrozenRewritePatternSet { /// Create the `mlir.rewrite` here. void mlir::python::populateRewriteSubmodule(nb::module_ &m) { + nb::class_(m, "PatternRewriter"); //---------------------------------------------------------------------------- - // Mapping of the top-level PassManager + // Mapping of the PDLModule //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH + nb::class_(m, "PDLValue").def("get", [](MlirPDLValue value) { + if (mlirPDLValueIsValue(value)) { + return nb::cast(mlirPDLValueAsValue(value)); + } + if (mlirPDLValueIsOperation(value)) { + return nb::cast(mlirPDLValueAsOperation(value)); + } + if (mlirPDLValueIsAttribute(value)) { + return nb::cast(mlirPDLValueAsAttribute(value)); + } + if (mlirPDLValueIsType(value)) { + return nb::cast(mlirPDLValueAsType(value)); + } + + throw std::runtime_error("unsupported PDL value type"); + }); + nb::class_(m, "PDLResultList") + .def("push_back", + [](MlirPDLResultList results, const PyValue &value) { + mlirPDLResultListPushBackValue(results, value); + }) + .def("push_back", + [](MlirPDLResultList results, const PyOperation &op) { + mlirPDLResultListPushBackOperation(results, op); + }) + .def("push_back", + [](MlirPDLResultList results, const PyType &type) { + mlirPDLResultListPushBackType(results, type); + }) + .def("push_back", [](MlirPDLResultList results, MlirAttribute attr) { + mlirPDLResultListPushBackAttribute(results, attr); + }); nb::class_(m, "PDLModule") .def( "__init__", @@ -88,10 +140,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); }, "module"_a, "Create a PDL module from the given module.") - .def("freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( - mlirRewritePatternSetFromPDLPatternModule(self.get()))); - }); + .def( + "freeze", + [](PyPDLPatternModule &self) { + return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + mlirRewritePatternSetFromPDLPatternModule(self.get()))); + }, + nb::keep_alive<0, 1>()) + .def( + "register_rewrite_function", + [](PyPDLPatternModule &self, const std::string &name, + const nb::callable &fn) { + self.registerRewriteFunction(name, fn); + }, + nb::keep_alive<1, 3>()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "FrozenRewritePatternSet") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 6f85357a14a18..0033abde986ea 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -13,6 +13,8 @@ #include "mlir/CAPI/Rewrite.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PDLPatternMatch.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } +//===----------------------------------------------------------------------===// +/// PatternRewriter API +//===----------------------------------------------------------------------===// + +inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { + assert(rewriter.ptr && "unexpected null rewriter"); + return static_cast(rewriter.ptr); +} + +inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { + return {rewriter}; +} + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// @@ -331,4 +346,88 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { op.ptr = nullptr; return wrap(m); } + +inline const mlir::PDLValue *unwrap(MlirPDLValue value) { + assert(value.ptr && "unexpected null PDL value"); + return static_cast(value.ptr); +} + +inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } + +inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { + assert(results.ptr && "unexpected null PDL results"); + return static_cast(results.ptr); +} + +inline MlirPDLResultList wrap(mlir::PDLResultList *results) { + return {results}; +} + +bool mlirPDLValueIsValue(MlirPDLValue value) { + return unwrap(value)->isa(); +} + +MlirValue mlirPDLValueAsValue(MlirPDLValue value) { + return wrap(unwrap(value)->cast()); +} + +bool mlirPDLValueIsType(MlirPDLValue value) { + return unwrap(value)->isa(); +} + +MlirType mlirPDLValueAsType(MlirPDLValue value) { + return wrap(unwrap(value)->cast()); +} + +bool mlirPDLValueIsOperation(MlirPDLValue value) { + return unwrap(value)->isa(); +} + +MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) { + return wrap(unwrap(value)->cast()); +} + +bool mlirPDLValueIsAttribute(MlirPDLValue value) { + return unwrap(value)->isa(); +} + +MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) { + return wrap(unwrap(value)->cast()); +} + +void mlirPDLResultListPushBackValue(MlirPDLResultList results, + MlirValue value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackOperation(MlirPDLResultList results, + MlirOperation value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, + MlirAttribute value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLPatternModuleRegisterRewriteFunction( + MlirPDLPatternModule module, MlirStringRef name, + MlirPDLRewriteFunction rewriteFn, void *userData) { + unwrap(module)->registerRewriteFunction( + unwrap(name), + [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, + ArrayRef values) -> LogicalResult { + std::vector mlirValues; + for (auto &value : values) { + mlirValues.push_back(wrap(&value)); + } + return unwrap(rewriteFn(wrap(&rewriter), wrap(&results), + mlirValues.size(), mlirValues.data(), + userData)); + }); +} #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index dd6c74ce622c8..8954e3622d3ef 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -86,3 +86,85 @@ def add_func(a, b): frozen = get_pdl_patterns() apply_patterns_and_fold_greedily(module_.operation, frozen) return module_ + + +# If we use arith.constant and arith.addi here, +# these C++-defined folding/canonicalization will be applied +# implicitly in the greedy pattern rewrite driver to +# make our Python-defined folding useless, +# so here we define a new dialect to workaround this. +def load_myint_dialect(): + from mlir.dialects import irdl + m = Module.create() + with InsertionPoint(m.body): + myint = irdl.dialect("myint") + with InsertionPoint(myint.body): + constant = irdl.operation_("constant") + with InsertionPoint(constant.body): + iattr = irdl.base(base_name="#builtin.integer") + i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32))) + irdl.attributes_([iattr], ["value"]) + irdl.results_([i32], ["cst"], [irdl.Variadicity.single]) + add = irdl.operation_("add") + with InsertionPoint(add.body): + i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32))) + irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single]) + irdl.results_([i32], ["res"], [irdl.Variadicity.single]) + + m.operation.verify() + irdl.load_dialects(m) + +# this PDL pattern is to fold constant additions, +# i.e. add(constant0, constant1) -> constant2 +# where constant2 = constant0 + constant1 +def get_pdl_pattern_fold(): + m = Module.create() + with InsertionPoint(m.body): + @pdl.pattern(benefit=1, sym_name="myint_add_fold") + def pat(): + t = pdl.TypeOp(IntegerType.get_signless(32)) + a0 = pdl.AttributeOp() + a1 = pdl.AttributeOp() + c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t]) + c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t]) + v0 = pdl.ResultOp(c0, 0) + v1 = pdl.ResultOp(c1, 0) + op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t]) + + @pdl.rewrite() + def rew(): + sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1]) + newOp = pdl.OperationOp( + name="myint.constant", attributes={"value": sum}, types=[t] + ) + pdl.ReplaceOp(op0, with_op=newOp) + + pdl_module = PDLModule(m) + def add_fold(rewriter, results, values): + a0, a1 = [i.get() for i in values] + i32 = IntegerType.get_signless(32) + results.push_back(IntegerAttr.get(i32, a0.value + a1.value)) + return True + pdl_module.register_rewrite_function("add_fold", add_fold) + return pdl_module.freeze() + + +# CHECK-LABEL: TEST: test_pdl_register_function +# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32 +@construct_and_print_in_module +def test_pdl_register_function(module_): + load_myint_dialect() + + module_ = Module.parse( + """ + %c0 = "myint.constant"() { value = 2 }: () -> (i32) + %c1 = "myint.constant"() { value = 3 }: () -> (i32) + %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32) + "myint.add"(%x, %c1): (i32, i32) -> (i32) + """ + ) + + frozen = get_pdl_pattern_fold() + apply_patterns_and_fold_greedily(module_, frozen) + + return module_ From f1315a6088031967a673406c239173333bd3103a Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 20 Sep 2025 20:56:58 +0800 Subject: [PATCH 02/12] fix style --- mlir/test/python/integration/dialects/pdl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index 8954e3622d3ef..64628db072be9 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -114,9 +114,9 @@ def load_myint_dialect(): m.operation.verify() irdl.load_dialects(m) -# this PDL pattern is to fold constant additions, +# This PDL pattern is to fold constant additions, # i.e. add(constant0, constant1) -> constant2 -# where constant2 = constant0 + constant1 +# where constant2 = constant0 + constant1. def get_pdl_pattern_fold(): m = Module.create() with InsertionPoint(m.body): @@ -139,12 +139,13 @@ def rew(): ) pdl.ReplaceOp(op0, with_op=newOp) - pdl_module = PDLModule(m) def add_fold(rewriter, results, values): a0, a1 = [i.get() for i in values] i32 = IntegerType.get_signless(32) results.push_back(IntegerAttr.get(i32, a0.value + a1.value)) return True + + pdl_module = PDLModule(m) pdl_module.register_rewrite_function("add_fold", add_fold) return pdl_module.freeze() From 2f3da2e0c356721311d2828b1293c35204d2fd6e Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 20 Sep 2025 22:08:15 +0800 Subject: [PATCH 03/12] format --- mlir/test/python/integration/dialects/pdl.py | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index 64628db072be9..5b33802cefaba 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -95,6 +95,7 @@ def add_func(a, b): # so here we define a new dialect to workaround this. def load_myint_dialect(): from mlir.dialects import irdl + m = Module.create() with InsertionPoint(m.body): myint = irdl.dialect("myint") @@ -108,32 +109,44 @@ def load_myint_dialect(): add = irdl.operation_("add") with InsertionPoint(add.body): i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32))) - irdl.operands_([i32, i32], ["lhs", "rhs"], [irdl.Variadicity.single, irdl.Variadicity.single]) + irdl.operands_( + [i32, i32], + ["lhs", "rhs"], + [irdl.Variadicity.single, irdl.Variadicity.single] + ) irdl.results_([i32], ["res"], [irdl.Variadicity.single]) m.operation.verify() irdl.load_dialects(m) + # This PDL pattern is to fold constant additions, # i.e. add(constant0, constant1) -> constant2 # where constant2 = constant0 + constant1. def get_pdl_pattern_fold(): m = Module.create() with InsertionPoint(m.body): + @pdl.pattern(benefit=1, sym_name="myint_add_fold") def pat(): t = pdl.TypeOp(IntegerType.get_signless(32)) a0 = pdl.AttributeOp() a1 = pdl.AttributeOp() - c0 = pdl.OperationOp(name="myint.constant", attributes={"value": a0}, types=[t]) - c1 = pdl.OperationOp(name="myint.constant", attributes={"value": a1}, types=[t]) + c0 = pdl.OperationOp( + name="myint.constant", attributes={"value": a0}, types=[t] + ) + c1 = pdl.OperationOp( + name="myint.constant", attributes={"value": a1}, types=[t] + ) v0 = pdl.ResultOp(c0, 0) v1 = pdl.ResultOp(c1, 0) op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t]) @pdl.rewrite() def rew(): - sum = pdl.apply_native_rewrite([pdl.AttributeType.get()], "add_fold", [a0, a1]) + sum = pdl.apply_native_rewrite( + [pdl.AttributeType.get()], "add_fold", [a0, a1] + ) newOp = pdl.OperationOp( name="myint.constant", attributes={"value": sum}, types=[t] ) From b2de4ec9ea125ce6ff181fce352da75fec00d9ba Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 20 Sep 2025 22:27:58 +0800 Subject: [PATCH 04/12] remove useless bindings --- mlir/lib/Bindings/Python/Rewrite.cpp | 43 +++++++++++--------- mlir/test/python/integration/dialects/pdl.py | 8 ++-- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 89dda560702ba..e1cf61677fef2 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -39,6 +39,23 @@ class PyPDLPatternModule { } MlirPDLPatternModule get() { return module; } + static nb::object fromPDLValue(MlirPDLValue value) { + if (mlirPDLValueIsValue(value)) { + return nb::cast(mlirPDLValueAsValue(value)); + } + if (mlirPDLValueIsOperation(value)) { + return nb::cast(mlirPDLValueAsOperation(value)); + } + if (mlirPDLValueIsAttribute(value)) { + return nb::cast(mlirPDLValueAsAttribute(value)); + } + if (mlirPDLValueIsType(value)) { + return nb::cast(mlirPDLValueAsType(value)); + } + + throw std::runtime_error("unsupported PDL value type"); + } + void registerRewriteFunction(const std::string &name, const nb::callable &fn) { mlirPDLPatternModuleRegisterRewriteFunction( @@ -47,8 +64,12 @@ class PyPDLPatternModule { size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { auto f = nb::handle(static_cast(userData)); - auto valueVec = std::vector(values, values + nValues); - return nb::cast(f(rewriter, results, valueVec)) + std::vector args; + args.reserve(nValues); + for (size_t i = 0; i < nValues; ++i) { + args.push_back(fromPDLValue(values[i])); + } + return nb::cast(f(rewriter, results, args)) ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); }, @@ -97,25 +118,9 @@ class PyFrozenRewritePatternSet { void mlir::python::populateRewriteSubmodule(nb::module_ &m) { nb::class_(m, "PatternRewriter"); //---------------------------------------------------------------------------- - // Mapping of the PDLModule + // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH - nb::class_(m, "PDLValue").def("get", [](MlirPDLValue value) { - if (mlirPDLValueIsValue(value)) { - return nb::cast(mlirPDLValueAsValue(value)); - } - if (mlirPDLValueIsOperation(value)) { - return nb::cast(mlirPDLValueAsOperation(value)); - } - if (mlirPDLValueIsAttribute(value)) { - return nb::cast(mlirPDLValueAsAttribute(value)); - } - if (mlirPDLValueIsType(value)) { - return nb::cast(mlirPDLValueAsType(value)); - } - - throw std::runtime_error("unsupported PDL value type"); - }); nb::class_(m, "PDLResultList") .def("push_back", [](MlirPDLResultList results, const PyValue &value) { diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index 5b33802cefaba..c78f2d4f9a0dc 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -112,7 +112,7 @@ def load_myint_dialect(): irdl.operands_( [i32, i32], ["lhs", "rhs"], - [irdl.Variadicity.single, irdl.Variadicity.single] + [irdl.Variadicity.single, irdl.Variadicity.single], ) irdl.results_([i32], ["res"], [irdl.Variadicity.single]) @@ -125,11 +125,12 @@ def load_myint_dialect(): # where constant2 = constant0 + constant1. def get_pdl_pattern_fold(): m = Module.create() + i32 = IntegerType.get_signless(32) with InsertionPoint(m.body): @pdl.pattern(benefit=1, sym_name="myint_add_fold") def pat(): - t = pdl.TypeOp(IntegerType.get_signless(32)) + t = pdl.TypeOp(i32) a0 = pdl.AttributeOp() a1 = pdl.AttributeOp() c0 = pdl.OperationOp( @@ -153,8 +154,7 @@ def rew(): pdl.ReplaceOp(op0, with_op=newOp) def add_fold(rewriter, results, values): - a0, a1 = [i.get() for i in values] - i32 = IntegerType.get_signless(32) + a0, a1 = values results.push_back(IntegerAttr.get(i32, a0.value + a1.value)) return True From d6db1e5be4d3f11127c8b978276efc19da9f9eb0 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 20 Sep 2025 22:33:59 +0800 Subject: [PATCH 05/12] fix header order --- mlir/lib/Bindings/Python/Rewrite.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index e1cf61677fef2..52be91223c5f8 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -9,11 +9,13 @@ #include "Rewrite.h" #include "IRModule.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir-c/IR.h" #include "mlir-c/Rewrite.h" #include "mlir-c/Support.h" +// clang-format off #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on #include "mlir/Config/mlir-config.h" #include "nanobind/nanobind.h" From c2f727e12c525a00c0192037cb79ead204cf5ab5 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 20 Sep 2025 23:04:35 +0800 Subject: [PATCH 06/12] fix --- mlir/lib/Bindings/Python/Rewrite.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 52be91223c5f8..eceb5895fd901 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -136,7 +136,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { [](MlirPDLResultList results, const PyType &type) { mlirPDLResultListPushBackType(results, type); }) - .def("push_back", [](MlirPDLResultList results, MlirAttribute attr) { + .def("push_back", [](MlirPDLResultList results, const PyAttribute &attr) { mlirPDLResultListPushBackAttribute(results, attr); }); nb::class_(m, "PDLModule") From 64350847220748ddc87ff764cf402bb03367baf6 Mon Sep 17 00:00:00 2001 From: Twice Date: Sun, 21 Sep 2025 11:27:08 +0800 Subject: [PATCH 07/12] Apply suggestion from @makslevental Co-authored-by: Maksim Levental --- mlir/lib/Bindings/Python/Rewrite.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index eceb5895fd901..d8194388b195b 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -42,18 +42,14 @@ class PyPDLPatternModule { MlirPDLPatternModule get() { return module; } static nb::object fromPDLValue(MlirPDLValue value) { - if (mlirPDLValueIsValue(value)) { + if (mlirPDLValueIsValue(value)) return nb::cast(mlirPDLValueAsValue(value)); - } - if (mlirPDLValueIsOperation(value)) { + if (mlirPDLValueIsOperation(value)) return nb::cast(mlirPDLValueAsOperation(value)); - } - if (mlirPDLValueIsAttribute(value)) { + if (mlirPDLValueIsAttribute(value)) return nb::cast(mlirPDLValueAsAttribute(value)); - } - if (mlirPDLValueIsType(value)) { + if (mlirPDLValueIsType(value)) return nb::cast(mlirPDLValueAsType(value)); - } throw std::runtime_error("unsupported PDL value type"); } From 0653ac60d67ed853d1dea24ac352d5bd59cb806b Mon Sep 17 00:00:00 2001 From: Twice Date: Sun, 21 Sep 2025 11:27:25 +0800 Subject: [PATCH 08/12] Apply suggestion from @makslevental Co-authored-by: Maksim Levental --- mlir/lib/Bindings/Python/Rewrite.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index d8194388b195b..aee8534b33f7e 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -64,9 +64,8 @@ class PyPDLPatternModule { auto f = nb::handle(static_cast(userData)); std::vector args; args.reserve(nValues); - for (size_t i = 0; i < nValues; ++i) { + for (size_t i = 0; i < nValues; ++i) args.push_back(fromPDLValue(values[i])); - } return nb::cast(f(rewriter, results, args)) ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); From e9845d639f6243d6133cd79d4a861563dc6784ca Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sun, 21 Sep 2025 13:25:11 +0800 Subject: [PATCH 09/12] move out pdlvalue cast and logical result conversion --- mlir/lib/Bindings/Python/Rewrite.cpp | 42 +++++++++++--------- mlir/test/python/integration/dialects/pdl.py | 1 - 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index ca4e3e331e9b5..3870e6887dfdb 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -27,6 +27,27 @@ using namespace mlir::python; namespace { #if MLIR_ENABLE_PDL_IN_PATTERNMATCH +nb::object objectFromPDLValue(MlirPDLValue value) { + if (mlirPDLValueIsValue(value)) + return nb::cast(mlirPDLValueAsValue(value)); + if (mlirPDLValueIsOperation(value)) + return nb::cast(mlirPDLValueAsOperation(value)); + if (mlirPDLValueIsAttribute(value)) + return nb::cast(mlirPDLValueAsAttribute(value)); + if (mlirPDLValueIsType(value)) + return nb::cast(mlirPDLValueAsType(value)); + + throw std::runtime_error("unsupported PDL value type"); +} + +MlirLogicalResult logicalResultFromObject(const nb::object &obj) { + if (obj.is_none()) + return mlirLogicalResultSuccess(); + + return nb::cast(obj) ? mlirLogicalResultFailure() + : mlirLogicalResultSuccess(); +} + /// Owning Wrapper around a PDLPatternModule. class PyPDLPatternModule { public: @@ -41,19 +62,6 @@ class PyPDLPatternModule { } MlirPDLPatternModule get() { return module; } - static nb::object fromPDLValue(MlirPDLValue value) { - if (mlirPDLValueIsValue(value)) - return nb::cast(mlirPDLValueAsValue(value)); - if (mlirPDLValueIsOperation(value)) - return nb::cast(mlirPDLValueAsOperation(value)); - if (mlirPDLValueIsAttribute(value)) - return nb::cast(mlirPDLValueAsAttribute(value)); - if (mlirPDLValueIsType(value)) - return nb::cast(mlirPDLValueAsType(value)); - - throw std::runtime_error("unsupported PDL value type"); - } - void registerRewriteFunction(const std::string &name, const nb::callable &fn) { mlirPDLPatternModuleRegisterRewriteFunction( @@ -61,14 +69,12 @@ class PyPDLPatternModule { [](MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { - auto f = nb::handle(static_cast(userData)); + nb::handle f = nb::handle(static_cast(userData)); std::vector args; args.reserve(nValues); for (size_t i = 0; i < nValues; ++i) - args.push_back(fromPDLValue(values[i])); - return nb::cast(f(rewriter, results, args)) - ? mlirLogicalResultSuccess() - : mlirLogicalResultFailure(); + args.push_back(objectFromPDLValue(values[i])); + return logicalResultFromObject(f(rewriter, results, args)); }, fn.ptr()); } diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index c78f2d4f9a0dc..8fbe1a7151f63 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -156,7 +156,6 @@ def rew(): def add_fold(rewriter, results, values): a0, a1 = values results.push_back(IntegerAttr.get(i32, a0.value + a1.value)) - return True pdl_module = PDLModule(m) pdl_module.register_rewrite_function("add_fold", add_fold) From bf7940966941de09273a170c98fa5e6ffbdaf867 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sun, 21 Sep 2025 14:28:32 +0800 Subject: [PATCH 10/12] add sigs --- mlir/lib/Bindings/Python/Rewrite.cpp | 51 ++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 3870e6887dfdb..7f71f134e44ae 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -125,21 +125,42 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "PDLResultList") - .def("push_back", - [](MlirPDLResultList results, const PyValue &value) { - mlirPDLResultListPushBackValue(results, value); - }) - .def("push_back", - [](MlirPDLResultList results, const PyOperation &op) { - mlirPDLResultListPushBackOperation(results, op); - }) - .def("push_back", - [](MlirPDLResultList results, const PyType &type) { - mlirPDLResultListPushBackType(results, type); - }) - .def("push_back", [](MlirPDLResultList results, const PyAttribute &attr) { - mlirPDLResultListPushBackAttribute(results, attr); - }); + .def( + "push_back", + [](MlirPDLResultList results, const PyValue &value) { + mlirPDLResultListPushBackValue(results, value); + }, + // clang-format off + nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")") + // clang-format on + ) + .def( + "push_back", + [](MlirPDLResultList results, const PyOperation &op) { + mlirPDLResultListPushBackOperation(results, op); + }, + // clang-format off + nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")") + // clang-format on + ) + .def( + "push_back", + [](MlirPDLResultList results, const PyType &type) { + mlirPDLResultListPushBackType(results, type); + }, + // clang-format off + nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")") + // clang-format on + ) + .def( + "push_back", + [](MlirPDLResultList results, const PyAttribute &attr) { + mlirPDLResultListPushBackAttribute(results, attr); + }, + // clang-format off + nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")") + // clang-format on + ); nb::class_(m, "PDLModule") .def( "__init__", From e821fd4339180b45bdcf8bb2aa77edb404b6a7c2 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Tue, 23 Sep 2025 22:08:42 +0800 Subject: [PATCH 11/12] apply review suggestions --- mlir/include/mlir-c/Rewrite.h | 30 ++++++++++++--- mlir/include/mlir/IR/PDLPatternMatch.h.inc | 2 +- mlir/lib/Bindings/Python/Rewrite.cpp | 40 +++++++++++--------- mlir/lib/CAPI/Transforms/Rewrite.cpp | 25 +++--------- mlir/test/python/integration/dialects/pdl.py | 2 +- 5 files changed, 54 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index c20558fc8f9d9..f4974348945c5 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -327,32 +327,52 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op); MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); -MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value); +/// Cast the MlirPDLValue to an MlirValue. +/// Return a null value if the cast fails, just like llvm::dyn_cast. MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value); -MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirType. +/// Return a null value if the cast fails, just like llvm::dyn_cast. MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value); -MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirOperation. +/// Return a null value if the cast fails, just like llvm::dyn_cast. MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value); -MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirAttribute. +/// Return a null value if the cast fails, just like llvm::dyn_cast. MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value); +/// Push the MlirValue into the given MlirPDLResultList. MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value); + +/// Push the MlirType into the given MlirPDLResultList. MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value); + +/// Push the MlirOperation into the given MlirPDLResultList. MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackOperation(MlirPDLResultList results, MlirOperation value); + +/// Push the MlirAttribute into the given MlirPDLResultList. MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, MlirAttribute value); +/// This function type is used as callbacks for PDL native rewrite functions. +/// Input values can be accessed by `values` with its size `nValues`; +/// output values can be added into `results` by `mlirPDLResultListPushBack*` +/// APIs. And the return value indicates whether the rewrite succeeds. typedef MlirLogicalResult (*MlirPDLRewriteFunction)( MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, MlirPDLValue *values, void *userData); +/// Register a rewrite function into the given PDL pattern module. +/// `userData` will be provided as an argument to the rewrite function. MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction( - MlirPDLPatternModule module, MlirStringRef name, + MlirPDLPatternModule pdlModule, MlirStringRef name, MlirPDLRewriteFunction rewriteFn, void *userData); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc index 96ba98a850de0..d5fb57d7c360d 100644 --- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc +++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc @@ -53,7 +53,7 @@ public: /// value is not an instance of `T`. template ::value, T, std::optional>> + std::is_constructible_v, T, std::optional>> ResultT dyn_cast() const { return isa() ? castImpl() : ResultT(); } diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 7f71f134e44ae..72062f660458b 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -27,20 +27,24 @@ using namespace mlir::python; namespace { #if MLIR_ENABLE_PDL_IN_PATTERNMATCH -nb::object objectFromPDLValue(MlirPDLValue value) { - if (mlirPDLValueIsValue(value)) - return nb::cast(mlirPDLValueAsValue(value)); - if (mlirPDLValueIsOperation(value)) - return nb::cast(mlirPDLValueAsOperation(value)); - if (mlirPDLValueIsAttribute(value)) - return nb::cast(mlirPDLValueAsAttribute(value)); - if (mlirPDLValueIsType(value)) - return nb::cast(mlirPDLValueAsType(value)); +static nb::object objectFromPDLValue(MlirPDLValue value) { + if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) + return nb::cast(v); + if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v)) + return nb::cast(v); + if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v)) + return nb::cast(v); + if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v)) + return nb::cast(v); throw std::runtime_error("unsupported PDL value type"); } -MlirLogicalResult logicalResultFromObject(const nb::object &obj) { +// Convert the Python object to a boolean. +// If it evaluates to False, treat it as success; +// otherwise, treat it as failure. +// Note that None is considered success. +static MlirLogicalResult logicalResultFromObject(const nb::object &obj) { if (obj.is_none()) return mlirLogicalResultSuccess(); @@ -126,39 +130,39 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { #if MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "PDLResultList") .def( - "push_back", + "append", [](MlirPDLResultList results, const PyValue &value) { mlirPDLResultListPushBackValue(results, value); }, // clang-format off - nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")") + nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")") // clang-format on ) .def( - "push_back", + "append", [](MlirPDLResultList results, const PyOperation &op) { mlirPDLResultListPushBackOperation(results, op); }, // clang-format off - nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")") + nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")") // clang-format on ) .def( - "push_back", + "append", [](MlirPDLResultList results, const PyType &type) { mlirPDLResultListPushBackType(results, type); }, // clang-format off - nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")") + nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")") // clang-format on ) .def( - "push_back", + "append", [](MlirPDLResultList results, const PyAttribute &attr) { mlirPDLResultListPushBackAttribute(results, attr); }, // clang-format off - nb::sig("def push_back(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")") + nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")") // clang-format on ); nb::class_(m, "PDLModule") diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 0033abde986ea..8b41e7022bc18 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -363,36 +363,20 @@ inline MlirPDLResultList wrap(mlir::PDLResultList *results) { return {results}; } -bool mlirPDLValueIsValue(MlirPDLValue value) { - return unwrap(value)->isa(); -} - MlirValue mlirPDLValueAsValue(MlirPDLValue value) { - return wrap(unwrap(value)->cast()); -} - -bool mlirPDLValueIsType(MlirPDLValue value) { - return unwrap(value)->isa(); + return wrap(unwrap(value)->dyn_cast()); } MlirType mlirPDLValueAsType(MlirPDLValue value) { - return wrap(unwrap(value)->cast()); -} - -bool mlirPDLValueIsOperation(MlirPDLValue value) { - return unwrap(value)->isa(); + return wrap(unwrap(value)->dyn_cast()); } MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) { - return wrap(unwrap(value)->cast()); -} - -bool mlirPDLValueIsAttribute(MlirPDLValue value) { - return unwrap(value)->isa(); + return wrap(unwrap(value)->dyn_cast()); } MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) { - return wrap(unwrap(value)->cast()); + return wrap(unwrap(value)->dyn_cast()); } void mlirPDLResultListPushBackValue(MlirPDLResultList results, @@ -422,6 +406,7 @@ void mlirPDLPatternModuleRegisterRewriteFunction( [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, ArrayRef values) -> LogicalResult { std::vector mlirValues; + mlirValues.reserve(values.size()); for (auto &value : values) { mlirValues.push_back(wrap(&value)); } diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index 8fbe1a7151f63..e85c6c77ef955 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -155,7 +155,7 @@ def rew(): def add_fold(rewriter, results, values): a0, a1 = values - results.push_back(IntegerAttr.get(i32, a0.value + a1.value)) + results.append(IntegerAttr.get(i32, a0.value + a1.value)) pdl_module = PDLModule(m) pdl_module.register_rewrite_function("add_fold", add_fold) From b08315f65006cfaebaaa56e5f401217ca7300290 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Tue, 23 Sep 2025 22:11:54 +0800 Subject: [PATCH 12/12] rename --- mlir/lib/CAPI/Transforms/Rewrite.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 8b41e7022bc18..9ecce956a05b9 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -399,9 +399,9 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, } void mlirPDLPatternModuleRegisterRewriteFunction( - MlirPDLPatternModule module, MlirStringRef name, + MlirPDLPatternModule pdlModule, MlirStringRef name, MlirPDLRewriteFunction rewriteFn, void *userData) { - unwrap(module)->registerRewriteFunction( + unwrap(pdlModule)->registerRewriteFunction( unwrap(name), [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, ArrayRef values) -> LogicalResult {