Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mlir/include/mlir/Pass/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,20 @@ struct ReproducerStream {
using ReproducerStreamFactory =
std::function<std::unique_ptr<ReproducerStream>(std::string &error)>;

ReproducerStreamFactory makeReproducerStreamFactory(StringRef outputFile);

std::string
makeReproducer(StringRef anchorName,
const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
Operation *op, StringRef outputFile, bool disableThreads = false,
bool verifyPasses = false);

std::string
makeReproducer(StringRef anchorName,
const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
Operation *op, const ReproducerStreamFactory &streamFactory,
bool disableThreads = false, bool verifyPasses = false);

/// The main pass manager and pipeline builder.
class PassManager : public OpPassManager {
public:
Expand Down Expand Up @@ -282,6 +290,15 @@ class PassManager : public OpPassManager {
/// Add the provided instrumentation to the pass manager.
void addInstrumentation(std::unique_ptr<PassInstrumentation> pi);

/// Enable or disable the printing of pass manager reproducer.
void enableGeneratePassManagerReproducer(StringRef outputFile) {
forceGenerateReproducer = makeReproducerStreamFactory(outputFile);
}

void enableGeneratePassManagerReproducer(ReproducerStreamFactory factory) {
forceGenerateReproducer = std::move(factory);
}

//===--------------------------------------------------------------------===//
// IR Printing

Expand Down Expand Up @@ -492,6 +509,9 @@ class PassManager : public OpPassManager {
llvm::hash_code pipelineInitializationKey =
DenseMapInfo<llvm::hash_code>::getTombstoneKey();

/// A flag that indicates if the pass manager reproducer should be generated.
ReproducerStreamFactory forceGenerateReproducer;

/// Flag that specifies if pass timing is enabled.
bool passTiming : 1;

Expand Down
8 changes: 1 addition & 7 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,7 @@ class MlirOptMainConfig {
return hasFilters;
}

/// Reproducer file generation (no crash required).
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }

/// Set the reproducer output filename
/// Set the remarks output filename
RemarkFormat getRemarkFormat() const { return remarkFormatFlag; }
/// Set the remark format to use.
std::string getRemarksAllFilter() const { return remarksAllFilterFlag; }
Expand Down Expand Up @@ -340,9 +337,6 @@ class MlirOptMainConfig {

/// Verify that the input IR round-trips perfectly.
bool verifyRoundtripFlag = false;

/// The reproducer output filename (no crash required).
std::string generateReproducerFileFlag = "";
};

/// This defines the function type used to setup the pass manager. This can be
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,14 @@ LogicalResult PassManager::run(Operation *op) {
<< size() << " passes, verifyPasses=" << verifyPasses << " pipeline: ";
printAsTextualPipeline(os, /*pretty=*/false);
});
// Generate reproducers if requested
if (forceGenerateReproducer) {
StringRef anchorName = getAnyOpAnchorName();
const auto &passes = getPasses();
makeReproducer(anchorName, passes, op, forceGenerateReproducer,
/*disableThreads=*/!getContext()->isMultithreadingEnabled(),
verifyPasses);
}

MLIRContext *context = getContext();
std::optional<OperationName> anchorOp = getOpName(*context);
Expand Down
17 changes: 13 additions & 4 deletions mlir/lib/Pass/PassCrashRecovery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,8 @@ LogicalResult PassManager::runWithCrashRecovery(Operation *op,
return passManagerResult;
}

static ReproducerStreamFactory
makeReproducerStreamFactory(StringRef outputFile) {
ReproducerStreamFactory
mlir::makeReproducerStreamFactory(StringRef outputFile) {
// Capture the filename by value in case outputFile is out of scope when
// invoked.
std::string filename = outputFile.str();
Expand All @@ -453,13 +453,22 @@ std::string mlir::makeReproducer(
const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
Operation *op, StringRef outputFile, bool disableThreads,
bool verifyPasses) {
return makeReproducer(anchorName, passes, op,
makeReproducerStreamFactory(outputFile), disableThreads,
verifyPasses);
}

std::string mlir::makeReproducer(
StringRef anchorName,
const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
Operation *op, const ReproducerStreamFactory &streamFactory,
bool disableThreads, bool verifyPasses) {
std::string description;
std::string pipelineStr;
llvm::raw_string_ostream passOS(pipelineStr);
::printAsTextualPipeline(passOS, anchorName, passes);
appendReproducer(description, op, makeReproducerStreamFactory(outputFile),
pipelineStr, disableThreads, verifyPasses);
appendReproducer(description, op, streamFactory, pipelineStr, disableThreads,
verifyPasses);
return description;
}

Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Pass/PassManagerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ struct PassManagerOptions {
llvm::cl::desc("When printing the IR before/after a pass, print file "
"tree rooted at this directory. Use in conjunction with "
"mlir-print-ir-* flags")};
llvm::cl::opt<std::string> generateReproducerFile{
"mlir-generate-reproducer",
llvm::cl::desc("Generate an mlir reproducer at the provided filename"
" (no crash required)"),
llvm::cl::init(""), llvm::cl::value_desc("filename")};

/// Add an IR printing instrumentation if enabled by any 'print-ir' flags.
void addPrinterInstrumentation(PassManager &pm);
Expand Down Expand Up @@ -172,6 +177,9 @@ LogicalResult mlir::applyPassManagerCLOptions(PassManager &pm) {

// Add the IR printing instrumentation.
options->addPrinterInstrumentation(pm);

if (options->generateReproducerFile.getNumOccurrences())
pm.enableGeneratePassManagerReproducer(options->generateReproducerFile);
return success();
}

Expand Down
17 changes: 0 additions & 17 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,6 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));

static cl::opt<std::string, /*ExternalStorage=*/true>
generateReproducerFile(
"mlir-generate-reproducer",
llvm::cl::desc(
"Generate an mlir reproducer at the provided filename"
" (no crash required)"),
cl::location(generateReproducerFileFlag), cl::init(""),
cl::value_desc("filename"));

static cl::OptionCategory remarkCategory(
"Remark Options",
"Filter remarks by regular expression (llvm::Regex syntax).");
Expand Down Expand Up @@ -568,14 +559,6 @@ performActions(raw_ostream &os,
if (failed(pm.run(*op)))
return failure();

// Generate reproducers if requested
if (!config.getReproducerFilename().empty()) {
StringRef anchorName = pm.getAnyOpAnchorName();
const auto &passes = pm.getPasses();
makeReproducer(anchorName, passes, op.get(),
config.getReproducerFilename());
}

// Print the output.
TimingScope outputTiming = timing.nest("Output");
if (config.shouldEmitBytecode()) {
Expand Down
Loading