Skip to content

Commit 020b928

Browse files
authored
[MLIR] Implement remark emitting policies in MLIR (#160526)
This update introduces two new remark emitting policies: 1. `RemarkEmittingPolicyAll`, which emits all remarks, 2. `RemarkEmittingPolicyFinal`, which only emits final remarks after processing. The `RemarkEngine` is modified to support these policies, allowing for more flexible remark handling based on user configuration. PR also adds flag to `mlir-opt` ``` --remark-policy=<value> - Specify the policy for remark output. =all - Print all remarks =final - Print final remarks ```
1 parent 70a26da commit 020b928

File tree

10 files changed

+315
-35
lines changed

10 files changed

+315
-35
lines changed

mlir/include/mlir/IR/Remarks.h

Lines changed: 138 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "mlir/IR/MLIRContext.h"
2525
#include "mlir/IR/Value.h"
2626

27+
#include <functional>
28+
2729
namespace mlir::remark {
2830

2931
/// Define an the set of categories to accept. By default none are, the provided
@@ -144,7 +146,7 @@ class Remark {
144146

145147
llvm::StringRef getCategoryName() const { return categoryName; }
146148

147-
llvm::StringRef getFullCategoryName() const {
149+
llvm::StringRef getCombinedCategoryName() const {
148150
if (categoryName.empty() && subCategoryName.empty())
149151
return {};
150152
if (subCategoryName.empty())
@@ -318,7 +320,7 @@ class InFlightRemark {
318320
};
319321

320322
//===----------------------------------------------------------------------===//
321-
// MLIR Remark Streamer
323+
// Pluggable Remark Utilities
322324
//===----------------------------------------------------------------------===//
323325

324326
/// Base class for MLIR remark streamers that is used to stream
@@ -338,6 +340,26 @@ class MLIRRemarkStreamerBase {
338340
virtual void finalize() {} // optional
339341
};
340342

343+
using ReportFn = llvm::unique_function<void(const Remark &)>;
344+
345+
/// Base class for MLIR remark emitting policies that is used to emit
346+
/// optimization remarks to the underlying remark streamer. The derived classes
347+
/// should implement the `reportRemark` method to provide the actual emitting
348+
/// implementation.
349+
class RemarkEmittingPolicyBase {
350+
protected:
351+
ReportFn reportImpl;
352+
353+
public:
354+
RemarkEmittingPolicyBase() = default;
355+
virtual ~RemarkEmittingPolicyBase() = default;
356+
357+
void initialize(ReportFn fn) { reportImpl = std::move(fn); }
358+
359+
virtual void reportRemark(const Remark &remark) = 0;
360+
virtual void finalize() = 0;
361+
};
362+
341363
//===----------------------------------------------------------------------===//
342364
// Remark Engine (MLIR Context will own this class)
343365
//===----------------------------------------------------------------------===//
@@ -355,6 +377,8 @@ class RemarkEngine {
355377
std::optional<llvm::Regex> failedFilter;
356378
/// The MLIR remark streamer that will be used to emit the remarks.
357379
std::unique_ptr<MLIRRemarkStreamerBase> remarkStreamer;
380+
/// The MLIR remark policy that will be used to emit the remarks.
381+
std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy;
358382
/// When is enabled, engine also prints remarks as mlir::emitRemarks.
359383
bool printAsEmitRemarks = false;
360384

@@ -392,6 +416,8 @@ class RemarkEngine {
392416
InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts,
393417
bool (RemarkEngine::*isEnabled)(StringRef)
394418
const);
419+
/// Report a remark.
420+
void reportImpl(const Remark &remark);
395421

396422
public:
397423
/// Default constructor is deleted, use the other constructor.
@@ -407,8 +433,10 @@ class RemarkEngine {
407433
~RemarkEngine();
408434

409435
/// Setup the remark engine with the given output path and format.
410-
LogicalResult initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
411-
std::string *errMsg);
436+
LogicalResult
437+
initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
438+
std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy,
439+
std::string *errMsg);
412440

413441
/// Report a remark.
414442
void report(const Remark &&remark);
@@ -446,6 +474,54 @@ inline InFlightRemark withEngine(Fn fn, Location loc, Args &&...args) {
446474

447475
namespace mlir::remark {
448476

477+
//===----------------------------------------------------------------------===//
478+
// Remark Emitting Policies
479+
//===----------------------------------------------------------------------===//
480+
481+
/// Policy that emits all remarks.
482+
class RemarkEmittingPolicyAll : public detail::RemarkEmittingPolicyBase {
483+
public:
484+
RemarkEmittingPolicyAll();
485+
486+
void reportRemark(const detail::Remark &remark) override {
487+
reportImpl(remark);
488+
}
489+
void finalize() override {}
490+
};
491+
492+
/// Policy that emits final remarks.
493+
class RemarkEmittingPolicyFinal : public detail::RemarkEmittingPolicyBase {
494+
private:
495+
/// user can intercept them for custom processing via a registered callback,
496+
/// otherwise they will be reported on engine destruction.
497+
llvm::DenseSet<detail::Remark> postponedRemarks;
498+
/// Optional user callback for intercepting postponed remarks.
499+
std::function<void(const detail::Remark &)> postponedRemarkCallback;
500+
501+
public:
502+
RemarkEmittingPolicyFinal();
503+
504+
/// Register a callback to intercept postponed remarks before they are
505+
/// reported. The callback will be invoked for each postponed remark in
506+
/// finalize().
507+
void
508+
setPostponedRemarkCallback(std::function<void(const detail::Remark &)> cb) {
509+
postponedRemarkCallback = std::move(cb);
510+
}
511+
512+
void reportRemark(const detail::Remark &remark) override {
513+
postponedRemarks.erase(remark);
514+
postponedRemarks.insert(remark);
515+
}
516+
void finalize() override {
517+
for (auto &remark : postponedRemarks) {
518+
if (postponedRemarkCallback)
519+
postponedRemarkCallback(remark);
520+
reportImpl(remark);
521+
}
522+
}
523+
};
524+
449525
/// Create a Reason with llvm::formatv formatting.
450526
template <class... Ts>
451527
inline detail::LazyTextBuild reason(const char *fmt, Ts &&...ts) {
@@ -505,16 +581,72 @@ inline detail::InFlightRemark analysis(Location loc, RemarkOpts opts) {
505581

506582
/// Setup remarks for the context. This function will enable the remark engine
507583
/// and set the streamer to be used for optimization remarks. The remark
508-
/// categories are used to filter the remarks that will be emitted by the remark
509-
/// engine. If a category is not specified, it will not be emitted. If
584+
/// categories are used to filter the remarks that will be emitted by the
585+
/// remark engine. If a category is not specified, it will not be emitted. If
510586
/// `printAsEmitRemarks` is true, the remarks will be printed as
511587
/// mlir::emitRemarks. 'streamer' must inherit from MLIRRemarkStreamerBase and
512588
/// will be used to stream the remarks.
513589
LogicalResult enableOptimizationRemarks(
514590
MLIRContext &ctx,
515591
std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
592+
std::unique_ptr<remark::detail::RemarkEmittingPolicyBase>
593+
remarkEmittingPolicy,
516594
const remark::RemarkCategories &cats, bool printAsEmitRemarks = false);
517595

518596
} // namespace mlir::remark
519597

598+
// DenseMapInfo specialization for Remark
599+
namespace llvm {
600+
template <>
601+
struct DenseMapInfo<mlir::remark::detail::Remark> {
602+
static constexpr StringRef kEmptyKey = "<EMPTY_KEY>";
603+
static constexpr StringRef kTombstoneKey = "<TOMBSTONE_KEY>";
604+
605+
/// Helper to provide a static dummy context for sentinel keys.
606+
static mlir::MLIRContext *getStaticDummyContext() {
607+
static mlir::MLIRContext dummyContext;
608+
return &dummyContext;
609+
}
610+
611+
/// Create an empty remark
612+
static inline mlir::remark::detail::Remark getEmptyKey() {
613+
return mlir::remark::detail::Remark(
614+
mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note,
615+
mlir::UnknownLoc::get(getStaticDummyContext()),
616+
mlir::remark::RemarkOpts::name(kEmptyKey));
617+
}
618+
619+
/// Create a dead remark
620+
static inline mlir::remark::detail::Remark getTombstoneKey() {
621+
return mlir::remark::detail::Remark(
622+
mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note,
623+
mlir::UnknownLoc::get(getStaticDummyContext()),
624+
mlir::remark::RemarkOpts::name(kTombstoneKey));
625+
}
626+
627+
/// Compute the hash value of the remark
628+
static unsigned getHashValue(const mlir::remark::detail::Remark &remark) {
629+
return llvm::hash_combine(
630+
remark.getLocation().getAsOpaquePointer(),
631+
llvm::hash_value(remark.getRemarkName()),
632+
llvm::hash_value(remark.getCombinedCategoryName()));
633+
}
634+
635+
static bool isEqual(const mlir::remark::detail::Remark &lhs,
636+
const mlir::remark::detail::Remark &rhs) {
637+
// Check for empty/tombstone keys first
638+
if (lhs.getRemarkName() == kEmptyKey ||
639+
lhs.getRemarkName() == kTombstoneKey ||
640+
rhs.getRemarkName() == kEmptyKey ||
641+
rhs.getRemarkName() == kTombstoneKey) {
642+
return lhs.getRemarkName() == rhs.getRemarkName();
643+
}
644+
645+
// For regular remarks, compare key identifying fields
646+
return lhs.getLocation() == rhs.getLocation() &&
647+
lhs.getRemarkName() == rhs.getRemarkName() &&
648+
lhs.getCombinedCategoryName() == rhs.getCombinedCategoryName();
649+
}
650+
};
651+
} // namespace llvm
520652
#endif // MLIR_IR_REMARKS_H

mlir/include/mlir/Remark/RemarkStreamer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ namespace mlir::remark {
4545
/// mlir::emitRemarks.
4646
LogicalResult enableOptimizationRemarksWithLLVMStreamer(
4747
MLIRContext &ctx, StringRef filePath, llvm::remarks::Format fmt,
48+
std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
4849
const RemarkCategories &cat, bool printAsEmitRemarks = false);
4950

5051
} // namespace mlir::remark

mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ enum class RemarkFormat {
4444
REMARK_FORMAT_BITSTREAM,
4545
};
4646

47+
enum class RemarkPolicy {
48+
REMARK_POLICY_ALL,
49+
REMARK_POLICY_FINAL,
50+
};
51+
4752
/// Configuration options for the mlir-opt tool.
4853
/// This is intended to help building tools like mlir-opt by collecting the
4954
/// supported options.
@@ -242,6 +247,8 @@ class MlirOptMainConfig {
242247

243248
/// Set the reproducer output filename
244249
RemarkFormat getRemarkFormat() const { return remarkFormatFlag; }
250+
/// Set the remark policy to use.
251+
RemarkPolicy getRemarkPolicy() const { return remarkPolicyFlag; }
245252
/// Set the remark format to use.
246253
std::string getRemarksAllFilter() const { return remarksAllFilterFlag; }
247254
/// Set the remark output file.
@@ -265,6 +272,8 @@ class MlirOptMainConfig {
265272

266273
/// Remark format
267274
RemarkFormat remarkFormatFlag = RemarkFormat::REMARK_FORMAT_STDOUT;
275+
/// Remark policy
276+
RemarkPolicy remarkPolicyFlag = RemarkPolicy::REMARK_POLICY_ALL;
268277
/// Remark file to output to
269278
std::string remarksOutputFileFlag = "";
270279
/// Remark filters

mlir/lib/IR/MLIRContext.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ class MLIRContextImpl {
278278
}
279279
}
280280
~MLIRContextImpl() {
281+
// finalize remark engine before destroying anything else.
282+
remarkEngine.reset();
281283
for (auto typeMapping : registeredTypes)
282284
typeMapping.second->~AbstractType();
283285
for (auto attrMapping : registeredAttributes)

mlir/lib/IR/Remarks.cpp

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "llvm/ADT/StringRef.h"
1717

1818
using namespace mlir::remark::detail;
19-
19+
using namespace mlir::remark;
2020
//------------------------------------------------------------------------------
2121
// Remark
2222
//------------------------------------------------------------------------------
@@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) {
7070
void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
7171
// Header: [Type] pass:remarkName
7272
StringRef type = getRemarkTypeString();
73-
StringRef categoryName = getFullCategoryName();
73+
StringRef categoryName = getCombinedCategoryName();
7474
StringRef name = remarkName;
7575

7676
os << '[' << type << "] ";
@@ -140,7 +140,7 @@ llvm::remarks::Remark Remark::generateRemark() const {
140140
r.RemarkType = getRemarkType();
141141
r.RemarkName = getRemarkName();
142142
// MLIR does not use passes; instead, it has categories and sub-categories.
143-
r.PassName = getFullCategoryName();
143+
r.PassName = getCombinedCategoryName();
144144
r.FunctionName = getFunction();
145145
r.Loc = locLambda();
146146
for (const Remark::Arg &arg : getArgs()) {
@@ -225,7 +225,7 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc,
225225
// RemarkEngine
226226
//===----------------------------------------------------------------------===//
227227

228-
void RemarkEngine::report(const Remark &&remark) {
228+
void RemarkEngine::reportImpl(const Remark &remark) {
229229
// Stream the remark
230230
if (remarkStreamer)
231231
remarkStreamer->streamOptimizationRemark(remark);
@@ -235,19 +235,19 @@ void RemarkEngine::report(const Remark &&remark) {
235235
emitRemark(remark.getLocation(), remark.getMsg());
236236
}
237237

238+
void RemarkEngine::report(const Remark &&remark) {
239+
if (remarkEmittingPolicy)
240+
remarkEmittingPolicy->reportRemark(remark);
241+
}
242+
238243
RemarkEngine::~RemarkEngine() {
244+
if (remarkEmittingPolicy)
245+
remarkEmittingPolicy->finalize();
246+
239247
if (remarkStreamer)
240248
remarkStreamer->finalize();
241249
}
242250

243-
llvm::LogicalResult
244-
RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
245-
std::string *errMsg) {
246-
// If you need to validate categories/filters, do so here and set errMsg.
247-
remarkStreamer = std::move(streamer);
248-
return success();
249-
}
250-
251251
/// Returns true if filter is already anchored like ^...$
252252
static bool isAnchored(llvm::StringRef s) {
253253
s = s.trim();
@@ -300,19 +300,44 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks,
300300
failedFilter = buildFilter(cats, cats.failed);
301301
}
302302

303+
llvm::LogicalResult RemarkEngine::initialize(
304+
std::unique_ptr<MLIRRemarkStreamerBase> streamer,
305+
std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy,
306+
std::string *errMsg) {
307+
308+
remarkStreamer = std::move(streamer);
309+
310+
// Capture `this`. Ensure RemarkEngine is not moved after this.
311+
auto reportFunc = [this](const Remark &r) { this->reportImpl(r); };
312+
remarkEmittingPolicy->initialize(ReportFn(std::move(reportFunc)));
313+
314+
this->remarkEmittingPolicy = std::move(remarkEmittingPolicy);
315+
return success();
316+
}
317+
303318
llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
304-
MLIRContext &ctx,
305-
std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
306-
const remark::RemarkCategories &cats, bool printAsEmitRemarks) {
319+
MLIRContext &ctx, std::unique_ptr<detail::MLIRRemarkStreamerBase> streamer,
320+
std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
321+
const RemarkCategories &cats, bool printAsEmitRemarks) {
307322
auto engine =
308-
std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats);
323+
std::make_unique<detail::RemarkEngine>(printAsEmitRemarks, cats);
309324

310325
std::string errMsg;
311-
if (failed(engine->initialize(std::move(streamer), &errMsg))) {
326+
if (failed(engine->initialize(std::move(streamer),
327+
std::move(remarkEmittingPolicy), &errMsg))) {
312328
llvm::report_fatal_error(
313329
llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg);
314330
}
315331
ctx.setRemarkEngine(std::move(engine));
316332

317333
return success();
318334
}
335+
336+
//===----------------------------------------------------------------------===//
337+
// Remark emitting policies
338+
//===----------------------------------------------------------------------===//
339+
340+
namespace mlir::remark {
341+
RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default;
342+
RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default;
343+
} // namespace mlir::remark

mlir/lib/Remark/RemarkStreamer.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,16 @@ void LLVMRemarkStreamer::finalize() {
6060
namespace mlir::remark {
6161
LogicalResult enableOptimizationRemarksWithLLVMStreamer(
6262
MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt,
63+
std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
6364
const RemarkCategories &cat, bool printAsEmitRemarks) {
6465

6566
FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr =
6667
detail::LLVMRemarkStreamer::createToFile(path, fmt);
6768
if (failed(sOr))
6869
return failure();
6970

70-
return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat,
71+
return remark::enableOptimizationRemarks(ctx, std::move(*sOr),
72+
std::move(remarkEmittingPolicy), cat,
7173
printAsEmitRemarks);
7274
}
7375

0 commit comments

Comments
 (0)