-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] Add bindings for PDL native rewrite function registering #159926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
✅ With the latest revision this PR passed the Python code formatter. |
@llvm/pr-subscribers-mlir Author: Twice (PragmaTwice) ChangesIn the MLIR Python bindings, we can currently use PDL to define simple patterns and then execute them with the greedy rewrite driver. However, when dealing with more complex patterns—such as constant folding for integer addition—we find that we need pdl.pattern : benefit(1) {
%a0 = pdl.attribute
%a1 = pdl.attribute
%c0 = pdl.operation "arith.constant" {value = %a0}
%c1 = pdl.operation "arith.constant" {value = %a1}
%op = pdl.operation "arith.addi"(%c0, %c1)
%sum = pdl.apply_native_rewrite "addIntegers"(%a0, %a1)
%new_cst = pdl.operation "arith.constant" {value = %sum}
pdl.replace %op with %new_cst
} Here, As a test case, we defined two new operations ( m = Module.create()
with InsertionPoint(m.body):
@<!-- -->pdl.pattern(benefit=1, sym_name="myint_add_fold")
def pat():
...
op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
@<!-- -->pdl.rewrite()
def rew():
sum = pdl.apply_native_rewrite(
[pdl.AttributeType.get()], "add_fold", [a0, a1]
)
newOp = pdl.OperationOp(
name="myint.constant", attributes={"value": sum}, types=[t]
)
pdl.ReplaceOp(op0, with_op=newOp)
def add_fold(rewriter, results, values):
a0, a1 = values
results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
return True
pdl_module = PDLModule(m)
pdl_module.register_rewrite_function("add_fold", add_fold) Full diff: https://github.com/llvm/llvm-project/pull/159926.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 374d2fb78de88..c20558fc8f9d9 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void);
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
@@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+DEFINE_C_API_STRUCT(MlirPDLValue, const void);
+DEFINE_C_API_STRUCT(MlirPDLResultList, void);
MLIR_CAPI_EXPORTED MlirPDLPatternModule
mlirPDLPatternModuleFromModule(MlirModule op);
@@ -323,6 +326,35 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
+
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);
+MLIR_CAPI_EXPORTED bool mlirPDLValueIsAttribute(MlirPDLValue value);
+MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);
+
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);
+MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
+ MlirType value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+ MlirOperation value);
+MLIR_CAPI_EXPORTED void
+mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+ MlirAttribute value);
+
+typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
+ MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+ MlirPDLValue *values, void *userData);
+
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
+ MlirPDLPatternModule module, MlirStringRef name,
+ MlirPDLRewriteFunction rewriteFn, void *userData);
+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
#undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5b7de50f02e6a..eceb5895fd901 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -9,10 +9,15 @@
#include "Rewrite.h"
#include "IRModule.h"
+#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
+#include "mlir-c/Support.h"
+// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
+// clang-format on
#include "mlir/Config/mlir-config.h"
+#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace mlir;
@@ -36,6 +41,43 @@ class PyPDLPatternModule {
}
MlirPDLPatternModule get() { return module; }
+ static nb::object fromPDLValue(MlirPDLValue value) {
+ if (mlirPDLValueIsValue(value)) {
+ return nb::cast(mlirPDLValueAsValue(value));
+ }
+ if (mlirPDLValueIsOperation(value)) {
+ return nb::cast(mlirPDLValueAsOperation(value));
+ }
+ if (mlirPDLValueIsAttribute(value)) {
+ return nb::cast(mlirPDLValueAsAttribute(value));
+ }
+ if (mlirPDLValueIsType(value)) {
+ return nb::cast(mlirPDLValueAsType(value));
+ }
+
+ throw std::runtime_error("unsupported PDL value type");
+ }
+
+ void registerRewriteFunction(const std::string &name,
+ const nb::callable &fn) {
+ mlirPDLPatternModuleRegisterRewriteFunction(
+ get(), mlirStringRefCreate(name.data(), name.size()),
+ [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+ size_t nValues, MlirPDLValue *values,
+ void *userData) -> MlirLogicalResult {
+ auto f = nb::handle(static_cast<PyObject *>(userData));
+ std::vector<nb::object> args;
+ args.reserve(nValues);
+ for (size_t i = 0; i < nValues; ++i) {
+ args.push_back(fromPDLValue(values[i]));
+ }
+ return nb::cast<bool>(f(rewriter, results, args))
+ ? mlirLogicalResultSuccess()
+ : mlirLogicalResultFailure();
+ },
+ fn.ptr());
+ }
+
private:
MlirPDLPatternModule module;
};
@@ -76,10 +118,27 @@ class PyFrozenRewritePatternSet {
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+ nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
//----------------------------------------------------------------------------
- // Mapping of the top-level PassManager
+ // Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+ nb::class_<MlirPDLResultList>(m, "PDLResultList")
+ .def("push_back",
+ [](MlirPDLResultList results, const PyValue &value) {
+ mlirPDLResultListPushBackValue(results, value);
+ })
+ .def("push_back",
+ [](MlirPDLResultList results, const PyOperation &op) {
+ mlirPDLResultListPushBackOperation(results, op);
+ })
+ .def("push_back",
+ [](MlirPDLResultList results, const PyType &type) {
+ mlirPDLResultListPushBackType(results, type);
+ })
+ .def("push_back", [](MlirPDLResultList results, const PyAttribute &attr) {
+ mlirPDLResultListPushBackAttribute(results, attr);
+ });
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
@@ -88,10 +147,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
},
"module"_a, "Create a PDL module from the given module.")
- .def("freeze", [](PyPDLPatternModule &self) {
- return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
- mlirRewritePatternSetFromPDLPatternModule(self.get())));
- });
+ .def(
+ "freeze",
+ [](PyPDLPatternModule &self) {
+ return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+ mlirRewritePatternSetFromPDLPatternModule(self.get())));
+ },
+ nb::keep_alive<0, 1>())
+ .def(
+ "register_rewrite_function",
+ [](PyPDLPatternModule &self, const std::string &name,
+ const nb::callable &fn) {
+ self.registerRewriteFunction(name, fn);
+ },
+ nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 6f85357a14a18..0033abde986ea 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -13,6 +13,8 @@
#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
+//===----------------------------------------------------------------------===//
+/// PatternRewriter API
+//===----------------------------------------------------------------------===//
+
+inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
+ assert(rewriter.ptr && "unexpected null rewriter");
+ return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+}
+
+inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
+ return {rewriter};
+}
+
//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
@@ -331,4 +346,88 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
op.ptr = nullptr;
return wrap(m);
}
+
+inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
+ assert(value.ptr && "unexpected null PDL value");
+ return static_cast<const mlir::PDLValue *>(value.ptr);
+}
+
+inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
+
+inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
+ assert(results.ptr && "unexpected null PDL results");
+ return static_cast<mlir::PDLResultList *>(results.ptr);
+}
+
+inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
+ return {results};
+}
+
+bool mlirPDLValueIsValue(MlirPDLValue value) {
+ return unwrap(value)->isa<mlir::Value>();
+}
+
+MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
+ return wrap(unwrap(value)->cast<mlir::Value>());
+}
+
+bool mlirPDLValueIsType(MlirPDLValue value) {
+ return unwrap(value)->isa<mlir::Type>();
+}
+
+MlirType mlirPDLValueAsType(MlirPDLValue value) {
+ return wrap(unwrap(value)->cast<mlir::Type>());
+}
+
+bool mlirPDLValueIsOperation(MlirPDLValue value) {
+ return unwrap(value)->isa<mlir::Operation *>();
+}
+
+MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
+ return wrap(unwrap(value)->cast<mlir::Operation *>());
+}
+
+bool mlirPDLValueIsAttribute(MlirPDLValue value) {
+ return unwrap(value)->isa<mlir::Attribute>();
+}
+
+MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
+ return wrap(unwrap(value)->cast<mlir::Attribute>());
+}
+
+void mlirPDLResultListPushBackValue(MlirPDLResultList results,
+ MlirValue value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
+ MlirOperation value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
+ MlirAttribute value) {
+ unwrap(results)->push_back(unwrap(value));
+}
+
+void mlirPDLPatternModuleRegisterRewriteFunction(
+ MlirPDLPatternModule module, MlirStringRef name,
+ MlirPDLRewriteFunction rewriteFn, void *userData) {
+ unwrap(module)->registerRewriteFunction(
+ unwrap(name),
+ [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
+ ArrayRef<PDLValue> values) -> LogicalResult {
+ std::vector<MlirPDLValue> mlirValues;
+ for (auto &value : values) {
+ mlirValues.push_back(wrap(&value));
+ }
+ return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
+ mlirValues.size(), mlirValues.data(),
+ userData));
+ });
+}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index dd6c74ce622c8..c78f2d4f9a0dc 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -86,3 +86,99 @@ def add_func(a, b):
frozen = get_pdl_patterns()
apply_patterns_and_fold_greedily(module_.operation, frozen)
return module_
+
+
+# If we use arith.constant and arith.addi here,
+# these C++-defined folding/canonicalization will be applied
+# implicitly in the greedy pattern rewrite driver to
+# make our Python-defined folding useless,
+# so here we define a new dialect to workaround this.
+def load_myint_dialect():
+ from mlir.dialects import irdl
+
+ m = Module.create()
+ with InsertionPoint(m.body):
+ myint = irdl.dialect("myint")
+ with InsertionPoint(myint.body):
+ constant = irdl.operation_("constant")
+ with InsertionPoint(constant.body):
+ iattr = irdl.base(base_name="#builtin.integer")
+ i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+ irdl.attributes_([iattr], ["value"])
+ irdl.results_([i32], ["cst"], [irdl.Variadicity.single])
+ add = irdl.operation_("add")
+ with InsertionPoint(add.body):
+ i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
+ irdl.operands_(
+ [i32, i32],
+ ["lhs", "rhs"],
+ [irdl.Variadicity.single, irdl.Variadicity.single],
+ )
+ irdl.results_([i32], ["res"], [irdl.Variadicity.single])
+
+ m.operation.verify()
+ irdl.load_dialects(m)
+
+
+# This PDL pattern is to fold constant additions,
+# i.e. add(constant0, constant1) -> constant2
+# where constant2 = constant0 + constant1.
+def get_pdl_pattern_fold():
+ m = Module.create()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(m.body):
+
+ @pdl.pattern(benefit=1, sym_name="myint_add_fold")
+ def pat():
+ t = pdl.TypeOp(i32)
+ a0 = pdl.AttributeOp()
+ a1 = pdl.AttributeOp()
+ c0 = pdl.OperationOp(
+ name="myint.constant", attributes={"value": a0}, types=[t]
+ )
+ c1 = pdl.OperationOp(
+ name="myint.constant", attributes={"value": a1}, types=[t]
+ )
+ v0 = pdl.ResultOp(c0, 0)
+ v1 = pdl.ResultOp(c1, 0)
+ op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+ @pdl.rewrite()
+ def rew():
+ sum = pdl.apply_native_rewrite(
+ [pdl.AttributeType.get()], "add_fold", [a0, a1]
+ )
+ newOp = pdl.OperationOp(
+ name="myint.constant", attributes={"value": sum}, types=[t]
+ )
+ pdl.ReplaceOp(op0, with_op=newOp)
+
+ def add_fold(rewriter, results, values):
+ a0, a1 = values
+ results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
+ return True
+
+ pdl_module = PDLModule(m)
+ pdl_module.register_rewrite_function("add_fold", add_fold)
+ return pdl_module.freeze()
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function
+# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32
+@construct_and_print_in_module
+def test_pdl_register_function(module_):
+ load_myint_dialect()
+
+ module_ = Module.parse(
+ """
+ %c0 = "myint.constant"() { value = 2 }: () -> (i32)
+ %c1 = "myint.constant"() { value = 3 }: () -> (i32)
+ %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+ "myint.add"(%x, %c1): (i32, i32) -> (i32)
+ """
+ )
+
+ frozen = get_pdl_pattern_fold()
+ apply_patterns_and_fold_greedily(module_, frozen)
+
+ return module_
|
Man great work - this is way simpler than I thought since we're not actually using any |
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
…-python-pdl-register
template <typename T, | ||
typename ResultT = std::conditional_t< | ||
std::is_convertible<T, bool>::value, T, std::optional<T>>> | ||
std::is_constructible_v<bool, T>, T, std::optional<T>>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some classes like mlir::Value
, mlir::Type
.. have an explicit operator bool()
, and is_convertible
will ignore explicit conversions. So here we replace it with is_constructible_v
to make it work for these types, so that mlir::Value
can be used instead of weird std::optional<mlir::Value>
(since mlir::Value
is nullable itself).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I don't understand this one - also is this needed for this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup. Due to this review suggestion #159926 (comment), now we use dyn_cast
instead of isa
and cast
of PDLValue. And this change is to address issues in PDLValue::dyn_cast
. (previously we did't need this change : )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can change the API back to the origin form if we want to avoid touching that header. Maybe @jpienaar has idea for this : )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nah it's fine I just didn't see that comment about dyn_cast
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one random q: why is this file named .h.inc
? @jpienaar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - @ftynse any other asks?
Thank you all and I'll merge it soon. 🥳 |
``` external/llvm-project/mlir/lib/CAPI/Transforms/Rewrite.cpp:17:10: error: use of private header from outside its module: 'mlir/IR/PDLPatternMatch.h.inc' [-Wprivate-header] 17 | #include "mlir/IR/PDLPatternMatch.h.inc" | ^ ```
…160520) This is a follow-up to #159926. That PR (#159926) exposed native rewrite function registration in PDL through the C API and Python, enabling use with `pdl.apply_native_rewrite`. In this PR, we add support for native constraint functions in PDL via `pdl.apply_native_constraint`, further completing the PDL API.
…ng (llvm#159926) In the MLIR Python bindings, we can currently use PDL to define simple patterns and then execute them with the greedy rewrite driver. However, when dealing with more complex patterns—such as constant folding for integer addition—we find that we need `apply_native_rewrite` to actually perform arithmetic (i.e., compute the sum of two constants). For example, consider the following PDL pseudocode: ```mlir pdl.pattern : benefit(1) { %a0 = pdl.attribute %a1 = pdl.attribute %c0 = pdl.operation "arith.constant" {value = %a0} %c1 = pdl.operation "arith.constant" {value = %a1} %op = pdl.operation "arith.addi"(%c0, %c1) %sum = pdl.apply_native_rewrite "addIntegers"(%a0, %a1) %new_cst = pdl.operation "arith.constant" {value = %sum} pdl.replace %op with %new_cst } ``` Here, `addIntegers` cannot be expressed in PDL alone—it requires a *native rewrite function*. This PR introduces a mechanism to support exactly that, allowing complex rewrite patterns to be expressed in Python and enabling many passes to be implemented directly in Python as well. As a test case, we defined two new operations (`myint.constant` and `myint.add`) in Python and implemented a constant-folding rewrite pattern for them. The core code looks like this: ```python m = Module.create() with InsertionPoint(m.body): @pdl.pattern(benefit=1, sym_name="myint_add_fold") def pat(): ... op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t]) @pdl.rewrite() def rew(): sum = pdl.apply_native_rewrite( [pdl.AttributeType.get()], "add_fold", [a0, a1] ) newOp = pdl.OperationOp( name="myint.constant", attributes={"value": sum}, types=[t] ) pdl.ReplaceOp(op0, with_op=newOp) def add_fold(rewriter, results, values): a0, a1 = values results.push_back(IntegerAttr.get(i32, a0.value + a1.value)) pdl_module = PDLModule(m) pdl_module.register_rewrite_function("add_fold", add_fold) ``` The idea is previously discussed in Discord #mlir-python channel with @makslevental. --------- Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
…lvm#160432) ``` external/llvm-project/mlir/lib/CAPI/Transforms/Rewrite.cpp:17:10: error: use of private header from outside its module: 'mlir/IR/PDLPatternMatch.h.inc' [-Wprivate-header] 17 | #include "mlir/IR/PDLPatternMatch.h.inc" | ^ ```
…lvm#160520) This is a follow-up to llvm#159926. That PR (llvm#159926) exposed native rewrite function registration in PDL through the C API and Python, enabling use with `pdl.apply_native_rewrite`. In this PR, we add support for native constraint functions in PDL via `pdl.apply_native_constraint`, further completing the PDL API.
Adds argument names to the method stubs for PDLResultList (from #159926).
Adds argument names to the method stubs for PDLResultList (from llvm/llvm-project#159926).
Adds argument names to the method stubs for PDLResultList (from llvm/llvm-project#159926).
…lvm#160520) This is a follow-up to llvm#159926. That PR (llvm#159926) exposed native rewrite function registration in PDL through the C API and Python, enabling use with `pdl.apply_native_rewrite`. In this PR, we add support for native constraint functions in PDL via `pdl.apply_native_constraint`, further completing the PDL API.
Adds argument names to the method stubs for PDLResultList (from llvm#159926).
In the MLIR Python bindings, we can currently use PDL to define simple patterns and then execute them with the greedy rewrite driver. However, when dealing with more complex patterns—such as constant folding for integer addition—we find that we need
apply_native_rewrite
to actually perform arithmetic (i.e., compute the sum of two constants). For example, consider the following PDL pseudocode:Here,
addIntegers
cannot be expressed in PDL alone—it requires a native rewrite function. This PR introduces a mechanism to support exactly that, allowing complex rewrite patterns to be expressed in Python and enabling many passes to be implemented directly in Python as well.As a test case, we defined two new operations (
myint.constant
andmyint.add
) in Python and implemented a constant-folding rewrite pattern for them. The core code looks like this:The idea is previously discussed in Discord #mlir-python channel with @makslevental.