-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MLIR][Python] Add the ability to signal pass failures in python-defined passes #157613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Twice (PragmaTwice) ChangesThis is a follow-up PR for #156000. In this PR we add the ability to signal pass failures ( To achieve this, we store the Note that the Full diff: https://github.com/llvm/llvm-project/pull/157613.diff 2 Files Affected:
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a31492..c5fe7bda4a680 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,6 +56,8 @@ class PyPassManager {
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+ constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
+
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
@@ -182,10 +184,29 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
callbacks.clone = [](void *) -> void * {
throw std::runtime_error("Cloning Python passes not supported");
};
- callbacks.run = [](MlirOperation op, MlirExternalPass,
+ callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+ auto callable =
+ nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
+ nb::setattr(callable, mlirExternalPassAttr,
+ nb::capsule(pass.ptr));
+ callable(op);
+ // delete it to avoid that it is used after
+ // the external pass is freed by the pass manager
+ nb::delattr(callable, mlirExternalPassAttr);
};
+ nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
+ nb::capsule cap;
+ try {
+ cap = run.attr(mlirExternalPassAttr);
+ } catch (nb::python_error &e) {
+ throw std::runtime_error(
+ "signal_pass_failure() should always be called "
+ "from the __call__ method");
+ }
+ mlirExternalPassSignalFailure(
+ MlirExternalPass{cap.data()});
+ }));
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966f..4784e073fef0a 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -86,3 +86,26 @@ def __call__(self, m):
# CHECK: llvm.mul
pm.add("convert-arith-to-llvm")
pm.run(module)
+
+ # test signal_pass_failure
+ class CustomPassThatFails:
+ def __call__(self, m):
+ print("hello from pass that fails")
+ self.signal_pass_failure()
+
+ custom_pass_that_fails = CustomPassThatFails()
+
+ pm = PassManager("any")
+ pm.add(custom_pass_that_fails, "CustomPassThatFails")
+ # CHECK: hello from pass that fails
+ # CHECK: caught exception: Failure while executing pass pipeline
+ try:
+ pm.run(module)
+ except Exception as e:
+ print(f"caught exception: {e}")
+
+ # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
+ try:
+ custom_pass_that_fails.signal_pass_failure()
+ except Exception as e:
+ print(f"caught exception: {e}")
|
this isn't really how signalling pass failure works in C++ (signalling doesn't generate an exception or anything like that). also getting/setting attrs like this from cpp side is kind of "cursed" (i know i did it in the other one but it's generally to be used sparingly). I can review more closely tomorrow but how about something like this instead https://gist.github.com/makslevental/ca6170ea884ef2b8d04601be3b6c8cac |
signalling pass failure doesn't really do anything except prevent verifaction from occurring llvm-project/mlir/lib/Pass/Pass.cpp Line 563 in 4b1d5b8
EDIT: sorry the pass does return llvm-project/mlir/lib/Pass/Pass.cpp Line 594 in 4b1d5b8
|
Ahh let me explain a little bit, the exception will only be thrown if the user mis-use this method (e.g. the method is used outside the In the happy path, the The exception you see in the test file, is actually thrown by the pass manager, not from the pass : ) |
ah sorry right right i skimmed too quickly. still i think it's a little too clever to attach the method to the callback (class or otherwise) instead of just passing the |
Yeah the current method looks magic LOL. I'll consider to pass it as an argument to avoid surprise to users. |
Hi @makslevental, I have finished the changes guided by the gist. A minor difference: I did some check and it seems that we don't need to cast the handle to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
This is a follow-up PR for #156000.
In this PR we add the ability to signal pass failures (
signal_pass_failure()
) in python-defined passes.To achieve this, we expose
MlirExternalPass
vianb::class_
with a methodsignal_pass_failure()
, and the callable passed topm.add(..)
now accepts two arguments (op: MlirOperation, pass_: MlirExternalPass
).For example: