Skip to content

Commit

Permalink
Decouple OpPassManager from the the MLIRContext (NFC)
Browse files Browse the repository at this point in the history
This is allowing to build an OpPassManager from a StringRef instead of an
Identifier, which enables building pipelines without an MLIRContext.
An identifier is still cached on-demand on the OpPassManager for efficiency
during the IR traversal.
  • Loading branch information
joker-eph committed Sep 3, 2020
1 parent 8d35080 commit c0b6bc0
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 39 deletions.
16 changes: 11 additions & 5 deletions mlir/include/mlir/Pass/PassManager.h
Expand Up @@ -47,7 +47,8 @@ struct OpPassManagerImpl;
/// other OpPassManagers or the top-level PassManager.
class OpPassManager {
public:
OpPassManager(Identifier name, MLIRContext *context, bool verifyPasses);
OpPassManager(Identifier name, bool verifyPasses);
OpPassManager(StringRef name, bool verifyPasses);
OpPassManager(OpPassManager &&rhs);
OpPassManager(const OpPassManager &rhs);
~OpPassManager();
Expand All @@ -73,7 +74,7 @@ class OpPassManager {
OpPassManager &nest(Identifier nestedName);
OpPassManager &nest(StringRef nestedName);
template <typename OpT> OpPassManager &nest() {
return nest(Identifier::get(OpT::getOperationName(), getContext()));
return nest(OpT::getOperationName());
}

/// Add the given pass to this pass manager. If this pass has a concrete
Expand All @@ -89,11 +90,11 @@ class OpPassManager {
/// Returns the number of passes held by this manager.
size_t size() const;

/// Return an instance of the context.
MLIRContext *getContext() const;
/// Return the operation name that this pass manager operates on.
Identifier getOpName(MLIRContext &context) const;

/// Return the operation name that this pass manager operates on.
Identifier getOpName() const;
StringRef getOpName() const;

/// Returns the internal implementation instance.
detail::OpPassManagerImpl &getImpl();
Expand Down Expand Up @@ -151,6 +152,9 @@ class PassManager : public OpPassManager {
LLVM_NODISCARD
LogicalResult run(ModuleOp module);

/// Return an instance of the context.
MLIRContext *getContext() const { return context; }

/// Enable support for the pass manager to generate a reproducer on the event
/// of a crash or a pass failure. `outputFile` is a .mlir filename used to
/// write the generated reproducer. If `genLocalReproducer` is true, the pass
Expand Down Expand Up @@ -304,6 +308,8 @@ class PassManager : public OpPassManager {
runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
ModuleOp module, AnalysisManager am);

MLIRContext *context;

/// Flag that specifies if pass statistics should be dumped.
Optional<PassDisplayMode> passStatisticsMode;

Expand Down
92 changes: 59 additions & 33 deletions mlir/lib/Pass/Pass.cpp
Expand Up @@ -92,18 +92,18 @@ void VerifierPass::runOnOperation() {
namespace mlir {
namespace detail {
struct OpPassManagerImpl {
OpPassManagerImpl(Identifier name, MLIRContext *ctx, bool verifyPasses)
: name(name), context(ctx), verifyPasses(verifyPasses) {}
OpPassManagerImpl(Identifier identifier, bool verifyPasses)
: name(identifier), identifier(identifier), verifyPasses(verifyPasses) {}
OpPassManagerImpl(StringRef name, bool verifyPasses)
: name(name), verifyPasses(verifyPasses) {}

/// Merge the passes of this pass manager into the one provided.
void mergeInto(OpPassManagerImpl &rhs);

/// Nest a new operation pass manager for the given operation kind under this
/// pass manager.
OpPassManager &nest(Identifier nestedName);
OpPassManager &nest(StringRef nestedName) {
return nest(Identifier::get(nestedName, getContext()));
}
OpPassManager &nest(StringRef nestedName);

/// Add the given pass to this pass manager. If this pass has a concrete
/// operation type, it must be the same type as this pass manager.
Expand All @@ -117,14 +117,18 @@ struct OpPassManagerImpl {
/// pass.
void splitAdaptorPasses();

/// Return an instance of the context.
MLIRContext *getContext() const { return context; }
Identifier getOpName(MLIRContext &context) {
if (!identifier)
identifier = Identifier::get(name, &context);
return *identifier;
}

/// The name of the operation that passes of this pass manager operate on.
Identifier name;
StringRef name;

/// The current context for this pass manager
MLIRContext *context;
/// The cached identifier (internalized in the context) for the name of the
/// operation that passes of this pass manager operate on.
Optional<Identifier> identifier;

/// Flag that specifies if the IR should be verified after each pass has run.
bool verifyPasses : 1;
Expand All @@ -143,7 +147,14 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
}

OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
OpPassManager nested(nestedName, getContext(), verifyPasses);
OpPassManager nested(nestedName, verifyPasses);
auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
addPass(std::unique_ptr<Pass>(adaptor));
return adaptor->getPassManagers().front();
}

OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
OpPassManager nested(nestedName, verifyPasses);
auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
addPass(std::unique_ptr<Pass>(adaptor));
return adaptor->getPassManagers().front();
Expand All @@ -153,7 +164,7 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
// If this pass runs on a different operation than this pass manager, then
// implicitly nest a pass manager for this operation.
auto passOpName = pass->getOpName();
if (passOpName && passOpName != name.strref())
if (passOpName && passOpName != name)
return nest(*passOpName).addPass(std::move(pass));

passes.emplace_back(std::move(pass));
Expand Down Expand Up @@ -240,14 +251,14 @@ void OpPassManagerImpl::splitAdaptorPasses() {
// OpPassManager
//===----------------------------------------------------------------------===//

OpPassManager::OpPassManager(Identifier name, MLIRContext *context,
bool verifyPasses)
: impl(new OpPassManagerImpl(name, context, verifyPasses)) {}
OpPassManager::OpPassManager(Identifier name, bool verifyPasses)
: impl(new OpPassManagerImpl(name, verifyPasses)) {}
OpPassManager::OpPassManager(StringRef name, bool verifyPasses)
: impl(new OpPassManagerImpl(name, verifyPasses)) {}
OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->getContext(),
rhs.impl->verifyPasses));
impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
for (auto &pass : rhs.impl->passes)
impl->passes.emplace_back(pass->clone());
return *this;
Expand Down Expand Up @@ -290,11 +301,13 @@ size_t OpPassManager::size() const { return impl->passes.size(); }
/// Returns the internal implementation instance.
OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }

/// Return an instance of the context.
MLIRContext *OpPassManager::getContext() const { return impl->getContext(); }
/// Return the operation name that this pass manager operates on.
StringRef OpPassManager::getOpName() const { return impl->name; }

/// Return the operation name that this pass manager operates on.
Identifier OpPassManager::getOpName() const { return impl->name; }
Identifier OpPassManager::getOpName(MLIRContext &context) const {
return impl->getOpName(context);
}

/// Prints out the given passes as the textual representation of a pipeline.
static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
Expand Down Expand Up @@ -389,12 +402,22 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
/// Find an operation pass manager that can operate on an operation of the given
/// type, or nullptr if one does not exist.
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
Identifier name) {
StringRef name) {
auto it = llvm::find_if(
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
return it == mgrs.end() ? nullptr : &*it;
}

/// Find an operation pass manager that can operate on an operation of the given
/// type, or nullptr if one does not exist.
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
Identifier name,
MLIRContext &context) {
auto it = llvm::find_if(
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
return it == mgrs.end() ? nullptr : &*it;
}

OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
mgrs.emplace_back(std::move(mgr));
}
Expand All @@ -421,8 +444,7 @@ void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
// After coalescing, sort the pass managers within rhs by name.
llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
[](const OpPassManager *lhs, const OpPassManager *rhs) {
return lhs->getOpName().strref().compare(
rhs->getOpName().strref());
return lhs->getOpName().compare(rhs->getOpName());
});
}

Expand Down Expand Up @@ -454,16 +476,18 @@ void OpToOpPassAdaptor::runOnOperationImpl() {
for (auto &region : getOperation()->getRegions()) {
for (auto &block : region) {
for (auto &op : block) {
auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier());
auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
*op.getContext());
if (!mgr)
continue;
Identifier opName = mgr->getOpName(*getOperation()->getContext());

// Run the held pipeline over the current operation.
if (instrumentor)
instrumentor->runBeforePipeline(mgr->getOpName(), parentInfo);
instrumentor->runBeforePipeline(opName, parentInfo);
auto result = runPipeline(mgr->getPasses(), &op, am.nest(&op));
if (instrumentor)
instrumentor->runAfterPipeline(mgr->getOpName(), parentInfo);
instrumentor->runAfterPipeline(opName, parentInfo);

if (failed(result))
return signalPassFailure();
Expand Down Expand Up @@ -499,7 +523,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
for (auto &block : region) {
for (auto &op : block) {
// Add this operation iff the name matches any of the pass managers.
if (findPassManagerFor(mgrs, op.getName().getIdentifier()))
if (findPassManagerFor(mgrs, op.getName().getIdentifier(),
getContext()))
opAMPairs.emplace_back(&op, am.nest(&op));
}
}
Expand Down Expand Up @@ -535,16 +560,17 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {

// Get the pass manager for this operation and execute it.
auto &it = opAMPairs[nextID];
auto *pm =
findPassManagerFor(pms, it.first->getName().getIdentifier());
auto *pm = findPassManagerFor(
pms, it.first->getName().getIdentifier(), getContext());
assert(pm && "expected valid pass manager for operation");

Identifier opName = pm->getOpName(*getOperation()->getContext());
if (instrumentor)
instrumentor->runBeforePipeline(pm->getOpName(), parentInfo);
instrumentor->runBeforePipeline(opName, parentInfo);
auto pipelineResult =
runPipeline(pm->getPasses(), it.first, it.second);
if (instrumentor)
instrumentor->runAfterPipeline(pm->getOpName(), parentInfo);
instrumentor->runAfterPipeline(opName, parentInfo);

// Drop this thread from being tracked by the diagnostic handler.
// After this task has finished, the thread may be used outside of
Expand Down Expand Up @@ -737,9 +763,9 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
//===----------------------------------------------------------------------===//

PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
: OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), ctx,
: OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
verifyPasses),
passTiming(false), localReproducer(false) {}
context(ctx), passTiming(false), localReproducer(false) {}

PassManager::~PassManager() {}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Pass/PassStatistics.cpp
Expand Up @@ -116,7 +116,7 @@ static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {

// Print each of the children passes.
for (OpPassManager &mgr : mgrs) {
auto name = ("'" + mgr.getOpName().strref() + "' Pipeline").str();
auto name = ("'" + mgr.getOpName() + "' Pipeline").str();
printPassEntry(os, indent, name);
for (Pass &pass : mgr.getPasses())
printPass(indent + 2, &pass);
Expand Down

0 comments on commit c0b6bc0

Please sign in to comment.