diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 1fccd3b5620c1..646ced1cbd3ad 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -551,96 +551,6 @@ class fir_IntegralSwitchTerminatorOp]; - let parser = [{ - mlir::OpAsmParser::OperandType selector; - mlir::Type type; - if (parseSelector(parser, result, selector, type)) - return mlir::failure(); - - llvm::SmallVector ivalues; - llvm::SmallVector dests; - llvm::SmallVector> destArgs; - while (true) { - mlir::Attribute ivalue; // Integer or Unit - mlir::Block *dest; - llvm::SmallVector destArg; - mlir::NamedAttrList temp; - if (parser.parseAttribute(ivalue, "i", temp) || - parser.parseComma() || - parser.parseSuccessorAndUseList(dest, destArg)) - return mlir::failure(); - ivalues.push_back(ivalue); - dests.push_back(dest); - destArgs.push_back(destArg); - if (!parser.parseOptionalRSquare()) - break; - if (parser.parseComma()) - return mlir::failure(); - } - auto &bld = parser.getBuilder(); - result.addAttribute(getCasesAttr(), bld.getArrayAttr(ivalues)); - llvm::SmallVector argOffs; - int32_t sumArgs = 0; - const auto count = dests.size(); - for (std::remove_const_t i = 0; i != count; ++i) { - result.addSuccessors(dests[i]); - result.addOperands(destArgs[i]); - auto argSize = destArgs[i].size(); - argOffs.push_back(argSize); - sumArgs += argSize; - } - result.addAttribute(getOperandSegmentSizeAttr(), - bld.getI32VectorAttr({1, 0, sumArgs})); - result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); - return mlir::success(); - }]; - - let printer = [{ - p << ' '; - p.printOperand(getSelector()); - p << " : " << getSelector().getType() << " ["; - auto cases = - (*this)->getAttrOfType(getCasesAttr()).getValue(); - auto count = getNumConditions(); - for (decltype(count) i = 0; i != count; ++i) { - if (i) - p << ", "; - auto &attr = cases[i]; - if (auto intAttr = attr.dyn_cast_or_null()) - p << intAttr.getValue(); - else - p.printAttribute(attr); - p << ", "; - printSuccessorAtIndex(p, i); - } - p << ']'; - p.printOptionalAttrDict((*this)->getAttrs(), {getCasesAttr(), - getCompareOffsetAttr(), getTargetOffsetAttr(), - getOperandSegmentSizeAttr()}); - }]; - - let verifier = [{ - if (!(getSelector().getType().isa() || - getSelector().getType().isa() || - getSelector().getType().isa())) - return emitOpError("must be an integer"); - auto cases = - (*this)->getAttrOfType(getCasesAttr()).getValue(); - auto count = getNumDest(); - if (count == 0) - return emitOpError("must have at least one successor"); - if (getNumConditions() != count) - return emitOpError("number of cases and targets don't match"); - if (targetOffsetSize() != count) - return emitOpError("incorrect number of successor operand groups"); - for (decltype(count) i = 0; i != count; ++i) { - auto &attr = cases[i]; - if (!(attr.isa() || attr.isa())) - return emitOpError("invalid case alternative"); - } - return mlir::success(); - }]; - let extraClassDeclaration = extraSwitchClassDeclaration; } @@ -663,6 +573,9 @@ def fir_SelectOp : fir_IntegralSwitchTerminatorOp<"select"> { unit, ^bb5] ``` }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + } def fir_SelectRankOp : fir_IntegralSwitchTerminatorOp<"select_rank"> { @@ -683,6 +596,8 @@ def fir_SelectRankOp : fir_IntegralSwitchTerminatorOp<"select_rank"> { unit, ^bb5] ``` }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; } def fir_SelectCaseOp : fir_SwitchTerminatorOp<"select_case"> { @@ -2103,9 +2018,8 @@ def FirRegionTerminator : SingleBlockImplicitTerminator<"ResultOp">; class region_Op traits = []> : fir_Op { - let printer = "return ::print(p, *this);"; - let verifier = "return ::verify(*this);"; - let parser = "return ::parse$cppClass(parser, result);"; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; } def fir_DoLoopOp : region_Op<"do_loop", @@ -2131,6 +2045,9 @@ def fir_DoLoopOp : region_Op<"do_loop", keyword indicates that the iterations can be executed in any order. }]; + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + let arguments = (ins Index:$lowerBound, Index:$upperBound, diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 3864fa3d93108..2333dfc66aaf9 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1585,8 +1585,8 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder, result.addAttributes(attributes); } -static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, - mlir::OperationState &result) { +mlir::ParseResult IterWhileOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { auto &builder = parser.getBuilder(); mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || @@ -1684,65 +1684,64 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, return mlir::success(); } -static mlir::LogicalResult verify(fir::IterWhileOp op) { +mlir::LogicalResult IterWhileOp::verify() { // Check that the body defines as single block argument for the induction // variable. - auto *body = op.getBody(); + auto *body = getBody(); if (!body->getArgument(1).getType().isInteger(1)) - return op.emitOpError( + return emitOpError( "expected body second argument to be an index argument for " "the induction variable"); if (!body->getArgument(0).getType().isIndex()) - return op.emitOpError( + return emitOpError( "expected body first argument to be an index argument for " "the induction variable"); - auto opNumResults = op.getNumResults(); - if (op.getFinalValue()) { + auto opNumResults = getNumResults(); + if (getFinalValue()) { // Result type must be "(index, i1, ...)". - if (!op.getResult(0).getType().isa()) - return op.emitOpError("result #0 expected to be index"); - if (!op.getResult(1).getType().isSignlessInteger(1)) - return op.emitOpError("result #1 expected to be i1"); + if (!getResult(0).getType().isa()) + return emitOpError("result #0 expected to be index"); + if (!getResult(1).getType().isSignlessInteger(1)) + return emitOpError("result #1 expected to be i1"); opNumResults--; } else { // iterate_while always returns the early exit induction value. // Result type must be "(i1, ...)" - if (!op.getResult(0).getType().isSignlessInteger(1)) - return op.emitOpError("result #0 expected to be i1"); + if (!getResult(0).getType().isSignlessInteger(1)) + return emitOpError("result #0 expected to be i1"); } if (opNumResults == 0) return mlir::failure(); - if (op.getNumIterOperands() != opNumResults) - return op.emitOpError( + if (getNumIterOperands() != opNumResults) + return emitOpError( "mismatch in number of loop-carried values and defined values"); - if (op.getNumRegionIterArgs() != opNumResults) - return op.emitOpError( + if (getNumRegionIterArgs() != opNumResults) + return emitOpError( "mismatch in number of basic block args and defined values"); - auto iterOperands = op.getIterOperands(); - auto iterArgs = op.getRegionIterArgs(); - auto opResults = - op.getFinalValue() ? op.getResults().drop_front() : op.getResults(); + auto iterOperands = getIterOperands(); + auto iterArgs = getRegionIterArgs(); + auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); unsigned i = 0; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) - return op.emitOpError() << "types mismatch between " << i - << "th iter operand and defined value"; + return emitOpError() << "types mismatch between " << i + << "th iter operand and defined value"; if (std::get<1>(e).getType() != std::get<2>(e).getType()) - return op.emitOpError() << "types mismatch between " << i - << "th iter region arg and defined value"; + return emitOpError() << "types mismatch between " << i + << "th iter region arg and defined value"; i++; } return mlir::success(); } -static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { - p << " (" << op.getInductionVar() << " = " << op.getLowerBound() << " to " - << op.getUpperBound() << " step " << op.getStep() << ") and ("; - assert(op.hasIterOperands()); - auto regionArgs = op.getRegionIterArgs(); - auto operands = op.getIterOperands(); +void IterWhileOp::print(mlir::OpAsmPrinter &p) { + p << " (" << getInductionVar() << " = " << getLowerBound() << " to " + << getUpperBound() << " step " << getStep() << ") and ("; + assert(hasIterOperands()); + auto regionArgs = getRegionIterArgs(); + auto operands = getIterOperands(); p << regionArgs.front() << " = " << *operands.begin() << ")"; if (regionArgs.size() > 1) { p << " iter_args("; @@ -1751,15 +1750,15 @@ static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> ("; llvm::interleaveComma( - llvm::drop_begin(op.getResultTypes(), op.getFinalValue() ? 0 : 1), p); + llvm::drop_begin(getResultTypes(), getFinalValue() ? 0 : 1), p); p << ")"; - } else if (op.getFinalValue()) { - p << " -> (" << op.getResultTypes() << ')'; + } else if (getFinalValue()) { + p << " -> (" << getResultTypes() << ')'; } - p.printOptionalAttrDictWithKeyword(op->getAttrs(), - {op.getFinalValueAttrNameStr()}); + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), + {getFinalValueAttrNameStr()}); p << ' '; - p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false, + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -1910,8 +1909,8 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, result.addAttributes(attributes); } -static mlir::ParseResult parseDoLoopOp(mlir::OpAsmParser &parser, - mlir::OperationState &result) { +mlir::ParseResult DoLoopOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { auto &builder = parser.getBuilder(); mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; // Parse the induction variable followed by '='. @@ -1994,71 +1993,70 @@ fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { } // Lifted from loop.loop -static mlir::LogicalResult verify(fir::DoLoopOp op) { +mlir::LogicalResult DoLoopOp::verify() { // Check that the body defines as single block argument for the induction // variable. - auto *body = op.getBody(); + auto *body = getBody(); if (!body->getArgument(0).getType().isIndex()) - return op.emitOpError( + return emitOpError( "expected body first argument to be an index argument for " "the induction variable"); - auto opNumResults = op.getNumResults(); + auto opNumResults = getNumResults(); if (opNumResults == 0) return success(); - if (op.getFinalValue()) { - if (op.getUnordered()) - return op.emitOpError("unordered loop has no final value"); + if (getFinalValue()) { + if (getUnordered()) + return emitOpError("unordered loop has no final value"); opNumResults--; } - if (op.getNumIterOperands() != opNumResults) - return op.emitOpError( + if (getNumIterOperands() != opNumResults) + return emitOpError( "mismatch in number of loop-carried values and defined values"); - if (op.getNumRegionIterArgs() != opNumResults) - return op.emitOpError( + if (getNumRegionIterArgs() != opNumResults) + return emitOpError( "mismatch in number of basic block args and defined values"); - auto iterOperands = op.getIterOperands(); - auto iterArgs = op.getRegionIterArgs(); - auto opResults = - op.getFinalValue() ? op.getResults().drop_front() : op.getResults(); + auto iterOperands = getIterOperands(); + auto iterArgs = getRegionIterArgs(); + auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); unsigned i = 0; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) - return op.emitOpError() << "types mismatch between " << i - << "th iter operand and defined value"; + return emitOpError() << "types mismatch between " << i + << "th iter operand and defined value"; if (std::get<1>(e).getType() != std::get<2>(e).getType()) - return op.emitOpError() << "types mismatch between " << i - << "th iter region arg and defined value"; + return emitOpError() << "types mismatch between " << i + << "th iter region arg and defined value"; i++; } return success(); } -static void print(mlir::OpAsmPrinter &p, fir::DoLoopOp op) { +void DoLoopOp::print(mlir::OpAsmPrinter &p) { bool printBlockTerminators = false; - p << ' ' << op.getInductionVar() << " = " << op.getLowerBound() << " to " - << op.getUpperBound() << " step " << op.getStep(); - if (op.getUnordered()) + p << ' ' << getInductionVar() << " = " << getLowerBound() << " to " + << getUpperBound() << " step " << getStep(); + if (getUnordered()) p << " unordered"; - if (op.hasIterOperands()) { + if (hasIterOperands()) { p << " iter_args("; - auto regionArgs = op.getRegionIterArgs(); - auto operands = op.getIterOperands(); + auto regionArgs = getRegionIterArgs(); + auto operands = getIterOperands(); llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); - p << ") -> (" << op.getResultTypes() << ')'; + p << ") -> (" << getResultTypes() << ')'; printBlockTerminators = true; - } else if (op.getFinalValue()) { - p << " -> " << op.getResultTypes(); + } else if (getFinalValue()) { + p << " -> " << getResultTypes(); printBlockTerminators = true; } - p.printOptionalAttrDictWithKeyword(op->getAttrs(), + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"unordered", "finalValue"}); p << ' '; - p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false, + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, printBlockTerminators); } @@ -2289,9 +2287,8 @@ mlir::LogicalResult SaveResultOp::verify() { } //===----------------------------------------------------------------------===// -// SelectOp +// IntegralSwitchTerminator //===----------------------------------------------------------------------===// - static constexpr llvm::StringRef getCompareOffsetAttr() { return "compare_operand_offsets"; } @@ -2300,6 +2297,116 @@ static constexpr llvm::StringRef getTargetOffsetAttr() { return "target_operand_offsets"; } +template +static LogicalResult verifyIntegralSwitchTerminator(OpT op) { + if (!(op.getSelector().getType().template isa() || + op.getSelector().getType().template isa() || + op.getSelector().getType().template isa())) + return op.emitOpError("must be an integer"); + auto cases = + op->template getAttrOfType(op.getCasesAttr()).getValue(); + auto count = op.getNumDest(); + if (count == 0) + return op.emitOpError("must have at least one successor"); + if (op.getNumConditions() != count) + return op.emitOpError("number of cases and targets don't match"); + if (op.targetOffsetSize() != count) + return op.emitOpError("incorrect number of successor operand groups"); + for (decltype(count) i = 0; i != count; ++i) { + if (!(cases[i].template isa())) + return op.emitOpError("invalid case alternative"); + } + return mlir::success(); +} + +static mlir::ParseResult parseIntegralSwitchTerminator( + mlir::OpAsmParser &parser, mlir::OperationState &result, + llvm::StringRef casesAttr, llvm::StringRef operandSegmentAttr) { + mlir::OpAsmParser::OperandType selector; + mlir::Type type; + if (parseSelector(parser, result, selector, type)) + return mlir::failure(); + + llvm::SmallVector ivalues; + llvm::SmallVector dests; + llvm::SmallVector> destArgs; + while (true) { + mlir::Attribute ivalue; // Integer or Unit + mlir::Block *dest; + llvm::SmallVector destArg; + mlir::NamedAttrList temp; + if (parser.parseAttribute(ivalue, "i", temp) || parser.parseComma() || + parser.parseSuccessorAndUseList(dest, destArg)) + return mlir::failure(); + ivalues.push_back(ivalue); + dests.push_back(dest); + destArgs.push_back(destArg); + if (!parser.parseOptionalRSquare()) + break; + if (parser.parseComma()) + return mlir::failure(); + } + auto &bld = parser.getBuilder(); + result.addAttribute(casesAttr, bld.getArrayAttr(ivalues)); + llvm::SmallVector argOffs; + int32_t sumArgs = 0; + const auto count = dests.size(); + for (std::remove_const_t i = 0; i != count; ++i) { + result.addSuccessors(dests[i]); + result.addOperands(destArgs[i]); + auto argSize = destArgs[i].size(); + argOffs.push_back(argSize); + sumArgs += argSize; + } + result.addAttribute(operandSegmentAttr, + bld.getI32VectorAttr({1, 0, sumArgs})); + result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); + return mlir::success(); +} + +template +static void printIntegralSwitchTerminator(OpT op, mlir::OpAsmPrinter &p) { + p << ' '; + p.printOperand(op.getSelector()); + p << " : " << op.getSelector().getType() << " ["; + auto cases = + op->template getAttrOfType(op.getCasesAttr()).getValue(); + auto count = op.getNumConditions(); + for (decltype(count) i = 0; i != count; ++i) { + if (i) + p << ", "; + auto &attr = cases[i]; + if (auto intAttr = attr.template dyn_cast_or_null()) + p << intAttr.getValue(); + else + p.printAttribute(attr); + p << ", "; + op.printSuccessorAtIndex(p, i); + } + p << ']'; + p.printOptionalAttrDict( + op->getAttrs(), {op.getCasesAttr(), getCompareOffsetAttr(), + getTargetOffsetAttr(), op.getOperandSegmentSizeAttr()}); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult fir::SelectOp::verify() { + return verifyIntegralSwitchTerminator(*this); +} + +mlir::ParseResult fir::SelectOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), + getOperandSegmentSizeAttr()); +} + +void fir::SelectOp::print(mlir::OpAsmPrinter &p) { + printIntegralSwitchTerminator(*this, p); +} + template static A getSubOperands(unsigned pos, A allArgs, mlir::DenseIntElementsAttr ranges, @@ -2644,6 +2751,20 @@ mlir::LogicalResult SelectCaseOp::verify() { // SelectRankOp //===----------------------------------------------------------------------===// +LogicalResult fir::SelectRankOp::verify() { + return verifyIntegralSwitchTerminator(*this); +} + +mlir::ParseResult fir::SelectRankOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), + getOperandSegmentSizeAttr()); +} + +void fir::SelectRankOp::print(mlir::OpAsmPrinter &p) { + printIntegralSwitchTerminator(*this, p); +} + llvm::Optional fir::SelectRankOp::getCompareOperands(unsigned) { return {}; @@ -3123,8 +3244,7 @@ void fir::IfOp::build(mlir::OpBuilder &builder, OperationState &result, } } -static mlir::ParseResult parseIfOp(OpAsmParser &parser, - OperationState &result) { +mlir::ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { result.regions.reserve(2); mlir::Region *thenRegion = result.addRegion(); mlir::Region *elseRegion = result.addRegion(); @@ -3155,32 +3275,32 @@ static mlir::ParseResult parseIfOp(OpAsmParser &parser, return mlir::success(); } -static LogicalResult verify(fir::IfOp op) { - if (op.getNumResults() != 0 && op.getElseRegion().empty()) - return op.emitOpError("must have an else block if defining values"); +LogicalResult IfOp::verify() { + if (getNumResults() != 0 && getElseRegion().empty()) + return emitOpError("must have an else block if defining values"); return mlir::success(); } -static void print(mlir::OpAsmPrinter &p, fir::IfOp op) { +void IfOp::print(mlir::OpAsmPrinter &p) { bool printBlockTerminators = false; - p << ' ' << op.getCondition(); - if (!op.getResults().empty()) { - p << " -> (" << op.getResultTypes() << ')'; + p << ' ' << getCondition(); + if (!getResults().empty()) { + p << " -> (" << getResultTypes() << ')'; printBlockTerminators = true; } p << ' '; - p.printRegion(op.getThenRegion(), /*printEntryBlockArgs=*/false, + p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, printBlockTerminators); // Print the 'else' regions if it exists and has a block. - auto &otherReg = op.getElseRegion(); + auto &otherReg = getElseRegion(); if (!otherReg.empty()) { p << " else "; p.printRegion(otherReg, /*printEntryBlockArgs=*/false, printBlockTerminators); } - p.printOptionalAttrDict(op->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs()); } void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl &results,