diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md index 63b727ae428b7a..5f413582c698cf 100644 --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -772,8 +772,13 @@ When a variable is optional, the provided value may be null. In certain situations operations may have "optional" information, e.g. attributes or an empty set of variadic operands. In these situations a section of the assembly format can be marked as `optional` based on the presence of this -information. An optional group is defined by wrapping a set of elements within -`()` followed by a `?` and has the following requirements: +information. An optional group is defined as follows: + +``` +optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?` +``` + +The `elements` of an optional group have the following requirements: * The first element of the group must either be a attribute, literal, operand, or region. @@ -837,6 +842,32 @@ foo.op is_read_only foo.op ``` +##### Optional "else" Group + +Optional groups also have support for an "else" group of elements. These are +elements that are parsed/printed if the `anchor` element of the optional group +is *not* present. Unlike the main element group, the "else" group has no +restriction on the first element and none of the elements may act as the +`anchor` for the optional. An example is shown below: + +```tablegen +def FooOp : ... { + let arguments = (ins UnitAttr:$foo); + + let assemblyFormat = "attr-dict (`foo_is_present` $foo^):(`foo_is_absent`)?"; +} +``` + +would be formatted as such: + +```mlir +// When the `foo` attribute is present: +foo.op foo_is_present + +// When the `foo` attribute is not present: +foo.op foo_is_absent +``` + #### Requirements The format specification has a certain set of requirements that must be adhered diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 7d48f8d4547a95..8be84f2aacbc9b 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1651,6 +1651,11 @@ def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> { let assemblyFormat = "($attr^)? attr-dict"; } +def FormatOptionalWithElse : TEST_Op<"format_optional_else"> { + let arguments = (ins UnitAttr:$isFirstBranchPresent); + let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict"; +} + //===----------------------------------------------------------------------===// // Custom Directives diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td index 4f5ca63c4e72a3..8c6bb09f34a37e 100644 --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -390,6 +390,18 @@ def OptionalInvalidL : TestFormat_Op<[{ def OptionalInvalidM : TestFormat_Op<[{ (` `^)? }]>, Arguments<(ins)>; +// CHECK: error: expected '(' to start else branch of optional group +def OptionalInvalidN : TestFormat_Op<[{ + ($arg^): +}]>, Arguments<(ins Variadic:$arg)>; +// CHECK: error: expected directive, literal, variable, or optional group +def OptionalInvalidO : TestFormat_Op<[{ + ($arg^):(`test` +}]>, Arguments<(ins Variadic:$arg)>; +// CHECK: error: expected '?' after optional group +def OptionalInvalidP : TestFormat_Op<[{ + ($arg^):(`test`) +}]>, Arguments<(ins Variadic:$arg)>; // CHECK-NOT: error def OptionalValidA : TestFormat_Op<[{ diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index 8043786faf0808..e6f998fa4ac39c 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -239,6 +239,16 @@ test.format_optional_result_b_op : i64 -> i64, i64 // CHECK: test.format_optional_result_c_op : (i64) -> (i64, i64) test.format_optional_result_c_op : (i64) -> (i64, i64) +//===----------------------------------------------------------------------===// +// Format optional with else +//===----------------------------------------------------------------------===// + +// CHECK: test.format_optional_else then +test.format_optional_else then + +// CHECK: test.format_optional_else else +test.format_optional_else else + //===----------------------------------------------------------------------===// // Format custom directives //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index f474bbfb4f2045..abf77a55004ecc 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -348,29 +348,41 @@ class SpaceElement : public WhitespaceElement { namespace { /// This class represents a group of elements that are optionally emitted based -/// upon an optional variable of the operation. +/// upon an optional variable of the operation, and a group of elements that are +/// emotted when the anchor element is not present. class OptionalElement : public Element { public: - OptionalElement(std::vector> &&elements, + OptionalElement(std::vector> &&thenElements, + std::vector> &&elseElements, unsigned anchor, unsigned parseStart) - : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor), + : Element{Kind::Optional}, thenElements(std::move(thenElements)), + elseElements(std::move(elseElements)), anchor(anchor), parseStart(parseStart) {} static bool classof(const Element *element) { return element->getKind() == Kind::Optional; } - /// Return the nested elements of this grouping. - auto getElements() const { return llvm::make_pointee_range(elements); } + /// Return the `then` elements of this grouping. + auto getThenElements() const { + return llvm::make_pointee_range(thenElements); + } + + /// Return the `else` elements of this grouping. + auto getElseElements() const { + return llvm::make_pointee_range(elseElements); + } /// Return the anchor of this optional group. - Element *getAnchor() const { return elements[anchor].get(); } + Element *getAnchor() const { return thenElements[anchor].get(); } /// Return the index of the first element that needs to be parsed. unsigned getParseStart() const { return parseStart; } private: - /// The child elements of this optional. - std::vector> elements; + /// The child elements of `then` branch of this optional. + std::vector> thenElements; + /// The child elements of `else` branch of this optional. + std::vector> elseElements; /// The index of the element that acts as the anchor for the optional group. unsigned anchor; /// The index of the first element that is parsed (is not a @@ -792,7 +804,7 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) { /// Generate the storage code required for parsing the given element. static void genElementParserStorage(Element *element, OpMethodBody &body) { if (auto *optional = dyn_cast(element)) { - auto elements = optional->getElements(); + auto elements = optional->getThenElements(); // If the anchor is a unit attribute, it won't be parsed directly so elide // it. @@ -803,6 +815,8 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) { for (auto &childElement : elements) if (&childElement != elidedAnchorElement) genElementParserStorage(&childElement, body); + for (auto &childElement : optional->getElseElements()) + genElementParserStorage(&childElement, body); } else if (auto *custom = dyn_cast(element)) { for (auto ¶mElement : custom->getArguments()) @@ -1094,8 +1108,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body, FmtContext &attrTypeCtx) { /// Optional Group. if (auto *optional = dyn_cast(element)) { - auto elements = - llvm::drop_begin(optional->getElements(), optional->getParseStart()); + auto elements = llvm::drop_begin(optional->getThenElements(), + optional->getParseStart()); // Generate a special optional parser for the first element to gate the // parsing of the rest of the elements. @@ -1140,7 +1154,17 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body, if (&childElement != elidedAnchorElement) genElementParser(&childElement, body, attrTypeCtx); } - body << " }\n"; + body << " }"; + + // Generate the else elements. + auto elseElements = optional->getElseElements(); + if (!elseElements.empty()) { + body << " else {\n"; + for (Element &childElement : elseElements) + genElementParser(&childElement, body, attrTypeCtx); + body << " }"; + } + body << "\n"; /// Literals. } else if (LiteralElement *literal = dyn_cast(element)) { @@ -1778,7 +1802,7 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body, // If the anchor is a unit attribute, we don't need to print it. When // parsing, we will add this attribute if this group is present. - auto elements = optional->getElements(); + auto elements = optional->getThenElements(); Element *elidedAnchorElement = nullptr; auto *anchorAttr = dyn_cast(anchor); if (anchorAttr && anchorAttr != &*elements.begin() && @@ -1793,7 +1817,20 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body, lastWasPunctuation); } } - body << " }\n"; + body << " }"; + + // Emit each of the else elements. + auto elseElements = optional->getElseElements(); + if (!elseElements.empty()) { + body << " else {\n"; + for (Element &childElement : elseElements) { + genElementPrinter(&childElement, body, op, shouldEmitSpace, + lastWasPunctuation); + } + body << " }"; + } + + body << "\n"; return; } @@ -1911,6 +1948,7 @@ class Token { l_paren, r_paren, caret, + colon, comma, equal, less, @@ -2065,6 +2103,8 @@ Token FormatLexer::lexToken() { // Lex punctuation. case '^': return formToken(Token::caret, tokStart); + case ':': + return formToken(Token::colon, tokStart); case ',': return formToken(Token::comma, tokStart); case '=': @@ -2393,8 +2433,11 @@ LogicalResult FormatParser::verifyAttributes( // Traverse into optional groups. if (auto *optional = dyn_cast(element)) { - auto elements = optional->getElements(); - iteratorStack.emplace_back(elements.begin(), elements.end()); + auto thenElements = optional->getThenElements(); + iteratorStack.emplace_back(thenElements.begin(), thenElements.end()); + + auto elseElements = optional->getElseElements(); + iteratorStack.emplace_back(elseElements.begin(), elseElements.end()); return ::mlir::success(); } @@ -2795,13 +2838,31 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr &element, consumeToken(); // Parse the child elements for this optional group. - std::vector> elements; + std::vector> thenElements, elseElements; Optional anchorIdx; do { - if (failed(parseOptionalChildElement(elements, anchorIdx))) + if (failed(parseOptionalChildElement(thenElements, anchorIdx))) return ::mlir::failure(); } while (curToken.getKind() != Token::r_paren); consumeToken(); + + // Parse the `else` elements of this optional group. + if (curToken.getKind() == Token::colon) { + consumeToken(); + if (failed(parseToken(Token::l_paren, "expected '(' to start else branch " + "of optional group"))) + return failure(); + do { + llvm::SMLoc childLoc = curToken.getLoc(); + elseElements.push_back({}); + if (failed(parseElement(elseElements.back(), TopLevelContext)) || + failed(verifyOptionalChildElement(elseElements.back().get(), childLoc, + /*isAnchor=*/false))) + return failure(); + } while (curToken.getKind() != Token::r_paren); + consumeToken(); + } + if (failed(parseToken(Token::question, "expected '?' after optional group"))) return ::mlir::failure(); @@ -2811,7 +2872,7 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr &element, // The first parsable element of the group must be able to be parsed in an // optional fashion. - auto parseBegin = llvm::find_if_not(elements, [](auto &element) { + auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) { return isa(element.get()); }); Element *firstElement = parseBegin->get(); @@ -2822,9 +2883,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr &element, "first parsable element of an operand group must be " "an attribute, literal, operand, or region"); - auto parseStart = parseBegin - elements.begin(); - element = std::make_unique(std::move(elements), *anchorIdx, - parseStart); + auto parseStart = parseBegin - thenElements.begin(); + element = std::make_unique( + std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart); return ::mlir::success(); }