diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 96c8f3d79c5e0..b6b9e0e0ef239 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1156,17 +1156,32 @@ class AsmParser { virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result, Type type = {}) = 0; + /// Parse an optional attribute that must begin with a '#' hash identifier + /// (i.e. `#dialect.mnemonic<...>`). Returns std::nullopt if the next token + /// is not a hash identifier, without consuming any tokens. + virtual OptionalParseResult parseOptionalHashAttribute(Attribute &result, + Type type = {}) = 0; + /// Parse an optional attribute of a specific typed result. This overload /// handles concrete attribute types (e.g. FloatAttr) that are not covered by /// a dedicated virtual overload. It parses any attribute and then validates /// that the result is of the expected type, emitting an error if not. + /// + /// For attribute types that define a custom `parse` method (i.e. AttrDef + /// types with assemblyFormat), parsing is restricted to `#`-prefixed + /// attribute syntax to avoid greedily consuming tokens (like `@symbol`) + /// that are intended for later format elements. template < typename AttrType, typename = std::enable_if_t::value>> OptionalParseResult parseOptionalAttribute(AttrType &result, Type type = {}) { Attribute attr; - OptionalParseResult parseResult = parseOptionalAttribute(attr, type); + OptionalParseResult parseResult; + if constexpr (detect_has_parse_method::value) + parseResult = parseOptionalHashAttribute(attr, type); + else + parseResult = parseOptionalAttribute(attr, type); if (!parseResult.has_value() || failed(*parseResult)) return parseResult; result = dyn_cast(attr); diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 3c60287a09a0f..e26a12da9aada 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -458,6 +458,13 @@ class AsmParserImpl : public BaseT { Type type) override { return parser.parseOptionalAttribute(result, type); } + OptionalParseResult parseOptionalHashAttribute(Attribute &result, + Type type) override { + if (parser.getToken().isNot(Token::hash_identifier)) + return std::nullopt; + result = parser.parseAttribute(type); + return success(static_cast(result)); + } OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type) override { return parser.parseOptionalAttribute(result, type); diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td index 35a34b493035a..3ba52bff37c28 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td @@ -477,6 +477,16 @@ def FormatOptionalPropDict : TEST_Op<"format_optional_prop_dict"> { let assemblyFormat = "prop-dict attr-dict"; } +// Test that an optional AttrDef attribute (with a custom parse method) does not +// greedily consume tokens meant for a later symbol name in the format. +def FormatOptionalAttrDefBeforeSymbol + : TEST_Op<"format_opt_attrdef_before_symbol", [Symbol]> { + let arguments = (ins + OptionalAttr:$opt_compound, + SymbolNameAttr:$sym_name); + let assemblyFormat = "(qualified($opt_compound)^)? $sym_name attr-dict"; +} + def FormatCompoundAttr : TEST_Op<"format_compound_attr"> { let arguments = (ins CompoundAttrA:$compound); let assemblyFormat = "$compound attr-dict-with-keyword"; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index 7ff9091d5500d..1fb130cf2a3d5 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -306,6 +306,19 @@ test.format_optional_prop_dict <{b = 2 : i32}> // CHECK: test.format_optional_prop_dict <{a = ["foo"], b = 2 : i32}> test.format_optional_prop_dict <{a = ["foo"], b = 2 : i32}> +//===----------------------------------------------------------------------===// +// Optional AttrDef attribute before symbol name +//===----------------------------------------------------------------------===// + +// Verify that an optional AttrDef attribute does not greedily consume the @ +// token when the attribute is absent. +// CHECK: test.format_opt_attrdef_before_symbol @opt_attrdef_sym1 +test.format_opt_attrdef_before_symbol @opt_attrdef_sym1 + +// Verify round-trip when the optional AttrDef attribute is present (full form). +// CHECK: test.format_opt_attrdef_before_symbol #test.cmpnd_nested> @opt_attrdef_sym2 +test.format_opt_attrdef_before_symbol #test.cmpnd_nested> @opt_attrdef_sym2 + //===----------------------------------------------------------------------===// // Format a custom attribute //===----------------------------------------------------------------------===//