diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 6ee85e8a31492..47ef5d8e9dd3b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -56,6 +56,13 @@ class PyPassManager { /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { + //---------------------------------------------------------------------------- + // Mapping of MlirExternalPass + //---------------------------------------------------------------------------- + nb::class_(m, "ExternalPass") + .def("signal_pass_failure", + [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); }); + //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- @@ -182,9 +189,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { callbacks.clone = [](void *) -> void * { throw std::runtime_error("Cloning Python passes not supported"); }; - callbacks.run = [](MlirOperation op, MlirExternalPass, + callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) { - nb::borrow(static_cast(userData))(op); + nb::handle(static_cast(userData))(op, pass); }; auto externalPass = mlirCreateExternalPass( passID, mlirStringRefCreate(name->data(), name->length()), diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py index c94f96e20966f..50c42102f66d3 100644 --- a/mlir/test/python/python_pass.py +++ b/mlir/test/python/python_pass.py @@ -64,12 +64,12 @@ def testCustomPass(): """ ) - def custom_pass_1(op): + def custom_pass_1(op, pass_): print("hello from pass 1!!!", file=sys.stderr) class CustomPass2: - def __call__(self, m): - apply_patterns_and_fold_greedily(m, frozen) + def __call__(self, op, pass_): + apply_patterns_and_fold_greedily(op, frozen) custom_pass_2 = CustomPass2() @@ -86,3 +86,17 @@ def __call__(self, m): # CHECK: llvm.mul pm.add("convert-arith-to-llvm") pm.run(module) + + # test signal_pass_failure + def custom_pass_that_fails(op, pass_): + print("hello from pass that fails") + pass_.signal_pass_failure() + + pm = PassManager("any") + pm.add(custom_pass_that_fails, "CustomPassThatFails") + # CHECK: hello from pass that fails + # CHECK: caught exception: Failure while executing pass pipeline + try: + pm.run(module) + except Exception as e: + print(f"caught exception: {e}")