diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index a38630b2a04f0..51336e15fb8b1 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2328,11 +2328,6 @@ def fir_DoLoopOp : region_Op<"do_loop", ]; let extraClassDeclaration = [{ - static constexpr llvm::StringRef unorderedAttrName() { return "unordered"; } - static constexpr llvm::StringRef finalValueAttrName() { - return "finalValue"; - } - mlir::Value getInductionVar() { return getBody()->getArgument(0); } mlir::OpBuilder getBodyBuilder() { return OpBuilder(getBody(), std::prev(getBody()->end())); @@ -2367,8 +2362,7 @@ def fir_DoLoopOp : region_Op<"do_loop", mlir::Block *getBody() { return ®ion().front(); } void setUnordered() { - (*this)->setAttr(unorderedAttrName(), - mlir::UnitAttr::get(getContext())); + unorderedAttr(mlir::UnitAttr::get(getContext())); } mlir::BlockArgument iterArgToBlockArg(mlir::Value iterArg); @@ -2478,9 +2472,6 @@ def fir_IterWhileOp : region_Op<"iterate_while", ]; let extraClassDeclaration = [{ - static constexpr llvm::StringRef finalValueAttrName() { - return "finalValue"; - } mlir::Block *getBody() { return ®ion().front(); } mlir::Value getIterateVar() { return getBody()->getArgument(1); } mlir::Value getInductionVar() { return getBody()->getArgument(0); } @@ -2553,7 +2544,7 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> { CArg<"mlir::ValueRange", "{}">:$operands), [{ $_state.addOperands(operands); - $_state.addAttribute(calleeAttrName(), + $_state.addAttribute(calleeAttrName($_state.name), $_builder.getSymbolRefAttr(callee)); $_state.addTypes(callee.getType().getResults()); }]>, @@ -2562,7 +2553,7 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> { CArg<"mlir::ValueRange", "{}">:$operands), [{ $_state.addOperands(operands); - $_state.addAttribute(calleeAttrName(), callee); + $_state.addAttribute(calleeAttrName($_state.name), callee); $_state.addTypes(results); }]>, OpBuilder<(ins "llvm::StringRef":$callee, @@ -2574,14 +2565,11 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> { }]>]; let extraClassDeclaration = [{ - static constexpr StringRef calleeAttrName() { return "callee"; } - mlir::FunctionType getFunctionType(); /// Get the argument operands to the called function. operand_range getArgOperands() { - if (auto calling = - (*this)->getAttrOfType(calleeAttrName())) + if (calleeAttr()) return {arg_operand_begin(), arg_operand_end()}; return {arg_operand_begin() + 1, arg_operand_end()}; } @@ -2591,8 +2579,7 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> { /// Return the callee of this operation. CallInterfaceCallable getCallableForCallee() { - if (auto calling = - (*this)->getAttrOfType(calleeAttrName())) + if (auto calling = calleeAttr()) return calling; return getOperand(0); } @@ -2627,11 +2614,10 @@ def fir_DispatchOp : fir_Op<"dispatch", []> { llvm::StringRef calleeName; if (failed(parser.parseOptionalKeyword(&calleeName))) { mlir::StringAttr calleeAttr; - if (parser.parseAttribute(calleeAttr, methodAttrName(), - result.attributes)) + if (parser.parseAttribute(calleeAttr, "method", result.attributes)) return mlir::failure(); } else { - result.addAttribute(methodAttrName(), + result.addAttribute(methodAttrName(result.name), parser.getBuilder().getStringAttr(calleeName)); } if (parser.parseOperandList(operands, @@ -2646,7 +2632,7 @@ def fir_DispatchOp : fir_Op<"dispatch", []> { }]; let printer = [{ - p << getOperationName() << ' ' << (*this)->getAttr(methodAttrName()) << '('; + p << getOperationName() << ' ' << methodAttr() << '('; p.printOperand(object()); if (!args().empty()) { p << ", "; @@ -2668,7 +2654,6 @@ def fir_DispatchOp : fir_Op<"dispatch", []> { static constexpr llvm::StringRef passArgAttrName() { return "pass_arg_pos"; } - static constexpr llvm::StringRef methodAttrName() { return "method"; } unsigned passArgPos(); }]; } @@ -3144,7 +3129,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { p.printAttributeWithoutType((*this)->getAttr(symbolAttrName())); if (auto val = getValueOrNull()) p << '(' << val << ')'; - if ((*this)->getAttr(constantAttrName())) + if (constantAttr()) p << " constant"; p << " : "; p.printType(getType()); @@ -3177,14 +3162,11 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { let extraClassDeclaration = [{ static constexpr llvm::StringRef symbolAttrName() { return "symref"; } - static constexpr llvm::StringRef constantAttrName() { return "constant"; } - static constexpr llvm::StringRef initValAttrName() { return "initVal"; } static constexpr llvm::StringRef linkageAttrName() { return "linkName"; } - static constexpr llvm::StringRef typeAttrName() { return "type"; } /// The printable type of the global mlir::Type getType() { - return (*this)->getAttrOfType(typeAttrName()).getValue(); + return typeAttr().getValue(); } /// The semantic type of the global @@ -3378,28 +3360,23 @@ def fir_DTEntryOp : fir_Op<"dt_entry", []> { // allow `methodName` or `"methodName"` if (failed(parser.parseOptionalKeyword(&methodName))) { mlir::StringAttr methodAttr; - if (parser.parseAttribute(methodAttr, methodAttrName(), + if (parser.parseAttribute(methodAttr, "method", result.attributes)) return mlir::failure(); } else { - result.addAttribute(methodAttrName(), + result.addAttribute(methodAttrName(result.name), parser.getBuilder().getStringAttr(methodName)); } mlir::SymbolRefAttr calleeAttr; if (parser.parseComma() || - parser.parseAttribute(calleeAttr, procAttrName(), result.attributes)) + parser.parseAttribute(calleeAttr, "proc", result.attributes)) return mlir::failure(); return mlir::success(); }]; let printer = [{ - p << getOperationName() << ' ' << (*this)->getAttr(methodAttrName()) << ", " - << (*this)->getAttr(procAttrName()); - }]; - - let extraClassDeclaration = [{ - static constexpr llvm::StringRef methodAttrName() { return "method"; } - static constexpr llvm::StringRef procAttrName() { return "proc"; } + p << getOperationName() << ' ' << methodAttr() << ", " + << procAttr(); }]; } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 38390d8011347..ecdcdfcaaaea5 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -252,7 +252,7 @@ static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { else p << op.getOperand(0); p << '(' << op->getOperands().drop_front(isDirect ? 0 : 1) << ')'; - p.printOptionalAttrDict(op->getAttrs(), {fir::CallOp::calleeAttrName()}); + p.printOptionalAttrDict(op->getAttrs(), {"callee"}); auto resultTypes{op.getResultTypes()}; llvm::SmallVector argTypes( llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); @@ -269,7 +269,7 @@ static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, mlir::SymbolRefAttr funcAttr; bool isDirect = operands.empty(); if (isDirect) - if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) + if (parser.parseAttribute(funcAttr, "callee", attrs)) return mlir::failure(); Type type; @@ -586,8 +586,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { bool simpleInitializer = false; if (mlir::succeeded(parser.parseOptionalLParen())) { Attribute attr; - if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), - result.attributes) || + if (parser.parseAttribute(attr, "initVal", result.attributes) || parser.parseRParen()) return mlir::failure(); simpleInitializer = true; @@ -595,15 +594,14 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { if (succeeded(parser.parseOptionalKeyword("constant"))) { // if "constant" keyword then mark this as a constant, not a variable - result.addAttribute(fir::GlobalOp::constantAttrName(), - builder.getUnitAttr()); + result.addAttribute("constant", builder.getUnitAttr()); } mlir::Type globalType; if (parser.parseColonType(globalType)) return mlir::failure(); - result.addAttribute(fir::GlobalOp::typeAttrName(), + result.addAttribute(fir::GlobalOp::typeAttrName(result.name), mlir::TypeAttr::get(globalType)); if (simpleInitializer) { @@ -628,14 +626,14 @@ void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, Attribute initialVal, StringAttr linkage, ArrayRef attrs) { result.addRegion(); - result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); + result.addAttribute(typeAttrName(result.name), mlir::TypeAttr::get(type)); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); if (isConstant) - result.addAttribute(constantAttrName(), builder.getUnitAttr()); + result.addAttribute(constantAttrName(result.name), builder.getUnitAttr()); if (initialVal) - result.addAttribute(initValAttrName(), initialVal); + result.addAttribute(initValAttrName(result.name), initialVal); if (linkage) result.addAttribute(linkageAttrName(), linkage); result.attributes.append(attrs.begin(), attrs.end()); @@ -754,7 +752,7 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder, result.addOperands({lb, ub, step, iterate}); if (finalCountValue) { result.addTypes(builder.getIndexType()); - result.addAttribute(finalValueAttrName(), builder.getUnitAttr()); + result.addAttribute(finalValueAttrName(result.name), builder.getUnitAttr()); } result.addTypes(iterate.getType()); result.addOperands(iterArgs); @@ -846,7 +844,7 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, llvm::SmallVector argTypes; // Induction variable (hidden) if (prependCount) - result.addAttribute(IterWhileOp::finalValueAttrName(), + result.addAttribute(IterWhileOp::finalValueAttrName(result.name), builder.getUnitAttr()); else argTypes.push_back(indexType); @@ -940,8 +938,7 @@ static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { } else if (op.finalValue()) { p << " -> (" << op.getResultTypes() << ')'; } - p.printOptionalAttrDictWithKeyword(op->getAttrs(), - {IterWhileOp::finalValueAttrName()}); + p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"finalValue"}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -1011,7 +1008,7 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, result.addOperands(iterArgs); if (finalCountValue) { result.addTypes(builder.getIndexType()); - result.addAttribute(finalValueAttrName(), builder.getUnitAttr()); + result.addAttribute(finalValueAttrName(result.name), builder.getUnitAttr()); } for (auto v : iterArgs) result.addTypes(v.getType()); @@ -1022,7 +1019,7 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, bodyRegion->front().addArgument(builder.getIndexType()); bodyRegion->front().addArguments(iterArgs.getTypes()); if (unordered) - result.addAttribute(unorderedAttrName(), builder.getUnitAttr()); + result.addAttribute(unorderedAttrName(result.name), builder.getUnitAttr()); result.addAttributes(attributes); } @@ -1045,8 +1042,7 @@ static mlir::ParseResult parseDoLoopOp(mlir::OpAsmParser &parser, return failure(); if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) - result.addAttribute(fir::DoLoopOp::unorderedAttrName(), - builder.getUnitAttr()); + result.addAttribute("unordered", builder.getUnitAttr()); // Parse the optional initial iteration arguments. llvm::SmallVector regionArgs, operands; @@ -1080,7 +1076,8 @@ static mlir::ParseResult parseDoLoopOp(mlir::OpAsmParser &parser, // Induction variable. if (prependCount) - result.addAttribute(DoLoopOp::finalValueAttrName(), builder.getUnitAttr()); + result.addAttribute(DoLoopOp::finalValueAttrName(result.name), + builder.getUnitAttr()); else argTypes.push_back(indexType); // Loop carried variables @@ -1172,8 +1169,7 @@ static void print(mlir::OpAsmPrinter &p, fir::DoLoopOp op) { printBlockTerminators = true; } p.printOptionalAttrDictWithKeyword(op->getAttrs(), - {fir::DoLoopOp::unorderedAttrName(), - fir::DoLoopOp::finalValueAttrName()}); + {"unordered", "finalValue"}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, printBlockTerminators); }