diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 6e59b0f32ac6f..16f12d2dc16eb 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -463,6 +463,8 @@ class PassManager : public OpPassManager { void enableStatistics(PassDisplayMode displayMode = PassDisplayMode::Pipeline); + void dumpStatistics(raw_ostream &os, PassDisplayMode displayMode); + private: /// Dump the statistics of the passes within this pass manager. void dumpStatistics(); diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp index 01191aa824440..3d271129ebfe6 100644 --- a/mlir/lib/Pass/PassStatistics.cpp +++ b/mlir/lib/Pass/PassStatistics.cpp @@ -138,27 +138,26 @@ static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) { #endif } -static void printStatistics(OpPassManager &pm, PassDisplayMode displayMode) { - auto os = llvm::CreateInfoOutputFile(); - +static void printStatistics(OpPassManager &pm, raw_ostream &os, + PassDisplayMode displayMode) { // Print the stats header. - *os << "===" << std::string(73, '-') << "===\n"; + os << "===" << std::string(73, '-') << "===\n"; // Figure out how many spaces for the description name. unsigned padding = (80 - kPassStatsDescription.size()) / 2; - os->indent(padding) << kPassStatsDescription << '\n'; - *os << "===" << std::string(73, '-') << "===\n"; + os.indent(padding) << kPassStatsDescription << '\n'; + os << "===" << std::string(73, '-') << "===\n"; // Defer to a specialized printer for each display mode. switch (displayMode) { case PassDisplayMode::List: - printResultsAsList(*os, pm); + printResultsAsList(os, pm); break; case PassDisplayMode::Pipeline: - printResultsAsPipeline(*os, pm); + printResultsAsPipeline(os, pm); break; } - *os << "\n"; - os->flush(); + os << "\n"; + os.flush(); } //===----------------------------------------------------------------------===// @@ -242,8 +241,13 @@ static void prepareStatistics(OpPassManager &pm) { /// Dump the statistics of the passes within this pass manager. void PassManager::dumpStatistics() { + auto os = llvm::CreateInfoOutputFile(); + dumpStatistics(*os, passStatisticsMode.value()); +} + +void PassManager::dumpStatistics(raw_ostream &os, PassDisplayMode displayMode) { prepareStatistics(*this); - printStatistics(*this, *passStatisticsMode); + printStatistics(*this, os, displayMode); } /// Dump the statistics for each pass after running. diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp index 3f5db8ebcbb6d..ea6ecc1759ad8 100644 --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -380,4 +380,48 @@ TEST(PassManagerTest, PassInitialization) { EXPECT_TRUE(succeeded(pm.run(module.get()))); } +struct IncrementStatisticsPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IncrementStatisticsPass) + + IncrementStatisticsPass() {}; + IncrementStatisticsPass(const IncrementStatisticsPass &other) {} + + void runOnOperation() override { + testStat1++; + testStat2 += 5; + testStat3 = 10; + } + + std::unique_ptr clonePass() const override { + return std::make_unique(); + } + +private: + Pass::Statistic testStat1{this, "test-stat-1", "Test1"}; + Pass::Statistic testStat2{this, "test-stat-2", "Test2"}; + Pass::Statistic testStat3{this, "test-stat-3", "Test3"}; +}; + +TEST(PassManagerTest, StatisticsPrint) { + MLIRContext context; + context.allowUnregisteredDialects(); + + OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); + + auto pm = PassManager::on(&context); + pm.addPass(std::make_unique()); + LogicalResult result = pm.run(module.get()); + EXPECT_TRUE(succeeded(result)); + + std::string statistics; + llvm::raw_string_ostream os(statistics); + + pm.dumpStatistics(os, PassDisplayMode::List); + + EXPECT_NE(statistics.find("1 test-stat-1"), std::string::npos); + EXPECT_NE(statistics.find("5 test-stat-2"), std::string::npos); + EXPECT_NE(statistics.find("10 test-stat-3"), std::string::npos); +} + } // namespace