23 changes: 12 additions & 11 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ const char *const variadicOperandParserCode = R"(
const char *const optionalOperandParserCode = R"(
{
{0}OperandsLoc = parser.getCurrentLocation();
::mlir::OpAsmParser::OperandType operand;
::mlir::OpAsmParser::UnresolvedOperand operand;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalOperand(operand);
if (parseResult.hasValue()) {
Expand Down Expand Up @@ -787,7 +787,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
genElementParserStorage(paramElement, op, body);

} else if (isa<OperandsDirective>(element)) {
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
body << " ::mlir::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
"allOperands;\n";

} else if (isa<RegionsDirective>(element)) {
Expand All @@ -805,17 +805,18 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name;
if (operand->getVar()->isVariableLength()) {
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
<< name << "Operands;\n";
body
<< " ::mlir::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
<< name << "Operands;\n";
if (operand->getVar()->isVariadicOfVariadic()) {
body << " llvm::SmallVector<int32_t> " << name
<< "OperandGroupSizes;\n";
}
} else {
body << " ::mlir::OpAsmParser::OperandType " << name
body << " ::mlir::OpAsmParser::UnresolvedOperand " << name
<< "RawOperands[1];\n"
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
<< "Operands(" << name << "RawOperands);";
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
<< name << "Operands(" << name << "RawOperands);";
}
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
" (void){0}OperandsLoc;\n",
Expand Down Expand Up @@ -929,13 +930,13 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
<< "OperandsLoc = parser.getCurrentLocation();\n";
if (var->isOptional()) {
body << llvm::formatv(
" ::llvm::Optional<::mlir::OpAsmParser::OperandType> "
" ::llvm::Optional<::mlir::OpAsmParser::UnresolvedOperand> "
"{0}Operand;\n",
var->name);
} else if (var->isVariadicOfVariadic()) {
body << llvm::formatv(" "
"::llvm::SmallVector<::llvm::SmallVector<::mlir::"
"OpAsmParser::OperandType>> "
"OpAsmParser::UnresolvedOperand>> "
"{0}OperandGroups;\n",
var->name);
}
Expand All @@ -958,7 +959,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
body << llvm::formatv(
" {0} {1}Operand = {1}Operands.empty() ? {0}() : "
"{1}Operands[0];\n",
"::llvm::Optional<::mlir::OpAsmParser::OperandType>",
"::llvm::Optional<::mlir::OpAsmParser::UnresolvedOperand>",
operand->getVar()->name);

} else if (auto *type = dyn_cast<TypeDirective>(input)) {
Expand Down Expand Up @@ -1432,7 +1433,7 @@ void OperationFormat::genParserOperandTypeResolution(
// llvm::concat does not allow the case of a single range, so guard it here.
body << " if (parser.resolveOperands(";
if (op.getNumOperands() > 1) {
body << "::llvm::concat<const ::mlir::OpAsmParser::OperandType>(";
body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
body << operand.name << "Operands";
});
Expand Down