diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp index b4767cdb25072..246eedf3a515d 100644 --- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp +++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp @@ -12,6 +12,7 @@ #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Factory.h" +#include "flang/Optimizer/Builder/Runtime/Derived.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Support/FIRContext.h" @@ -1008,6 +1009,50 @@ findNonconstantExtents(mlir::Type memrefTy, return nce; } +/// Allocate temporary storage for an ArrayLoadOp \load and initialize any +/// allocatable direct components of the array elements with an unallocated +/// status. Returns the temporary address as well as a callback to generate the +/// temporary clean-up once it has been used. The clean-up will take care of +/// deallocating all the element allocatable components that may have been +/// allocated while using the temporary. +static std::pair> +allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter, + ArrayLoadOp load, llvm::ArrayRef extents, + mlir::Value shape) { + mlir::Type baseType = load.getMemref().getType(); + llvm::SmallVector nonconstantExtents = + findNonconstantExtents(baseType, extents); + llvm::SmallVector typeParams = + genArrayLoadTypeParameters(loc, rewriter, load); + mlir::Value allocmem = rewriter.create( + loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents); + mlir::Type eleType = + fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType)); + if (fir::isRecordWithAllocatableMember(eleType)) { + // The allocatable component descriptors need to be set to a clean + // deallocated status before anything is done with them. + mlir::Value box = rewriter.create( + loc, fir::BoxType::get(baseType), allocmem, shape, + /*slice=*/mlir::Value{}, typeParams); + auto module = load->getParentOfType(); + FirOpBuilder builder(rewriter, getKindMapping(module)); + runtime::genDerivedTypeInitialize(builder, loc, box); + // Any allocatable component that may have been allocated must be + // deallocated during the clean-up. + auto cleanup = [=](mlir::PatternRewriter &r) { + FirOpBuilder builder(r, getKindMapping(module)); + runtime::genDerivedTypeDestroy(builder, loc, box); + r.create(loc, allocmem); + }; + return {allocmem, cleanup}; + } + auto cleanup = [=](mlir::PatternRewriter &r) { + r.create(loc, allocmem); + }; + return {allocmem, cleanup}; +} + namespace { /// Conversion of fir.array_update and fir.array_modify Ops. /// If there is a conflict for the update, then we need to perform a @@ -1039,11 +1084,8 @@ class ArrayUpdateConversionBase : public mlir::OpRewritePattern { bool copyUsingSlice = false; auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents, copyUsingSlice); - llvm::SmallVector nonconstantExtents = - findNonconstantExtents(load.getMemref().getType(), extents); - auto allocmem = rewriter.create( - loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()), - genArrayLoadTypeParameters(loc, rewriter, load), nonconstantExtents); + auto [allocmem, genTempCleanUp] = + allocateArrayTemp(loc, rewriter, load, extents, shapeOp); genArrayCopy(load.getLoc(), rewriter, allocmem, load.getMemref(), shapeOp, load.getSlice(), load); @@ -1061,7 +1103,7 @@ class ArrayUpdateConversionBase : public mlir::OpRewritePattern { // Copy out. genArrayCopy(store.getLoc(), rewriter, store.getMemref(), allocmem, shapeOp, store.getSlice(), load); - rewriter.create(loc, allocmem); + genTempCleanUp(rewriter); return coor; } @@ -1091,11 +1133,9 @@ class ArrayUpdateConversionBase : public mlir::OpRewritePattern { bool copyUsingSlice = false; auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents, copyUsingSlice); - llvm::SmallVector nonconstantExtents = - findNonconstantExtents(load.getMemref().getType(), extents); - auto allocmem = rewriter.create( - loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()), - genArrayLoadTypeParameters(loc, rewriter, load), nonconstantExtents); + auto [allocmem, genTempCleanUp] = + allocateArrayTemp(loc, rewriter, load, extents, shapeOp); + genArrayCopy(load.getLoc(), rewriter, allocmem, load.getMemref(), shapeOp, load.getSlice(), load); @@ -1113,7 +1153,7 @@ class ArrayUpdateConversionBase : public mlir::OpRewritePattern { genArrayCopy(store.getLoc(), rewriter, store.getMemref(), allocmem, shapeOp, store.getSlice(), load); - rewriter.create(loc, allocmem); + genTempCleanUp(rewriter); return {coor, load.getResult()}; } // Otherwise, when there is no conflict (a possible loop-carried diff --git a/flang/test/Fir/array-value-copy-3.fir b/flang/test/Fir/array-value-copy-3.fir new file mode 100644 index 0000000000000..457ba341d62d9 --- /dev/null +++ b/flang/test/Fir/array-value-copy-3.fir @@ -0,0 +1,55 @@ +// Test overlapping assignment of derived type arrays with allocatable components. +// This requires initializing the allocatable components to an unallocated status +// before they can be used in component assignments, and to deallocate the components +// that may have been allocated in the end. + +// RUN: fir-opt --array-value-copy %s | FileCheck %s + + +!t_with_alloc_comp = type !fir.type>>}> +func private @custom_assign(!fir.ref, !fir.ref) +func @test_overlap_with_alloc_components(%arg0: !fir.ref>) { + %0 = fir.alloca !fir.box + %c10 = arith.constant 10 : index + %c9 = arith.constant 9 : index + %c1 = arith.constant 1 : index + %c-1 = arith.constant -1 : index + %c0 = arith.constant 0 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %6 = fir.slice %c10, %c1, %c-1 : (index, index, index) -> !fir.slice<1> + %2 = fir.array_load %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.array<10x!t_with_alloc_comp> + %7 = fir.array_load %arg0(%1) [%6] : (!fir.ref>, !fir.shape<1>, !fir.slice<1>) -> !fir.array<10x!t_with_alloc_comp> + %9 = fir.do_loop %arg1 = %c0 to %c9 step %c1 unordered iter_args(%arg2 = %2) -> (!fir.array<10x!t_with_alloc_comp>) { + %10 = fir.array_access %7, %arg1 : (!fir.array<10x!t_with_alloc_comp>, index) -> !fir.ref + %11 = fir.array_access %arg2, %arg1 : (!fir.array<10x!t_with_alloc_comp>, index) -> !fir.ref + fir.call @custom_assign(%11, %10) : (!fir.ref, !fir.ref) -> none + %19 = fir.array_amend %arg2, %11 : (!fir.array<10x!t_with_alloc_comp>, !fir.ref) -> !fir.array<10x!t_with_alloc_comp> + fir.result %19 : !fir.array<10x!t_with_alloc_comp> + } + fir.array_merge_store %2, %9 to %arg0 : !fir.array<10x!t_with_alloc_comp>, !fir.array<10x!t_with_alloc_comp>, !fir.ref> + return +} + +// CHECK-LABEL: func @test_overlap_with_alloc_components( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>}>>>) { +// CHECK: %[[VAL_4:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = arith.constant -1 : index +// CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_10:.*]] = fir.slice %[[VAL_4]], %[[VAL_6]], %[[VAL_7]] : (index, index, index) -> !fir.slice<1> +// CHECK: %[[VAL_11:.*]] = fir.allocmem !fir.array<10x!fir.type>>}>> +// CHECK: %[[VAL_12:.*]] = fir.embox %[[VAL_11]](%[[VAL_9]]) : (!fir.heap>>}>>>, !fir.shape<1>) -> !fir.box>>}>>>> +// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_12]] : (!fir.box>>}>>>>) -> !fir.box +// CHECK: fir.call @_FortranAInitialize(%[[VAL_16]], %{{.*}}, %{{.*}}) : (!fir.box, !fir.ref, i32) -> none +// CHECK: fir.do_loop {{.*}} { +// CHECK: fir.call @_FortranAAssign +// CHECK: } +// CHECK: fir.do_loop {{.*}} { +// CHECK: fir.call @custom_assign +// CHECK: } +// CHECK: fir.do_loop %{{.*}} { +// CHECK: fir.call @_FortranAAssign +// CHECK: } +// CHECK: %[[VAL_72:.*]] = fir.convert %[[VAL_12]] : (!fir.box>>}>>>>) -> !fir.box +// CHECK: %[[VAL_73:.*]] = fir.call @_FortranADestroy(%[[VAL_72]]) : (!fir.box) -> none +// CHECK: fir.freemem %[[VAL_11]]