diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h index 9877926116e24..3102542731b33 100644 --- a/mlir/include/mlir/IR/Remarks.h +++ b/mlir/include/mlir/IR/Remarks.h @@ -99,18 +99,30 @@ class Remark { } // Remark argument that is a key-value pair that can be printed as machine - // parsable args. + // parsable args. For Attribute arguments, the original attribute is also + // stored to allow custom streamers to handle them specially. struct Arg { std::string key; std::string val; + /// Optional attribute storage for Attribute-based args. Allows streamers + /// to access the original attribute for custom handling. + std::optional attr; + Arg(llvm::StringRef m) : key("Remark"), val(m) {} Arg(llvm::StringRef k, llvm::StringRef v) : key(k), val(v) {} Arg(llvm::StringRef k, std::string v) : key(k), val(std::move(v)) {} Arg(llvm::StringRef k, const char *v) : Arg(k, llvm::StringRef(v)) {} Arg(llvm::StringRef k, Value v); Arg(llvm::StringRef k, Type t); + Arg(llvm::StringRef k, Attribute a); Arg(llvm::StringRef k, bool b) : key(k), val(b ? "true" : "false") {} + /// Check if this arg has an associated attribute. + bool hasAttribute() const { return attr.has_value(); } + + /// Get the attribute if present. + Attribute getAttribute() const { return attr.value_or(Attribute()); } + // One constructor for all arithmetic types except bool. template && !std::is_same_v>> diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index 031eae22af7f2..4cce16b172d80 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -31,6 +31,11 @@ Remark::Arg::Arg(llvm::StringRef k, Type t) : key(k) { os << t; } +Remark::Arg::Arg(llvm::StringRef k, Attribute a) : key(k), attr(a) { + llvm::raw_string_ostream os(val); + os << a; +} + void Remark::insert(llvm::StringRef s) { args.emplace_back(s); } void Remark::insert(Arg a) { args.push_back(std::move(a)); } diff --git a/mlir/unittests/IR/RemarkTest.cpp b/mlir/unittests/IR/RemarkTest.cpp index 94753c10a9a93..f33d3caebad37 100644 --- a/mlir/unittests/IR/RemarkTest.cpp +++ b/mlir/unittests/IR/RemarkTest.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Remarks.h" @@ -377,4 +379,35 @@ TEST(Remark, TestRemarkFinal) { EXPECT_NE(errOut.find(pass3Msg), std::string::npos); // shown EXPECT_NE(errOut.find(pass4Msg), std::string::npos); // shown } + +TEST(Remark, TestArgWithAttribute) { + MLIRContext context; + + SmallVector elements; + elements.push_back(IntegerAttr::get(IntegerType::get(&context, 32), 1)); + elements.push_back(IntegerAttr::get(IntegerType::get(&context, 32), 2)); + elements.push_back(IntegerAttr::get(IntegerType::get(&context, 32), 3)); + ArrayAttr arrayAttr = ArrayAttr::get(&context, elements); + remark::detail::Remark::Arg argWithArray("Values", arrayAttr); + + // Verify the attribute is stored + EXPECT_TRUE(argWithArray.hasAttribute()); + EXPECT_EQ(argWithArray.getAttribute(), arrayAttr); + + // Ensure it can be retrieved as an ArrayAttr. + auto retrievedAttr = dyn_cast(argWithArray.getAttribute()); + EXPECT_TRUE(retrievedAttr); + EXPECT_EQ(retrievedAttr.size(), 3u); + EXPECT_EQ(cast(retrievedAttr[0]).getInt(), 1); + EXPECT_EQ(cast(retrievedAttr[1]).getInt(), 2); + EXPECT_EQ(cast(retrievedAttr[2]).getInt(), 3); + + // Create an Arg without an Attribute (string-based) + remark::detail::Remark::Arg argWithoutAttr("Key", "Value"); + + // Verify no attribute is stored + EXPECT_FALSE(argWithoutAttr.hasAttribute()); + EXPECT_FALSE(argWithoutAttr.getAttribute()); // Returns null Attribute + EXPECT_EQ(argWithoutAttr.val, "Value"); +} } // namespace