diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0f2ca666ccc05..d4e947a61a1af 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -110,6 +110,15 @@ static const char kOperationPrintDocstring[] = invalid, behavior is undefined. )"; +static const char kOperationPrintStateDocstring[] = + R"(Prints the assembly form of the operation to a file like object. + +Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + state: AsmState capturing the operation numbering and flags. +)"; + static const char kOperationGetAsmDocstring[] = R"(Gets the assembly form of the operation with all options available. @@ -1169,11 +1178,11 @@ void PyOperation::checkValid() const { } } -void PyOperationBase::print(py::object fileObject, bool binary, - std::optional largeElementsLimit, +void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified) { + bool assumeVerified, py::object fileObject, + bool binary) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) @@ -1198,6 +1207,17 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsDestroy(flags); } +void PyOperationBase::print(PyAsmState &state, py::object fileObject, + bool binary) { + PyOperation &operation = getOperation(); + operation.checkValid(); + if (fileObject.is_none()) + fileObject = py::module::import("sys").attr("stdout"); + PyFileAccumulator accum(fileObject, binary); + mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), + accum.getUserData()); +} + void PyOperationBase::writeBytecode(const py::object &fileObject, std::optional bytecodeVersion) { PyOperation &operation = getOperation(); @@ -1230,13 +1250,14 @@ py::object PyOperationBase::getAsm(bool binary, } else { fileObject = py::module::import("io").attr("StringIO")(); } - print(fileObject, /*binary=*/binary, - /*largeElementsLimit=*/largeElementsLimit, + print(/*largeElementsLimit=*/largeElementsLimit, /*enableDebugInfo=*/enableDebugInfo, /*prettyDebugInfo=*/prettyDebugInfo, /*printGenericOpForm=*/printGenericOpForm, /*useLocalScope=*/useLocalScope, - /*assumeVerified=*/assumeVerified); + /*assumeVerified=*/assumeVerified, + /*fileObject=*/fileObject, + /*binary=*/binary); return fileObject.attr("getvalue")(); } @@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) { /*assumeVerified=*/false); }, "Returns the assembly form of the operation.") - .def("print", &PyOperationBase::print, + .def("print", + py::overload_cast( + &PyOperationBase::print), + py::arg("state"), py::arg("file") = py::none(), + py::arg("binary") = false, kOperationPrintStateDocstring) + .def("print", + py::overload_cast, bool, bool, bool, bool, + bool, py::object, bool>( + &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. - py::arg("file") = py::none(), py::arg("binary") = false, py::arg("large_elements_limit") = py::none(), py::arg("enable_debug_info") = false, py::arg("pretty_debug_info") = false, py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, kOperationPrintDocstring) + py::arg("assume_verified") = false, py::arg("file") = py::none(), + py::arg("binary") = false, kOperationPrintDocstring) .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), py::arg("desired_version") = py::none(), kOperationPrintBytecodeDocstring) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index af55693f18fbb..d99b87d19bbea 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -550,16 +550,19 @@ class PyModule : public BaseContextObject { pybind11::handle handle; }; +class PyAsmState; + /// Base class for PyOperation and PyOpView which exposes the primary, user /// visible methods for manipulating it. class PyOperationBase { public: virtual ~PyOperationBase() = default; /// Implements the bound 'print' method and helps with others. - void print(pybind11::object fileObject, bool binary, - std::optional largeElementsLimit, bool enableDebugInfo, + void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified); + bool assumeVerified, py::object fileObject, bool binary); + void print(PyAsmState &state, py::object fileObject, bool binary); + pybind11::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 04239b048c1c6..04f8a9936e31f 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -622,10 +622,15 @@ def testOperationPrint(): print(bytes_value.__class__) print(bytes_value) - # Test get_asm local_scope. + # Test print local_scope. # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom") module.operation.print(enable_debug_info=True, use_local_scope=True) + # Test printing using state. + state = AsmState(module.operation) + # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> + module.operation.print(state) + # Test get_asm with options. # CHECK: value = dense_resource<__elided__> : tensor<4xi32> # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7