Skip to content

Commit

Permalink
[mlir][ir] Custom ops' parse/print fall back to dialect hooks
Browse files Browse the repository at this point in the history
Custom ops that have no parser or printer should fall back to the dialect's parser and/or printer hooks. This avoids the need to define parsers and printers that simply dispatch to the dialect hook.

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D115481
  • Loading branch information
Mogball committed Dec 10, 2021
1 parent 7c8f4e7 commit 0845635
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 10 deletions.
18 changes: 12 additions & 6 deletions mlir/include/mlir/IR/OpDefinition.h
Expand Up @@ -173,13 +173,19 @@ class OpState {
/// back to this one which accepts everything.
LogicalResult verify() { return success(); }

/// Unless overridden, the custom assembly form of an op is always rejected.
/// Op implementations should implement this to return failure.
/// On success, they should fill in result with the fields to use.
/// Parse the custom form of an operation. Unless overridden, this method will
/// first try to get an operation parser from the op's dialect. Otherwise the
/// custom assembly form of an op is always rejected. Op implementations
/// should implement this to return failure. On success, they should fill in
/// result with the fields to use.
static ParseResult parse(OpAsmParser &parser, OperationState &result);

// The fallback for the printer is to print it the generic assembly form.
static void print(Operation *op, OpAsmPrinter &p);
/// Print the operation. Unless overridden, this method will first try to get
/// an operation printer from the dialect. Otherwise, it prints the operation
/// in generic form.
static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);

/// Print an operation name, eliding the dialect prefix if necessary.
static void printOpName(Operation *op, OpAsmPrinter &p,
StringRef defaultDialect);

Expand Down Expand Up @@ -1781,7 +1787,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
OperationName::PrintAssemblyFn>
getPrintAssemblyFnImpl() {
return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
return OpState::print(op, printer);
return OpState::print(op, printer, defaultDialect);
};
}
/// The internal implementation of `getPrintAssemblyFn` that is invoked when
Expand Down
21 changes: 17 additions & 4 deletions mlir/lib/IR/Operation.cpp
Expand Up @@ -580,14 +580,27 @@ Operation *Operation::clone() {
// OpState trait class.
//===----------------------------------------------------------------------===//

// The fallback for the parser is to reject the custom assembly form.
// The fallback for the parser is to try for a dialect operation parser.
// Otherwise, reject the custom assembly form.
ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
if (auto parseFn = result.name.getDialect()->getParseOperationHook(
result.name.getStringRef()))
return (*parseFn)(parser, result);
return parser.emitError(parser.getNameLoc(), "has no custom assembly form");
}

// The fallback for the printer is to print in the generic assembly form.
void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); }
// The fallback for the printer is to print in the generic assembly form.
// The fallback for the printer is to try for a dialect operation printer.
// Otherwise, it prints the generic form.
void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
if (auto printFn = op->getDialect()->getOperationPrinter(op)) {
printOpName(op, p, defaultDialect);
printFn(op, p);
} else {
p.printGenericOp(op);
}
}

/// Print an operation name, eliding the dialect prefix if necessary.
void OpState::printOpName(Operation *op, OpAsmPrinter &p,
StringRef defaultDialect) {
StringRef name = op->getName().getStringRef();
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/IR/parser.mlir
Expand Up @@ -1425,3 +1425,8 @@ test.graph_region {
// This is an unregister operation, the printing/parsing is handled by the dialect.
// CHECK: test.dialect_custom_printer custom_format
test.dialect_custom_printer custom_format

// This is a registered operation with no custom parser and printer, and should
// be handled by the dialect.
// CHECK: test.dialect_custom_format_fallback custom_format_fallback
test.dialect_custom_format_fallback custom_format_fallback
10 changes: 10 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Expand Up @@ -318,6 +318,11 @@ TestDialect::getParseOperationHook(StringRef opName) const {
return parser.parseKeyword("custom_format");
}};
}
if (opName == "test.dialect_custom_format_fallback") {
return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
return parser.parseKeyword("custom_format_fallback");
}};
}
return None;
}

Expand All @@ -329,6 +334,11 @@ TestDialect::getOperationPrinter(Operation *op) const {
printer.getStream() << " custom_format";
};
}
if (opName == "test.dialect_custom_format_fallback") {
return [](Operation *op, OpAsmPrinter &printer) {
printer.getStream() << " custom_format_fallback";
};
}
return {};
}

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Expand Up @@ -597,6 +597,10 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
);
}

// This is used to test that the fallback for a custom op's parser and printer
// is the dialect parser and printer hooks.
def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;

// This is used to test encoding of a string attribute into an SSA name of a
// pretty printed value name.
def StringAttrPrettyNameOp
Expand Down

0 comments on commit 0845635

Please sign in to comment.