Skip to content

Conversation

PragmaTwice
Copy link
Member

@PragmaTwice PragmaTwice commented Sep 9, 2025

This is a follow-up PR for #156000.

In this PR we add the ability to signal pass failures (signal_pass_failure()) in python-defined passes.

To achieve this, we expose MlirExternalPass via nb::class_ with a method signal_pass_failure(), and the callable passed to pm.add(..) now accepts two arguments (op: MlirOperation, pass_: MlirExternalPass).

For example:

def custom_pass_that_fails(op, pass_):
    if some_condition:
        pass_.signal_pass_failure()
    # do something

@PragmaTwice PragmaTwice marked this pull request as ready for review September 9, 2025 05:15
@llvmbot llvmbot added the mlir label Sep 9, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

This is a follow-up PR for #156000.

In this PR we add the ability to signal pass failures (pass.signal_pass_failure()) in python-defined passes.

To achieve this, we store the MlirExternalPass C pointer into an attribute of the pass object, and add a method signal_pass_failure into the pass object while it is added into the pass manager via pm.add(..).

Note that the signal_pass_failure() method should be always called from __call__ in the pass object since the MlirExternalPass should be only available in this context (otherwise a friendly exception message will be raised).


Full diff: https://github.com/llvm/llvm-project/pull/157613.diff

2 Files Affected:

  • (modified) mlir/lib/Bindings/Python/Pass.cpp (+23-2)
  • (modified) mlir/test/python/python_pass.py (+23)
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8a31492..c5fe7bda4a680 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,6 +56,8 @@ class PyPassManager {
 
 /// Create the `mlir.passmanager` here.
 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
+  constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
+
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
@@ -182,10 +184,29 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
             callbacks.clone = [](void *) -> void * {
               throw std::runtime_error("Cloning Python passes not supported");
             };
-            callbacks.run = [](MlirOperation op, MlirExternalPass,
+            callbacks.run = [](MlirOperation op, MlirExternalPass pass,
                                void *userData) {
-              nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+              auto callable =
+                  nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
+              nb::setattr(callable, mlirExternalPassAttr,
+                          nb::capsule(pass.ptr));
+              callable(op);
+              // delete it to avoid that it is used after
+              // the external pass is freed by the pass manager
+              nb::delattr(callable, mlirExternalPassAttr);
             };
+            nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
+                          nb::capsule cap;
+                          try {
+                            cap = run.attr(mlirExternalPassAttr);
+                          } catch (nb::python_error &e) {
+                            throw std::runtime_error(
+                                "signal_pass_failure() should always be called "
+                                "from the __call__ method");
+                          }
+                          mlirExternalPassSignalFailure(
+                              MlirExternalPass{cap.data()});
+                        }));
             auto externalPass = mlirCreateExternalPass(
                 passID, mlirStringRefCreate(name->data(), name->length()),
                 mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966f..4784e073fef0a 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -86,3 +86,26 @@ def __call__(self, m):
         # CHECK: llvm.mul
         pm.add("convert-arith-to-llvm")
         pm.run(module)
+
+        # test signal_pass_failure
+        class CustomPassThatFails:
+            def __call__(self, m):
+                print("hello from pass that fails")
+                self.signal_pass_failure()
+
+        custom_pass_that_fails = CustomPassThatFails()
+
+        pm = PassManager("any")
+        pm.add(custom_pass_that_fails, "CustomPassThatFails")
+        # CHECK: hello from pass that fails
+        # CHECK: caught exception: Failure while executing pass pipeline
+        try:
+            pm.run(module)
+        except Exception as e:
+            print(f"caught exception: {e}")
+
+        # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
+        try:
+            custom_pass_that_fails.signal_pass_failure()
+        except Exception as e:
+            print(f"caught exception: {e}")

@makslevental
Copy link
Contributor

this isn't really how signalling pass failure works in C++ (signalling doesn't generate an exception or anything like that). also getting/setting attrs like this from cpp side is kind of "cursed" (i know i did it in the other one but it's generally to be used sparingly). I can review more closely tomorrow but how about something like this instead https://gist.github.com/makslevental/ca6170ea884ef2b8d04601be3b6c8cac

@makslevental
Copy link
Contributor

makslevental commented Sep 9, 2025

signalling pass failure doesn't really do anything except prevent verifaction from occurring

if (!passFailed && verifyPasses) {
(the ir is "allowed to be in an invalid state").

EDIT:

sorry the pass does return failure if you signal failure

return failure(passFailed);

@PragmaTwice
Copy link
Member Author

PragmaTwice commented Sep 9, 2025

signalling doesn't generate an exception or anything like that

Ahh let me explain a little bit, the exception will only be thrown if the user mis-use this method (e.g. the method is used outside the __call__ method).

In the happy path, the mlirExternalPassSignalFailure C API will be called. No exception thrown from here.

The exception you see in the test file, is actually thrown by the pass manager, not from the pass : )

@makslevental
Copy link
Contributor

In the happy path, the mlirExternalPassSignalFailure C API will be called. No exception thrown from here.

ah sorry right right i skimmed too quickly. still i think it's a little too clever to attach the method to the callback (class or otherwise) instead of just passing the ExternalPass all the way through.

@PragmaTwice
Copy link
Member Author

ah sorry right right i skimmed too quickly. still i think it's a little too clever to attach the method to the callback (class or otherwise) instead of just passing the ExternalPass all the way through.

Yeah the current method looks magic LOL. I'll consider to pass it as an argument to avoid surprise to users.

@PragmaTwice
Copy link
Member Author

Hi @makslevental, I have finished the changes guided by the gist.

A minor difference: I did some check and it seems that we don't need to cast the handle to nb::callable before calling it, so nb::handle should be enough here to use operator().

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@makslevental makslevental merged commit 7123463 into llvm:main Sep 9, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants