diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 0d2e19ee7fb0a..1328b4de5d4eb 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -184,6 +184,13 @@ MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass( intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData); +MLIR_CAPI_EXPORTED void +mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name, + MlirStringRef argument, MlirStringRef description, + MlirStringRef opName, intptr_t nDependentDialects, + MlirDialectHandle *dependentDialects, + MlirExternalPassCallbacks callbacks, void *userData); + /// This signals that the pass has failed. This is only valid to call during /// the `run` callback of `MlirExternalPassCallbacks`. /// See Pass::signalPassFailure(). diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 47ef5d8e9dd3b..558ab6a43d87b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -52,6 +52,24 @@ class PyPassManager { MlirPassManager passManager; }; +MlirExternalPassCallbacks createExternalPassCallbacksForPythonCallable() { + MlirExternalPassCallbacks callbacks; + callbacks.construct = [](void *obj) { + (void)nb::handle(static_cast(obj)).inc_ref(); + }; + callbacks.destruct = [](void *obj) { + (void)nb::handle(static_cast(obj)).dec_ref(); + }; + callbacks.initialize = nullptr; + callbacks.clone = [](void *) -> void * { + throw std::runtime_error("Cloning Python passes not supported"); + }; + callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) { + nb::handle(static_cast(userData))(op, pass); + }; + return callbacks; +} + } // namespace /// Create the `mlir.passmanager` here. @@ -63,6 +81,33 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { .def("signal_pass_failure", [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); }); + //---------------------------------------------------------------------------- + // Mapping of register_pass + //---------------------------------------------------------------------------- + m.def( + "register_pass", + [](const std::string &argument, const nb::callable &run, + std::optional &name, const std::string &description, + const std::string &opName) { + if (!name.has_value()) { + name = + nb::cast(nb::borrow(run.attr("__name__"))); + } + MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); + MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); + auto callbacks = createExternalPassCallbacksForPythonCallable(); + mlirRegisterExternalPass( + passID, mlirStringRefCreate(name->data(), name->length()), + mlirStringRefCreate(argument.data(), argument.length()), + mlirStringRefCreate(description.data(), description.length()), + mlirStringRefCreate(opName.data(), opName.size()), + /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, callbacks, + /*userData*/ run.ptr()); + }, + "argument"_a, "run"_a, "name"_a.none() = nb::none(), + "description"_a.none() = "", "op_name"_a.none() = "", + "Register a python-defined pass."); + //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- @@ -178,21 +223,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); - MlirExternalPassCallbacks callbacks; - callbacks.construct = [](void *obj) { - (void)nb::handle(static_cast(obj)).inc_ref(); - }; - callbacks.destruct = [](void *obj) { - (void)nb::handle(static_cast(obj)).dec_ref(); - }; - callbacks.initialize = nullptr; - callbacks.clone = [](void *) -> void * { - throw std::runtime_error("Cloning Python passes not supported"); - }; - callbacks.run = [](MlirOperation op, MlirExternalPass pass, - void *userData) { - nb::handle(static_cast(userData))(op, pass); - }; + auto callbacks = createExternalPassCallbacksForPythonCallable(); auto externalPass = mlirCreateExternalPass( passID, mlirStringRefCreate(name->data(), name->length()), mlirStringRefCreate(argument.data(), argument.length()), diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index b0a6ec1ace3cc..8924f6d9ec6a9 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -216,6 +216,32 @@ MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, userData))); } +void mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name, + MlirStringRef argument, MlirStringRef description, + MlirStringRef opName, intptr_t nDependentDialects, + MlirDialectHandle *dependentDialects, + MlirExternalPassCallbacks callbacks, + void *userData) { + // here we clone these arguments as owned and pass them to + // the lambda as copies to avoid dangling refs, + // since the lambda below lives longer than the current function + std::string nameStr = unwrap(name).str(); + std::string argumentStr = unwrap(argument).str(); + std::string descriptionStr = unwrap(description).str(); + std::string opNameStr = unwrap(opName).str(); + std::vector dependentDialectVec( + dependentDialects, dependentDialects + nDependentDialects); + + mlir::registerPass([passID, nameStr, argumentStr, descriptionStr, opNameStr, + dependentDialectVec, callbacks, userData] { + return std::unique_ptr(new mlir::ExternalPass( + unwrap(passID), nameStr, argumentStr, descriptionStr, + opNameStr.length() > 0 ? std::optional(opNameStr) + : std::nullopt, + dependentDialectVec, callbacks, userData)); + }); +} + void mlirExternalPassSignalFailure(MlirExternalPass pass) { unwrap(pass)->signalPassFailure(); } diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py index 50c42102f66d3..0fbd96ec71ddc 100644 --- a/mlir/test/python/python_pass.py +++ b/mlir/test/python/python_pass.py @@ -89,7 +89,7 @@ def __call__(self, op, pass_): # test signal_pass_failure def custom_pass_that_fails(op, pass_): - print("hello from pass that fails") + print("hello from pass that fails", file=sys.stderr) pass_.signal_pass_failure() pm = PassManager("any") @@ -99,4 +99,44 @@ def custom_pass_that_fails(op, pass_): try: pm.run(module) except Exception as e: - print(f"caught exception: {e}") + print(f"caught exception: {e}", file=sys.stderr) + + +# CHECK-LABEL: TEST: testRegisterPass +@run +def testRegisterPass(): + with Context(): + pdl_module = make_pdl_module() + frozen = PDLModule(pdl_module).freeze() + + module = ModuleOp.parse( + r""" + module { + func.func @add(%a: i64, %b: i64) -> i64 { + %sum = arith.addi %a, %b : i64 + return %sum : i64 + } + } + """ + ) + + def custom_pass_3(op, pass_): + print("hello from pass 3!!!", file=sys.stderr) + + def custom_pass_4(op, pass_): + apply_patterns_and_fold_greedily(op, frozen) + + register_pass("custom-pass-one", custom_pass_3) + register_pass("custom-pass-two", custom_pass_4) + + pm = PassManager("any") + pm.enable_ir_printing() + + # CHECK: hello from pass 3!!! + # CHECK-LABEL: Dump After custom_pass_3 + # CHECK-LABEL: Dump After custom_pass_4 + # CHECK: arith.muli + # CHECK-LABEL: Dump After ArithToLLVMConversionPass + # CHECK: llvm.mul + pm.add("custom-pass-one, custom-pass-two, convert-arith-to-llvm") + pm.run(module)