diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 9b894027a9f20..f589cb8715ff1 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -341,11 +341,30 @@ struct GetLengthOpConversion } }; -static bool allOtherUsesAreDestroys(mlir::Value value, - mlir::Operation *currentUse) { +/// The current hlfir.associate lowering does not handle multiple uses of a +/// non-trivial expression value because it generates the cleanup for the +/// expression bufferization at hlfir.end_associate. If there was more than one +/// hlfir.end_associate, it would be cleaned up multiple times, perhaps before +/// one of the other uses. +static bool allOtherUsesAreSafeForAssociate(mlir::Value value, + mlir::Operation *currentUse, + mlir::Operation *endAssociate) { for (mlir::Operation *useOp : value.getUsers()) - if (!mlir::isa(useOp) && useOp != currentUse) + if (!mlir::isa(useOp) && useOp != currentUse) { + // hlfir.shape_of will not disrupt cleanup so it is safe for + // hlfir.associate. hlfir.shape_of might read the box dimensions and so it + // needs to come before the hlfir.end_associate (which may deallocate). 
+ if (mlir::isa(useOp)) { + if (!endAssociate) + continue; + // not known to occur in practice: + if (useOp->getBlock() != endAssociate->getBlock()) + TODO(endAssociate->getLoc(), "Associate split over multiple blocks"); + if (useOp->isBeforeInBlock(endAssociate)) + continue; + } return false; + } return true; } @@ -370,6 +389,15 @@ struct AssociateOpConversion mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getSource()); const bool isTrivialValue = fir::isa_trivial(bufferizedExpr.getType()); + auto getEndAssociate = + [](hlfir::AssociateOp associate) -> mlir::Operation * { + for (mlir::Operation *useOp : associate->getUsers()) + if (mlir::isa(useOp)) + return useOp; + // happens in some hand-coded MLIR in tests + return nullptr; + }; + auto replaceWith = [&](mlir::Value hlfirVar, mlir::Value firVar, mlir::Value flag) { // 0-dim variables may need special handling: @@ -382,8 +410,8 @@ struct AssociateOpConversion // !fir.ref>, // i1) // - // !fir.box>> value must be propagated - // as the box address !fir.ref>. + // !fir.box>> value must be + // propagated as the box address !fir.ref>. mlir::Type associateHlfirVarType = associate.getResultTypes()[0]; if (hlfirVar.getType().isa() && !associateHlfirVarType.isa()) @@ -410,8 +438,9 @@ struct AssociateOpConversion // If this is the last use of the expression value and this is an hlfir.expr // that was bufferized, re-use the storage. // Otherwise, create a temp and assign the storage to it. - if (!isTrivialValue && allOtherUsesAreDestroys(associate.getSource(), - associate.getOperation())) { + if (!isTrivialValue && allOtherUsesAreSafeForAssociate( + associate.getSource(), associate.getOperation(), + getEndAssociate(associate))) { // Re-use hlfir.expr buffer if this is the only use of the hlfir.expr // outside of the hlfir.destroy. 
Take on the cleaning-up responsibility // for the related hlfir.end_associate, and erase the hlfir.destroy (if diff --git a/flang/test/HLFIR/associate-codegen.fir b/flang/test/HLFIR/associate-codegen.fir index 5127f78e783cc..7bd70ad58d8ac 100644 --- a/flang/test/HLFIR/associate-codegen.fir +++ b/flang/test/HLFIR/associate-codegen.fir @@ -194,6 +194,51 @@ func.func @test_0dim_box(%x : !fir.ref>>) { // CHECK: return // CHECK: } +// test that we support a hlfir.associate operation where the expr is also used in a hlfir.shape_of op +func.func @test_shape_of(%arg0: !fir.ref>) { + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %shape = fir.shape %c3, %c4 : (index, index) -> !fir.shape<2> + // %0 = hlfir.transpose %arg0 : (!fir.ref>) -> !hlfir.expr<3x4xi32> + %0 = hlfir.elemental %shape unordered : (!fir.shape<2>) -> !hlfir.expr<3x4xi32> { + ^bb0(%arg1: index, %arg2: index): + %4 = hlfir.designate %arg0 (%arg2, %arg1) : (!fir.ref>, index, index) -> !fir.ref + %5 = fir.load %4 : !fir.ref + hlfir.yield_element %5 : i32 + } + %1 = hlfir.shape_of %0 : (!hlfir.expr<3x4xi32>) -> !fir.shape<2> + %2:3 = hlfir.associate %0(%1) {uniq_name = "adapt.valuebyref"} : (!hlfir.expr<3x4xi32>, !fir.shape<2>) -> (!fir.ref>, !fir.ref>, i1) + // ... 
+ hlfir.end_associate %2#1, %2#2 : !fir.ref>, i1 + return +} +// CHECK-LABEL: func.func @test_shape_of( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]], %[[VAL_1]] : (index, index) -> !fir.shape<2> +// CHECK: %[[VAL_4:.*]] = fir.allocmem !fir.array<3x4xi32> {bindc_name = ".tmp.array", uniq_name = ""} +// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap>, !fir.shape<2>) -> (!fir.heap>, !fir.heap>) +// CHECK: %[[VAL_6:.*]] = arith.constant true +// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_7]] to %[[VAL_1]] step %[[VAL_7]] unordered { +// CHECK: fir.do_loop %[[VAL_9:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered { +// CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_8]], %[[VAL_9]]) : (!fir.ref>, index, index) -> !fir.ref +// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref +// CHECK: %[[VAL_12:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_9]], %[[VAL_8]]) : (!fir.heap>, index, index) -> !fir.ref +// CHECK: hlfir.assign %[[VAL_11]] to %[[VAL_12]] temporary_lhs : i32, !fir.ref +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_13:.*]] = fir.undefined tuple>, i1> +// CHECK: %[[VAL_14:.*]] = fir.insert_value %[[VAL_13]], %[[VAL_6]], [1 : index] : (tuple>, i1>, i1) -> tuple>, i1> +// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_5]]#0, [0 : index] : (tuple>, i1>, !fir.heap>) -> tuple>, i1> +// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_5]]#0 : (!fir.heap>) -> !fir.ref> +// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_5]]#1 : (!fir.heap>) -> !fir.ref> +// CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (!fir.ref>) -> !fir.heap> +// CHECK: fir.freemem %[[VAL_18]] : !fir.heap> +// CHECK: return +// CHECK: } + func.func private @take_i4(!fir.ref) func.func private @take_r4(!fir.ref) func.func 
private @take_l4(!fir.ref>)