diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h index 20e84ec83cd01..bfba018d99b82 100644 --- a/mlir/include/mlir/IR/Remarks.h +++ b/mlir/include/mlir/IR/Remarks.h @@ -13,6 +13,7 @@ #ifndef MLIR_IR_REMARKS_H #define MLIR_IR_REMARKS_H +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/Remarks/Remark.h" @@ -60,22 +61,27 @@ struct RemarkOpts { StringRef categoryName; // Category name (subject to regex filtering) StringRef subCategoryName; // Subcategory name StringRef functionName; // Function name if available + bool postponed = false; // Postpone showing the remark // Construct RemarkOpts from a remark name. static constexpr RemarkOpts name(StringRef n) { - return RemarkOpts{n, {}, {}, {}}; + return RemarkOpts{n, {}, {}, {}, false}; } /// Return a copy with the category set. constexpr RemarkOpts category(StringRef v) const { - return {remarkName, v, subCategoryName, functionName}; + return {remarkName, v, subCategoryName, functionName, postponed}; } /// Return a copy with the subcategory set. constexpr RemarkOpts subCategory(StringRef v) const { - return {remarkName, categoryName, v, functionName}; + return {remarkName, categoryName, v, functionName, postponed}; } /// Return a copy with the function name set. constexpr RemarkOpts function(StringRef v) const { - return {remarkName, categoryName, subCategoryName, v}; + return {remarkName, categoryName, subCategoryName, v, postponed}; + } + /// Return a copy with the postponed flag set. + constexpr RemarkOpts postpone() const { + return {remarkName, categoryName, subCategoryName, functionName, true}; } }; @@ -92,7 +98,7 @@ class Remark { RemarkOpts opts) : remarkKind(remarkKind), functionName(opts.functionName), loc(loc), categoryName(opts.categoryName), subCategoryName(opts.subCategoryName), - remarkName(opts.remarkName) { + remarkName(opts.remarkName), postponed(opts.postponed) { if (!categoryName.empty() && !subCategoryName.empty()) { (llvm::Twine(categoryName) + ":" + subCategoryName) .toStringRef(fullCategoryName); @@ -168,6 +174,8 @@ class Remark { StringRef getRemarkTypeString() const; + bool isPostponed() const { return postponed; } + protected: /// Keeps the MLIR diagnostic kind, which is used to determine the /// diagnostic kind in the LLVM remark streamer. @@ -191,6 +199,9 @@ class Remark { /// Args collected via the streaming interface. SmallVector args; + /// Whether the remark is postponed (to be shown later). + bool postponed = false; + private: /// Convert the MLIR diagnostic severity to LLVM diagnostic severity. static llvm::DiagnosticSeverity @@ -344,6 +355,10 @@ class MLIRRemarkStreamerBase { class RemarkEngine { private: + /// Postponed remarks. They are deferred to the end of the pipeline, where the + /// user can intercept them for custom processing, otherwise they will be + /// reported on engine destruction. + llvm::SmallVector postponedRemarks; /// Regex that filters missed optimization remarks: only matching one are /// reported. std::optional missFilter; @@ -392,6 +407,12 @@ class RemarkEngine { InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts, bool (RemarkEngine::*isEnabled)(StringRef) const); + /// Emit all postponed remarks. + void emitPostponedRemarks(); + + /// Report a remark. When `forcePrintPostponedRemarks` is true, the remark + /// will be printed even if it is postponed. + void reportImpl(const Remark &remark); public: /// Default constructor is deleted, use the other constructor. @@ -411,7 +432,7 @@ class RemarkEngine { std::string *errMsg); /// Report a remark. - void report(const Remark &&remark); + void report(const Remark &remark); /// Report a successful remark, this will create an InFlightRemark /// that can be used to build the remark using the << operator. @@ -428,6 +449,12 @@ class RemarkEngine { /// Report an analysis remark, this will create an InFlightRemark /// that can be used to build the remark using the << operator. InFlightRemark emitOptimizationRemarkAnalysis(Location loc, RemarkOpts opts); + + /// Get the postponed remarks. + ArrayRef getPostponedRemarks() const { return postponedRemarks; } + + /// Clear the postponed remarks. + void clearPostponedRemarks() { postponedRemarks.clear(); } }; template diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 1fa04ed8e738f..b9c8bf1de6787 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -278,6 +278,9 @@ class MLIRContextImpl { } } ~MLIRContextImpl() { + // finalize remark engine before destroying anything else. + remarkEngine.reset(); + for (auto typeMapping : registeredTypes) typeMapping.second->~AbstractType(); for (auto attrMapping : registeredAttributes) diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index a55f61aff77bb..6df002401f05a 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -157,7 +157,7 @@ llvm::remarks::Remark Remark::generateRemark() const { InFlightRemark::~InFlightRemark() { if (remark && owner) - owner->report(std::move(*remark)); + owner->report(*remark); owner = nullptr; } @@ -225,7 +225,7 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc, // RemarkEngine //===----------------------------------------------------------------------===// -void RemarkEngine::report(const Remark &&remark) { +void RemarkEngine::reportImpl(const Remark &remark) { // Stream the remark if (remarkStreamer) remarkStreamer->streamOptimizationRemark(remark); @@ -235,7 +235,25 @@ void RemarkEngine::report(const Remark &&remark) { emitRemark(remark.getLocation(), remark.getMsg()); } +void RemarkEngine::report(const Remark &remark) { + // Postponed remarks are deferred to the end of pipeline. + if (remark.isPostponed()) { + postponedRemarks.push_back(remark); + return; + } + + reportImpl(remark); +} + +void RemarkEngine::emitPostponedRemarks() { + for (auto &remark : postponedRemarks) + reportImpl(remark); + postponedRemarks.clear(); +} + RemarkEngine::~RemarkEngine() { + emitPostponedRemarks(); + if (remarkStreamer) remarkStreamer->finalize(); } diff --git a/mlir/unittests/IR/RemarkTest.cpp b/mlir/unittests/IR/RemarkTest.cpp index 5bfca255c22ca..ece7b9fb8624a 100644 --- a/mlir/unittests/IR/RemarkTest.cpp +++ b/mlir/unittests/IR/RemarkTest.cpp @@ -280,7 +280,7 @@ TEST(Remark, TestCustomOptimizationRemarkDiagnostic) { Location loc = UnknownLoc::get(&context); // Setup the remark engine - mlir::remark::RemarkCategories cats{/*all=*/"", + mlir::remark::RemarkCategories cats{/*all=*/std::nullopt, /*passed=*/categoryLoopunroll, /*missed=*/std::nullopt, /*analysis=*/std::nullopt, @@ -315,4 +315,94 @@ TEST(Remark, TestCustomOptimizationRemarkDiagnostic) { EXPECT_NE(errOut.find(pass2Msg), std::string::npos); // printed EXPECT_EQ(errOut.find(pass3Msg), std::string::npos); // filtered out } + +TEST(Remark, TestCustomOptimizationRemarkPostponeDiagnostic) { + testing::internal::CaptureStderr(); + const auto *pass1Msg = "My message"; + const auto *pass2Msg = "My another message"; + const auto *pass3Msg = "Do not show this message"; + + std::string categoryLoopunroll("LoopUnroll"); + std::string myPassname1("myPass1"); + std::string myPassname2("myPass2"); + std::string funcName("myFunc"); + + { + MLIRContext context; + Location loc = UnknownLoc::get(&context); + + // Setup the remark engine + mlir::remark::RemarkCategories cats{/*all=*/std::nullopt, + /*passed=*/categoryLoopunroll, + /*missed=*/std::nullopt, + /*analysis=*/std::nullopt, + /*failed=*/categoryLoopunroll}; + + LogicalResult isEnabled = remark::enableOptimizationRemarks( + context, std::make_unique(), cats, true); + ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; + + // Postponed remark should not be printed yet. + // Remark 1: pass, category LoopUnroll + { + remark::passed(loc, remark::RemarkOpts::name("") + .category(categoryLoopunroll) + .subCategory(myPassname2) + .postpone()) + << pass1Msg; + llvm::errs().flush(); + std::string errOut1 = testing::internal::GetCapturedStderr(); + // Ensure no remark has been printed yet. + EXPECT_TRUE(errOut1.empty()) + << "Expected no stderr output before postponed remarks are flushed"; + } + { + // Postponed remark should not be printed yet. + testing::internal::CaptureStderr(); + // Remark 2: failure, category LoopUnroll + remark::failed(loc, remark::RemarkOpts::name("") + .category(categoryLoopunroll) + .subCategory(myPassname2) + .postpone()) + << remark::reason(pass2Msg); + + llvm::errs().flush(); + std::string errOut2 = testing::internal::GetCapturedStderr(); + // Ensure no remark has been printed yet. + EXPECT_TRUE(errOut2.empty()) + << "Expected no stderr output before postponed remarks are flushed"; + } + { + // Remark 3: pass, category Inline (should be printed) + testing::internal::CaptureStderr(); + remark::passed(loc, remark::RemarkOpts::name("") + .category(categoryLoopunroll) + .subCategory(myPassname1)) + << pass3Msg; + + llvm::errs().flush(); + std::string errOut = testing::internal::GetCapturedStderr(); + auto third = errOut.find("Custom remark:"); + EXPECT_NE(third, std::string::npos); + } + + testing::internal::CaptureStderr(); + } + + llvm::errs().flush(); + std::string errOut = ::testing::internal::GetCapturedStderr(); + + // Expect exactly two "Custom remark:" lines. + auto first = errOut.find("Custom remark:"); + EXPECT_NE(first, std::string::npos); + auto second = errOut.find("Custom remark:", first + 1); + EXPECT_NE(second, std::string::npos); + auto third = errOut.find("Custom remark:", second + 1); + EXPECT_EQ(third, std::string::npos); + + // Containment checks for messages. + EXPECT_NE(errOut.find(pass1Msg), std::string::npos); // printed + EXPECT_NE(errOut.find(pass2Msg), std::string::npos); // printed +} + } // namespace