diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 49a6442a79d90..16862ae4fe586 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -62,6 +62,10 @@ std::unique_ptr createStripDebugInfoPass(); /// the module. std::unique_ptr createPrintOpStatsPass(raw_ostream &os = llvm::errs()); +/// Creates a pass which prints the list of ops and the number of occurrences in +/// the module with the output format option. +std::unique_ptr createPrintOpStatsPass(raw_ostream &os, bool printAsJSON); + /// Creates a pass which inlines calls and callable operations as defined by /// the CallGraph. std::unique_ptr createInlinerPass(); diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp index e7740abb3d19f..0ee49e371b8dd 100644 --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -21,6 +21,10 @@ namespace { struct PrintOpStatsPass : public PrintOpStatsBase { explicit PrintOpStatsPass(raw_ostream &os) : os(os) {} + explicit PrintOpStatsPass(raw_ostream &os, bool printAsJSON) : os(os) { + this->printAsJSON = printAsJSON; + } + // Prints the resultant operation statistics post iterating over the module. void runOnOperation() override; @@ -107,3 +111,8 @@ void PrintOpStatsPass::printSummaryInJSON() { std::unique_ptr mlir::createPrintOpStatsPass(raw_ostream &os) { return std::make_unique(os); } + +std::unique_ptr mlir::createPrintOpStatsPass(raw_ostream &os, + bool printAsJSON) { + return std::make_unique(os); +}