diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index f826f2566b897..483f318b649d9 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -20,6 +20,7 @@ #include "flang/Optimizer/Support/Utils.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -36,6 +37,18 @@ namespace { #include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc" } // namespace +static void propagateAttributes(mlir::Operation *fromOp, + mlir::Operation *toOp) { + if (!fromOp || !toOp) + return; + + for (mlir::NamedAttribute attr : fromOp->getAttrs()) { + if (attr.getName().getValue().starts_with( + mlir::acc::OpenACCDialect::getDialectNamespace())) + toOp->setAttr(attr.getName(), attr.getValue()); + } +} + /// Return true if a sequence type is of some incomplete size or a record type /// is malformed or contains an incomplete sequence type. An incomplete sequence /// type is one with more unknown extents in the type than have been provided @@ -626,8 +639,10 @@ mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) { if (auto *v = getVal().getDefiningOp()) { if (auto box = mlir::dyn_cast(v)) { // Fold only if not sliced - if (!box.getSlice() && box.getMemref().getType() == getType()) + if (!box.getSlice() && box.getMemref().getType() == getType()) { + propagateAttributes(getOperation(), box.getMemref().getDefiningOp()); return box.getMemref(); + } } if (auto box = mlir::dyn_cast(v)) if (box.getMemref().getType() == getType()) diff --git a/flang/test/Fir/OpenACC/propagate-attr-folding.fir b/flang/test/Fir/OpenACC/propagate-attr-folding.fir new file mode 100644 index 0000000000000..99ac7516690e2 --- /dev/null +++ b/flang/test/Fir/OpenACC/propagate-attr-folding.fir @@ -0,0 +1,40 @@ +// RUN: fir-opt %s --opt-bufferization | FileCheck %s + +// Check that OpenACC attributes are propagated to the defining operations when +// fir.box_addr is folded in bufferization optimization. + +func.func @_QPsub1(%arg0: !fir.ref> {fir.bindc_name = "a"}, %arg1: !fir.ref {fir.bindc_name = "n1"}, %arg2: !fir.ref {fir.bindc_name = "n2"}) { + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = fir.declare %arg1 {uniq_name = "_QFsub1En1"} : (!fir.ref) -> !fir.ref + %1 = fir.declare %arg2 {uniq_name = "_QFsub1En2"} : (!fir.ref) -> !fir.ref + %2 = fir.load %0 : !fir.ref + %3 = fir.convert %2 : (i32) -> index + %4 = arith.cmpi sgt, %3, %c0 : index + %5 = arith.select %4, %3, %c0 : index + %6 = fir.load %1 : !fir.ref + %7 = fir.convert %6 : (i32) -> index + %8 = arith.cmpi sgt, %7, %c0 : index + %9 = arith.select %8, %7, %c0 : index + %10 = fir.shape %5, %9 : (index, index) -> !fir.shape<2> + %11 = fir.declare %arg0(%10) {uniq_name = "_QFsub1Ea"} : (!fir.ref>, !fir.shape<2>) -> !fir.ref> + %12 = fir.embox %11(%10) : (!fir.ref>, !fir.shape<2>) -> !fir.box> + %13:3 = fir.box_dims %12, %c0 : (!fir.box>, index) -> (index, index, index) + %14 = arith.subi %13#1, %c1 : index + %15 = acc.bounds lowerbound(%c0 : index) upperbound(%14 : index) extent(%13#1 : index) stride(%13#2 : index) startIdx(%c1 : index) {strideInBytes = true} + %16 = arith.muli %13#2, %13#1 : index + %17:3 = fir.box_dims %12, %c1 : (!fir.box>, index) -> (index, index, index) + %18 = arith.subi %17#1, %c1 : index + %19 = acc.bounds lowerbound(%c0 : index) upperbound(%18 : index) extent(%17#1 : index) stride(%16 : index) startIdx(%c1 : index) {strideInBytes = true} + %20 = fir.box_addr %12 {acc.declare = #acc.declare} : (!fir.box>) -> !fir.ref> + %21 = acc.present varPtr(%20 : !fir.ref>) bounds(%15, %19) -> !fir.ref> {name = "a"} + %22 = acc.declare_enter dataOperands(%21 : !fir.ref>) + acc.declare_exit token(%22) + return +} + +// CHECK-LABEL: func.func @_QPsub1( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref> {fir.bindc_name = "a"} +// CHECK: %[[DECL:.*]] = fir.declare %[[ARG0]](%{{.*}}) {acc.declare = #acc.declare, uniq_name = "_QFsub1Ea"} : (!fir.ref>, !fir.shape<2>) -> !fir.ref> +// CHECK: %[[PRES:.*]] = acc.present varPtr(%[[DECL]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {name = "a"} +// CHECK: %{{.*}} = acc.declare_enter dataOperands(%[[PRES]] : !fir.ref>)