-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] Add the --mlir-print-ir-tree-dir
to the C and Python API
#117339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesFull diff: https://github.com/llvm/llvm-project/pull/117339.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 2218ec0f47d199..6019071cfdaa29 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,10 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
/// Enable IR printing.
+/// The treePrintingPath argument is an optional path to a directory
+/// where the dumps will be produced. If it isn't provided then dumps
+/// are produced to stderr.
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
bool printModuleScope, bool printAfterOnlyOnChange,
- bool printAfterOnlyOnFailure);
+ bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
/// Enable / disable verify-each.
MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1d0e5ce2115a0a..00cb9510cecd47 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -76,14 +76,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"enable_ir_printing",
[](PyPassManager &passManager, bool printBeforeAll,
bool printAfterAll, bool printModuleScope, bool printAfterChange,
- bool printAfterFailure) {
+ bool printAfterFailure, std::string treePrintingPath) {
mlirPassManagerEnableIRPrinting(
passManager.get(), printBeforeAll, printAfterAll,
- printModuleScope, printAfterChange, printAfterFailure);
+ printModuleScope, printAfterChange, printAfterFailure,
+ mlirStringRefCreate(treePrintingPath.data(),
+ treePrintingPath.size()));
},
"print_before_all"_a = false, "print_after_all"_a = true,
"print_module_scope"_a = false, "print_after_change"_a = false,
- "print_after_failure"_a = false,
+ "print_after_failure"_a = false, "tree_printing_dir_path"_a = "",
"Enable IR printing, default as mlir-print-ir-after-all.")
.def(
"enable_verifier",
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index a6c9fbd08d45a6..01151eafeb5268 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -48,17 +48,25 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
bool printBeforeAll, bool printAfterAll,
bool printModuleScope,
bool printAfterOnlyOnChange,
- bool printAfterOnlyOnFailure) {
+ bool printAfterOnlyOnFailure,
+ MlirStringRef treePrintingPath) {
auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
return printBeforeAll;
};
auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
return printAfterAll;
};
- return unwrap(passManager)
- ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
- printModuleScope, printAfterOnlyOnChange,
- printAfterOnlyOnFailure);
+ if (unwrap(treePrintingPath).empty())
+ return unwrap(passManager)
+ ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+ printModuleScope, printAfterOnlyOnChange,
+ printAfterOnlyOnFailure);
+
+ unwrap(passManager)
+ ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
+ printModuleScope, printAfterOnlyOnChange,
+ printAfterOnlyOnFailure,
+ unwrap(treePrintingPath));
}
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 74967032562351..f0698d8d3fb73d 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -1,6 +1,6 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s
-import gc, sys
+import gc, os, sys, tempfile
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.func import FuncOp
@@ -340,3 +340,45 @@ def testPrintIrBeforeAndAfterAll():
# CHECK: }
# CHECK: }
pm.run(module)
+
+
+# CHECK-LABEL: TEST: testPrintIrTree
+@run
+def testPrintIrTree():
+ with Context() as ctx:
+ module = ModuleOp.parse(
+ """
+ module {
+ func.func @main() {
+ %0 = arith.constant 10
+ return
+ }
+ }
+ """
+ )
+ pm = PassManager.parse("builtin.module(canonicalize)")
+ ctx.enable_multithreading(False)
+ pm.enable_ir_printing()
+ # CHECK-LABEL: // Tree printing begin
+ # CHECK: └── builtin_module_no-symbol-name
+ # CHECK: └── 0_canonicalize.mlir
+ # CHECK-LABEL: // Tree printing end
+ pm.run(module)
+ log("// Tree printing begin")
+ with tempfile.TemporaryDirectory() as temp_dir:
+ pm.enable_ir_printing(tree_printing_dir_path=temp_dir)
+ pm.run(module)
+
+ def print_file_tree(directory, prefix=""):
+ entries = sorted(os.listdir(directory))
+ for i, entry in enumerate(entries):
+ path = os.path.join(directory, entry)
+ connector = "└── " if i == len(entries) - 1 else "├── "
+ log(f"{prefix}{connector}{entry}")
+ if os.path.isdir(path):
+ print_file_tree(
+ path, prefix + (" " if i == len(entries) - 1 else "│ ")
+ )
+
+ print_file_tree(temp_dir)
+ log("// Tree printing end")
|
4024e6e
to
dc0dad3
Compare
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, | ||
bool printModuleScope, bool printAfterOnlyOnChange, | ||
bool printAfterOnlyOnFailure); | ||
bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Random: I'd consider having a PassManagerPrintConfig
object to make for a shorter and more stable API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can give it a try in a subsequent PR.
Do you mean it as a C API thing? A Python API thing? Both?
dc0dad3
to
4e8435c
Compare
No description provided.