Skip to content

Commit

Permalink
[flang] Updated FIR dialect to _Both
Browse files Browse the repository at this point in the history
Change dialect (and remove now redundant accessors) to generate both
form of accessors of being generated. Tried to keep this change
reasonably minimal (this also includes keeping note about not generating
getType accessor to avoid shadowing).

Differential Revision: https://reviews.llvm.org/D115420
  • Loading branch information
jpienaar committed Dec 9, 2021
1 parent 428ed61 commit 3012f35
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 98 deletions.
2 changes: 1 addition & 1 deletion flang/include/flang/Lower/Support/Utils.h
Expand Up @@ -32,7 +32,7 @@ inline llvm::StringRef toStringRef(const Fortran::parser::CharBlock &cb) {
namespace fir {
/// Return the integer value of a arith::ConstantOp.
inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
return cop.value().cast<mlir::IntegerAttr>().getValue().getSExtValue();
return cop.getValue().cast<mlir::IntegerAttr>().getValue().getSExtValue();
}
} // namespace fir

Expand Down
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Dialect/FIRDialect.td
Expand Up @@ -24,6 +24,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def fir_Dialect : Dialect {
let name = "fir";
let cppNamespace = "::fir";
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
}

#endif // FORTRAN_DIALECT_FIR_DIALECT
53 changes: 7 additions & 46 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Expand Up @@ -177,7 +177,6 @@ def fir_AllocaOp : fir_Op<"alloca", [AttrSizedOperandSegments,
unsigned numShapeOperands() { return shape().size(); }
operand_range getShapeOperands() { return shape(); }
static mlir::Type getRefTy(mlir::Type ty);
mlir::Type getInType() { return in_type(); }
}];
}

Expand Down Expand Up @@ -235,7 +234,6 @@ def fir_AllocMemOp : fir_Op<"allocmem",
unsigned numShapeOperands() { return shape().size(); }
operand_range getShapeOperands() { return shape(); }
static mlir::Type getRefTy(mlir::Type ty);
mlir::Type getInType() { return in_type(); }
}];
}

Expand Down Expand Up @@ -484,7 +482,6 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
}

// The selector is the value being tested to determine the destination
mlir::Value getSelector() { return selector(); }
mlir::Value getSelector(llvm::ArrayRef<mlir::Value> operands) {
return operands[0];
}
Expand Down Expand Up @@ -893,8 +890,6 @@ def fir_EmboxOp : fir_Op<"embox", [NoSideEffect, AttrSizedOperandSegments]> {
let verifier = "return ::verify(*this);";

let extraClassDeclaration = [{
mlir::Value getShape() { return shape(); }
mlir::Value getSlice() { return slice(); }
bool hasLenParams() { return !typeparams().empty(); }
unsigned numLenParams() { return typeparams().size(); }
}];
Expand Down Expand Up @@ -1673,11 +1668,6 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
[{ return build($_builder, $_state, resultType, ref, coor,
mlir::TypeAttr::get(ref.getType())); }]>,
];

let extraClassDeclaration = [{
/// Get the type of the base object.
mlir::Type getBaseType() { return baseType(); }
}];
}

def fir_ExtractValueOp : fir_OneResultOp<"extract_value", [NoSideEffect]> {
Expand Down Expand Up @@ -1771,12 +1761,6 @@ def fir_ShapeOp : fir_Op<"shape", [NoSideEffect]> {
}];

let verifier = "return ::verify(*this);";

let extraClassDeclaration = [{
std::vector<mlir::Value> getExtents() {
return {extents().begin(), extents().end()};
}
}];
}

def fir_ShapeShiftOp : fir_Op<"shape_shift", [NoSideEffect]> {
Expand Down Expand Up @@ -1854,12 +1838,6 @@ def fir_ShiftOp : fir_Op<"shift", [NoSideEffect]> {
}];

let verifier = "return ::verify(*this);";

let extraClassDeclaration = [{
std::vector<mlir::Value> getOrigins() {
return {origins().begin(), origins().end()};
}
}];
}

def fir_SliceOp : fir_Op<"slice", [NoSideEffect, AttrSizedOperandSegments]> {
Expand Down Expand Up @@ -2022,9 +2000,6 @@ def fir_LenParamIndexOp : fir_OneResultOp<"len_param_index", [NoSideEffect]> {
let extraClassDeclaration = [{
static constexpr llvm::StringRef fieldAttrName() { return "field_id"; }
static constexpr llvm::StringRef typeAttrName() { return "on_type"; }
mlir::Type getOnType() {
return (*this)->getAttrOfType<TypeAttr>(typeAttrName()).getValue();
}
}];
}

Expand Down Expand Up @@ -2250,7 +2225,7 @@ def fir_IterWhileOp : region_Op<"iterate_while",
];

let extraClassDeclaration = [{
static constexpr llvm::StringRef getFinalValueAttrName() {
static constexpr llvm::StringRef getFinalValueAttrNameStr() {
return "finalValue";
}
mlir::Block *getBody() { return &region().front(); }
Expand Down Expand Up @@ -2336,7 +2311,7 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
}]>];

let extraClassDeclaration = [{
static constexpr StringRef getCalleeAttrName() { return "callee"; }
static constexpr StringRef getCalleeAttrNameStr() { return "callee"; }

mlir::FunctionType getFunctionType();

Expand Down Expand Up @@ -2396,7 +2371,7 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
static constexpr llvm::StringRef passArgAttrName() {
return "pass_arg_pos";
}
static constexpr llvm::StringRef getMethodAttrName() { return "method"; }
static constexpr llvm::StringRef getMethodAttrNameStr() { return "method"; }
unsigned passArgPos();
}];
}
Expand Down Expand Up @@ -2646,13 +2621,6 @@ def fir_GenTypeDescOp : fir_OneResultOp<"gentypedesc", [NoSideEffect]> {
let builders = [OpBuilder<(ins "mlir::TypeAttr":$inty)>];

let verifier = "return ::verify(*this);";

let extraClassDeclaration = [{
mlir::Type getInType() {
// get the type that the type descriptor describes
return (*this)->getAttrOfType<mlir::TypeAttr>("in_type").getValue();
}
}];
}

def fir_NoReassocOp : fir_OneResultOp<"no_reassoc",
Expand Down Expand Up @@ -2744,8 +2712,8 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
];

let extraClassDeclaration = [{
static constexpr llvm::StringRef symbolAttrName() { return "symref"; }
static constexpr llvm::StringRef getConstantAttrName() {
static constexpr llvm::StringRef symbolAttrNameStr() { return "symref"; }
static constexpr llvm::StringRef getConstantAttrNameStr() {
return "constant";
}
static constexpr llvm::StringRef linkageAttrName() { return "linkName"; }
Expand All @@ -2765,9 +2733,6 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
/// the variable's initial value.
void appendInitialValue(mlir::Operation *op);

/// A GlobalOp has one region.
mlir::Region &getRegion() { return (*this)->getRegion(0); }

/// A GlobalOp has one block.
mlir::Block &getBlock() { return getRegion().front(); }

Expand Down Expand Up @@ -2860,10 +2825,6 @@ def fir_DispatchTableOp : fir_Op<"dispatch_table",
/// Append a dispatch table entry to the table.
void appendTableEntry(mlir::Operation *op);

mlir::Region &getRegion() {
return (*this)->getRegion(0);
}

mlir::Block &getBlock() {
return getRegion().front();
}
Expand Down Expand Up @@ -2892,8 +2853,8 @@ def fir_DTEntryOp : fir_Op<"dt_entry", [HasParent<"DispatchTableOp">]> {
let printer = "::print(p, *this);";

let extraClassDeclaration = [{
static constexpr llvm::StringRef getMethodAttrName() { return "method"; }
static constexpr llvm::StringRef getProcAttrName() { return "proc"; }
static constexpr llvm::StringRef getMethodAttrNameStr() { return "method"; }
static constexpr llvm::StringRef getProcAttrNameStr() { return "proc"; }
}];
}

Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Support/Utils.h
Expand Up @@ -19,7 +19,7 @@
namespace fir {
/// Return the integer value of a arith::ConstantOp.
inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
return cop.value().cast<mlir::IntegerAttr>().getValue().getSExtValue();
return cop.getValue().cast<mlir::IntegerAttr>().getValue().getSExtValue();
}
} // namespace fir

Expand Down
6 changes: 4 additions & 2 deletions flang/include/flang/Optimizer/Transforms/Factory.h
Expand Up @@ -38,8 +38,10 @@ inline std::vector<mlir::Value> getOrigins(mlir::Value shapeVal) {
if (auto *shapeOp = shapeVal.getDefiningOp()) {
if (auto shOp = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp))
return shOp.getOrigins();
if (auto shOp = mlir::dyn_cast<fir::ShiftOp>(shapeOp))
return shOp.getOrigins();
if (auto shOp = mlir::dyn_cast<fir::ShiftOp>(shapeOp)) {
auto operands = shOp.getOrigins();
return {operands.begin(), operands.end()};
}
}
return {};
}
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/CodeGen/CodeGen.cpp
Expand Up @@ -1060,7 +1060,7 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
mlir::Type vecType = mlir::VectorType::get(
insertOp.getType().getShape(), constant.getType());
auto denseAttr = mlir::DenseElementsAttr::get(
vecType.cast<ShapedType>(), constant.value());
vecType.cast<ShapedType>(), constant.getValue());
rewriter.setInsertionPointAfter(insertOp);
rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
insertOp, seqTyAttr, denseAttr);
Expand Down
51 changes: 26 additions & 25 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Expand Up @@ -382,8 +382,10 @@ static mlir::Type adjustedElementType(mlir::Type t) {
std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() {
if (auto sh = shape())
if (auto *op = sh.getDefiningOp()) {
if (auto shOp = dyn_cast<fir::ShapeOp>(op))
return shOp.getExtents();
if (auto shOp = dyn_cast<fir::ShapeOp>(op)) {
auto extents = shOp.getExtents();
return {extents.begin(), extents.end()};
}
return cast<fir::ShapeShiftOp>(op).getExtents();
}
return {};
Expand Down Expand Up @@ -632,7 +634,7 @@ static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
mlir::FuncOp callee, mlir::ValueRange operands) {
result.addOperands(operands);
result.addAttribute(getCalleeAttrName(), SymbolRefAttr::get(callee));
result.addAttribute(getCalleeAttrNameStr(), SymbolRefAttr::get(callee));
result.addTypes(callee.getType().getResults());
}

Expand All @@ -642,7 +644,7 @@ void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
mlir::ValueRange operands) {
result.addOperands(operands);
if (callee)
result.addAttribute(getCalleeAttrName(), callee);
result.addAttribute(getCalleeAttrNameStr(), callee);
result.addTypes(results);
}

Expand Down Expand Up @@ -921,11 +923,12 @@ static mlir::ParseResult parseDispatchOp(mlir::OpAsmParser &parser,
llvm::StringRef calleeName;
if (failed(parser.parseOptionalKeyword(&calleeName))) {
mlir::StringAttr calleeAttr;
if (parser.parseAttribute(calleeAttr, fir::DispatchOp::getMethodAttrName(),
if (parser.parseAttribute(calleeAttr,
fir::DispatchOp::getMethodAttrNameStr(),
result.attributes))
return mlir::failure();
} else {
result.addAttribute(fir::DispatchOp::getMethodAttrName(),
result.addAttribute(fir::DispatchOp::getMethodAttrNameStr(),
parser.getBuilder().getStringAttr(calleeName));
}
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
Expand All @@ -939,8 +942,7 @@ static mlir::ParseResult parseDispatchOp(mlir::OpAsmParser &parser,
}

static void print(mlir::OpAsmPrinter &p, fir::DispatchOp &op) {
p << ' ' << op.getOperation()->getAttr(fir::DispatchOp::getMethodAttrName())
<< '(';
p << ' ' << op.getMethodAttr() << '(';
p.printOperand(op.object());
if (!op.args().empty()) {
p << ", ";
Expand Down Expand Up @@ -1167,7 +1169,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {

// Parse the name as a symbol reference attribute.
mlir::SymbolRefAttr nameAttr;
if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(),
if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrNameStr(),
result.attributes))
return mlir::failure();
result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
Expand Down Expand Up @@ -1211,11 +1213,10 @@ static void print(mlir::OpAsmPrinter &p, fir::GlobalOp &op) {
if (op.linkName().hasValue())
p << ' ' << op.linkName().getValue();
p << ' ';
p.printAttributeWithoutType(
op.getOperation()->getAttr(fir::GlobalOp::symbolAttrName()));
p.printAttributeWithoutType(op.getSymrefAttr());
if (auto val = op.getValueOrNull())
p << '(' << val << ')';
if (op.getOperation()->getAttr(fir::GlobalOp::getConstantAttrName()))
if (op.getOperation()->getAttr(fir::GlobalOp::getConstantAttrNameStr()))
p << " constant";
p << " : ";
p.printType(op.getType());
Expand All @@ -1237,7 +1238,7 @@ void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
result.addAttribute(typeAttrName(result.name), mlir::TypeAttr::get(type));
result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
result.addAttribute(symbolAttrName(),
result.addAttribute(symbolAttrNameStr(),
SymbolRefAttr::get(builder.getContext(), name));
if (isConstant)
result.addAttribute(constantAttrName(result.name), builder.getUnitAttr());
Expand Down Expand Up @@ -1483,13 +1484,13 @@ struct UndoComplexPattern : public mlir::RewritePattern {
!isZero(insval2.coor()[0]))
return mlir::failure();
auto eai =
dyn_cast_or_null<fir::ExtractValueOp>(binf.lhs().getDefiningOp());
dyn_cast_or_null<fir::ExtractValueOp>(binf.getLhs().getDefiningOp());
auto ebi =
dyn_cast_or_null<fir::ExtractValueOp>(binf.rhs().getDefiningOp());
dyn_cast_or_null<fir::ExtractValueOp>(binf.getRhs().getDefiningOp());
auto ear =
dyn_cast_or_null<fir::ExtractValueOp>(binf2.lhs().getDefiningOp());
dyn_cast_or_null<fir::ExtractValueOp>(binf2.getLhs().getDefiningOp());
auto ebr =
dyn_cast_or_null<fir::ExtractValueOp>(binf2.rhs().getDefiningOp());
dyn_cast_or_null<fir::ExtractValueOp>(binf2.getRhs().getDefiningOp());
if (!eai || !ebi || !ear || !ebr || ear.adt() != eai.adt() ||
ebr.adt() != ebi.adt() || eai.coor().size() != 1 ||
!isOne(eai.coor()[0]) || ebi.coor().size() != 1 ||
Expand Down Expand Up @@ -1521,7 +1522,7 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder,
result.addOperands({lb, ub, step, iterate});
if (finalCountValue) {
result.addTypes(builder.getIndexType());
result.addAttribute(getFinalValueAttrName(), builder.getUnitAttr());
result.addAttribute(getFinalValueAttrNameStr(), builder.getUnitAttr());
}
result.addTypes(iterate.getType());
result.addOperands(iterArgs);
Expand Down Expand Up @@ -1613,7 +1614,7 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
llvm::SmallVector<mlir::Type> argTypes;
// Induction variable (hidden)
if (prependCount)
result.addAttribute(IterWhileOp::getFinalValueAttrName(),
result.addAttribute(IterWhileOp::getFinalValueAttrNameStr(),
builder.getUnitAttr());
else
argTypes.push_back(indexType);
Expand Down Expand Up @@ -1707,7 +1708,7 @@ static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) {
p << " -> (" << op.getResultTypes() << ')';
}
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
{IterWhileOp::getFinalValueAttrName()});
{op.getFinalValueAttrNameStr()});
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
Expand Down Expand Up @@ -2056,24 +2057,24 @@ static mlir::ParseResult parseDTEntryOp(mlir::OpAsmParser &parser,
// allow `methodName` or `"methodName"`
if (failed(parser.parseOptionalKeyword(&methodName))) {
mlir::StringAttr methodAttr;
if (parser.parseAttribute(methodAttr, fir::DTEntryOp::getMethodAttrName(),
if (parser.parseAttribute(methodAttr,
fir::DTEntryOp::getMethodAttrNameStr(),
result.attributes))
return mlir::failure();
} else {
result.addAttribute(fir::DTEntryOp::getMethodAttrName(),
result.addAttribute(fir::DTEntryOp::getMethodAttrNameStr(),
parser.getBuilder().getStringAttr(methodName));
}
mlir::SymbolRefAttr calleeAttr;
if (parser.parseComma() ||
parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrName(),
parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(),
result.attributes))
return mlir::failure();
return mlir::success();
}

static void print(mlir::OpAsmPrinter &p, fir::DTEntryOp &op) {
p << ' ' << op.getOperation()->getAttr(fir::DTEntryOp::getMethodAttrName())
<< ", " << op.getOperation()->getAttr(fir::DTEntryOp::getProcAttrName());
p << ' ' << op.getMethodAttr() << ", " << op.getProcAttr();
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/AffinePromotion.cpp
Expand Up @@ -325,7 +325,7 @@ static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,

static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>())
if (auto stepAttr = definition.value().dyn_cast<IntegerAttr>())
if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>())
return stepAttr.getInt();
return {};
}
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp
Expand Up @@ -333,7 +333,7 @@ ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
accesses.push_back(owner);
appendToQueue(update.getResult(1));
} else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) {
branchOp(br.getDest(), br.destOperands());
branchOp(br.getDest(), br.getDestOperands());
} else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) {
branchOp(br.getTrueDest(), br.getTrueOperands());
branchOp(br.getFalseDest(), br.getFalseOperands());
Expand Down

0 comments on commit 3012f35

Please sign in to comment.