diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 4a8cc562cb52db..f02470891fcfbc 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -292,6 +292,12 @@ class DictionaryAttr /// Requires: uniquely named attributes. static bool sortInPlace(SmallVectorImpl &array); + /// Returns an entry with a duplicate name in `array`, if it exists, else + /// returns llvm::None. If `isSorted` is true, the array is assumed to be + /// sorted else it will be sorted in place before finding the duplicate entry. + static Optional + findDuplicate(SmallVectorImpl &array, bool isSorted); + private: /// Return empty dictionary. static DictionaryAttr getEmpty(MLIRContext *context); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 563a906ec803d3..96d6d1194b60ba 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -262,6 +262,10 @@ class NamedAttrList { /// Pop last element from list. void pop_back() { attrs.pop_back(); } + /// Returns an entry with a duplicate name the list, if it exists, else + /// returns llvm::None. + Optional findDuplicate() const; + /// Return a dictionary attribute for the underlying dictionary. This will /// return an empty dictionary attribute if empty rather than null. DictionaryAttr getDictionary(MLIRContext *context) const; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 59a66d0c342595..37c4edb7322fb8 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -99,8 +99,6 @@ static bool dictionaryAttrSort(ArrayRef value, storage.assign({value[0]}); break; case 2: { - assert(value[0].first != value[1].first && - "DictionaryAttr element names must be unique"); bool isSorted = value[0] < value[1]; if (inPlace) { if (!isSorted) @@ -122,25 +120,49 @@ static bool dictionaryAttrSort(ArrayRef value, llvm::array_pod_sort(storage.begin(), storage.end()); value = storage; } - - // Ensure that the attribute elements are unique. - assert(std::adjacent_find(value.begin(), value.end(), - [](NamedAttribute l, NamedAttribute r) { - return l.first == r.first; - }) == value.end() && - "DictionaryAttr element names must be unique"); return !isSorted; } return false; } +/// Returns an entry with a duplicate name from the given sorted array of named +/// attributes. Returns llvm::None if all elements have unique names. +static Optional +findDuplicateElement(ArrayRef value) { + const Optional none{llvm::None}; + if (value.size() < 2) + return none; + + if (value.size() == 2) + return value[0].first == value[1].first ? value[0] : none; + + auto it = std::adjacent_find( + value.begin(), value.end(), + [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); + return it != value.end() ? *it : none; +} + bool DictionaryAttr::sort(ArrayRef value, SmallVectorImpl &storage) { - return dictionaryAttrSort(value, storage); + bool isSorted = dictionaryAttrSort(value, storage); + assert(!findDuplicateElement(storage) && + "DictionaryAttr element names must be unique"); + return isSorted; } bool DictionaryAttr::sortInPlace(SmallVectorImpl &array) { - return dictionaryAttrSort(array, array); + bool isSorted = dictionaryAttrSort(array, array); + assert(!findDuplicateElement(array) && + "DictionaryAttr element names must be unique"); + return isSorted; +} + +Optional +DictionaryAttr::findDuplicate(SmallVectorImpl &array, + bool isSorted) { + if (!isSorted) + dictionaryAttrSort(array, array); + return findDuplicateElement(array); } DictionaryAttr DictionaryAttr::get(ArrayRef value, @@ -155,7 +177,8 @@ DictionaryAttr DictionaryAttr::get(ArrayRef value, SmallVector storage; if (dictionaryAttrSort(value, storage)) value = storage; - + assert(!findDuplicateElement(value) && + "DictionaryAttr element names must be unique"); return Base::get(context, value); } /// Construct a dictionary with an array of values that is known to already be @@ -170,10 +193,7 @@ DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef value, return l.first.strref() < r.first.strref(); }) && "expected attribute values to be sorted"); - assert(std::adjacent_find(value.begin(), value.end(), - [](NamedAttribute l, NamedAttribute r) { - return l.first == r.first; - }) == value.end() && + assert(!findDuplicateElement(value) && "DictionaryAttr element names must be unique"); return Base::get(context, value); } diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 0d72ea5f0ea99e..c7aa0e323088d1 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -32,6 +32,16 @@ NamedAttrList::NamedAttrList(const_iterator in_start, const_iterator in_end) { ArrayRef NamedAttrList::getAttrs() const { return attrs; } +Optional NamedAttrList::findDuplicate() const { + Optional duplicate = + DictionaryAttr::findDuplicate(attrs, isSorted()); + // DictionaryAttr::findDuplicate will sort the list, so reset the sorted + // state. + if (!isSorted()) + dictionarySorted.setPointerAndInt(nullptr, true); + return duplicate; +} + DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const { if (!isSorted()) { DictionaryAttr::sortInPlace(attrs); diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp index 6234d3b596600b..e8f5213768bd8d 100644 --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -253,7 +253,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { else return emitError("expected attribute name"); if (!seenKeys.insert(*nameId).second) - return emitError("duplicate key in dictionary attribute"); + return emitError("duplicate key '") + << *nameId << "' in dictionary attribute"; consumeToken(); // Lazy load a dialect in the context if there is a possible namespace. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e89972007156fb..a824687aefb292 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -846,6 +846,15 @@ class CustomOpAsmParser : public OpAsmParser { ParseResult parseOperation(OperationState &opState) { if (opDefinition->parseAssembly(*this, opState)) return failure(); + // Verify that the parsed attributes does not have duplicate attributes. + // This can happen if an attribute set during parsing is also specified in + // the attribute dictionary in the assembly, or the attribute is set + // multiple during parsing. + Optional duplicate = opState.attributes.findDuplicate(); + if (duplicate) + return emitError(getNameLoc(), "attribute '") + << duplicate->first + << "' occurs more than once in the attribute list"; return success(); } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index ae686365dded7f..c8ed517434b213 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -1513,12 +1513,17 @@ func @really_large_bound() { // ----- func @duplicate_dictionary_attr_key() { - // expected-error @+1 {{duplicate key in dictionary attribute}} + // expected-error @+1 {{duplicate key 'a' in dictionary attribute}} "foo.op"() {a, a} : () -> () } // ----- +// expected-error @+1 {{attribute 'attr' occurs more than once in the attribute list}} +test.format_symbol_name_attr_op @name { attr = "xx" } + +// ----- + func @forward_reference_type_check() -> (i8) { br ^bb2