diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md index 5772ee4b2d7454..16e3167e647007 100644 --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -86,8 +86,7 @@ struct MyFunctionPass : public PassWrapper( - "flag-name-to-invoke-pass-via-mlir-opt", "Pass description here"); + PassRegistration(); } ``` @@ -503,7 +502,15 @@ struct MyPass ... { /// ensure that the options are initialized properly. MyPass() = default; MyPass(const MyPass& pass) {} - + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "argument"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "description"; + } /// Define the statistic to track during the execution of MyPass. Statistic exampleStat{this, "exampleStat", "An example statistic"}; @@ -562,21 +569,22 @@ example registration is shown below: ```c++ void registerMyPass() { - PassRegistration("argument", "description"); + PassRegistration(); } ``` * `MyPass` is the name of the derived pass class. -* "argument" is the argument used to refer to the pass in the textual format. -* "description" is a brief description of the pass. +* The pass `getArgument()` method is used to get the identifier that will be + used to refer to the pass. +* The pass `getDescription()` method provides a short summary describing the + pass. For passes that cannot be default-constructed, `PassRegistration` accepts an -optional third argument that takes a callback to create the pass: +optional argument that takes a callback to create the pass: ```c++ void registerMyPass() { PassRegistration( - "argument", "description", []() -> std::unique_ptr { std::unique_ptr p = std::make_unique(/*options*/); /*... non-trivial-logic to configure the pass ...*/; @@ -710,7 +718,7 @@ std::unique_ptr foo::createMyPass() { /// Register this pass. void foo::registerMyPass() { - PassRegistration("my-pass", "My pass summary"); + PassRegistration(); } ``` diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 67c695467706a2..da91ef303cd1dd 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -73,10 +73,14 @@ class Pass { /// register the Affine dialect but does not need to register Linalg. virtual void getDependentDialects(DialectRegistry ®istry) const {} - /// Returns the command line argument used when registering this pass. Return + /// Return the command line argument used when registering this pass. Return /// an empty string if one does not exist. virtual StringRef getArgument() const { return ""; } + /// Return the command line description used when registering this pass. + /// Return an empty string if one does not exist. + virtual StringRef getDescription() const { return ""; } + /// Returns the name of the operation that this pass operates on, or None if /// this is a generic OperationPass. Optional getOpName() const { return opName; } diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h index d03aaf8dfd25f5..449889b9178122 100644 --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -125,20 +125,33 @@ void registerPassPipeline( /// Register a specific dialect pass allocator function with the system, /// typically used through the PassRegistration template. +/// Deprecated: please use the alternate version below. void registerPass(StringRef arg, StringRef description, const PassAllocatorFunction &function); +/// Register a specific dialect pass allocator function with the system, +/// typically used through the PassRegistration template. +void registerPass(const PassAllocatorFunction &function); + /// PassRegistration provides a global initializer that registers a Pass -/// allocation routine for a concrete pass instance. The third argument is +/// allocation routine for a concrete pass instance. The argument is /// optional and provides a callback to construct a pass that does not have /// a default constructor. /// /// Usage: /// /// /// At namespace scope. -/// static PassRegistration reg("my-pass", "My Pass Description."); +/// static PassRegistration reg; /// template struct PassRegistration { + PassRegistration(const PassAllocatorFunction &constructor) { + registerPass(constructor); + } + PassRegistration() + : PassRegistration([] { return std::make_unique(); }) {} + + /// Constructor below are deprecated. + PassRegistration(StringRef arg, StringRef description, const PassAllocatorFunction &constructor) { registerPass(arg, description, constructor); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 52a822c26b054c..ecd60de1104ed4 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -622,11 +622,6 @@ def PrintOpStats : Pass<"print-op-stats"> { let constructor = "mlir::createPrintOpStatsPass()"; } -def PrintOp : Pass<"print-op-graph", "ModuleOp"> { - let summary = "Print op graph per-Region"; - let constructor = "mlir::createPrintOpGraphPass()"; -} - def SCCP : Pass<"sccp"> { let summary = "Sparse Conditional Constant Propagation"; let description = [{ diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 2c690a2659ac3a..7f002ac0186ec1 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -122,6 +122,15 @@ void mlir::registerPass(StringRef arg, StringRef description, } } +void mlir::registerPass(const PassAllocatorFunction &function) { + std::unique_ptr pass = function(); + StringRef arg = pass->getArgument(); + if (arg.empty()) + llvm::report_fatal_error( + "Trying to register a pass that does not override `getArgument()`"); + registerPass(arg, pass->getDescription(), function); +} + /// Returns the pass info for the specified pass argument or null if unknown. const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) { auto it = passRegistry->find(passArg); diff --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir index 8ab60508b96079..4a5ac380632e10 100644 --- a/mlir/test/Transforms/print-op-graph.mlir +++ b/mlir/test/Transforms/print-op-graph.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -print-op-graph %s -o %t 2>&1 | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s // CHECK-LABEL: digraph "merge_blocks" // CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>} diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 35e6d980c9f39d..4d6432cbece288 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -71,10 +71,10 @@ def testParseFail(): def testInvalidNesting(): with Context(): try: - pm = PassManager.parse("func(print-op-graph)") + pm = PassManager.parse("func(view-op-graph)") except ValueError as e: # CHECK: Can't add pass 'ViewOpGraphPass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest? - # CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'. + # CHECK: ValueError exception: invalid pass pipeline 'func(view-op-graph)'. log("ValueError exception:", e) else: log("Exception not produced") diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index a33297e8fdfd76..8f3a19daaa5fb6 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -56,6 +56,8 @@ class {0}Base : public {1} { } ::llvm::StringRef getArgument() const override { return "{2}"; } + ::llvm::StringRef getDescription() const override { return "{3}"; } + /// Returns the derived pass name. static constexpr ::llvm::StringLiteral getPassName() { return ::llvm::StringLiteral("{0}"); @@ -74,7 +76,7 @@ class {0}Base : public {1} { /// Return the dialect that must be loaded in the context before this pass. void getDependentDialects(::mlir::DialectRegistry ®istry) const override { - {3} + {4} } protected: @@ -122,7 +124,8 @@ static void emitPassDecl(const Pass &pass, raw_ostream &os) { dependentDialect); } os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(), - pass.getArgument(), dependentDialectRegistrations); + pass.getArgument(), pass.getSummary(), + dependentDialectRegistrations); emitPassOptionDecls(pass, os); emitPassStatisticDecls(pass, os); os << "};\n"; @@ -154,8 +157,8 @@ const char *const passRegistrationCode = R"( //===----------------------------------------------------------------------===// inline void register{0}Pass() {{ - ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{ - return {3}; + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ + return {1}; }); } )"; @@ -175,7 +178,6 @@ static void emitRegistration(ArrayRef passes, raw_ostream &os) { os << "#ifdef GEN_PASS_REGISTRATION\n"; for (const Pass &pass : passes) { os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(), - pass.getArgument(), pass.getSummary(), pass.getConstructor()); }