-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][OpenMP] Allocate array reduction variables on the heap #87773
[flang][OpenMP] Allocate array reduction variables on the heap #87773
Conversation
Following up on a review comment: llvm#84958 (comment) Reductions might be inlined inside of a loop so stack allocations are not safe. Normally flang allocates arrays on the stack. Allocatable arrays have a different type: fir.box<fir.heap<fir.array<...>>> instead of fir.box<fir.array<...>>. This patch will allocate all arrays on the heap. Reductions on allocatable arrays still aren't supported (but I will get to this soon).
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-openmp Author: Tom Eccles (tblah) ChangesFollowing up on a review comment: Reductions might be inlined inside of a loop so stack allocations are not safe. Normally flang allocates arrays on the stack. Allocatable arrays have a different type: fir.box<fir.heap<fir.array<...>>> instead of fir.box<fir.array<...>>. This patch will allocate all arrays on the heap. Reductions on allocatable arrays still aren't supported (but I will get to this soon). Patch is 24.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87773.diff 7 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c1c94119fd9083..7a4cefb420f997 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -18,6 +18,7 @@
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
@@ -379,8 +380,60 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
}
+static void
+createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::omp::DeclareReductionOp &reductionDecl) {
+ mlir::Type redTy = reductionDecl.getType();
+
+ mlir::Region &cleanupRegion = reductionDecl.getCleanupRegion();
+ assert(cleanupRegion.empty());
+ mlir::Block *block =
+ builder.createBlock(&cleanupRegion, cleanupRegion.end(), {redTy}, {loc});
+ builder.setInsertionPointToEnd(block);
+
+ auto typeError = [loc]() {
+ fir::emitFatalError(loc,
+ "Attempt to create an omp reduction cleanup region "
+ "for a type that wasn't allocated",
+ /*genCrashDiag=*/true);
+ };
+
+ mlir::Type valTy = fir::unwrapRefType(redTy);
+ if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
+ mlir::Type innerTy = fir::extractSequenceType(boxTy);
+ if (!mlir::isa<fir::SequenceType>(innerTy))
+ typeError();
+
+ mlir::Value arg = block->getArgument(0);
+ arg = builder.loadIfRef(loc, arg);
+ assert(mlir::isa<fir::BaseBoxType>(arg.getType()));
+
+ // Deallocate box
+ // The FIR type system doesn't nesecarrily know that this is a mutable box
+ // if we allocated the thread local array on the heap to avoid looped stack
+ // allocations.
+ mlir::Value addr =
+ hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg});
+ mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr);
+ fir::IfOp ifOp =
+ builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+ mlir::Value cast = builder.createConvert(
+ loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr);
+ builder.create<fir::FreeMemOp>(loc, cast);
+
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<mlir::omp::YieldOp>(loc);
+ return;
+ }
+
+ typeError();
+}
+
static mlir::Value
createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::omp::DeclareReductionOp &reductionDecl,
const ReductionProcessor::ReductionIdentifier redId,
mlir::Type type, bool isByRef) {
mlir::Type ty = fir::unwrapRefType(type);
@@ -407,11 +460,21 @@ createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
// Create the private copy from the initial fir.box:
hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
- // TODO: if the whole reduction is nested inside of a loop, this alloca
- // could lead to a stack overflow (the memory is only freed at the end of
- // the stack frame). The reduction declare operation needs a deallocation
- // region to undo the init region.
- hlfir::Entity temp = createStackTempFromMold(loc, builder, source);
+ // Allocating on the heap in case the whole reduction is nested inside of a
+ // loop
+ auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
+ // if needsDealloc isn't statically false, add cleanup region. TODO: always
+ // do this for allocatable boxes because they might have been re-allocated
+ // in the body of the loop/parallel region
+ std::optional<int64_t> cstNeedsDealloc =
+ fir::getIntIfConstant(needsDealloc);
+ assert(cstNeedsDealloc.has_value() &&
+ "createTempFromMold decides this statically");
+ if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
+ auto insPt = builder.saveInsertionPoint();
+ createReductionCleanupRegion(builder, loc, reductionDecl);
+ builder.restoreInsertionPoint(insPt);
+ }
// Put the temporary inside of a box:
hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
@@ -450,7 +513,7 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
mlir::Value init =
- createReductionInitRegion(builder, loc, redId, type, isByRef);
+ createReductionInitRegion(builder, loc, decl, redId, type, isByRef);
builder.create<mlir::omp::YieldOp>(loc, init);
builder.createBlock(&decl.getReductionRegion(),
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
index 56dcabbb75c3a8..26c9d4f0850964 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
@@ -15,13 +15,15 @@ program reduce
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_3xi32 : !fir.ref<!fir.box<!fir.array<3xi32>>> init {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
-! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.array<3xi32> {bindc_name = ".tmp"}
! CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
! CHECK: %[[VAL_4:.*]] = arith.constant 3 : index
! CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
-! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<3xi32>>, !fir.ref<!fir.array<3xi32>>)
-! CHECK: %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
+! CHECK: %[[VAL_1:.*]] = fir.allocmem !fir.array<3xi32> {bindc_name = ".tmp", uniq_name = ""}
+! CHECK: %[[TRUE:.*]] = arith.constant true
+! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<3xi32>>,
+!fir.shape<1>) -> (!fir.heap<!fir.array<3xi32>>, !fir.heap<!fir.array<3xi32>>)
+! CHECK: %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.heap<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
! CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_7]] : i32, !fir.box<!fir.array<3xi32>>
! CHECK: %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
! CHECK: fir.store %[[VAL_7]] to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
@@ -43,6 +45,18 @@ program reduce
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
! CHECK: }
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
+! CHECK: } cleanup {
+! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
+! CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
+! CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>) -> !fir.ref<!fir.array<3xi32>>
+! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> i64
+! CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64
+! CHECK: %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
+! CHECK: fir.if %[[VAL_5]] {
+! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> !fir.heap<!fir.array<3xi32>>
+! CHECK: fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<3xi32>>
+! CHECK: }
+! CHECK: omp.yield
! CHECK: }
! CHECK-LABEL: func.func @_QQmain()
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
index 94bff410a2f0d7..bed04401248bed 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
@@ -15,13 +15,15 @@ program reduce
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_3xi32 : !fir.ref<!fir.box<!fir.array<3xi32>>> init {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
-! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.array<3xi32> {bindc_name = ".tmp"}
! CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
! CHECK: %[[VAL_4:.*]] = arith.constant 3 : index
! CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
-! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<3xi32>>, !fir.ref<!fir.array<3xi32>>)
-! CHECK: %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
+! CHECK: %[[VAL_1:.*]] = fir.allocmem !fir.array<3xi32>
+! CHECK: %[[TRUE:.*]] = arith.constant true
+! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<3xi32>>,
+!fir.shape<1>) -> (!fir.heap<!fir.array<3xi32>>, !fir.heap<!fir.array<3xi32>>)
+! CHECK: %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.heap<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
! CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_7]] : i32, !fir.box<!fir.array<3xi32>>
! CHECK: %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
! CHECK: fir.store %[[VAL_7]] to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
@@ -43,6 +45,18 @@ program reduce
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
! CHECK: }
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
+! CHECK: } cleanup {
+! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
+! CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
+! CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>) -> !fir.ref<!fir.array<3xi32>>
+! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> i64
+! CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64
+! CHECK: %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
+! CHECK: fir.if %[[VAL_5]] {
+! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> !fir.heap<!fir.array<3xi32>>
+! CHECK: fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<3xi32>>
+! CHECK: }
+! CHECK: omp.yield
! CHECK: }
! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "reduce"} {
diff --git a/flang/test/Lower/OpenMP/parallel-reduction3.f90 b/flang/test/Lower/OpenMP/parallel-reduction3.f90
index b25759713e318e..ce6bd17265ddba 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction3.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction3.f90
@@ -1,15 +1,6 @@
-! NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-
-! The script is designed to make adding checks to
-! a test case fast, it is *not* designed to be authoritative
-! about what constitutes a good test! The CHECK should be
-! minimized and named to reflect the test intent.
-
! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_Uxi32 : !fir.ref<!fir.box<!fir.array<?xi32>>> init {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
! CHECK: %[[VAL_1:.*]] = arith.constant 0 : i32
@@ -17,14 +8,14 @@
! CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_4:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
! CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]]#1 : (index) -> !fir.shape<1>
-! CHECK: %[[VAL_6:.*]] = fir.alloca !fir.array<?xi32>, %[[VAL_4]]#1 {bindc_name = ".tmp"}
-! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?xi32>>)
+! CHECK: %[[VAL_6:.*]] = fir.allocmem !fir.array<?xi32>, %[[VAL_4]]#1 {bindc_name = ".tmp", uniq_name = ""}
+! CHECK: %[[TRUE:.*]] = arith.constant true
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.heap<!fir.array<?xi32>>)
! CHECK: hlfir.assign %[[VAL_1]] to %[[VAL_7]]#0 : i32, !fir.box<!fir.array<?xi32>>
! CHECK: %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
! CHECK: fir.store %[[VAL_7]]#0 to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
! CHECK: omp.yield(%[[VAL_8]] : !fir.ref<!fir.box<!fir.array<?xi32>>>)
-
-! CHECK-LABEL: } combiner {
+! CHECK: } combiner {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>, %[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
@@ -41,6 +32,18 @@
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
! CHECK: }
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>)
+! CHECK: } cleanup {
+! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
+! CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
+! CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
+! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xi32>>) -> i64
+! CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64
+! CHECK: %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
+! CHECK: fir.if %[[VAL_5]] {
+! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xi32>>) -> !fir.heap<!fir.array<?xi32>>
+! CHECK: fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<?xi32>>
+! CHECK: }
+! CHECK: omp.yield
! CHECK: }
! CHECK-LABEL: func.func @_QPs(
@@ -122,4 +125,4 @@ subroutine s(x)
!$omp end parallel do
if (c(1) /= 5050) stop 1
-end subroutine s
\ No newline at end of file
+end subroutine s
diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90
index a1f339faea5cd5..8f83a30c9fe782 100644
--- a/flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90
@@ -29,8 +29,9 @@ subroutine reduce(r)
! CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_4:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
! CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]]#1 : (index) -> !fir.shape<1>
-! CHECK: %[[VAL_6:.*]] = fir.alloca !fir.array<?xf64>, %[[VAL_4]]#1 {bindc_name = ".tmp"}
-! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<?xf64>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf64>>, !fir.ref<!fir.array<?xf64>>)
+! CHECK: %[[VAL_6:.*]] = fir.allocmem !fir.array<?xf64>, %[[VAL_4]]#1 {bindc_name = ".tmp", uniq_name = ""}
+! CHECK: %[[TRUE:.*]] = arith.constant true
+! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xf64>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf64>>, !fir.heap<!fir.array<?xf64>>)
! CHECK: hlfir.assign %[[VAL_1]] to %[[VAL_7]]#0 : f64, !fir.box<!fir.array<?xf64>>
! CHECK: %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<?xf64>>
! CHECK: fir.store %[[VAL_7]]#0 to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
@@ -53,6 +54,18 @@ subroutine reduce(r)
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<f64>
! CHECK: }
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xf64>>>)
+! CHECK: } cleanup {
+! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xf64>>>):
+! CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
+! CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<?xf64>>) -> !fir.ref<!fir.array<?xf64>>
+! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf64>>) -> i64
+! CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64
+! CHECK: %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
+! CHECK: fir.if %[[VAL_5]] {
+! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf64>>) -> !fir.heap<!fir.array<?xf64>>
+! CHECK: fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<?xf64>>
+! CHECK: }
+! CHECK: omp.yield
! CHECK: }
! CHECK-LABEL: func.func private @_QFPreduce(
diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-array.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-array.f90
index a898204c881d9b..a08bca9eb283b5 100644
--- a/flang/test/Lower/OpenMP/wsloop-reduction-array.f90
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-array.f90
@@ -16,13 +16,14 @@ program reduce
! CHECK-LABEL omp.declare_reduction @add_reduction_byref_box_2xi32 : !fir.ref<!fir.box<!fir.array<2xi32>>> init {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<2xi32>>>):
-! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.array<2xi32> {bindc_name = ".tmp"}
! CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<2xi32>>>
! CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
! CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
-! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<2xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<2xi32>>, !fir.ref<!fir.array<2xi32>>)
-! CHECK: %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.ref<!fir.array<2xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<2xi32>>
+! CHECK: %[[VAL_1:.*]] = fir.allocmem !fir.array<2xi32> {bindc_name = ".tmp", uniq_name = ""}
+! CHECK: %[[TRUE:.*]] = arith.constant true
+! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<2xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<2xi32>>, !fir.heap<!fir.array<2xi32>>)
+! CHECK: %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.heap<!fir.array<2xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<2xi32>>
! CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_7]] : i32, !fir.box<!fir.array<2xi32>>
! CHECK: %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<2xi32>>
! CHECK: fir.store %[[VAL_7]] to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<2xi32>>>
@@ -45,6 +46,18 @@ program reduce
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
! CHECK: }
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<2xi32>>>)
+! CHECK: } cleanup {
+! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<2xi32>>>):
+! CHECK: %[[VAL_1:.*...
[truncated]
|
Allocating on the stack vs heap has some tradeoffs with OpenMP. If there are a large number of threads and if the memory required is high then there is a possibility that we run out of stack space. On the other hand, allocating on the heap probably takes more time. The runtime probably uses locks so I guess it is safe, but please check.
We could use stacksave and stackrestore as well. This was the original suggestion in the patch. |
Stacksave/stackrestore would be awkward to implement because there isn't a convenient way to pass the saved stack pointer into the cleanup region. Adding an additional argument to the OpenMP operation definition feels very flang-specific. Maybe the stack save/restore should be added in OpenMPIRBuilder?
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG.
Could you add a TODO to do a performance comparison with the alloca version. Most allocas will be hoisted. Only those allocas which are variable length would need the stacksave and restore. stacksave and restore can be inserted when the reduction declaration is processed.
Following up on a review comment:
#84958 (comment)
Reductions might be inlined inside of a loop so stack allocations are not safe.
Normally flang allocates arrays on the stack. Allocatable arrays have a different type: fir.box<fir.heap<fir.array<...>>> instead of fir.box<fir.array<...>>. This patch will allocate all arrays on the heap.
Reductions on allocatable arrays still aren't supported (but I will get to this soon).