diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h index 20e84ec83cd01..9877926116e24 100644 --- a/mlir/include/mlir/IR/Remarks.h +++ b/mlir/include/mlir/IR/Remarks.h @@ -18,7 +18,6 @@ #include "llvm/Remarks/Remark.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Regex.h" -#include #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" @@ -144,7 +143,7 @@ class Remark { llvm::StringRef getCategoryName() const { return categoryName; } - llvm::StringRef getFullCategoryName() const { + llvm::StringRef getCombinedCategoryName() const { if (categoryName.empty() && subCategoryName.empty()) return {}; if (subCategoryName.empty()) @@ -318,7 +317,7 @@ class InFlightRemark { }; //===----------------------------------------------------------------------===// -// MLIR Remark Streamer +// Pluggable Remark Utilities //===----------------------------------------------------------------------===// /// Base class for MLIR remark streamers that is used to stream @@ -338,6 +337,26 @@ class MLIRRemarkStreamerBase { virtual void finalize() {} // optional }; +using ReportFn = llvm::unique_function; + +/// Base class for MLIR remark emitting policies that is used to emit +/// optimization remarks to the underlying remark streamer. The derived classes +/// should implement the `reportRemark` method to provide the actual emitting +/// implementation. +class RemarkEmittingPolicyBase { +protected: + ReportFn reportImpl; + +public: + RemarkEmittingPolicyBase() = default; + virtual ~RemarkEmittingPolicyBase() = default; + + void initialize(ReportFn fn) { reportImpl = std::move(fn); } + + virtual void reportRemark(const Remark &remark) = 0; + virtual void finalize() = 0; +}; + //===----------------------------------------------------------------------===// // Remark Engine (MLIR Context will own this class) //===----------------------------------------------------------------------===// @@ -355,6 +374,8 @@ class RemarkEngine { std::optional failedFilter; /// The MLIR remark streamer that will be used to emit the remarks. std::unique_ptr remarkStreamer; + /// The MLIR remark policy that will be used to emit the remarks. + std::unique_ptr remarkEmittingPolicy; /// When is enabled, engine also prints remarks as mlir::emitRemarks. bool printAsEmitRemarks = false; @@ -392,6 +413,8 @@ class RemarkEngine { InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts, bool (RemarkEngine::*isEnabled)(StringRef) const); + /// Report a remark. + void reportImpl(const Remark &remark); public: /// Default constructor is deleted, use the other constructor. @@ -407,8 +430,15 @@ class RemarkEngine { ~RemarkEngine(); /// Setup the remark engine with the given output path and format. - LogicalResult initialize(std::unique_ptr streamer, - std::string *errMsg); + LogicalResult + initialize(std::unique_ptr streamer, + std::unique_ptr remarkEmittingPolicy, + std::string *errMsg); + + /// Get the remark emitting policy. + RemarkEmittingPolicyBase *getRemarkEmittingPolicy() const { + return remarkEmittingPolicy.get(); + } /// Report a remark. void report(const Remark &&remark); @@ -446,6 +476,46 @@ inline InFlightRemark withEngine(Fn fn, Location loc, Args &&...args) { namespace mlir::remark { +//===----------------------------------------------------------------------===// +// Remark Emitting Policies +//===----------------------------------------------------------------------===// + +/// Policy that emits all remarks. +class RemarkEmittingPolicyAll : public detail::RemarkEmittingPolicyBase { +public: + RemarkEmittingPolicyAll(); + + void reportRemark(const detail::Remark &remark) override { + assert(reportImpl && "reportImpl is not set"); + reportImpl(remark); + } + void finalize() override {} +}; + +/// Policy that emits final remarks. +class RemarkEmittingPolicyFinal : public detail::RemarkEmittingPolicyBase { +private: + /// user can intercept them for custom processing via a registered callback, + /// otherwise they will be reported on engine destruction. + llvm::DenseSet postponedRemarks; + +public: + RemarkEmittingPolicyFinal(); + + void reportRemark(const detail::Remark &remark) override { + postponedRemarks.erase(remark); + postponedRemarks.insert(remark); + } + + void finalize() override { + assert(reportImpl && "reportImpl is not set"); + for (auto &remark : postponedRemarks) { + if (reportImpl) + reportImpl(remark); + } + } +}; + /// Create a Reason with llvm::formatv formatting. template inline detail::LazyTextBuild reason(const char *fmt, Ts &&...ts) { @@ -505,16 +575,72 @@ inline detail::InFlightRemark analysis(Location loc, RemarkOpts opts) { /// Setup remarks for the context. This function will enable the remark engine /// and set the streamer to be used for optimization remarks. The remark -/// categories are used to filter the remarks that will be emitted by the remark -/// engine. If a category is not specified, it will not be emitted. If +/// categories are used to filter the remarks that will be emitted by the +/// remark engine. If a category is not specified, it will not be emitted. If /// `printAsEmitRemarks` is true, the remarks will be printed as /// mlir::emitRemarks. 'streamer' must inherit from MLIRRemarkStreamerBase and /// will be used to stream the remarks. LogicalResult enableOptimizationRemarks( MLIRContext &ctx, std::unique_ptr streamer, + std::unique_ptr + remarkEmittingPolicy, const remark::RemarkCategories &cats, bool printAsEmitRemarks = false); } // namespace mlir::remark +// DenseMapInfo specialization for Remark +namespace llvm { +template <> +struct DenseMapInfo { + static constexpr StringRef kEmptyKey = ""; + static constexpr StringRef kTombstoneKey = ""; + + /// Helper to provide a static dummy context for sentinel keys. + static mlir::MLIRContext *getStaticDummyContext() { + static mlir::MLIRContext dummyContext; + return &dummyContext; + } + + /// Create an empty remark + static inline mlir::remark::detail::Remark getEmptyKey() { + return mlir::remark::detail::Remark( + mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note, + mlir::UnknownLoc::get(getStaticDummyContext()), + mlir::remark::RemarkOpts::name(kEmptyKey)); + } + + /// Create a dead remark + static inline mlir::remark::detail::Remark getTombstoneKey() { + return mlir::remark::detail::Remark( + mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note, + mlir::UnknownLoc::get(getStaticDummyContext()), + mlir::remark::RemarkOpts::name(kTombstoneKey)); + } + + /// Compute the hash value of the remark + static unsigned getHashValue(const mlir::remark::detail::Remark &remark) { + return llvm::hash_combine( + remark.getLocation().getAsOpaquePointer(), + llvm::hash_value(remark.getRemarkName()), + llvm::hash_value(remark.getCombinedCategoryName())); + } + + static bool isEqual(const mlir::remark::detail::Remark &lhs, + const mlir::remark::detail::Remark &rhs) { + // Check for empty/tombstone keys first + if (lhs.getRemarkName() == kEmptyKey || + lhs.getRemarkName() == kTombstoneKey || + rhs.getRemarkName() == kEmptyKey || + rhs.getRemarkName() == kTombstoneKey) { + return lhs.getRemarkName() == rhs.getRemarkName(); + } + + // For regular remarks, compare key identifying fields + return lhs.getLocation() == rhs.getLocation() && + lhs.getRemarkName() == rhs.getRemarkName() && + lhs.getCombinedCategoryName() == rhs.getCombinedCategoryName(); + } +}; +} // namespace llvm #endif // MLIR_IR_REMARKS_H diff --git a/mlir/include/mlir/Remark/RemarkStreamer.h b/mlir/include/mlir/Remark/RemarkStreamer.h index 170d6b439a442..19a70fa4c4daa 100644 --- a/mlir/include/mlir/Remark/RemarkStreamer.h +++ b/mlir/include/mlir/Remark/RemarkStreamer.h @@ -45,6 +45,7 @@ namespace mlir::remark { /// mlir::emitRemarks. LogicalResult enableOptimizationRemarksWithLLVMStreamer( MLIRContext &ctx, StringRef filePath, llvm::remarks::Format fmt, + std::unique_ptr remarkEmittingPolicy, const RemarkCategories &cat, bool printAsEmitRemarks = false); } // namespace mlir::remark diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h index 0fbe15fa2e0db..b7394387b0f9a 100644 --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -44,6 +44,11 @@ enum class RemarkFormat { REMARK_FORMAT_BITSTREAM, }; +enum class RemarkPolicy { + REMARK_POLICY_ALL, + REMARK_POLICY_FINAL, +}; + /// Configuration options for the mlir-opt tool. /// This is intended to help building tools like mlir-opt by collecting the /// supported options. @@ -242,6 +247,8 @@ class MlirOptMainConfig { /// Set the reproducer output filename RemarkFormat getRemarkFormat() const { return remarkFormatFlag; } + /// Set the remark policy to use. + RemarkPolicy getRemarkPolicy() const { return remarkPolicyFlag; } /// Set the remark format to use. std::string getRemarksAllFilter() const { return remarksAllFilterFlag; } /// Set the remark output file. @@ -265,6 +272,8 @@ class MlirOptMainConfig { /// Remark format RemarkFormat remarkFormatFlag = RemarkFormat::REMARK_FORMAT_STDOUT; + /// Remark policy + RemarkPolicy remarkPolicyFlag = RemarkPolicy::REMARK_POLICY_ALL; /// Remark file to output to std::string remarksOutputFileFlag = ""; /// Remark filters diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 1fa04ed8e738f..89b81cfb1e2f9 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -120,6 +120,11 @@ namespace mlir { /// This class is completely private to this file, so everything is public. class MLIRContextImpl { public: + //===--------------------------------------------------------------------===// + // Remark + //===--------------------------------------------------------------------===// + std::unique_ptr remarkEngine; + //===--------------------------------------------------------------------===// // Debugging //===--------------------------------------------------------------------===// @@ -134,11 +139,6 @@ class MLIRContextImpl { //===--------------------------------------------------------------------===// DiagnosticEngine diagEngine; - //===--------------------------------------------------------------------===// - // Remark - //===--------------------------------------------------------------------===// - std::unique_ptr remarkEngine; - //===--------------------------------------------------------------------===// // Options //===--------------------------------------------------------------------===// @@ -357,7 +357,10 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) impl->affineUniquer.registerParametricStorageType(); } -MLIRContext::~MLIRContext() = default; +MLIRContext::~MLIRContext() { + // finalize remark engine before destroying anything else. + impl->remarkEngine.reset(); +} /// Copy the specified array of elements into memory managed by the provided /// bump pointer allocator. This assumes the elements are all PODs. diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index a55f61aff77bb..e310d23dd2cbb 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -16,7 +16,7 @@ #include "llvm/ADT/StringRef.h" using namespace mlir::remark::detail; - +using namespace mlir::remark; //------------------------------------------------------------------------------ // Remark //------------------------------------------------------------------------------ @@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef args) { void Remark::print(llvm::raw_ostream &os, bool printLocation) const { // Header: [Type] pass:remarkName StringRef type = getRemarkTypeString(); - StringRef categoryName = getFullCategoryName(); + StringRef categoryName = getCombinedCategoryName(); StringRef name = remarkName; os << '[' << type << "] "; @@ -81,9 +81,10 @@ void Remark::print(llvm::raw_ostream &os, bool printLocation) const { os << "Function=" << getFunction() << " | "; if (printLocation) { - if (auto flc = mlir::dyn_cast(getLocation())) + if (auto flc = mlir::dyn_cast(getLocation())) { os << " @" << flc.getFilename() << ":" << flc.getLine() << ":" << flc.getColumn(); + } } printArgs(os, getArgs()); @@ -140,7 +141,7 @@ llvm::remarks::Remark Remark::generateRemark() const { r.RemarkType = getRemarkType(); r.RemarkName = getRemarkName(); // MLIR does not use passes; instead, it has categories and sub-categories. - r.PassName = getFullCategoryName(); + r.PassName = getCombinedCategoryName(); r.FunctionName = getFunction(); r.Loc = locLambda(); for (const Remark::Arg &arg : getArgs()) { @@ -225,26 +226,39 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc, // RemarkEngine //===----------------------------------------------------------------------===// -void RemarkEngine::report(const Remark &&remark) { +void RemarkEngine::reportImpl(const Remark &remark) { // Stream the remark - if (remarkStreamer) + if (remarkStreamer) { remarkStreamer->streamOptimizationRemark(remark); + } // Print using MLIR's diagnostic if (printAsEmitRemarks) emitRemark(remark.getLocation(), remark.getMsg()); } +void RemarkEngine::report(const Remark &&remark) { + if (remarkEmittingPolicy) + remarkEmittingPolicy->reportRemark(remark); +} + RemarkEngine::~RemarkEngine() { if (remarkStreamer) remarkStreamer->finalize(); } -llvm::LogicalResult -RemarkEngine::initialize(std::unique_ptr streamer, - std::string *errMsg) { - // If you need to validate categories/filters, do so here and set errMsg. +llvm::LogicalResult RemarkEngine::initialize( + std::unique_ptr streamer, + std::unique_ptr remarkEmittingPolicy, + std::string *errMsg) { + remarkStreamer = std::move(streamer); + + auto reportFunc = + std::bind(&RemarkEngine::reportImpl, this, std::placeholders::_1); + remarkEmittingPolicy->initialize(ReportFn(std::move(reportFunc))); + + this->remarkEmittingPolicy = std::move(remarkEmittingPolicy); return success(); } @@ -301,14 +315,15 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks, } llvm::LogicalResult mlir::remark::enableOptimizationRemarks( - MLIRContext &ctx, - std::unique_ptr streamer, - const remark::RemarkCategories &cats, bool printAsEmitRemarks) { + MLIRContext &ctx, std::unique_ptr streamer, + std::unique_ptr remarkEmittingPolicy, + const RemarkCategories &cats, bool printAsEmitRemarks) { auto engine = - std::make_unique(printAsEmitRemarks, cats); + std::make_unique(printAsEmitRemarks, cats); std::string errMsg; - if (failed(engine->initialize(std::move(streamer), &errMsg))) { + if (failed(engine->initialize(std::move(streamer), + std::move(remarkEmittingPolicy), &errMsg))) { llvm::report_fatal_error( llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg); } @@ -316,3 +331,12 @@ llvm::LogicalResult mlir::remark::enableOptimizationRemarks( return success(); } + +//===----------------------------------------------------------------------===// +// Remark emitting policies +//===----------------------------------------------------------------------===// + +namespace mlir::remark { +RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default; +RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default; +} // namespace mlir::remark diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp index d213a1a2068d6..bf362862d24f6 100644 --- a/mlir/lib/Remark/RemarkStreamer.cpp +++ b/mlir/lib/Remark/RemarkStreamer.cpp @@ -60,6 +60,7 @@ void LLVMRemarkStreamer::finalize() { namespace mlir::remark { LogicalResult enableOptimizationRemarksWithLLVMStreamer( MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt, + std::unique_ptr remarkEmittingPolicy, const RemarkCategories &cat, bool printAsEmitRemarks) { FailureOr> sOr = @@ -67,7 +68,8 @@ LogicalResult enableOptimizationRemarksWithLLVMStreamer( if (failed(sOr)) return failure(); - return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat, + return remark::enableOptimizationRemarks(ctx, std::move(*sOr), + std::move(remarkEmittingPolicy), cat, printAsEmitRemarks); } diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 30fd384f3977c..0766795a420d7 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/Remarks/RemarkFormat.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/ManagedStatic.h" @@ -226,6 +227,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { "bitstream", "Print bitstream file")), llvm::cl::cat(remarkCategory)}; + static llvm::cl::opt remarkPolicy{ + "remark-policy", + llvm::cl::desc("Specify the policy for remark output."), + cl::location(remarkPolicyFlag), + llvm::cl::value_desc("format"), + llvm::cl::init(RemarkPolicy::REMARK_POLICY_ALL), + llvm::cl::values(clEnumValN(RemarkPolicy::REMARK_POLICY_ALL, "all", + "Print all remarks"), + clEnumValN(RemarkPolicy::REMARK_POLICY_FINAL, "final", + "Print final remarks")), + llvm::cl::cat(remarkCategory)}; + static cl::opt remarksAll( "remarks-filter", cl::desc("Show all remarks: passed, missed, failed, analysis"), @@ -517,18 +530,28 @@ performActions(raw_ostream &os, return failure(); context->enableMultithreading(wasThreadingEnabled); - + // Set the remark categories and policy. remark::RemarkCategories cats{ config.getRemarksAllFilter(), config.getRemarksPassedFilter(), config.getRemarksMissedFilter(), config.getRemarksAnalyseFilter(), config.getRemarksFailedFilter()}; mlir::MLIRContext &ctx = *context; + // Helper to create the appropriate policy based on configuration + auto createPolicy = [&config]() + -> std::unique_ptr { + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_ALL) + return std::make_unique(); + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_FINAL) + return std::make_unique(); + + llvm_unreachable("Invalid remark policy"); + }; switch (config.getRemarkFormat()) { case RemarkFormat::REMARK_FORMAT_STDOUT: if (failed(mlir::remark::enableOptimizationRemarks( - ctx, nullptr, cats, true /*printAsEmitRemarks*/))) + ctx, nullptr, createPolicy(), cats, true /*printAsEmitRemarks*/))) return failure(); break; @@ -537,7 +560,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.yaml" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::YAML, cats))) + ctx, file, llvm::remarks::Format::YAML, createPolicy(), cats))) return failure(); break; } @@ -547,7 +570,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.bitstream" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::Bitstream, cats))) + ctx, file, llvm::remarks::Format::Bitstream, createPolicy(), cats))) return failure(); break; } @@ -593,6 +616,10 @@ performActions(raw_ostream &os, AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); os << OpWithState(op.get(), asmState) << '\n'; + + if (remark::detail::RemarkEngine *engine = ctx.getRemarkEngine()) + engine->getRemarkEmittingPolicy()->finalize(); + return success(); } diff --git a/mlir/test/Pass/remark-final.mlir b/mlir/test/Pass/remark-final.mlir new file mode 100644 index 0000000000000..325271e04cc5c --- /dev/null +++ b/mlir/test/Pass/remark-final.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s --test-remark --remarks-filter="category.*" --remark-policy=final 2>&1 | FileCheck %s +// RUN: mlir-opt %s --test-remark --remarks-filter="category.*" --remark-policy=final --remark-format=yaml --remarks-output-file=%t.yaml +// RUN: FileCheck --check-prefix=CHECK-YAML %s < %t.yaml +module @foo { + "test.op"() : () -> () + +} + +// CHECK-YAML-NOT: This is a test passed remark (should be dropped) +// CHECK-YAML-DAG: !Analysis +// CHECK-YAML-DAG: !Failure +// CHECK-YAML-DAG: !Passed + +// CHECK-NOT: This is a test passed remark (should be dropped) +// CHECK-DAG: remark: [Analysis] test-remark +// CHECK-DAG: remark: [Failure] test-remark | Category:category-2-failed +// CHECK-DAG: remark: [Passed] test-remark | Category:category-1-passed diff --git a/mlir/test/lib/Pass/TestRemarksPass.cpp b/mlir/test/lib/Pass/TestRemarksPass.cpp index 3b25686b3dc14..5ca2d1a8550aa 100644 --- a/mlir/test/lib/Pass/TestRemarksPass.cpp +++ b/mlir/test/lib/Pass/TestRemarksPass.cpp @@ -43,7 +43,12 @@ class TestRemarkPass : public PassWrapper> { << remark::add("This is a test missed remark") << remark::reason("because we are testing the remark pipeline") << remark::suggest("try using the remark pipeline feature"); - + mlir::remark::passed( + loc, + remark::RemarkOpts::name("test-remark").category("category-1-passed")) + << remark::add("This is a test passed remark (should be dropped)") + << remark::reason("because we are testing the remark pipeline") + << remark::suggest("try using the remark pipeline feature"); mlir::remark::passed( loc, remark::RemarkOpts::name("test-remark").category("category-1-passed")) diff --git a/mlir/unittests/IR/RemarkTest.cpp b/mlir/unittests/IR/RemarkTest.cpp index 5bfca255c22ca..885d226c8f24c 100644 --- a/mlir/unittests/IR/RemarkTest.cpp +++ b/mlir/unittests/IR/RemarkTest.cpp @@ -53,10 +53,12 @@ TEST(Remark, TestOutputOptimizationRemark) { /*missed=*/categoryUnroll, /*analysis=*/categoryRegister, /*failed=*/categoryInliner}; - + std::unique_ptr policy = + std::make_unique(); LogicalResult isEnabled = mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - context, yamlFile, llvm::remarks::Format::YAML, cats); + context, yamlFile, llvm::remarks::Format::YAML, std::move(policy), + cats); ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; // PASS: something succeeded @@ -203,9 +205,10 @@ TEST(Remark, TestOutputOptimizationRemarkDiagnostic) { /*missed=*/categoryUnroll, /*analysis=*/categoryRegister, /*failed=*/categoryUnroll}; - - LogicalResult isEnabled = - remark::enableOptimizationRemarks(context, nullptr, cats, true); + std::unique_ptr policy = + std::make_unique(); + LogicalResult isEnabled = remark::enableOptimizationRemarks( + context, nullptr, std::move(policy), cats, true); ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; @@ -286,8 +289,11 @@ TEST(Remark, TestCustomOptimizationRemarkDiagnostic) { /*analysis=*/std::nullopt, /*failed=*/categoryLoopunroll}; + std::unique_ptr policy = + std::make_unique(); LogicalResult isEnabled = remark::enableOptimizationRemarks( - context, std::make_unique(), cats, true); + context, std::make_unique(), std::move(policy), cats, + true); ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; // Remark 1: pass, category LoopUnroll @@ -315,4 +321,69 @@ TEST(Remark, TestCustomOptimizationRemarkDiagnostic) { EXPECT_NE(errOut.find(pass2Msg), std::string::npos); // printed EXPECT_EQ(errOut.find(pass3Msg), std::string::npos); // filtered out } + +TEST(Remark, TestRemarkFinal) { + testing::internal::CaptureStderr(); + const auto *pass1Msg = "I failed"; + const auto *pass2Msg = "I failed too"; + const auto *pass3Msg = "I succeeded"; + const auto *pass4Msg = "I succeeded too"; + + std::string categoryLoopunroll("LoopUnroll"); + + std::string seenMsg = ""; + + { + MLIRContext context; + Location loc = FileLineColLoc::get(&context, "test.cpp", 1, 5); + Location locOther = FileLineColLoc::get(&context, "test.cpp", 55, 5); + + // Setup the remark engine + mlir::remark::RemarkCategories cats{/*all=*/"", + /*passed=*/categoryLoopunroll, + /*missed=*/categoryLoopunroll, + /*analysis=*/categoryLoopunroll, + /*failed=*/categoryLoopunroll}; + + std::unique_ptr policy = + std::make_unique(); + LogicalResult isEnabled = remark::enableOptimizationRemarks( + context, std::make_unique(), std::move(policy), cats, + true); + ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; + + // Remark 1: failure + remark::failed( + loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll)) + << pass1Msg; + + // Remark 2: failure + remark::missed( + loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll)) + << remark::reason(pass2Msg); + + // Remark 3: pass + remark::passed( + loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll)) + << pass3Msg; + + // Remark 4: pass + remark::passed( + locOther, + remark::RemarkOpts::name("Unroller").category(categoryLoopunroll)) + << pass4Msg; + + // Finalize the remark engine + policy->finalize(); + } + + llvm::errs().flush(); + std::string errOut = ::testing::internal::GetCapturedStderr(); + + // Containment checks for messages. + EXPECT_EQ(errOut.find(pass1Msg), std::string::npos); // dropped + EXPECT_EQ(errOut.find(pass2Msg), std::string::npos); // dropped + EXPECT_NE(errOut.find(pass3Msg), std::string::npos); // shown + EXPECT_NE(errOut.find(pass4Msg), std::string::npos); // shown +} } // namespace