89 changes: 37 additions & 52 deletions mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,19 @@ class ParameterElement
/// Returns the name of the parameter.
StringRef getName() const { return param.getName(); }

/// Return the code to check whether the parameter is present.
auto genIsPresent(FmtContext &ctx, const Twine &self) const {
assert(isOptional() && "cannot guard on a mandatory parameter");
std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str();
ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
return tgfmt(getParam().getComparator(), &ctx);
}

/// Generate the code to check whether the parameter should be printed.
MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
assert(isOptional() && "cannot guard on a mandatory parameter");
std::string self = param.getAccessorName() + "()";
ctx.withSelf(self);
os << tgfmt("($_self", &ctx);
if (llvm::Optional<StringRef> defaultValue = getParam().getDefaultValue()) {
// Use the `comparator` field if it exists, else the equality operator.
std::string valueStr = tgfmt(*defaultValue, &ctx).str();
ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
os << " && !(" << tgfmt(getParam().getComparator(), &ctx) << ")";
}
return os << ")";
return os << "!(" << genIsPresent(ctx, self) << ")";
}

private:
Expand Down Expand Up @@ -332,13 +333,9 @@ void DefFormat::genParser(MethodBody &os) {
os << ",\n ";
std::string paramSelfStr;
llvm::raw_string_ostream selfOs(paramSelfStr);
if (param.isOptional()) {
selfOs << formatv("(_result_{0}.value_or(", param.getName());
if (Optional<StringRef> defaultValue = param.getDefaultValue())
selfOs << tgfmt(*defaultValue, &ctx);
else
selfOs << param.getCppStorageType() << "()";
selfOs << "))";
if (Optional<StringRef> defaultValue = param.getDefaultValue()) {
selfOs << formatv("(_result_{0}.value_or(", param.getName())
<< tgfmt(*defaultValue, &ctx) << "))";
} else {
selfOs << formatv("(*_result_{0})", param.getName());
}
Expand Down Expand Up @@ -447,8 +444,9 @@ void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
ParameterElement *el = *std::prev(it);
// Parse a comma if the last optional parameter had a value.
if (el->isOptional()) {
os << formatv("if (::mlir::succeeded(_result_{0}) && *_result_{0}) {{\n",
el->getName());
os << formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
el->getName(),
el->genIsPresent(ctx, "(*_result_" + el->getName() + ")"));
os.indent();
}
if (it <= lastReqIt) {
Expand Down Expand Up @@ -522,18 +520,6 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
}
)";

// Optional parameters in a struct must be parsed successfully if the
// keyword is present.
//
// {0}: Name of the parameter.
// {1}: Emit error string
const char *const checkOptionalParam = R"(
if (::mlir::succeeded(_result_{0}) && !*_result_{0}) {{
{1}"expected a value for parameter '{0}'");
return {{};
}
)";

// First iteration of the loop parsing an optional struct.
const char *const optionalStructFirst = R"(
::llvm::StringRef _paramKey;
Expand All @@ -558,11 +544,6 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
" _seen_{0} = true;\n",
param->getName());
genVariableParser(param, ctx, os.indent());
if (param->isOptional()) {
os.getStream().printReindented(strfmt(checkOptionalParam,
param->getName(),
tgfmt(parserErrorStr, &ctx).str()));
}
os.unindent() << "} else ";
// Print the check for duplicate or unknown parameter.
}
Expand Down Expand Up @@ -656,10 +637,10 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,

void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
ArrayRef<FormatElement *> elements =
el->getThenElements().drop_front(el->getParseStart());
ArrayRef<FormatElement *> thenElements =
el->getThenElements(/*parseable=*/true);

FormatElement *first = elements.front();
FormatElement *first = thenElements.front();
const auto guardOn = [&](auto params) {
os << "if (!(";
llvm::interleave(
Expand Down Expand Up @@ -687,12 +668,12 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
}
os.indent();

// Generate the parsers for the rest of the elements.
for (FormatElement *element : el->getElseElements())
// Generate the parsers for the rest of the thenElements.
for (FormatElement *element : el->getElseElements(/*parseable=*/true))
genElementParser(element, ctx, os);
os.unindent() << "} else {\n";
os.indent();
for (FormatElement *element : elements.drop_front())
for (FormatElement *element : thenElements.drop_front())
genElementParser(element, ctx, os);
os.unindent() << "}\n";
}
Expand Down Expand Up @@ -781,12 +762,16 @@ void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,

/// Generate code to guard printing on the presence of any optional parameters.
template <typename ParameterRange>
static void guardOnAny(FmtContext &ctx, MethodBody &os,
ParameterRange &&params) {
static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
bool inverted = false) {
os << "if (";
if (inverted)
os << "!(";
llvm::interleave(
params, os,
[&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
if (inverted)
os << ")";
os << ") {\n";
os.indent();
}
Expand Down Expand Up @@ -860,12 +845,12 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
FormatElement *anchor = el->getAnchor();
if (auto *param = dyn_cast<ParameterElement>(anchor)) {
guardOnAny(ctx, os, llvm::makeArrayRef(param));
guardOnAny(ctx, os, llvm::makeArrayRef(param), el->isInverted());
} else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
guardOnAny(ctx, os, params->getParams());
guardOnAny(ctx, os, params->getParams(), el->isInverted());
} else {
auto *strct = cast<StructDirective>(anchor);
guardOnAny(ctx, os, strct->getParams());
guardOnAny(ctx, os, strct->getParams(), el->isInverted());
}
// Generate the printer for the contained elements.
{
Expand Down Expand Up @@ -917,9 +902,9 @@ class DefFormatParser : public FormatParser {
verifyCustomDirectiveArguments(SMLoc loc,
ArrayRef<FormatElement *> arguments) override;
/// Verify the elements of an optional group.
LogicalResult
verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
Optional<unsigned> anchorIndex) override;
LogicalResult verifyOptionalGroupElements(SMLoc loc,
ArrayRef<FormatElement *> elements,
FormatElement *anchor) override;

/// Parse an attribute or type variable.
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
Expand Down Expand Up @@ -989,7 +974,7 @@ LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
LogicalResult
DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
ArrayRef<FormatElement *> elements,
Optional<unsigned> anchorIndex) {
FormatElement *anchor) {
// `params` and `struct` directives are allowed only if all the contained
// parameters are optional.
for (FormatElement *el : elements) {
Expand All @@ -1011,8 +996,8 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
}
}
// The anchor must be a parameter or one of the aforementioned directives.
if (anchorIndex && !isa<ParameterElement, ParamsDirective, StructDirective>(
elements[*anchorIndex])) {
if (anchor &&
!isa<ParameterElement, ParamsDirective, StructDirective>(anchor)) {
return emitError(loc,
"optional group anchor must be a parameter or directive");
}
Expand Down
78 changes: 44 additions & 34 deletions mlir/tools/mlir-tblgen/FormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,65 +320,75 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {

// Parse the child elements for this optional group.
std::vector<FormatElement *> thenElements, elseElements;
Optional<unsigned> anchorIndex;
do {
FailureOr<FormatElement *> element = parseElement(TopLevelContext);
if (failed(element))
return failure();
// Check for an anchor.
if (curToken.is(FormatToken::caret)) {
if (anchorIndex)
return emitError(curToken.getLoc(), "only one element can be marked as "
"the anchor of an optional group");
anchorIndex = thenElements.size();
consumeToken();
}
thenElements.push_back(*element);
} while (!curToken.is(FormatToken::r_paren));
FormatElement *anchor = nullptr;
auto parseChildElements =
[this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult {
do {
FailureOr<FormatElement *> element = parseElement(TopLevelContext);
if (failed(element))
return failure();
// Check for an anchor.
if (curToken.is(FormatToken::caret)) {
if (anchor) {
return emitError(curToken.getLoc(),
"only one element can be marked as the anchor of an "
"optional group");
}
anchor = *element;
consumeToken();
}
elements.push_back(*element);
} while (!curToken.is(FormatToken::r_paren));
return success();
};

// Parse the 'then' elements. If the anchor was found in this group, then the
// optional is not inverted.
if (failed(parseChildElements(thenElements)))
return failure();
consumeToken();
bool inverted = !anchor;

// Parse the `else` elements of this optional group.
if (curToken.is(FormatToken::colon)) {
consumeToken();
if (failed(
parseToken(FormatToken::l_paren,
"expected '(' to start else branch of optional group")))
if (failed(parseToken(
FormatToken::l_paren,
"expected '(' to start else branch of optional group")) ||
failed(parseChildElements(elseElements)))
return failure();
do {
FailureOr<FormatElement *> element = parseElement(TopLevelContext);
if (failed(element))
return failure();
elseElements.push_back(*element);
} while (!curToken.is(FormatToken::r_paren));
consumeToken();
}
if (failed(parseToken(FormatToken::question,
"expected '?' after optional group")))
return failure();

// The optional group is required to have an anchor.
if (!anchorIndex)
if (!anchor)
return emitError(loc, "optional group has no anchor element");

// Verify the child elements.
if (failed(verifyOptionalGroupElements(loc, thenElements, anchorIndex)) ||
failed(verifyOptionalGroupElements(loc, elseElements, llvm::None)))
if (failed(verifyOptionalGroupElements(loc, thenElements, anchor)) ||
failed(verifyOptionalGroupElements(loc, elseElements, nullptr)))
return failure();

// Get the first parsable element. It must be an element that can be
// optionally-parsed.
auto parseBegin = llvm::find_if_not(thenElements, [](FormatElement *element) {
auto isWhitespace = [](FormatElement *element) {
return isa<WhitespaceElement>(element);
});
if (!isa<LiteralElement, VariableElement>(*parseBegin)) {
};
auto thenParseBegin = llvm::find_if_not(thenElements, isWhitespace);
auto elseParseBegin = llvm::find_if_not(elseElements, isWhitespace);
unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin);
unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin);

if (!isa<LiteralElement, VariableElement>(*thenParseBegin)) {
return emitError(loc, "first parsable element of an optional group must be "
"a literal or variable");
}

unsigned parseStart = std::distance(thenElements.begin(), parseBegin);
return create<OptionalElement>(std::move(thenElements),
std::move(elseElements), *anchorIndex,
parseStart);
std::move(elseElements), thenParseStart,
elseParseStart, anchor, inverted);
}

FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
Expand Down
46 changes: 30 additions & 16 deletions mlir/tools/mlir-tblgen/FormatGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,34 +378,48 @@ class OptionalElement : public FormatElementBase<FormatElement::Optional> {
/// Create an optional group with the given child elements.
OptionalElement(std::vector<FormatElement *> &&thenElements,
std::vector<FormatElement *> &&elseElements,
unsigned anchorIndex, unsigned parseStart)
unsigned thenParseStart, unsigned elseParseStart,
FormatElement *anchor, bool inverted)
: thenElements(std::move(thenElements)),
elseElements(std::move(elseElements)), anchorIndex(anchorIndex),
parseStart(parseStart) {}

/// Return the `then` elements of the optional group.
ArrayRef<FormatElement *> getThenElements() const { return thenElements; }
elseElements(std::move(elseElements)), thenParseStart(thenParseStart),
elseParseStart(elseParseStart), anchor(anchor), inverted(inverted) {}

/// Return the `then` elements of the optional group. Drops the first
/// `thenParseStart` whitespace elements if `parseable` is true.
ArrayRef<FormatElement *> getThenElements(bool parseable = false) const {
return llvm::makeArrayRef(thenElements)
.drop_front(parseable ? thenParseStart : 0);
}

/// Return the `else` elements of the optional group.
ArrayRef<FormatElement *> getElseElements() const { return elseElements; }
/// Return the `else` elements of the optional group. Drops the first
/// `elseParseStart` whitespace elements if `parseable` is true.
ArrayRef<FormatElement *> getElseElements(bool parseable = false) const {
return llvm::makeArrayRef(elseElements)
.drop_front(parseable ? elseParseStart : 0);
}

/// Return the anchor of the optional group.
FormatElement *getAnchor() const { return thenElements[anchorIndex]; }
FormatElement *getAnchor() const { return anchor; }

/// Return the index of the first element to be parsed.
unsigned getParseStart() const { return parseStart; }
/// Return true if the optional group is inverted.
bool isInverted() const { return inverted; }

private:
/// The child elements emitted when the anchor is present.
std::vector<FormatElement *> thenElements;
/// The child elements emitted when the anchor is not present.
std::vector<FormatElement *> elseElements;
/// The index of the anchor element of the optional group within
/// `thenElements`.
unsigned anchorIndex;
/// The index of the first element that is parsed in `thenElements`. That is,
/// the first non-whitespace element.
unsigned parseStart;
unsigned thenParseStart;
/// The index of the first element that is parsed in `elseElements`. That is,
/// the first non-whitespace element.
unsigned elseParseStart;
/// The anchor element of the optional group.
FormatElement *anchor;
/// Whether the optional group condition is inverted and the anchor element is
/// in the else group.
bool inverted;
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -496,7 +510,7 @@ class FormatParser {
virtual LogicalResult
verifyOptionalGroupElements(llvm::SMLoc loc,
ArrayRef<FormatElement *> elements,
Optional<unsigned> anchorIndex) = 0;
FormatElement *anchor) = 0;

//===--------------------------------------------------------------------===//
// Lexer Utilities
Expand Down
123 changes: 69 additions & 54 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1119,17 +1119,43 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
GenContext genCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
ArrayRef<FormatElement *> elements =
optional->getThenElements().drop_front(optional->getParseStart());
auto genElementParsers = [&](FormatElement *firstElement,
ArrayRef<FormatElement *> elements,
bool thenGroup) {
// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
FormatElement *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
if (anchorAttr && anchorAttr != firstElement &&
anchorAttr->isUnitAttr()) {
elidedAnchorElement = anchorAttr;

if (!thenGroup == optional->isInverted()) {
// Add the anchor unit attribute to the operation state.
body << " result.addAttribute(\"" << anchorAttr->getVar()->name
<< "\", parser.getBuilder().getUnitAttr());\n";
}
}

// Generate the rest of the elements inside an optional group. Elements in
// an optional group after the guard are parsed as required.
for (FormatElement *childElement : elements)
if (childElement != elidedAnchorElement)
genElementParser(childElement, body, attrTypeCtx,
GenContext::Optional);
};

ArrayRef<FormatElement *> thenElements =
optional->getThenElements(/*parseable=*/true);

// Generate a special optional parser for the first element to gate the
// parsing of the rest of the elements.
FormatElement *firstElement = elements.front();
FormatElement *firstElement = thenElements.front();
if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
genElementParser(attrVar, body, attrTypeCtx);
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (succeeded(parser.parseOptional";
body << " if (::mlir::succeeded(parser.parseOptional";
genLiteralParser(literal->getSpelling(), body);
body << ")) {\n";
} else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
Expand All @@ -1151,31 +1177,18 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
}
}

// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
FormatElement *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
elidedAnchorElement = anchorAttr;

// Add the anchor unit attribute to the operation state.
body << " result.addAttribute(\"" << anchorAttr->getVar()->name
<< "\", parser.getBuilder().getUnitAttr());\n";
}

// Generate the rest of the elements inside an optional group. Elements in
// an optional group after the guard are parsed as required.
for (FormatElement *childElement : llvm::drop_begin(elements, 1))
if (childElement != elidedAnchorElement)
genElementParser(childElement, body, attrTypeCtx, GenContext::Optional);
genElementParsers(firstElement, thenElements.drop_front(),
/*thenGroup=*/true);
body << " }";

// Generate the else elements.
auto elseElements = optional->getElseElements();
if (!elseElements.empty()) {
body << " else {\n";
for (FormatElement *childElement : elseElements)
genElementParser(childElement, body, attrTypeCtx);
ArrayRef<FormatElement *> elseElements =
optional->getElseElements(/*parsable=*/true);
genElementParsers(elseElements.front(), elseElements,
/*thenGroup=*/false);
body << " }";
}
body << "\n";
Expand Down Expand Up @@ -1842,15 +1855,15 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
const NamedTypeConstraint *var = element->getVar();
std::string name = op.getGetterName(var->name);
if (var->isOptional())
body << " if (" << name << "()) {\n";
body << name << "()";
else if (var->isVariadic())
body << " if (!" << name << "().empty()) {\n";
body << "!" << name << "().empty()";
})
.Case<RegionVariable>([&](RegionVariable *element) {
const NamedRegion *var = element->getVar();
std::string name = op.getGetterName(var->name);
// TODO: Add a check for optional regions here when ODS supports it.
body << " if (!" << name << "().empty()) {\n";
body << "!" << name << "().empty()";
})
.Case<TypeDirective>([&](TypeDirective *element) {
genOptionalGroupPrinterAnchor(element->getArg(), op, body);
Expand All @@ -1859,8 +1872,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
})
.Case<AttributeVariable>([&](AttributeVariable *attr) {
body << " if ((*this)->getAttr(\"" << attr->getVar()->name
<< "\")) {\n";
body << "(*this)->getAttr(\"" << attr->getVar()->name << "\")";
});
}

Expand Down Expand Up @@ -1912,39 +1924,45 @@ void OperationFormat::genElementPrinter(FormatElement *element,
if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
// Emit the check for the presence of the anchor element.
FormatElement *anchor = optional->getAnchor();
body << " if (";
if (optional->isInverted())
body << "!";
genOptionalGroupPrinterAnchor(anchor, op, body);
body << ") {\n";
body.indent();

// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
auto elements = optional->getThenElements();
ArrayRef<FormatElement *> thenElements = optional->getThenElements();
ArrayRef<FormatElement *> elseElements = optional->getElseElements();
FormatElement *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
if (anchorAttr && anchorAttr != elements.front() &&
if (anchorAttr && anchorAttr != thenElements.front() &&
(elseElements.empty() || anchorAttr != elseElements.front()) &&
anchorAttr->isUnitAttr()) {
elidedAnchorElement = anchorAttr;
}
auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) {
for (FormatElement *childElement : elements) {
if (childElement != elidedAnchorElement) {
genElementPrinter(childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
}
};

// Emit each of the elements.
for (FormatElement *childElement : elements) {
if (childElement != elidedAnchorElement) {
genElementPrinter(childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
}
body << " }";
genElementPrinters(thenElements);
body << "}";

// Emit each of the else elements.
auto elseElements = optional->getElseElements();
if (!elseElements.empty()) {
body << " else {\n";
for (FormatElement *childElement : elseElements) {
genElementPrinter(childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
body << " }";
genElementPrinters(elseElements);
body << "}";
}

body << "\n";
body.unindent() << "\n";
return;
}

Expand Down Expand Up @@ -2170,9 +2188,9 @@ class OpFormatParser : public FormatParser {
verifyCustomDirectiveArguments(SMLoc loc,
ArrayRef<FormatElement *> arguments) override;
/// Verify the elements of an optional group.
LogicalResult
verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
Optional<unsigned> anchorIndex) override;
LogicalResult verifyOptionalGroupElements(SMLoc loc,
ArrayRef<FormatElement *> elements,
FormatElement *anchor) override;
LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
bool isAnchor);

Expand Down Expand Up @@ -3150,13 +3168,10 @@ OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
return element;
}

LogicalResult
OpFormatParser::verifyOptionalGroupElements(SMLoc loc,
ArrayRef<FormatElement *> elements,
Optional<unsigned> anchorIndex) {
for (auto &it : llvm::enumerate(elements)) {
if (failed(verifyOptionalGroupElement(
loc, it.value(), anchorIndex && *anchorIndex == it.index())))
LogicalResult OpFormatParser::verifyOptionalGroupElements(
SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) {
for (FormatElement *element : elements) {
if (failed(verifyOptionalGroupElement(loc, element, element == anchor)))
return failure();
}
return success();
Expand Down