[MLIR][Python] Add bindings for PDL constraint function registering #160520
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
Changes
This is a follow-up to #159926. That PR (#159926) exposed native rewrite function registration in PDL through the C API and Python, enabling use with `pdl.apply_native_rewrite`. In this PR, we add support for native constraint functions in PDL via `pdl.apply_native_constraint`, further completing the PDL API.
Full diff: https://github.com/llvm/llvm-project/pull/160520.diff
4 Files Affected:
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index f4974348945c5..77be1f480eacf 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -375,6 +375,20 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLRewriteFunction rewriteFn, void *userData);
+/// This function type is used as callbacks for PDL native constraint functions.
+/// Input values can be accessed by `values` with its size `nValues`;
+/// output values can be added into `results` by `mlirPDLResultListPushBack*`
+/// APIs. And the return value indicates whether the constraint holds.
+typedef MlirLogicalResult (*MlirPDLConstraintFunction)(
+ MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
+ MlirPDLValue *values, void *userData);
+
+/// Register a constraint function into the given PDL pattern module.
+/// `userData` will be provided as an argument to the constraint function.
+MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction(
+ MlirPDLPatternModule pdlModule, MlirStringRef name,
+ MlirPDLConstraintFunction constraintFn, void *userData);
+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
#undef DEFINE_C_API_STRUCT
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index c53c6cf0dab1e..20392b9002706 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -40,6 +40,15 @@ static nb::object objectFromPDLValue(MlirPDLValue value) {
throw std::runtime_error("unsupported PDL value type");
}
+static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
+ MlirPDLValue *values) {
+ std::vector<nb::object> args;
+ args.reserve(nValues);
+ for (size_t i = 0; i < nValues; ++i)
+ args.push_back(objectFromPDLValue(values[i]));
+ return args;
+}
+
// Convert the Python object to a boolean.
// If it evaluates to False, treat it as success;
// otherwise, treat it as failure.
@@ -74,11 +83,22 @@ class PyPDLPatternModule {
size_t nValues, MlirPDLValue *values,
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
- std::vector<nb::object> args;
- args.reserve(nValues);
- for (size_t i = 0; i < nValues; ++i)
- args.push_back(objectFromPDLValue(values[i]));
- return logicalResultFromObject(f(rewriter, results, args));
+ return logicalResultFromObject(
+ f(rewriter, results, objectsFromPDLValues(nValues, values)));
+ },
+ fn.ptr());
+ }
+
+ void registerConstraintFunction(const std::string &name,
+ const nb::callable &fn) {
+ mlirPDLPatternModuleRegisterConstraintFunction(
+ get(), mlirStringRefCreate(name.data(), name.size()),
+ [](MlirPatternRewriter rewriter, MlirPDLResultList results,
+ size_t nValues, MlirPDLValue *values,
+ void *userData) -> MlirLogicalResult {
+ nb::handle f = nb::handle(static_cast<PyObject *>(userData));
+ return logicalResultFromObject(
+ f(rewriter, results, objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
@@ -199,6 +219,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
const nb::callable &fn) {
self.registerRewriteFunction(name, fn);
},
+ nb::keep_alive<1, 3>())
+ .def(
+ "register_constraint_function",
+ [](PyPDLPatternModule &self, const std::string &name,
+ const nb::callable &fn) {
+ self.registerConstraintFunction(name, fn);
+ },
nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 9ecce956a05b9..8ee6308cadf83 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -398,6 +398,15 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
unwrap(results)->push_back(unwrap(value));
}
+inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) {
+ std::vector<MlirPDLValue> mlirValues;
+ mlirValues.reserve(values.size());
+ for (auto &value : values) {
+ mlirValues.push_back(wrap(&value));
+ }
+ return mlirValues;
+}
+
void mlirPDLPatternModuleRegisterRewriteFunction(
MlirPDLPatternModule pdlModule, MlirStringRef name,
MlirPDLRewriteFunction rewriteFn, void *userData) {
@@ -405,14 +414,25 @@ void mlirPDLPatternModuleRegisterRewriteFunction(
unwrap(name),
[userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> values) -> LogicalResult {
- std::vector<MlirPDLValue> mlirValues;
- mlirValues.reserve(values.size());
- for (auto &value : values) {
- mlirValues.push_back(wrap(&value));
- }
+ std::vector<MlirPDLValue> mlirValues = wrap(values);
return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
mlirValues.size(), mlirValues.data(),
userData));
});
}
+
+void mlirPDLPatternModuleRegisterConstraintFunction(
+ MlirPDLPatternModule pdlModule, MlirStringRef name,
+ MlirPDLConstraintFunction constraintFn, void *userData) {
+ unwrap(pdlModule)->registerConstraintFunction(
+ unwrap(name),
+ [userData, constraintFn](PatternRewriter &rewriter,
+ PDLResultList &results,
+ ArrayRef<PDLValue> values) -> LogicalResult {
+ std::vector<MlirPDLValue> mlirValues = wrap(values);
+ return unwrap(constraintFn(wrap(&rewriter), wrap(&results),
+ mlirValues.size(), mlirValues.data(),
+ userData));
+ });
+}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py
index e85c6c77ef955..c8e6197e03842 100644
--- a/mlir/test/python/integration/dialects/pdl.py
+++ b/mlir/test/python/integration/dialects/pdl.py
@@ -153,12 +153,43 @@ def rew():
)
pdl.ReplaceOp(op0, with_op=newOp)
+ @pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
+ def pat():
+ t = pdl.TypeOp(i32)
+ v0 = pdl.OperandOp()
+ v1 = pdl.OperandOp()
+ v = pdl.apply_native_constraint([pdl.ValueType.get()], "has_zero", [v0, v1])
+ op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
+
+ @pdl.rewrite()
+ def rew():
+ pdl.ReplaceOp(op0, with_values=[v])
+
def add_fold(rewriter, results, values):
a0, a1 = values
results.append(IntegerAttr.get(i32, a0.value + a1.value))
+ def is_zero(value):
+ op = value.owner
+ if isinstance(op, Operation):
+ return op.name == "myint.constant" and op.attributes["value"].value == 0
+ return False
+
+ # Check if either operand is a constant zero,
+ # and append the other operand to the results if so.
+ def has_zero(rewriter, results, values):
+ v0, v1 = values
+ if is_zero(v0):
+ results.append(v1)
+ return False
+ if is_zero(v1):
+ results.append(v0)
+ return False
+ return True
+
pdl_module = PDLModule(m)
pdl_module.register_rewrite_function("add_fold", add_fold)
+ pdl_module.register_constraint_function("has_zero", has_zero)
return pdl_module.freeze()
@@ -181,3 +212,28 @@ def test_pdl_register_function(module_):
apply_patterns_and_fold_greedily(module_, frozen)
return module_
+
+
+# CHECK-LABEL: TEST: test_pdl_register_function_constraint
+# CHECK: return %arg0 : i32
+@construct_and_print_in_module
+def test_pdl_register_function_constraint(module_):
+ load_myint_dialect()
+
+ module_ = Module.parse(
+ """
+ func.func @f(%x : i32) -> i32 {
+ %c0 = "myint.constant"() { value = 1 }: () -> (i32)
+ %c1 = "myint.constant"() { value = -1 }: () -> (i32)
+ %a = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
+ %b = "myint.add"(%a, %x): (i32, i32) -> (i32)
+ %c = "myint.add"(%b, %a): (i32, i32) -> (i32)
+ func.return %c : i32
+ }
+ """
+ )
+
+ frozen = get_pdl_pattern_fold()
+ apply_patterns_and_fold_greedily(module_, frozen)
+
+ return module_
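To see why the CHECK line in the new test expects `return %arg0`, the intended chain of rewrites can be traced with a small pure-Python model. This is only a sketch under stated assumptions: the tuple encoding and the names `fold_constants`, `add_zero`, and `simplify` are hypothetical stand-ins, not MLIR APIs. It assumes the pre-existing fold pattern replaces an add of two constants with their sum, and the real greedy driver schedules rewrites from a worklist rather than with this bottom-up recursion.

```python
# Toy expression model (hypothetical, not MLIR):
#   ("const", n), ("arg", name), ("add", lhs, rhs)

def fold_constants(e):
    """Models the assumed constant-fold pattern: add(const a, const b) -> const(a + b)."""
    if e[0] == "add" and e[1][0] == "const" and e[2][0] == "const":
        return ("const", e[1][1] + e[2][1])
    return None

def add_zero(e):
    """Models the new pattern: add(0, v) -> v and add(v, 0) -> v."""
    if e[0] == "add":
        if e[1] == ("const", 0):
            return e[2]
        if e[2] == ("const", 0):
            return e[1]
    return None

def simplify(e):
    """Simplify operands first, then apply the first rule that matches."""
    if e[0] == "add":
        e = ("add", simplify(e[1]), simplify(e[2]))
    for rule in (fold_constants, add_zero):
        out = rule(e)
        if out is not None:
            return simplify(out)
    return e

# Same shape as the IR in the test: a = 1 + (-1), b = a + x, c = b + a.
x = ("arg", "x")
a = ("add", ("const", 1), ("const", -1))
b = ("add", a, x)
c = ("add", b, a)
simplified = simplify(c)
```

In this model `a` folds to the constant 0, after which `b` and `c` both collapse to `x`, matching the function argument being returned directly.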
        results.append(v1)
        return False
    if is_zero(v1):
        results.append(v0)
Sorry, I haven't used this PDL API before: are `results` used for anything? Are they meant to be? Because I'm pretty sure that currently (in the current state of the PR) they're not, and it'd be impossible to make it work (since you're wrapping in a `std::vector`)?
Ahh sorry, there's no type annotation, so it's easy to get wrong here. `results` here is typed `MlirPDLResultList`, and we expose an `append` method for this type (this is the only method of this type for now). And `values` is typed `std::vector<nb::object>`, so it is a list of `Value`/`Attribute`/`Type`, etc.
The `results` can be used with `pdl.apply_native_constraint`, for example:
%res = pdl.apply_native_constraint("some_constraint", %v1: pdl.value, %v2: pdl.value) -> pdl.value
Then the callable registered as `some_constraint` will be called like this (pseudocode; the call actually happens in C++):
values = [v1, v2]  # corresponding to arguments %v1 and %v2
if not some_constraint(rewriter, results, values):  # a falsy return value means the constraint holds
    assert len(results) == 1  # results[0] corresponds to %res
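The pseudocode above can be turned into a runnable toy, assuming only the convention stated in the binding's own comment (a return value that evaluates to False is treated as success). Everything here is a stand-in rather than the real bindings: `succeeded` and `call_constraint` are hypothetical names, a plain list plays the role of `MlirPDLResultList`, and the rewriter is passed as `None` because the sketch never uses it.

```python
def succeeded(ret):
    """Assumed convention from the binding: a falsy return value means success."""
    return not bool(ret)

def call_constraint(fn, values, n_results):
    """Call a constraint the way the pseudocode describes and collect its results."""
    results = []  # stand-in for MlirPDLResultList; only append() is used
    ok = succeeded(fn(None, results, values))
    if ok:
        # One entry per result declared on the apply_native_constraint op.
        assert len(results) == n_results
    return ok, results

def some_constraint(rewriter, results, values):
    """Hypothetical constraint: holds when both inputs are equal, yielding that input."""
    v1, v2 = values
    if v1 == v2:
        results.append(v1)
        return False  # falsy: the constraint holds
    return True  # truthy: the constraint fails, nothing is appended

held = call_constraint(some_constraint, ["a", "a"], 1)
failed = call_constraint(some_constraint, ["a", "b"], 1)
```

The inverted-looking `return False` for success is the part most likely to trip up callers, which is why the toy isolates it in `succeeded`.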
And here for `has_zero`, the story is:
- we check if either operand is zero
- if there is no zero, just fail and exit
- otherwise, we push the other (non-zero) operand into `results` and use it as the new op for rewrite
e.g. for `x + 0`:
- zero (the second operand) found! (the constraint holds)
- push `x` (the first operand) to `results`
- `x + 0` is rewritten to `x`
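The `x + 0` walkthrough can be reproduced without MLIR by swapping the real values for small placeholder objects. This is a sketch only: `Const` and `Arg` are hypothetical classes invented for the example, and `is_zero` checks them directly instead of inspecting the defining op as the test does. `has_zero` keeps the same control flow and the same falsy-means-success return convention as in the diff.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Const:
    """Placeholder for a value produced by a constant op."""
    value: int

@dataclass(frozen=True)
class Arg:
    """Placeholder for a non-constant value such as a block argument."""
    name: str

def is_zero(v):
    return isinstance(v, Const) and v.value == 0

def has_zero(rewriter, results, values):
    # Same shape as the test's constraint: on a match, push the *other* operand.
    v0, v1 = values
    if is_zero(v0):
        results.append(v1)
        return False  # falsy: constraint holds
    if is_zero(v1):
        results.append(v0)
        return False
    return True  # no zero operand: constraint fails

# x + 0: the second operand is zero, so x lands in the results.
results = []
ret = has_zero(None, results, [Arg("x"), Const(0)])
```

After the call, `ret` is `False` (the constraint holds) and `results` contains only the `x` placeholder, which is the value the rewrite would use as the replacement.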
Ok I understand now - I thought MlirPDLResultList was just an ArrayRef but I see it's not, it's actually a container that holds a bunch of SmallVectors.
Cool this looks great. Just wondering - have you noticed any other missing functionality for writing diverse rewrites using PDL?
I think the current Python-side PDL capabilities already make it possible to express most rewrite patterns. One possible future extension would be to expose some of the […].
Take a simple example: suppose we want to rewrite a constant with value […].
In short, the current PDL covers most patterns well, but some more complex logic still needs to be “expanded” into multiple patterns. That’s an area we might consider enhancing in the future.
Thank you, and I'll merge it soon : )
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/18844 Here is the relevant piece of the build log for the reference
Hmmm. I did some checking and it looks unrelated to this PR?
Ya, this is a flaky test I see fail sometimes.
…ter (#161001)
In [#160520](llvm/llvm-project#160520), we discussed the current limitations of PDL rewriting in Python (see [this comment](llvm/llvm-project#160520 (comment))). At the moment, we cannot create new operations in PDL native (Python) rewrite functions because the `PatternRewriter` APIs are not exposed. This PR introduces bindings to retrieve the insertion point of the `PatternRewriter`, enabling users to create new operations within Python rewrite functions. With this capability, more complex rewrites, e.g. with branching and loops that involve op creation, become possible.
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
This is a follow-up to #159926.
That PR (#159926) exposed native rewrite function registration in PDL through the C API and Python, enabling use with `pdl.apply_native_rewrite`.
In this PR, we add support for native constraint functions in PDL via `pdl.apply_native_constraint`, further completing the PDL API.