Skip to content

Commit

Permalink
[mlir][Pass] Make PassManager default to op-agnostic
Browse files Browse the repository at this point in the history
Currently `PassManager` defaults to being anchored on `builtin.module`.
Switching the default makes `PassManager` consistent with
`OpPassManager` and avoids the implicit dependency on `builtin.module`.

Specifying the anchor op type isn't strictly necessary when using
explicit nesting (existing pipelines will continue to work), but I've
updated most call sites to specify the anchor since it allows for better
error-checking during pipeline construction.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D137731
  • Loading branch information
rkayaith committed Jan 25, 2023
1 parent d1b775d commit 94a3092
Show file tree
Hide file tree
Showing 17 changed files with 51 additions and 34 deletions.
6 changes: 4 additions & 2 deletions flang/lib/Frontend/FrontendActions.cpp
Expand Up @@ -184,7 +184,8 @@ bool CodeGenAction::beginSourceFileAction() {
lb.lower(parseTree, ci.getInvocation().getSemanticsContext());

// run the default passes.
mlir::PassManager pm(mlirCtx.get(), mlir::OpPassManager::Nesting::Implicit);
mlir::PassManager pm((*mlirModule)->getName(),
mlir::OpPassManager::Nesting::Implicit);
pm.enableVerifier(/*verifyPasses=*/true);
pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());

Expand Down Expand Up @@ -535,7 +536,8 @@ void CodeGenAction::generateLLVMIR() {
fir::support::registerLLVMTranslation(*mlirCtx);

// Set-up the MLIR pass manager
mlir::PassManager pm(mlirCtx.get(), mlir::OpPassManager::Nesting::Implicit);
mlir::PassManager pm((*mlirModule)->getName(),
mlir::OpPassManager::Nesting::Implicit);

pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());
pm.enableVerifier(/*verifyPasses=*/true);
Expand Down
3 changes: 2 additions & 1 deletion flang/tools/bbc/bbc.cpp
Expand Up @@ -249,7 +249,8 @@ static mlir::LogicalResult convertFortranSourceToMLIR(
<< outputName;

// Otherwise run the default passes.
mlir::PassManager pm(&ctx, mlir::OpPassManager::Nesting::Implicit);
mlir::PassManager pm(mlirModule->getName(),
mlir::OpPassManager::Nesting::Implicit);
pm.enableVerifier(/*verifyPasses=*/true);
mlir::applyPassManagerCLOptions(pm);
if (passPipeline.hasAnyOccurrences()) {
Expand Down
3 changes: 2 additions & 1 deletion flang/tools/tco/tco.cpp
Expand Up @@ -103,7 +103,8 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
fir::KindMapping kindMap{&context};
fir::setTargetTriple(*owningRef, targetTriple);
fir::setKindMapping(*owningRef, kindMap);
mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
mlir::PassManager pm((*owningRef)->getName(),
mlir::OpPassManager::Nesting::Implicit);
pm.enableVerifier(/*verifyPasses=*/true);
mlir::applyPassManagerCLOptions(pm);
if (emitFir) {
Expand Down
7 changes: 2 additions & 5 deletions mlir/docs/PassManagement.md
Expand Up @@ -399,11 +399,8 @@ Below is an example of constructing a pipeline that operates on the above
structure:

```c++
// Create a top-level `PassManager` class. If an operation type is not
// explicitly specific, the default is the builtin `module` operation.
PassManager pm(ctx);
// Note: We could also create the above `PassManager` this way.
PassManager pm(ctx, /*operationName=*/"builtin.module");
// Create a top-level `PassManager` class.
auto pm = PassManager::on<ModuleOp>(ctx);

// Add a pass on the top-level module operation.
pm.addPass(std::make_unique<MyModulePass>());
Expand Down
2 changes: 1 addition & 1 deletion mlir/docs/Tutorials/Toy/Ch-3.md
Expand Up @@ -124,7 +124,7 @@ pipeline. In MLIR, the optimizations are run through a `PassManager` in a
similar way to LLVM:
```c++
mlir::PassManager pm(module.getContext());
mlir::PassManager pm(module->getName());
pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
```

Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch3/toyc.cpp
Expand Up @@ -113,7 +113,7 @@ int dumpMLIR() {
return error;

if (enableOpt) {
mlir::PassManager pm(&context);
mlir::PassManager pm(module.get()->getName());
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);

Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch4/toyc.cpp
Expand Up @@ -114,7 +114,7 @@ int dumpMLIR() {
return error;

if (enableOpt) {
mlir::PassManager pm(&context);
mlir::PassManager pm(module.get()->getName());
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);

Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch5/toyc.cpp
Expand Up @@ -117,7 +117,7 @@ int dumpMLIR() {
if (int error = loadMLIR(sourceMgr, context, module))
return error;

mlir::PassManager pm(&context);
mlir::PassManager pm(module.get()->getName());
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);

Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch6/toyc.cpp
Expand Up @@ -132,7 +132,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
if (int error = loadMLIR(context, module))
return error;

mlir::PassManager pm(&context);
mlir::PassManager pm(module.get()->getName());
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);

Expand Down
2 changes: 1 addition & 1 deletion mlir/examples/toy/Ch7/toyc.cpp
Expand Up @@ -132,7 +132,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
if (int error = loadMLIR(context, module))
return error;

mlir::PassManager pm(&context);
mlir::PassManager pm(module.get()->getName());
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);

Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/OperationSupport.h
Expand Up @@ -322,6 +322,9 @@ class OperationName {
/// Return the operation name with dialect name stripped, if it has one.
StringRef stripDialect() const { return getStringRef().split('.').second; }

/// Return the context this operation is associated with.
MLIRContext *getContext() { return getIdentifier().getContext(); }

/// Return the name of this operation. This always succeeds.
StringRef getStringRef() const { return getIdentifier(); }

Expand Down
21 changes: 14 additions & 7 deletions mlir/include/mlir/Pass/PassManager.h
Expand Up @@ -213,14 +213,20 @@ class PassManager : public OpPassManager {
/// Create a new pass manager under the given context with a specific nesting
/// style. The created pass manager can schedule operations that match
/// `operationName`.
/// FIXME: We should make the specification of `builtin.module` explicit here,
/// so that we can have top-level op-agnostic pass managers.
PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit,
StringRef operationName = "builtin.module");
PassManager(MLIRContext *ctx, StringRef operationName)
: PassManager(ctx, Nesting::Explicit, operationName) {}
PassManager(MLIRContext *ctx,
StringRef operationName = PassManager::getAnyOpAnchorName(),
Nesting nesting = Nesting::Explicit);
PassManager(OperationName operationName, Nesting nesting = Nesting::Explicit);
~PassManager();

/// Create a new pass manager under the given context with a specific nesting
/// style. The created pass manager can schedule operations that match
/// `OperationTy`.
template <typename OperationTy>
static PassManager on(MLIRContext *ctx, Nesting nesting = Nesting::Explicit) {
return PassManager(ctx, OperationTy::getOperationName(), nesting);
}

/// Run the passes within this manager on the provided operation. The
/// specified operation must have the same name as the one provided the pass
/// manager on construction.
Expand Down Expand Up @@ -438,7 +444,8 @@ class PassManager : public OpPassManager {
std::unique_ptr<detail::PassCrashReproducerGenerator> crashReproGenerator;

/// A hash key used to detect when reinitialization is necessary.
llvm::hash_code initializationKey;
llvm::hash_code initializationKey =
DenseMapInfo<llvm::hash_code>::getTombstoneKey();

/// Flag that specifies if pass timing is enabled.
bool passTiming : 1;
Expand Down
14 changes: 9 additions & 5 deletions mlir/lib/Pass/Pass.cpp
Expand Up @@ -769,11 +769,15 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
// PassManager
//===----------------------------------------------------------------------===//

PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
StringRef operationName)
: OpPassManager(OperationName(operationName, ctx), nesting), context(ctx),
initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
passTiming(false), verifyPasses(true) {}
PassManager::PassManager(MLIRContext *ctx, StringRef operationName,
Nesting nesting)
: OpPassManager(operationName, nesting), context(ctx), passTiming(false),
verifyPasses(true) {}

PassManager::PassManager(OperationName operationName, Nesting nesting)
: OpPassManager(operationName, nesting),
context(operationName.getContext()), passTiming(false),
verifyPasses(true) {}

PassManager::~PassManager() = default;

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
Expand Up @@ -34,7 +34,7 @@ convertPDLToPDLInterp(ModuleOp pdlModule,
pdlModule.getBody()->walk(simplifyFn);

/// Lower the PDL pattern module to the interpreter dialect.
PassManager pdlPipeline(pdlModule.getContext());
PassManager pdlPipeline(pdlModule->getName());
#ifdef NDEBUG
// We don't want to incur the hit of running the verifier when in release
// mode.
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Expand Up @@ -79,8 +79,7 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
parserTiming.stop();

// Prepare the pass manager, applying command-line and reproducer options.
PassManager pm(context, OpPassManager::Nesting::Implicit,
op.get()->getName().getStringRef());
PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
pm.enableVerifier(verifyPasses);
applyPassManagerCLOptions(pm);
pm.enableTiming(timing);
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/ExecutionEngine/Invoke.cpp
Expand Up @@ -52,7 +52,7 @@ static struct LLVMInitializer {
/// Simple conversion pipeline for the purpose of testing sources written in
/// dialects lowering to LLVM Dialect.
static LogicalResult lowerToLLVMDialect(ModuleOp module) {
PassManager pm(module.getContext());
PassManager pm(module->getName());
pm.addPass(mlir::createMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createConvertFuncToLLVMPass());
Expand Down
9 changes: 6 additions & 3 deletions mlir/unittests/Pass/PassManagerTest.cpp
Expand Up @@ -68,7 +68,7 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
}

// Instantiate and run our pass.
PassManager pm(&context);
auto pm = PassManager::on<ModuleOp>(&context);
pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
LogicalResult result = pm.run(module.get());
EXPECT_TRUE(succeeded(result));
Expand Down Expand Up @@ -123,7 +123,7 @@ TEST(PassManagerTest, InvalidPass) {
});

// Instantiate and run our pass.
PassManager pm(&context);
auto pm = PassManager::on<ModuleOp>(&context);
pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
LogicalResult result = pm.run(module.get());
EXPECT_TRUE(failed(result));
Expand All @@ -138,7 +138,10 @@ TEST(PassManagerTest, InvalidPass) {
EXPECT_TRUE(succeeded(result));

// Check that adding the pass at the top-level triggers a fatal error.
ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
"Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
"PassManager intended to run on 'builtin.module', did you "
"intend to nest?");
}

} // namespace

0 comments on commit 94a3092

Please sign in to comment.