Skip to content

Commit

Permalink
[mlir] Expose skipRegions option for Op printing in the C and Python …
Browse files Browse the repository at this point in the history
…bindings (#96150)

The MLIR C and Python Bindings expose various methods from
`mlir::OpPrintingFlags` . This PR adds a binding for the `skipRegions`
method, which allows to skip the printing of Regions when printing Ops.
It also exposes this option as parameter in the python `get_asm` and
`print` methods
  • Loading branch information
jorickert authored Jun 20, 2024
1 parent d4bfc4a commit abad845
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 13 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,10 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
MLIR_CAPI_EXPORTED void
mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);

/// Skip printing regions.
MLIR_CAPI_EXPORTED void
mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags);

//===----------------------------------------------------------------------===//
// Bytecode printing flags API.
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 15 additions & 7 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ static const char kOperationPrintDocstring[] =
and report failures in a more robust fashion. Set this to True if doing this
in order to avoid running a redundant verification. If the IR is actually
invalid, behavior is undefined.
skip_regions: Whether to skip printing regions. Defaults to False.
)";

static const char kOperationPrintStateDocstring[] =
Expand Down Expand Up @@ -1221,7 +1222,7 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, py::object fileObject,
bool binary) {
bool binary, bool skipRegions) {
PyOperation &operation = getOperation();
operation.checkValid();
if (fileObject.is_none())
Expand All @@ -1239,6 +1240,8 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
mlirOpPrintingFlagsUseLocalScope(flags);
if (assumeVerified)
mlirOpPrintingFlagsAssumeVerified(flags);
if (skipRegions)
mlirOpPrintingFlagsSkipRegions(flags);

PyFileAccumulator accum(fileObject, binary);
mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
Expand Down Expand Up @@ -1314,7 +1317,7 @@ py::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified) {
bool assumeVerified, bool skipRegions) {
py::object fileObject;
if (binary) {
fileObject = py::module::import("io").attr("BytesIO")();
Expand All @@ -1328,7 +1331,8 @@ py::object PyOperationBase::getAsm(bool binary,
/*useLocalScope=*/useLocalScope,
/*assumeVerified=*/assumeVerified,
/*fileObject=*/fileObject,
/*binary=*/binary);
/*binary=*/binary,
/*skipRegions=*/skipRegions);

return fileObject.attr("getvalue")();
}
Expand Down Expand Up @@ -3043,7 +3047,8 @@ void mlir::python::populateIRCore(py::module &m) {
/*prettyDebugInfo=*/false,
/*printGenericOpForm=*/false,
/*useLocalScope=*/false,
/*assumeVerified=*/false);
/*assumeVerified=*/false,
/*skipRegions=*/false);
},
"Returns the assembly form of the operation.")
.def("print",
Expand All @@ -3053,15 +3058,17 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
bool, py::object, bool>(&PyOperationBase::print),
bool, py::object, bool, bool>(
&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
py::arg("pretty_debug_info") = false,
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
py::arg("assume_verified") = false, py::arg("file") = py::none(),
py::arg("binary") = false, kOperationPrintDocstring)
py::arg("binary") = false, py::arg("skip_regions") = false,
kOperationPrintDocstring)
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
py::arg("desired_version") = py::none(),
kOperationPrintBytecodeDocstring)
Expand All @@ -3073,7 +3080,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("pretty_debug_info") = false,
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
py::arg("assume_verified") = false, kOperationGetAsmDocstring)
py::arg("assume_verified") = false, py::arg("skip_regions") = false,
kOperationGetAsmDocstring)
.def("verify", &PyOperationBase::verify,
"Verify the operation. Raises MLIRError if verification fails, and "
"returns true otherwise.")
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,15 @@ class PyOperationBase {
/// Implements the bound 'print' method and helps with others.
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, py::object fileObject, bool binary);
bool assumeVerified, py::object fileObject, bool binary,
bool skipRegions);
void print(PyAsmState &state, py::object fileObject, bool binary);

pybind11::object getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified);
bool assumeVerified, bool skipRegions);

// Implement the bound 'writeBytecode' method.
void writeBytecode(const pybind11::object &fileObject,
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
unwrap(flags)->assumeVerified();
}

void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) {
unwrap(flags)->skipRegions();
}
//===----------------------------------------------------------------------===//
// Bytecode printing flags API.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class _OperationBase:
print_generic_op_form: bool = False,
use_local_scope: bool = False,
assume_verified: bool = False,
skip_regions: bool = False,
) -> Union[io.BytesIO, io.StringIO]:
"""
Gets the assembly form of the operation with all options available.
Expand Down Expand Up @@ -256,6 +257,7 @@ class _OperationBase:
assume_verified: bool = False,
file: Optional[Any] = None,
binary: bool = False,
skip_regions: bool = False,
) -> None:
"""
Prints the assembly form of the operation to a file like object.
Expand All @@ -281,6 +283,7 @@ class _OperationBase:
and report failures in a more robust fashion. Set this to True if doing this
in order to avoid running a redundant verification. If the IR is actually
invalid, behavior is undefined.
skip_regions: Whether to skip printing regions. Defaults to False.
"""
def verify(self) -> bool:
"""
Expand Down
18 changes: 15 additions & 3 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// function.
MlirRegion region = mlirOperationGetRegion(operation, 0);
MlirBlock block = mlirRegionGetFirstBlock(region);
operation = mlirBlockGetFirstOperation(block);
region = mlirOperationGetRegion(operation, 0);
MlirOperation parentOperation = operation;
MlirOperation function = mlirBlockGetFirstOperation(block);
region = mlirOperationGetRegion(function, 0);
MlirOperation parentOperation = function;
block = mlirRegionGetFirstBlock(region);
operation = mlirBlockGetFirstOperation(block);
assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
Expand Down Expand Up @@ -490,6 +490,18 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
// clang-format on

mlirOpPrintingFlagsDestroy(flags);
flags = mlirOpPrintingFlagsCreate();
mlirOpPrintingFlagsSkipRegions(flags);
fprintf(stderr, "Op print with skip regions flag: ");
mlirOperationPrintWithFlags(function, flags, printToStderr, NULL);
fprintf(stderr, "\n");
// clang-format off
// CHECK: Op print with skip regions flag: func.func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>)
// CHECK-NOT: constant
// CHECK-NOT: return
// clang-format on

fprintf(stderr, "With state: |");
mlirValuePrintAsOperand(value, state, printToStderr, NULL);
// CHECK: With state: |%0|
Expand Down
9 changes: 8 additions & 1 deletion mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def testOperationPrint():
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
module.operation.print(state)

# Test get_asm with options.
# Test print with options.
# CHECK: value = dense_resource<__elided__> : tensor<4xi32>
# CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
module.operation.print(
Expand All @@ -642,6 +642,13 @@ def testOperationPrint():
use_local_scope=True,
)

# Test print with skip_regions option
# CHECK: func.func @f1(%arg0: i32) -> i32
# CHECK-NOT: func.return
module.body.operations[0].print(
skip_regions=True,
)


# CHECK-LABEL: TEST: testKnownOpView
@run
Expand Down

0 comments on commit abad845

Please sign in to comment.