Skip to content

Commit

Permalink
[mlir][OpFormat] Add support for emitting newlines from the custom fo…
Browse files Browse the repository at this point in the history
…rmat of an operation

This revision adds a new `printNewline` hook to OpAsmPrinter that allows for printing a newline within the custom format of an operation, that is then indented to the start of the operation. Support for the declarative assembly format is also added, in the form of a `\n` literal.

Differential Revision: https://reviews.llvm.org/D93151
  • Loading branch information
River707 committed Dec 14, 2020
1 parent 0936655 commit c234b65
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 19 deletions.
24 changes: 24 additions & 0 deletions mlir/docs/OpDefinitions.md
Expand Up @@ -646,6 +646,30 @@ The following are the set of valid punctuation:

`:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`, `?`, `+`, `*`

The following are valid whitespace punctuation:

`\n`, ` `

The `\n` literal emits a newline an indents to the start of the operation. An
example is shown below:

```tablegen
let assemblyFormat = [{
`{` `\n` ` ` ` ` `this_is_on_a_newline` `\n` `}` attr-dict
}];
```

```mlir
%results = my.operation {
this_is_on_a_newline
}
```

An empty literal \`\` may be used to remove a space that is inserted implicitly
after certain literal elements, such as `)`/`]`/etc. For example, "`]`" may
result in an output of `]` it is not the last element in the format. "`]` \`\`"
would trim the trailing space in this situation.

#### Variables

A variable is an entity that has been registered on the operation itself, i.e.
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/IR/OpImplementation.h
Expand Up @@ -36,6 +36,10 @@ class OpAsmPrinter {
virtual ~OpAsmPrinter();
virtual raw_ostream &getStream() const = 0;

/// Print a newline and indent the printer to the start of the current
/// operation.
virtual void printNewline() = 0;

/// Print implementations for various things an operation contains.
virtual void printOperand(Value value) = 0;
virtual void printOperand(Value value, raw_ostream &os) = 0;
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Expand Up @@ -429,6 +429,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
/// The following are hooks of `OpAsmPrinter` that are not necessary for
/// determining potential aliases.
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
void printNewline() override {}
void printOperand(Value) override {}
void printOperand(Value, raw_ostream &os) override {
// Users expect the output string to have at least the prefixed % to signal
Expand Down Expand Up @@ -2218,6 +2219,13 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
/// Return the current stream of the printer.
raw_ostream &getStream() const override { return os; }

/// Print a newline and indent the printer to the start of the current
/// operation.
void printNewline() override {
os << newLine;
os.indent(currentIndent);
}

/// Print the given type.
void printType(Type type) override { ModulePrinter::printType(type); }

Expand Down
3 changes: 2 additions & 1 deletion mlir/test/lib/Dialect/Test/TestOps.td
Expand Up @@ -1393,7 +1393,8 @@ def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> {

def FormatLiteralOp : TEST_Op<"format_literal_op"> {
let assemblyFormat = [{
`keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)` `?` `+` `*` attr-dict
`keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)`
`?` `+` `*` `{` `\n` `}` attr-dict
}];
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/mlir-tblgen/op-format-spec.td
Expand Up @@ -309,7 +309,7 @@ def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{
}]>;
// CHECK-NOT: error
def LiteralValid : TestFormat_Op<"literal_valid", [{
`_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `abc$._`
`_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `\n` `abc$._`
attr-dict
}]>;

Expand Down
6 changes: 4 additions & 2 deletions mlir/test/mlir-tblgen/op-format.mlir
Expand Up @@ -7,8 +7,10 @@
// CHECK: %[[MEMREF:.*]] =
%memref = "foo.op"() : () -> (memref<1xf64>)

// CHECK: test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * {foo.some_attr}
test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * {foo.some_attr}
// CHECK: test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * {
// CHECK-NEXT: } {foo.some_attr}
test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * {
} {foo.some_attr}

// CHECK: test.format_attr_op 10
// CHECK-NOT: {attr
Expand Down
65 changes: 50 additions & 15 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Expand Up @@ -58,7 +58,8 @@ class Element {
/// This element is a literal.
Literal,

/// This element prints or omits a space. It is ignored by the parser.
/// This element is a whitespace.
Newline,
Space,

/// This element is an variable value.
Expand Down Expand Up @@ -296,14 +297,35 @@ bool LiteralElement::isValidLiteral(StringRef value) {
}

//===----------------------------------------------------------------------===//
// SpaceElement
// WhitespaceElement

namespace {
/// This class represents a whitespace element, e.g. newline or space. It's a
/// literal that is printed but never parsed.
class WhitespaceElement : public Element {
public:
WhitespaceElement(Kind kind) : Element{kind} {}
static bool classof(const Element *element) {
Kind kind = element->getKind();
return kind == Kind::Newline || kind == Kind::Space;
}
};

/// This class represents an instance of a newline element. It's a literal that
/// prints a newline. It is ignored by the parser.
class NewlineElement : public WhitespaceElement {
public:
NewlineElement() : WhitespaceElement(Kind::Newline) {}
static bool classof(const Element *element) {
return element->getKind() == Kind::Newline;
}
};

/// This class represents an instance of a space element. It's a literal that
/// prints or omits printing a space. It is ignored by the parser.
class SpaceElement : public Element {
class SpaceElement : public WhitespaceElement {
public:
SpaceElement(bool value) : Element{Kind::Space}, value(value) {}
SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {}
static bool classof(const Element *element) {
return element->getKind() == Kind::Space;
}
Expand Down Expand Up @@ -347,7 +369,8 @@ class OptionalElement : public Element {
std::vector<std::unique_ptr<Element>> elements;
/// The index of the element that acts as the anchor for the optional group.
unsigned anchor;
/// The index of the first element that is parsed (is not a SpaceElement).
/// The index of the first element that is parsed (is not a
/// WhitespaceElement).
unsigned parseStart;
};
} // end anonymous namespace
Expand Down Expand Up @@ -1098,8 +1121,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
genLiteralParser(literal->getLiteral(), body);
body << ")\n return ::mlir::failure();\n";

/// Spaces.
} else if (isa<SpaceElement>(element)) {
/// Whitespaces.
} else if (isa<WhitespaceElement>(element)) {
// Nothing to parse.

/// Arguments.
Expand Down Expand Up @@ -1620,6 +1643,11 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
lastWasPunctuation);

// Emit a whitespace element.
if (NewlineElement *newline = dyn_cast<NewlineElement>(element)) {
body << " p.printNewline();\n";
return;
}
if (SpaceElement *space = dyn_cast<SpaceElement>(element))
return genSpacePrinter(space->getValue(), body, shouldEmitSpace,
lastWasPunctuation);
Expand Down Expand Up @@ -2272,9 +2300,10 @@ LogicalResult FormatParser::verifyAttributes(
for (auto &nextItPair : iteratorStack) {
ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
for (; nextIt != nextE; ++nextIt) {
// Skip any trailing spaces, attribute dictionaries, or optional groups.
if (isa<SpaceElement>(*nextIt) || isa<AttrDictDirective>(*nextIt) ||
isa<OptionalElement>(*nextIt))
// Skip any trailing whitespace, attribute dictionaries, or optional
// groups.
if (isa<WhitespaceElement>(*nextIt) ||
isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt))
continue;

// We are only interested in `:` literals.
Expand Down Expand Up @@ -2600,6 +2629,11 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
element = std::make_unique<SpaceElement>(!value.empty());
return ::mlir::success();
}
// The parsed literal is a newline element.
if (value == "\\n") {
element = std::make_unique<NewlineElement>();
return ::mlir::success();
}

// Check that the parsed literal is valid.
if (!LiteralElement::isValidLiteral(value))
Expand Down Expand Up @@ -2635,8 +2669,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,

// The first parsable element of the group must be able to be parsed in an
// optional fashion.
auto parseBegin = llvm::find_if_not(
elements, [](auto &element) { return isa<SpaceElement>(element.get()); });
auto parseBegin = llvm::find_if_not(elements, [](auto &element) {
return isa<WhitespaceElement>(element.get());
});
Element *firstElement = parseBegin->get();
if (!isa<AttributeVariable>(firstElement) &&
!isa<LiteralElement>(firstElement) &&
Expand Down Expand Up @@ -2718,9 +2753,9 @@ LogicalResult FormatParser::parseOptionalChildElement(
// a check here.
return ::mlir::success();
})
// Literals, spaces, custom directives, and type directives may be used,
// but they can't anchor the group.
.Case<LiteralElement, SpaceElement, CustomDirective,
// Literals, whitespace, custom directives, and type directives may be
// used, but they can't anchor the group.
.Case<LiteralElement, WhitespaceElement, CustomDirective,
FunctionalTypeDirective, OptionalElement, TypeRefDirective,
TypeDirective>([&](Element *) {
if (isAnchor)
Expand Down

0 comments on commit c234b65

Please sign in to comment.