diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 726152f8a7b3f..cf8d3370da1b9 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -199,7 +199,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> { $allocate_vars, type($allocate_vars), $allocators_vars, type($allocators_vars) ) `)` - | `nowait` + | `nowait` $nowait ) $region attr-dict }]; @@ -438,7 +438,7 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> { oilist( `if` `(` $if_expr `)` | `device` `(` $device `:` type($device) `)` | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)` - | `nowait` + | `nowait` $nowait ) $region attr-dict }]; } diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 2bd0ca6f3f020..bcd76797413d4 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -498,6 +498,10 @@ func @succeededOilistTrivial() { test.oilist_with_keywords_only keyword otherKeyword // CHECK: test.oilist_with_keywords_only keyword otherKeyword test.oilist_with_keywords_only otherKeyword keyword + // CHECK: test.oilist_with_keywords_only thirdKeyword + test.oilist_with_keywords_only thirdKeyword + // CHECK: test.oilist_with_keywords_only keyword thirdKeyword + test.oilist_with_keywords_only keyword thirdKeyword return } @@ -550,7 +554,7 @@ func @succeededOilistCustom(%arg0: i32, %arg1: i32, %arg2: i32) { test.oilist_custom private (%arg0, %arg1 : i32, i32) // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait test.oilist_custom private (%arg0, %arg1 : i32, i32) nowait - // CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) nowait reduction (%arg1) + // CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) reduction (%arg1) nowait test.oilist_custom nowait reduction (%arg1) private (%arg0, %arg1 : i32, i32) return } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index b675a5515c994..5bb397353af28 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -656,9 +656,12 @@ def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">; // Ops related to OIList primitive def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> { + let arguments = (ins UnitAttr:$keyword, UnitAttr:$otherKeyword, + UnitAttr:$diffNameUnitAttrKeyword); let assemblyFormat = [{ - oilist( `keyword` - | `otherKeyword`) attr-dict + oilist( `keyword` $keyword + | `otherKeyword` $otherKeyword + | `thirdKeyword` $diffNameUnitAttrKeyword) attr-dict }]; } @@ -690,8 +693,8 @@ def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> { UnitAttr:$nowait); let assemblyFormat = [{ oilist( `private` `(` $arg0 `:` type($arg0) `)` - | `nowait` | `reduction` custom($optOperand) + | `nowait` $nowait ) attr-dict }]; } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 3bf9ad99a34b8..fb54dcb3d85f8 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -207,6 +207,18 @@ class OIListElement : public DirectiveElementBase { return llvm::zip(getLiteralElements(), getParsingElements()); } + /// If the parsing element is a single UnitAttr element, then it returns the + /// attribute variable. Otherwise, returns nullptr. + AttributeVariable * + getUnitAttrParsingElement(ArrayRef pelement) { + if (pelement.size() == 1) { + auto attrElem = dyn_cast(pelement[0]); + if (attrElem && attrElem->isUnitAttr()) + return attrElem; + } + return nullptr; + } + private: /// A vector of `LiteralElement` objects. Each element stores the keyword /// for one case of oilist element. For example, an oilist element along with @@ -684,7 +696,6 @@ const char *oilistParserCode = R"( "oilist directive"; } {0}Clause = true; - result.addAttribute("{0}", UnitAttr::get(parser.getContext())); )"; namespace { @@ -778,9 +789,11 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, genElementParserStorage(childElement, op, body); } else if (auto *oilist = dyn_cast(element)) { - for (ArrayRef pelement : oilist->getParsingElements()) - for (FormatElement *element : pelement) - genElementParserStorage(element, op, body); + for (ArrayRef pelement : oilist->getParsingElements()) { + if (!oilist->getUnitAttrParsingElement(pelement)) + for (FormatElement *element : pelement) + genElementParserStorage(element, op, body); + } } else if (auto *custom = dyn_cast(element)) { for (FormatElement *paramElement : custom->getArguments()) @@ -1180,11 +1193,16 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, body << "if (succeeded(parser.parseOptional"; genLiteralParser(lelement->getSpelling(), body); body << ")) {\n"; - StringRef attrName = lelement->getSpelling(); - body << formatv(oilistParserCode, attrName); - inferredAttributes.insert(attrName); - for (FormatElement *el : pelement) - genElementParser(el, body, attrTypeCtx); + StringRef lelementName = lelement->getSpelling(); + body << formatv(oilistParserCode, lelementName); + if (AttributeVariable *unitAttrElem = + oilist->getUnitAttrParsingElement(pelement)) { + body << " result.addAttribute(\"" << unitAttrElem->getVar()->name + << "\", UnitAttr::get(parser.getContext()));\n"; + } else { + for (FormatElement *el : pelement) + genElementParser(el, body, attrTypeCtx); + } body << " } else "; } body << " {\n"; @@ -1873,6 +1891,31 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor, }); } +void collect(FormatElement *element, + SmallVectorImpl &variables) { + TypeSwitch(element) + .Case([&](VariableElement *var) { variables.emplace_back(var); }) + .Case([&](CustomDirective *ele) { + for (FormatElement *arg : ele->getArguments()) + collect(arg, variables); + }) + .Case([&](OptionalElement *ele) { + for (FormatElement *arg : ele->getThenElements()) + collect(arg, variables); + for (FormatElement *arg : ele->getElseElements()) + collect(arg, variables); + }) + .Case([&](FunctionalTypeDirective *funcType) { + collect(funcType->getInputs(), variables); + collect(funcType->getResults(), variables); + }) + .Case([&](OIListElement *oilist) { + for (ArrayRef arg : oilist->getParsingElements()) + for (FormatElement *arg_ : arg) + collect(arg_, variables); + }); +} + void OperationFormat::genElementPrinter(FormatElement *element, MethodBody &body, Operator &op, bool &shouldEmitSpace, @@ -1939,13 +1982,44 @@ void OperationFormat::genElementPrinter(FormatElement *element, LiteralElement *lelement = std::get<0>(clause); ArrayRef pelement = std::get<1>(clause); - body << " if ((*this)->hasAttrOfType(\"" - << lelement->getSpelling() << "\")) {\n"; + SmallVector vars; + for (FormatElement *el : pelement) + collect(el, vars); + body << " if (false"; + for (VariableElement *var : vars) { + TypeSwitch(var) + .Case([&](AttributeVariable *attrEle) { + body << " || " << op.getGetterName(attrEle->getVar()->name) + << "Attr()"; + }) + .Case([&](OperandVariable *ele) { + if (ele->getVar()->isVariadic()) { + body << " || " << op.getGetterName(ele->getVar()->name) + << "().size()"; + } else { + body << " || " << op.getGetterName(ele->getVar()->name) << "()"; + } + }) + .Case([&](ResultVariable *ele) { + if (ele->getVar()->isVariadic()) { + body << " || " << op.getGetterName(ele->getVar()->name) + << "().size()"; + } else { + body << " || " << op.getGetterName(ele->getVar()->name) << "()"; + } + }) + .Case([&](RegionVariable *reg) { + body << " || " << op.getGetterName(reg->getVar()->name) << "()"; + }); + } + + body << ") {\n"; genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace, lastWasPunctuation); - for (FormatElement *element : pelement) { - genElementPrinter(element, body, op, shouldEmitSpace, - lastWasPunctuation); + if (oilist->getUnitAttrParsingElement(pelement) == nullptr) { + for (FormatElement *element : pelement) + genElementPrinter(element, body, op, shouldEmitSpace, + lastWasPunctuation); } body << " }\n"; } @@ -2866,51 +2940,45 @@ OpFormatParser::parseOIListDirective(SMLoc loc, Context context) { LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element, SMLoc loc) { - return TypeSwitch(element) - // Only optional attributes can be within an oilist parsing group. - .Case([&](AttributeVariable *attrEle) { - if (!attrEle->getVar()->attr.isOptional()) - return emitError(loc, "only optional attributes can be used to " - "in an oilist parsing group"); - return success(); - }) - // Only optional-like(i.e. variadic) operands can be within an oilist - // parsing group. - .Case([&](OperandVariable *ele) { - if (!ele->getVar()->isVariableLength()) - return emitError(loc, "only variable length operands can be " - "used within an oilist parsing group"); - return success(); - }) - // Only optional-like(i.e. variadic) results can be within an oilist - // parsing group. - .Case([&](ResultVariable *ele) { - if (!ele->getVar()->isVariableLength()) - return emitError(loc, "only variable length results can be " - "used within an oilist parsing group"); - return success(); - }) - .Case([&](RegionVariable *) { - // TODO: When ODS has proper support for marking "optional" regions, add - // a check here. - return success(); - }) - .Case([&](TypeDirective *ele) { - return verifyOIListParsingElement(ele->getArg(), loc); - }) - .Case([&](FunctionalTypeDirective *ele) { - if (failed(verifyOIListParsingElement(ele->getInputs(), loc))) - return failure(); - return verifyOIListParsingElement(ele->getResults(), loc); - }) - // Literals, whitespace, and custom directives may be used. - .Case( - [&](FormatElement *) { return success(); }) - .Default([&](FormatElement *) { - return emitError(loc, "only literals, types, and variables can be " - "used within an oilist group"); - }); + SmallVector vars; + collect(element, vars); + for (VariableElement *elem : vars) { + LogicalResult res = + TypeSwitch(elem) + // Only optional attributes can be within an oilist parsing group. + .Case([&](AttributeVariable *attrEle) { + if (!attrEle->getVar()->attr.isOptional() && + !attrEle->getVar()->attr.hasDefaultValue()) + return emitError(loc, "only optional attributes can be used in " + "an oilist parsing group"); + return success(); + }) + // Only optional-like(i.e. variadic) operands can be within an + // oilist parsing group. + .Case([&](OperandVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(loc, "only variable length operands can be " + "used within an oilist parsing group"); + return success(); + }) + // Only optional-like(i.e. variadic) results can be within an oilist + // parsing group. + .Case([&](ResultVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(loc, "only variable length results can be " + "used within an oilist parsing group"); + return success(); + }) + .Case([&](RegionVariable *) { return success(); }) + .Default([&](FormatElement *) { + return emitError(loc, + "only literals, types, and variables can be " + "used within an oilist group"); + }); + if (failed(res)) + return failure(); + } + return success(); } FailureOr OpFormatParser::parseTypeDirective(SMLoc loc,