275 changes: 7 additions & 268 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ struct OpPassManagerImpl {
/// recursively through the pipeline graph.
void coalesceAdjacentAdaptorPasses();

/// Split all of AdaptorPasses such that each adaptor only contains one leaf
/// pass.
void splitAdaptorPasses();

/// Return the operation name of this pass manager as an identifier.
Identifier getOpName(MLIRContext &context) {
if (!identifier)
Expand Down Expand Up @@ -213,27 +209,6 @@ void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
}

void OpPassManagerImpl::splitAdaptorPasses() {
std::vector<std::unique_ptr<Pass>> oldPasses;
std::swap(passes, oldPasses);

for (std::unique_ptr<Pass> &pass : oldPasses) {
// If this pass isn't an adaptor, move it directly to the new pass list.
auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get());
if (!currentAdaptor) {
addPass(std::move(pass));
continue;
}

// Otherwise, split the adaptors of each manager within the adaptor.
for (OpPassManager &adaptorPM : currentAdaptor->getPassManagers()) {
adaptorPM.getImpl().splitAdaptorPasses();
for (std::unique_ptr<Pass> &nestedPass : adaptorPM.getImpl().passes)
nest(adaptorPM.getOpName()).addPass(std::move(nestedPass));
}
}
}

//===----------------------------------------------------------------------===//
// OpPassManager
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -645,210 +620,6 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
signalPassFailure();
}

//===----------------------------------------------------------------------===//
// PassCrashReproducer
//===----------------------------------------------------------------------===//

namespace {
/// This class contains all of the context for generating a recovery reproducer.
/// Each recovery context is registered globally to allow for generating
/// reproducers when a signal is raised, such as a segfault.
struct RecoveryReproducerContext {
/// Capture the state needed to emit a reproducer later: the textual form of
/// `passes`, a clone of `op`, and the given flags. The new context is added
/// to the global set of active contexts.
RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
Operation *op,
PassManager::ReproducerStreamFactory &crashStream,
bool disableThreads, bool verifyPasses);
/// Erase the cloned operation and remove this context from the global set.
~RecoveryReproducerContext();

/// Generate a reproducer with the current context.
/// Returns failure if the stream factory does not produce a stream; `error`
/// is forwarded to the factory.
LogicalResult generate(std::string &error);

private:
/// This function is invoked in the event of a crash.
static void crashHandler(void *);

/// Register a signal handler to run in the event of a crash.
static void registerSignalHandler();

/// The textual description of the currently executing pipeline.
std::string pipeline;

/// The MLIR operation representing the IR before the crash.
Operation *preCrashOperation;

/// The factory for the reproducer output stream to use when generating the
/// reproducer.
PassManager::ReproducerStreamFactory &crashStreamFactory;

/// Various pass manager and context flags.
/// When set, `generate` appends `-mlir-disable-threading` and `-verify-each`
/// respectively to the emitted configuration line.
bool disableThreads;
bool verifyPasses;

/// The current set of active reproducer contexts. This is used in the event
/// of a crash. This is not thread_local as the pass manager may produce any
/// number of child threads. This uses a set to allow for multiple MLIR pass
/// managers to be running at the same time.
static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
static llvm::ManagedStatic<
llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
reproducerSet;
};

/// Instance of ReproducerStream backed by file.
struct FileReproducerStream : public PassManager::ReproducerStream {
FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
: outputFile(std::move(outputFile)) {}
~FileReproducerStream() override;

/// Description of the reproducer stream.
StringRef description() override;

/// Stream on which to output reprooducer.
raw_ostream &os() override;

private:
/// ToolOutputFile corresponding to opened `filename`.
std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
};

} // end anonymous namespace

llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
RecoveryReproducerContext::reproducerMutex;
llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
RecoveryReproducerContext::reproducerSet;

RecoveryReproducerContext::RecoveryReproducerContext(
    MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
    PassManager::ReproducerStreamFactory &crashStreamFactory,
    bool disableThreads, bool verifyPasses)
    : preCrashOperation(op->clone()), crashStreamFactory(crashStreamFactory),
      disableThreads(disableThreads), verifyPasses(verifyPasses) {
  // Render the pipeline being executed to text up front. The stream is scoped
  // so that it is destroyed as soon as printing is done.
  {
    llvm::raw_string_ostream textStream(pipeline);
    ::printAsTextualPipeline(passes, textStream);
  }

  // Under the global lock: enable crash recovery if this is the first active
  // context, make sure the signal handler is installed, and register this
  // context.
  llvm::sys::SmartScopedLock<true> guard(*reproducerMutex);
  const bool isFirstContext = reproducerSet->empty();
  if (isFirstContext)
    llvm::CrashRecoveryContext::Enable();
  registerSignalHandler();
  reproducerSet->insert(this);
}

RecoveryReproducerContext::~RecoveryReproducerContext() {
  // Release the IR snapshot that was cloned at construction.
  preCrashOperation->erase();

  // Unregister this context; once no contexts remain, crash recovery is
  // disabled again.
  llvm::sys::SmartScopedLock<true> guard(*reproducerMutex);
  reproducerSet->remove(this);
  const bool noContextsRemain = reproducerSet->empty();
  if (noContextsRemain)
    llvm::CrashRecoveryContext::Disable();
}

/// Description of the reproducer stream: the filename reported by the backing
/// output file.
StringRef FileReproducerStream::description() {
return outputFile->getFilename();
}

/// Stream on which to output reproducer.
raw_ostream &FileReproducerStream::os() { return outputFile->os(); }

/// Calls `keep()` on the output file when the stream is destroyed.
/// NOTE(review): presumably this stops llvm::ToolOutputFile from deleting the
/// file when it is destroyed -- confirm against the ToolOutputFile docs.
FileReproducerStream::~FileReproducerStream() { outputFile->keep(); }

LogicalResult RecoveryReproducerContext::generate(std::string &error) {
  // Ask the factory for a destination stream; `error` is handed to the factory
  // so that it can describe why no stream could be created.
  std::unique_ptr<PassManager::ReproducerStream> stream =
      crashStreamFactory(error);
  if (!stream)
    return failure();
  raw_ostream &out = stream->os();

  // Header line: the pass manager configuration as command line flags.
  out << "// configuration: -pass-pipeline='" << pipeline << "'";
  if (disableThreads)
    out << " -mlir-disable-threading";
  if (verifyPasses)
    out << " -verify-each";
  out << '\n';

  // Followed by the IR as it was before the pipeline ran.
  preCrashOperation->print(out);

  // Report where the reproducer went. Printing the operation alongside the
  // diagnostic is switched off for this one error and then restored to its
  // previous setting.
  auto *ctx = preCrashOperation->getContext();
  const bool printedOpOnDiagnostic = ctx->shouldPrintOpOnDiagnostic();
  ctx->printOpOnDiagnostic(false);
  preCrashOperation->emitError()
      << "A failure has been detected while processing the MLIR module, a "
         "reproducer has been generated in '"
      << stream->description() << "'";
  ctx->printOpOnDiagnostic(printedOpOnDiagnostic);
  return success();
}

void RecoveryReproducerContext::crashHandler(void *) {
// Walk the current stack of contexts and generate a reproducer for each one.
// We can't know for certain which one was the cause, so we need to generate
// a reproducer for all of them.
std::string ignored;
for (RecoveryReproducerContext *context : *reproducerSet)
(void)context->generate(ignored);
}

/// Install `crashHandler` as a signal handler. Safe to call repeatedly: the
/// registration itself only happens on the first call.
void RecoveryReproducerContext::registerSignalHandler() {
// Ensure that the handler is only registered once.
// The initializer of a function-local static runs only the first time
// control reaches it, so the comma expression calls AddSignalHandler exactly
// once; the stored `false` is just a placeholder value.
static bool registered =
(llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
// The variable exists only for its initializer; mark it as used.
(void)registered;
}

/// Run the pass manager with crash recovery enabled.
LogicalResult PassManager::runWithCrashRecovery(Operation *op,
                                                AnalysisManager am) {
  if (localReproducer) {
    // A local reproducer requires that each pass can be run in isolation, so
    // split the passes held within adaptors first.
    impl->splitAdaptorPasses();

    // Then run each pass individually, stopping at the first failure.
    MutableArrayRef<std::unique_ptr<Pass>> isolatedPasses = impl->passes;
    for (std::unique_ptr<Pass> &singlePass : isolatedPasses) {
      if (failed(runWithCrashRecovery(singlePass, op, am)))
        return failure();
    }
    return success();
  }

  // Not a local reproducer: run all of the passes together in recovery mode.
  return runWithCrashRecovery(impl->passes, op, am);
}

/// Run the given passes with crash recovery enabled.
LogicalResult
PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
                                  Operation *op, AnalysisManager am) {
  // Set up the reproducer context before any pass runs.
  RecoveryReproducerContext context(passes, op, crashReproducerStreamFactory,
                                    !getContext()->isMultithreadingEnabled(),
                                    verifyPasses);

  // The result stays `failure()` unless the callback below runs every pass to
  // completion.
  LogicalResult pipelineResult = failure();
  llvm::CrashRecoveryContext recoveryContext;
  recoveryContext.RunSafelyOnThread([&] {
    for (std::unique_ptr<Pass> &pass : passes) {
      LogicalResult passResult = OpToOpPassAdaptor::run(
          pass.get(), op, am, verifyPasses, impl->initializationGeneration);
      if (failed(passResult))
        return;
    }
    pipelineResult = success();
  });

  if (failed(pipelineResult)) {
    // Something went wrong: emit the reproducer, or report why it could not be
    // generated.
    std::string error;
    if (failed(context.generate(error)))
      return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
    return failure();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// PassManager
//===----------------------------------------------------------------------===//
Expand All @@ -857,7 +628,7 @@ PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
StringRef operationName)
: OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
passTiming(false), localReproducer(false), verifyPasses(true) {}
passTiming(false), verifyPasses(true) {}

PassManager::~PassManager() = default;

Expand Down Expand Up @@ -898,10 +669,7 @@ LogicalResult PassManager::run(Operation *op) {
// If reproducer generation is enabled, run the pass manager with crash
// handling enabled.
LogicalResult result =
crashReproducerStreamFactory
? runWithCrashRecovery(op, am)
: OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses,
impl->initializationGeneration);
crashReproGenerator ? runWithCrashRecovery(op, am) : runPasses(op, am);

// Notify the context that the run is done.
context->exitMultiThreadedExecution();
Expand All @@ -912,40 +680,6 @@ LogicalResult PassManager::run(Operation *op) {
return result;
}

/// Enable support for the pass manager to generate a reproducer on the event
/// of a crash or a pass failure. `outputFile` is a .mlir filename used to write
/// the generated reproducer. If `genLocalReproducer` is true, the pass manager
/// will attempt to generate a local reproducer that contains the smallest
/// pipeline.
void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
                                                  bool genLocalReproducer) {
  // The factory owns its own copy of the filename, because `outputFile` may be
  // out of scope by the time the factory is invoked.
  auto fileStreamFactory =
      [filename = outputFile.str()](
          std::string &error) -> std::unique_ptr<ReproducerStream> {
    std::unique_ptr<llvm::ToolOutputFile> file =
        mlir::openOutputFile(filename, &error);
    if (file)
      return std::make_unique<FileReproducerStream>(std::move(file));
    error = "Failed to create reproducer stream: " + error;
    return nullptr;
  };
  enableCrashReproducerGeneration(std::move(fileStreamFactory),
                                  genLocalReproducer);
}

/// Enable support for the pass manager to generate a reproducer on the event
/// of a crash or a pass failure. `factory` is used to construct the streams
/// to write the generated reproducer to. If `genLocalReproducer` is true, the
/// pass manager will attempt to generate a local reproducer that contains the
/// smallest pipeline.
void PassManager::enableCrashReproducerGeneration(
    ReproducerStreamFactory factory, bool genLocalReproducer) {
  // `factory` is taken by value, so move it into the member instead of copying
  // the callable (and anything it captured) a second time.
  crashReproducerStreamFactory = std::move(factory);
  localReproducer = genLocalReproducer;
}

/// Add the provided instrumentation to the pass manager.
void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
if (!instrumentor)
Expand All @@ -954,6 +688,11 @@ void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
instrumentor->addInstrumentation(std::move(pi));
}

/// Run this pass manager's passes on `op` by forwarding to
/// OpToOpPassAdaptor::runPipeline with the manager's `verifyPasses` flag and
/// current initialization generation, and return that call's result.
LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) {
return OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses,
impl->initializationGeneration);
}

//===----------------------------------------------------------------------===//
// AnalysisManager
//===----------------------------------------------------------------------===//
Expand Down
441 changes: 441 additions & 0 deletions mlir/lib/Pass/PassCrashRecovery.cpp

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions mlir/lib/Pass/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,43 @@ class OpToOpPassAdaptor
friend class mlir::PassManager;
};

//===----------------------------------------------------------------------===//
// PassCrashReproducerGenerator
//===----------------------------------------------------------------------===//

/// Generates crash reproducers for runs of a pass manager. The generator is
/// initialized before each run, told which pass/operation pairs are being
/// processed via `prepareReproducerFor`/`removeLastReproducerFor`, and
/// finalized with the result of the run. The implementation lives behind the
/// private `Impl` struct.
class PassCrashReproducerGenerator {
public:
/// `streamFactory` is taken by reference.
/// NOTE(review): presumably the factory creates the streams that reproducers
/// are written to and must outlive the generator; `localReproducer`
/// presumably selects local reproducer generation -- confirm both against
/// the implementation.
PassCrashReproducerGenerator(
PassManager::ReproducerStreamFactory &streamFactory,
bool localReproducer);
~PassCrashReproducerGenerator();

/// Initialize the generator in preparation for reproducer generation. The
/// generator should be reinitialized before each run of the pass manager.
void initialize(iterator_range<PassManager::pass_iterator> passes,
Operation *op, bool pmFlagVerifyPasses);
/// Finalize the current run of the generator, generating any necessary
/// reproducers if the provided execution result is a failure.
void finalize(Operation *rootOp, LogicalResult executionResult);

/// Prepare a new reproducer for the given pass, operating on `op`.
void prepareReproducerFor(Pass *pass, Operation *op);

/// Prepare a new reproducer for the given passes, operating on `op`.
void prepareReproducerFor(iterator_range<PassManager::pass_iterator> passes,
Operation *op);

/// Remove the last recorded reproducer anchored at the given pass and
/// operation.
void removeLastReproducerFor(Pass *pass, Operation *op);

private:
struct Impl;

/// The internal implementation of the crash reproducer.
std::unique_ptr<Impl> impl;
};

} // end namespace detail
} // end namespace mlir
#endif // MLIR_PASS_PASSDETAIL_H_
13 changes: 10 additions & 3 deletions mlir/lib/Pass/PassManagerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ struct PassManagerOptions {
llvm::cl::desc(
"When printing the IR after a pass, only print if the IR changed"),
llvm::cl::init(false)};
llvm::cl::opt<bool> printAfterFailure{
"print-ir-after-failure",
llvm::cl::desc(
"When printing the IR after a pass, only print if the pass failed"),
llvm::cl::init(false)};
llvm::cl::opt<bool> printModuleScope{
"print-ir-module-scope",
llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} "
Expand Down Expand Up @@ -96,8 +101,9 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
}

// Handle print-after.
if (printAfterAll) {
// If we are printing after all, then just return true for the filter.
if (printAfterAll || printAfterFailure) {
// If we are printing after all or failure, then just return true for the
// filter.
shouldPrintAfterPass = [](Pass *, Operation *) { return true; };
} else if (printAfter.hasAnyOccurrences()) {
// Otherwise if there are specific passes to print after, then check to see
Expand All @@ -114,7 +120,8 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {

// Otherwise, add the IR printing instrumentation.
pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
printModuleScope, printAfterChange, llvm::errs());
printModuleScope, printAfterChange, printAfterFailure,
llvm::errs());
}

void mlir::registerPassManagerCLOptions() {
Expand Down
41 changes: 24 additions & 17 deletions mlir/test/Pass/crash-recovery.mlir
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
// RUN: mlir-opt %s -pass-pipeline='func(test-function-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics
// RUN: mlir-opt %s -pass-pipeline='module(test-module-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics
// RUN: cat %t | FileCheck -check-prefix=REPRO %s
// RUN: mlir-opt %s -pass-pipeline='func(test-function-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer
// RUN: mlir-opt %s -pass-pipeline='module(test-module-pass, test-pass-crash)' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer -mlir-disable-threading
// RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL %s

// Check that we correctly handle verifiers passes with local reproducer, this use to crash.
// RUN: mlir-opt %s -test-function-pass -test-function-pass -test-module-pass -pass-pipeline-crash-reproducer=%t -pass-pipeline-local-reproducer
// Check that we correctly handle verifiers passes with local reproducer, this used to crash.
// RUN: mlir-opt %s -test-module-pass -test-module-pass -test-module-pass -pass-pipeline-crash-reproducer=%t -pass-pipeline-local-reproducer -mlir-disable-threading
// RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL %s

// Check that local reproducers will also traverse dynamic pass pipelines.
// RUN: mlir-opt %s -pass-pipeline='test-module-pass,test-dynamic-pipeline{op-name=inner_mod1 run-on-nested-operations=1 dynamic-pipeline=test-pass-crash}' -pass-pipeline-crash-reproducer=%t -verify-diagnostics -pass-pipeline-local-reproducer --mlir-disable-threading
// RUN: cat %t | FileCheck -check-prefix=REPRO_LOCAL_DYNAMIC %s

// expected-error@+1 {{A failure has been detected while processing the MLIR module}}
module {
func @foo() {
return
}
// expected-error@below {{Failures have been detected while processing an MLIR pass pipeline}}
// expected-note@below {{Pipeline failed while executing}}
module @inner_mod1 {
module @foo {}
}

// REPRO: configuration: -pass-pipeline='func(test-function-pass, test-pass-crash)'
// REPRO: configuration: -pass-pipeline='module(test-module-pass, test-pass-crash)'

// REPRO: module @inner_mod1
// REPRO: module @foo {

// REPRO_LOCAL: configuration: -pass-pipeline='module(test-pass-crash)'

// REPRO: module
// REPRO: func @foo() {
// REPRO-NEXT: return
// REPRO_LOCAL: module @inner_mod1
// REPRO_LOCAL: module @foo {

// REPRO_LOCAL: configuration: -pass-pipeline='func(test-pass-crash)'
// REPRO_LOCAL_DYNAMIC: configuration: -pass-pipeline='module(test-pass-crash)'

// REPRO_LOCAL: module
// REPRO_LOCAL: func @foo() {
// REPRO_LOCAL-NEXT: return
// REPRO_LOCAL_DYNAMIC: module @inner_mod1
// REPRO_LOCAL_DYNAMIC: module @foo {
4 changes: 4 additions & 0 deletions mlir/test/Pass/ir-printing.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s
// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s
// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s
// RUN: not mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,test-pass-failure)' -print-ir-after-failure -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_FAILURE %s

func @foo() {
%0 = constant 0 : i32
Expand Down Expand Up @@ -60,3 +61,6 @@ func @bar() {
// AFTER_ALL_CHANGE-NOT: *** IR Dump After{{.*}}CSE ***
// We expect that only 'foo' changed during CSE, and the second run of CSE did
// nothing.

// AFTER_FAILURE-NOT: *** IR Dump After{{.*}}CSE
// AFTER_FAILURE: *** IR Dump After{{.*}}TestFailurePass Failed ***
8 changes: 8 additions & 0 deletions mlir/test/lib/Pass/TestPassManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ class TestCrashRecoveryPass
void runOnOperation() final { abort(); }
};

/// A test pass that always fails to enable testing the failure recovery
/// mechanisms of the pass manager.
class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
// Unconditionally signals a pass failure. Unlike TestCrashRecoveryPass, which
// calls abort(), this does not abort the process.
void runOnOperation() final { signalPassFailure(); }
};

/// A test pass that contains a statistic.
struct TestStatisticPass
: public PassWrapper<TestStatisticPass, OperationPass<>> {
Expand Down Expand Up @@ -103,6 +109,8 @@ void registerPassManagerTestPass() {

PassRegistration<TestCrashRecoveryPass>(
"test-pass-crash", "Test a pass in the pass manager that always crashes");
PassRegistration<TestFailurePass>(
"test-pass-failure", "Test a pass in the pass manager that always fails");

PassRegistration<TestStatisticPass> unusedStatP("test-stats-pass",
"Test pass statistics");
Expand Down