Skip to content

Conversation

joker-eph
Copy link
Collaborator

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Nov 22, 2024

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/117339.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Pass.h (+4-1)
  • (modified) mlir/lib/Bindings/Python/Pass.cpp (+5-3)
  • (modified) mlir/lib/CAPI/IR/Pass.cpp (+13-5)
  • (modified) mlir/test/python/pass_manager.py (+43-1)
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 2218ec0f47d199..6019071cfdaa29 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -75,10 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
 
 /// Enable IR printing.
+/// The treePrintingPath argument is an optional path to a directory
+/// where the dumps will be produced. If it isn't provided then dumps
+/// are produced to stderr.
 MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
     MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
     bool printModuleScope, bool printAfterOnlyOnChange,
-    bool printAfterOnlyOnFailure);
+    bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
 
 /// Enable / disable verify-each.
 MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1d0e5ce2115a0a..00cb9510cecd47 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -76,14 +76,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           "enable_ir_printing",
           [](PyPassManager &passManager, bool printBeforeAll,
              bool printAfterAll, bool printModuleScope, bool printAfterChange,
-             bool printAfterFailure) {
+             bool printAfterFailure, std::string treePrintingPath) {
             mlirPassManagerEnableIRPrinting(
                 passManager.get(), printBeforeAll, printAfterAll,
-                printModuleScope, printAfterChange, printAfterFailure);
+                printModuleScope, printAfterChange, printAfterFailure,
+                mlirStringRefCreate(treePrintingPath.data(),
+                                    treePrintingPath.size()));
           },
           "print_before_all"_a = false, "print_after_all"_a = true,
           "print_module_scope"_a = false, "print_after_change"_a = false,
-          "print_after_failure"_a = false,
+          "print_after_failure"_a = false, "tree_printing_dir_path"_a = "",
           "Enable IR printing, default as mlir-print-ir-after-all.")
       .def(
           "enable_verifier",
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index a6c9fbd08d45a6..01151eafeb5268 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -48,17 +48,25 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
                                      bool printBeforeAll, bool printAfterAll,
                                      bool printModuleScope,
                                      bool printAfterOnlyOnChange,
-                                     bool printAfterOnlyOnFailure) {
+                                     bool printAfterOnlyOnFailure,
+                                     MlirStringRef treePrintingPath) {
   auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
     return printBeforeAll;
   };
   auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
     return printAfterAll;
   };
-  return unwrap(passManager)
-      ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
-                         printModuleScope, printAfterOnlyOnChange,
-                         printAfterOnlyOnFailure);
+  if (unwrap(treePrintingPath).empty())
+    return unwrap(passManager)
+        ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
+                           printModuleScope, printAfterOnlyOnChange,
+                           printAfterOnlyOnFailure);
+
+  unwrap(passManager)
+      ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
+                                   printModuleScope, printAfterOnlyOnChange,
+                                   printAfterOnlyOnFailure,
+                                   unwrap(treePrintingPath));
 }
 
 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 74967032562351..f0698d8d3fb73d 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -1,6 +1,6 @@
 # RUN: %PYTHON %s 2>&1 | FileCheck %s
 
-import gc, sys
+import gc, os, sys, tempfile
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
@@ -340,3 +340,45 @@ def testPrintIrBeforeAndAfterAll():
         # CHECK:   }
         # CHECK: }
         pm.run(module)
+
+
+# CHECK-LABEL: TEST: testPrintIrTree
+@run
+def testPrintIrTree():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            func.func @main() {
+              %0 = arith.constant 10
+              return
+            }
+          }
+        """
+        )
+        pm = PassManager.parse("builtin.module(canonicalize)")
+        ctx.enable_multithreading(False)
+        pm.enable_ir_printing()
+        # CHECK-LABEL: // Tree printing begin
+        # CHECK: └── builtin_module_no-symbol-name
+        # CHECK:     └── 0_canonicalize.mlir
+        # CHECK-LABEL: // Tree printing end
+        pm.run(module)
+        log("// Tree printing begin")
+        with tempfile.TemporaryDirectory() as temp_dir:
+            pm.enable_ir_printing(tree_printing_dir_path=temp_dir)
+            pm.run(module)
+
+            def print_file_tree(directory, prefix=""):
+                entries = sorted(os.listdir(directory))
+                for i, entry in enumerate(entries):
+                    path = os.path.join(directory, entry)
+                    connector = "└── " if i == len(entries) - 1 else "├── "
+                    log(f"{prefix}{connector}{entry}")
+                    if os.path.isdir(path):
+                        print_file_tree(
+                            path, prefix + ("    " if i == len(entries) - 1 else "│   ")
+                        )
+
+            print_file_tree(temp_dir)
+        log("// Tree printing end")

@joker-eph joker-eph force-pushed the mlir-print-ir-tree-dir-python branch from 4024e6e to dc0dad3 Compare November 23, 2024 16:59
Comment on lines 82 to +84
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
bool printModuleScope, bool printAfterOnlyOnChange,
bool printAfterOnlyOnFailure);
bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Random: I'd consider having a PassManagerPrintConfig object to make for a shorter and more stable API.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can give it a try in a subsequent PR.

Do you mean it as a C API thing? A Python API thing? Both?

@joker-eph joker-eph force-pushed the mlir-print-ir-tree-dir-python branch from dc0dad3 to 4e8435c Compare November 23, 2024 18:35
@joker-eph joker-eph merged commit c8b837a into llvm:main Nov 23, 2024
8 checks passed
@joker-eph joker-eph deleted the mlir-print-ir-tree-dir-python branch November 23, 2024 19:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants