Skip to content
129 changes: 119 additions & 10 deletions mlir/include/mlir/IR/Remarks.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef MLIR_IR_REMARKS_H
#define MLIR_IR_REMARKS_H

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/Remarks/Remark.h"
Expand All @@ -21,6 +23,7 @@
#include <optional>

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Value.h"

Expand Down Expand Up @@ -60,22 +63,27 @@ struct RemarkOpts {
StringRef categoryName; // Category name (subject to regex filtering)
StringRef subCategoryName; // Subcategory name
StringRef functionName; // Function name if available
bool postponed = false; // Postpone showing the remark

// Construct RemarkOpts from a remark name.
static constexpr RemarkOpts name(StringRef n) {
return RemarkOpts{n, {}, {}, {}};
return RemarkOpts{n, {}, {}, {}, false};
}
/// Return a copy with the category set.
constexpr RemarkOpts category(StringRef v) const {
return {remarkName, v, subCategoryName, functionName};
return {remarkName, v, subCategoryName, functionName, postponed};
}
/// Return a copy with the subcategory set.
constexpr RemarkOpts subCategory(StringRef v) const {
return {remarkName, categoryName, v, functionName};
return {remarkName, categoryName, v, functionName, postponed};
}
/// Return a copy with the function name set.
constexpr RemarkOpts function(StringRef v) const {
return {remarkName, categoryName, subCategoryName, v};
return {remarkName, categoryName, subCategoryName, v, postponed};
}
/// Return a copy with the postponed flag set.
constexpr RemarkOpts postpone() const {
return {remarkName, categoryName, subCategoryName, functionName, true};
}
};

Expand All @@ -92,10 +100,10 @@ class Remark {
RemarkOpts opts)
: remarkKind(remarkKind), functionName(opts.functionName), loc(loc),
categoryName(opts.categoryName), subCategoryName(opts.subCategoryName),
remarkName(opts.remarkName) {
remarkName(opts.remarkName), postponed(opts.postponed) {
if (!categoryName.empty() && !subCategoryName.empty()) {
(llvm::Twine(categoryName) + ":" + subCategoryName)
.toStringRef(fullCategoryName);
.toStringRef(combinedCategoryName);
}
}

Expand Down Expand Up @@ -144,14 +152,14 @@ class Remark {

llvm::StringRef getCategoryName() const { return categoryName; }

llvm::StringRef getFullCategoryName() const {
llvm::StringRef getCombinedCategoryName() const {
if (categoryName.empty() && subCategoryName.empty())
return {};
if (subCategoryName.empty())
return categoryName;
if (categoryName.empty())
return subCategoryName;
return fullCategoryName;
return combinedCategoryName;
}

StringRef getRemarkName() const {
Expand All @@ -168,6 +176,8 @@ class Remark {

StringRef getRemarkTypeString() const;

bool isPostponed() const { return postponed; }

protected:
/// Keeps the MLIR diagnostic kind, which is used to determine the
/// diagnostic kind in the LLVM remark streamer.
Expand All @@ -183,14 +193,17 @@ class Remark {
StringRef subCategoryName;

/// Combined name for category and sub-category
SmallString<64> fullCategoryName;
SmallString<64> combinedCategoryName;

/// Remark identifier
StringRef remarkName;

/// Args collected via the streaming interface.
SmallVector<Arg, 4> args;

/// Whether the remark is postponed (to be shown later).
bool postponed = false;

private:
/// Convert the MLIR diagnostic severity to LLVM diagnostic severity.
static llvm::DiagnosticSeverity
Expand Down Expand Up @@ -344,6 +357,10 @@ class MLIRRemarkStreamerBase {

class RemarkEngine {
private:
/// Postponed remarks. They are deferred to the end of the pipeline, where the
/// user can intercept them for custom processing, otherwise they will be
/// reported on engine destruction.
llvm::DenseSet<Remark> postponedRemarks;
/// Regex that filters missed optimization remarks: only matching one are
/// reported.
std::optional<llvm::Regex> missFilter;
Expand Down Expand Up @@ -392,6 +409,12 @@ class RemarkEngine {
InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts,
bool (RemarkEngine::*isEnabled)(StringRef)
const);
/// Emit all postponed remarks.
void emitPostponedRemarks();

/// Report a remark. When `forcePrintPostponedRemarks` is true, the remark
/// will be printed even if it is postponed.
void reportImpl(const Remark &remark);

public:
/// Default constructor is deleted, use the other constructor.
Expand All @@ -411,7 +434,7 @@ class RemarkEngine {
std::string *errMsg);

/// Report a remark.
void report(const Remark &&remark);
void report(const Remark &remark);

/// Report a successful remark, this will create an InFlightRemark
/// that can be used to build the remark using the << operator.
Expand All @@ -428,6 +451,17 @@ class RemarkEngine {
/// Report an analysis remark, this will create an InFlightRemark
/// that can be used to build the remark using the << operator.
InFlightRemark emitOptimizationRemarkAnalysis(Location loc, RemarkOpts opts);

/// Get the postponed remarks.
const DenseSet<Remark> &getPostponedRemarks() const { return postponedRemarks; }

/// Clear the postponed remarks.
void clearPostponedRemarks() { postponedRemarks.clear(); }

/// Drop a postponed remark.
bool dropPostponedRemark(Remark &remark) {
return postponedRemarks.erase(remark);
}
};

template <typename Fn, typename... Args>
Expand Down Expand Up @@ -498,6 +532,21 @@ inline detail::InFlightRemark analysis(Location loc, RemarkOpts opts) {
return withEngine(&detail::RemarkEngine::emitOptimizationRemarkAnalysis, loc,
opts);
}
//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//

/// Drop a postponed remark.
inline bool dropPostponedRemark(Location loc, RemarkKind remarkKind,
RemarkOpts opts) {
MLIRContext *ctx = loc->getContext();
detail::Remark remark(remarkKind, DiagnosticSeverity::Remark, loc, opts);

// Drop the remark from the postponed remarks.
if (detail::RemarkEngine *engine = ctx->getRemarkEngine())
return engine->dropPostponedRemark(remark);
return false;
}

//===----------------------------------------------------------------------===//
// Setup
Expand All @@ -517,4 +566,64 @@ LogicalResult enableOptimizationRemarks(

} // namespace mlir::remark

//===----------------------------------------------------------------------===//
// DenseMapInfo specialization for Remark
//===----------------------------------------------------------------------===//

namespace llvm {
template <>
struct DenseMapInfo<mlir::remark::detail::Remark> {
static constexpr StringRef kEmptyKey = "<EMPTY_KEY>";
static constexpr StringRef kTombstoneKey = "<TOMBSTONE_KEY>";

/// Helper to provide a static dummy context for sentinel keys.
static mlir::MLIRContext *getStaticDummyContext() {
static mlir::MLIRContext dummyContext;
return &dummyContext;
}

/// Create an empty remark
static inline mlir::remark::detail::Remark getEmptyKey() {
return mlir::remark::detail::Remark(
mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note,
mlir::UnknownLoc::get(getStaticDummyContext()),
mlir::remark::RemarkOpts::name(kEmptyKey));
}

/// Create a dead remark
static inline mlir::remark::detail::Remark getTombstoneKey() {
return mlir::remark::detail::Remark(
mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note,
mlir::UnknownLoc::get(getStaticDummyContext()),
mlir::remark::RemarkOpts::name(kTombstoneKey));
}

/// Compute the hash value of the remark
static unsigned getHashValue(const mlir::remark::detail::Remark &remark) {
return llvm::hash_combine(
remark.getLocation().getAsOpaquePointer(),
llvm::hash_value(remark.getRemarkName()),
llvm::hash_value(remark.getCombinedCategoryName()),
static_cast<unsigned>(remark.getRemarkType()));
}

static bool isEqual(const mlir::remark::detail::Remark &lhs,
const mlir::remark::detail::Remark &rhs) {
// Check for empty/tombstone keys first
if (lhs.getRemarkName() == kEmptyKey ||
lhs.getRemarkName() == kTombstoneKey ||
rhs.getRemarkName() == kEmptyKey ||
rhs.getRemarkName() == kTombstoneKey) {
return lhs.getRemarkName() == rhs.getRemarkName();
}

// For regular remarks, compare key identifying fields
return lhs.getLocation() == rhs.getLocation() &&
lhs.getRemarkName() == rhs.getRemarkName() &&
lhs.getCombinedCategoryName() == rhs.getCombinedCategoryName() &&
lhs.getRemarkType() == rhs.getRemarkType();
}
};
} // namespace llvm

#endif // MLIR_IR_REMARKS_H
3 changes: 3 additions & 0 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ class MLIRContextImpl {
}
}
~MLIRContextImpl() {
// finalize remark engine before destroying anything else.
remarkEngine.reset();

for (auto typeMapping : registeredTypes)
typeMapping.second->~AbstractType();
for (auto attrMapping : registeredAttributes)
Expand Down
26 changes: 22 additions & 4 deletions mlir/lib/IR/Remarks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) {
void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
// Header: [Type] pass:remarkName
StringRef type = getRemarkTypeString();
StringRef categoryName = getFullCategoryName();
StringRef categoryName = getCombinedCategoryName();
StringRef name = remarkName;

os << '[' << type << "] ";
Expand Down Expand Up @@ -140,7 +140,7 @@ llvm::remarks::Remark Remark::generateRemark() const {
r.RemarkType = getRemarkType();
r.RemarkName = getRemarkName();
// MLIR does not use passes; instead, it has categories and sub-categories.
r.PassName = getFullCategoryName();
r.PassName = getCombinedCategoryName();
r.FunctionName = getFunction();
r.Loc = locLambda();
for (const Remark::Arg &arg : getArgs()) {
Expand All @@ -157,7 +157,7 @@ llvm::remarks::Remark Remark::generateRemark() const {

InFlightRemark::~InFlightRemark() {
if (remark && owner)
owner->report(std::move(*remark));
owner->report(*remark);
owner = nullptr;
}

Expand Down Expand Up @@ -225,7 +225,7 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc,
// RemarkEngine
//===----------------------------------------------------------------------===//

void RemarkEngine::report(const Remark &&remark) {
void RemarkEngine::reportImpl(const Remark &remark) {
// Stream the remark
if (remarkStreamer)
remarkStreamer->streamOptimizationRemark(remark);
Expand All @@ -235,7 +235,25 @@ void RemarkEngine::report(const Remark &&remark) {
emitRemark(remark.getLocation(), remark.getMsg());
}

void RemarkEngine::report(const Remark &remark) {
// Postponed remarks are deferred to the end of pipeline.
if (remark.isPostponed()) {
postponedRemarks.insert(remark);
return;
}

reportImpl(remark);
}

void RemarkEngine::emitPostponedRemarks() {
for (auto &remark : postponedRemarks)
reportImpl(remark);
postponedRemarks.clear();
}

RemarkEngine::~RemarkEngine() {
emitPostponedRemarks();

if (remarkStreamer)
remarkStreamer->finalize();
}
Expand Down
Loading
Loading