Skip to content

Commit

Permalink
[mlir][OpAsmFormat] Add support for an "else" group on optional elements
Browse files Browse the repository at this point in the history
The "else" group of an optional element is a collection of elements that get parsed/printed when the anchor of the main element group is *not* present. This is useful when there is a special syntax when an element is not present. The new syntax for an optional element is shown below:

```
optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?`
```

An example of how this might be used is shown below:

```tablegen
def FooOp : ... {
  let arguments = (ins UnitAttr:$foo);

  let assemblyFormat = "attr-dict (`foo_is_present` $foo^):(`foo_is_absent`)?";
}
```

would be formatted as such:

```mlir
// When the `foo` attribute is present:
foo.op foo_is_present

// When the `foo` attribute is not present:
foo.op foo_is_absent
```

Differential Revision: https://reviews.llvm.org/D99129
  • Loading branch information
River707 committed Mar 23, 2021
1 parent 0524a09 commit 6d6fe9c
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 24 deletions.
35 changes: 33 additions & 2 deletions mlir/docs/OpDefinitions.md
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,13 @@ When a variable is optional, the provided value may be null.
In certain situations operations may have "optional" information, e.g.
attributes or an empty set of variadic operands. In these situations a section
of the assembly format can be marked as `optional` based on the presence of this
information. An optional group is defined by wrapping a set of elements within
`()` followed by a `?` and has the following requirements:
information. An optional group is defined as follows:

```
optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?`
```

The `elements` of an optional group have the following requirements:

* The first element of the group must either be a attribute, literal, operand,
or region.
Expand Down Expand Up @@ -837,6 +842,32 @@ foo.op is_read_only
foo.op
```

##### Optional "else" Group

Optional groups also have support for an "else" group of elements. These are
elements that are parsed/printed if the `anchor` element of the optional group
is *not* present. Unlike the main element group, the "else" group has no
restriction on the first element and none of the elements may act as the
`anchor` for the optional. An example is shown below:

```tablegen
def FooOp : ... {
let arguments = (ins UnitAttr:$foo);
let assemblyFormat = "attr-dict (`foo_is_present` $foo^):(`foo_is_absent`)?";
}
```

would be formatted as such:

```mlir
// When the `foo` attribute is present:
foo.op foo_is_present
// When the `foo` attribute is not present:
foo.op foo_is_absent
```

#### Requirements

The format specification has a certain set of requirements that must be adhered
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,11 @@ def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> {
let assemblyFormat = "($attr^)? attr-dict";
}

def FormatOptionalWithElse : TEST_Op<"format_optional_else"> {
let arguments = (ins UnitAttr:$isFirstBranchPresent);
let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict";
}

//===----------------------------------------------------------------------===//
// Custom Directives

Expand Down
12 changes: 12 additions & 0 deletions mlir/test/mlir-tblgen/op-format-spec.td
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,18 @@ def OptionalInvalidL : TestFormat_Op<[{
def OptionalInvalidM : TestFormat_Op<[{
(` `^)?
}]>, Arguments<(ins)>;
// CHECK: error: expected '(' to start else branch of optional group
def OptionalInvalidN : TestFormat_Op<[{
($arg^):
}]>, Arguments<(ins Variadic<I64>:$arg)>;
// CHECK: error: expected directive, literal, variable, or optional group
def OptionalInvalidO : TestFormat_Op<[{
($arg^):(`test`
}]>, Arguments<(ins Variadic<I64>:$arg)>;
// CHECK: error: expected '?' after optional group
def OptionalInvalidP : TestFormat_Op<[{
($arg^):(`test`)
}]>, Arguments<(ins Variadic<I64>:$arg)>;

// CHECK-NOT: error
def OptionalValidA : TestFormat_Op<[{
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/mlir-tblgen/op-format.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ test.format_optional_result_b_op : i64 -> i64, i64
// CHECK: test.format_optional_result_c_op : (i64) -> (i64, i64)
test.format_optional_result_c_op : (i64) -> (i64, i64)

//===----------------------------------------------------------------------===//
// Format optional with else
//===----------------------------------------------------------------------===//

// CHECK: test.format_optional_else then
test.format_optional_else then

// CHECK: test.format_optional_else else
test.format_optional_else else

//===----------------------------------------------------------------------===//
// Format custom directives
//===----------------------------------------------------------------------===//
Expand Down
105 changes: 83 additions & 22 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,29 +348,41 @@ class SpaceElement : public WhitespaceElement {

namespace {
/// This class represents a group of elements that are optionally emitted based
/// upon an optional variable of the operation.
/// upon an optional variable of the operation, and a group of elements that are
/// emotted when the anchor element is not present.
class OptionalElement : public Element {
public:
OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
OptionalElement(std::vector<std::unique_ptr<Element>> &&thenElements,
std::vector<std::unique_ptr<Element>> &&elseElements,
unsigned anchor, unsigned parseStart)
: Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor),
: Element{Kind::Optional}, thenElements(std::move(thenElements)),
elseElements(std::move(elseElements)), anchor(anchor),
parseStart(parseStart) {}
static bool classof(const Element *element) {
return element->getKind() == Kind::Optional;
}

/// Return the nested elements of this grouping.
auto getElements() const { return llvm::make_pointee_range(elements); }
/// Return the `then` elements of this grouping.
auto getThenElements() const {
return llvm::make_pointee_range(thenElements);
}

/// Return the `else` elements of this grouping.
auto getElseElements() const {
return llvm::make_pointee_range(elseElements);
}

/// Return the anchor of this optional group.
Element *getAnchor() const { return elements[anchor].get(); }
Element *getAnchor() const { return thenElements[anchor].get(); }

/// Return the index of the first element that needs to be parsed.
unsigned getParseStart() const { return parseStart; }

private:
/// The child elements of this optional.
std::vector<std::unique_ptr<Element>> elements;
/// The child elements of `then` branch of this optional.
std::vector<std::unique_ptr<Element>> thenElements;
/// The child elements of `else` branch of this optional.
std::vector<std::unique_ptr<Element>> elseElements;
/// The index of the element that acts as the anchor for the optional group.
unsigned anchor;
/// The index of the first element that is parsed (is not a
Expand Down Expand Up @@ -792,7 +804,7 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
/// Generate the storage code required for parsing the given element.
static void genElementParserStorage(Element *element, OpMethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements = optional->getElements();
auto elements = optional->getThenElements();

// If the anchor is a unit attribute, it won't be parsed directly so elide
// it.
Expand All @@ -803,6 +815,8 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
for (auto &childElement : elements)
if (&childElement != elidedAnchorElement)
genElementParserStorage(&childElement, body);
for (auto &childElement : optional->getElseElements())
genElementParserStorage(&childElement, body);

} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
for (auto &paramElement : custom->getArguments())
Expand Down Expand Up @@ -1094,8 +1108,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
FmtContext &attrTypeCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements =
llvm::drop_begin(optional->getElements(), optional->getParseStart());
auto elements = llvm::drop_begin(optional->getThenElements(),
optional->getParseStart());

// Generate a special optional parser for the first element to gate the
// parsing of the rest of the elements.
Expand Down Expand Up @@ -1140,7 +1154,17 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
if (&childElement != elidedAnchorElement)
genElementParser(&childElement, body, attrTypeCtx);
}
body << " }\n";
body << " }";

// Generate the else elements.
auto elseElements = optional->getElseElements();
if (!elseElements.empty()) {
body << " else {\n";
for (Element &childElement : elseElements)
genElementParser(&childElement, body, attrTypeCtx);
body << " }";
}
body << "\n";

/// Literals.
} else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
Expand Down Expand Up @@ -1778,7 +1802,7 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,

// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
auto elements = optional->getElements();
auto elements = optional->getThenElements();
Element *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
if (anchorAttr && anchorAttr != &*elements.begin() &&
Expand All @@ -1793,7 +1817,20 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
lastWasPunctuation);
}
}
body << " }\n";
body << " }";

// Emit each of the else elements.
auto elseElements = optional->getElseElements();
if (!elseElements.empty()) {
body << " else {\n";
for (Element &childElement : elseElements) {
genElementPrinter(&childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
body << " }";
}

body << "\n";
return;
}

Expand Down Expand Up @@ -1911,6 +1948,7 @@ class Token {
l_paren,
r_paren,
caret,
colon,
comma,
equal,
less,
Expand Down Expand Up @@ -2065,6 +2103,8 @@ Token FormatLexer::lexToken() {
// Lex punctuation.
case '^':
return formToken(Token::caret, tokStart);
case ':':
return formToken(Token::colon, tokStart);
case ',':
return formToken(Token::comma, tokStart);
case '=':
Expand Down Expand Up @@ -2393,8 +2433,11 @@ LogicalResult FormatParser::verifyAttributes(

// Traverse into optional groups.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements = optional->getElements();
iteratorStack.emplace_back(elements.begin(), elements.end());
auto thenElements = optional->getThenElements();
iteratorStack.emplace_back(thenElements.begin(), thenElements.end());

auto elseElements = optional->getElseElements();
iteratorStack.emplace_back(elseElements.begin(), elseElements.end());
return ::mlir::success();
}

Expand Down Expand Up @@ -2795,13 +2838,31 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
consumeToken();

// Parse the child elements for this optional group.
std::vector<std::unique_ptr<Element>> elements;
std::vector<std::unique_ptr<Element>> thenElements, elseElements;
Optional<unsigned> anchorIdx;
do {
if (failed(parseOptionalChildElement(elements, anchorIdx)))
if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
return ::mlir::failure();
} while (curToken.getKind() != Token::r_paren);
consumeToken();

// Parse the `else` elements of this optional group.
if (curToken.getKind() == Token::colon) {
consumeToken();
if (failed(parseToken(Token::l_paren, "expected '(' to start else branch "
"of optional group")))
return failure();
do {
llvm::SMLoc childLoc = curToken.getLoc();
elseElements.push_back({});
if (failed(parseElement(elseElements.back(), TopLevelContext)) ||
failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
/*isAnchor=*/false)))
return failure();
} while (curToken.getKind() != Token::r_paren);
consumeToken();
}

if (failed(parseToken(Token::question, "expected '?' after optional group")))
return ::mlir::failure();

Expand All @@ -2811,7 +2872,7 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,

// The first parsable element of the group must be able to be parsed in an
// optional fashion.
auto parseBegin = llvm::find_if_not(elements, [](auto &element) {
auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) {
return isa<WhitespaceElement>(element.get());
});
Element *firstElement = parseBegin->get();
Expand All @@ -2822,9 +2883,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
"first parsable element of an operand group must be "
"an attribute, literal, operand, or region");

auto parseStart = parseBegin - elements.begin();
element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
parseStart);
auto parseStart = parseBegin - thenElements.begin();
element = std::make_unique<OptionalElement>(
std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart);
return ::mlir::success();
}

Expand Down

0 comments on commit 6d6fe9c

Please sign in to comment.