Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][OpenMP] Allocate array reduction variables on the heap #87773

Merged
merged 2 commits into from
Apr 11, 2024

Conversation

tblah
Copy link
Contributor

@tblah tblah commented Apr 5, 2024

Following up on a review comment:
#84958 (comment)

Reductions might be inlined inside of a loop so stack allocations are not safe.

Normally flang allocates arrays on the stack. Allocatable arrays have a different type: fir.box<fir.heap<fir.array<...>>> instead of fir.box<fir.array<...>>. This patch will allocate all arrays on the heap.

Reductions on allocatable arrays still aren't supported (but I will get to this soon).

Following up on a review comment:
llvm#84958 (comment)

Reductions might be inlined inside of a loop so stack allocations are
not safe.

Normally flang allocates arrays on the stack. Allocatable arrays have a
different type: fir.box<fir.heap<fir.array<...>>> instead of
fir.box<fir.array<...>>. This patch will allocate all arrays on the heap.

Reductions on allocatable arrays still aren't supported (but I will get
to this soon).
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Tom Eccles (tblah)

Changes

Following up on a review comment:
#84958 (comment)

Reductions might be inlined inside of a loop so stack allocations are not safe.

Normally flang allocates arrays on the stack. Allocatable arrays have a different type: fir.box<fir.heap<fir.array<...>>> instead of fir.box<fir.array<...>>. This patch will allocate all arrays on the heap.

Reductions on allocatable arrays still aren't supported (but I will get to this soon).


Patch is 24.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87773.diff

7 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+69-6)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-array.f90 (+17-3)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-array2.f90 (+17-3)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction3.f90 (+17-14)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90 (+15-2)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array.f90 (+16-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array2.f90 (+17-4)
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c1c94119fd9083..7a4cefb420f997 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -18,6 +18,7 @@
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/FatalError.h"
 #include "flang/Parser/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/Support/CommandLine.h"
@@ -379,8 +380,60 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
   TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
 }
 
+static void
+createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
+                             mlir::omp::DeclareReductionOp &reductionDecl) {
+  mlir::Type redTy = reductionDecl.getType();
+
+  mlir::Region &cleanupRegion = reductionDecl.getCleanupRegion();
+  assert(cleanupRegion.empty());
+  mlir::Block *block =
+      builder.createBlock(&cleanupRegion, cleanupRegion.end(), {redTy}, {loc});
+  builder.setInsertionPointToEnd(block);
+
+  auto typeError = [loc]() {
+    fir::emitFatalError(loc,
+                        "Attempt to create an omp reduction cleanup region "
+                        "for a type that wasn't allocated",
+                        /*genCrashDiag=*/true);
+  };
+
+  mlir::Type valTy = fir::unwrapRefType(redTy);
+  if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
+    mlir::Type innerTy = fir::extractSequenceType(boxTy);
+    if (!mlir::isa<fir::SequenceType>(innerTy))
+      typeError();
+
+    mlir::Value arg = block->getArgument(0);
+    arg = builder.loadIfRef(loc, arg);
+    assert(mlir::isa<fir::BaseBoxType>(arg.getType()));
+
+    // Deallocate box
+    // The FIR type system doesn't nesecarrily know that this is a mutable box
+    // if we allocated the thread local array on the heap to avoid looped stack
+    // allocations.
+    mlir::Value addr =
+        hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg});
+    mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr);
+    fir::IfOp ifOp =
+        builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false);
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+    mlir::Value cast = builder.createConvert(
+        loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr);
+    builder.create<fir::FreeMemOp>(loc, cast);
+
+    builder.setInsertionPointAfter(ifOp);
+    builder.create<mlir::omp::YieldOp>(loc);
+    return;
+  }
+
+  typeError();
+}
+
 static mlir::Value
 createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
+                          mlir::omp::DeclareReductionOp &reductionDecl,
                           const ReductionProcessor::ReductionIdentifier redId,
                           mlir::Type type, bool isByRef) {
   mlir::Type ty = fir::unwrapRefType(type);
@@ -407,11 +460,21 @@ createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
     // Create the private copy from the initial fir.box:
     hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
 
-    // TODO: if the whole reduction is nested inside of a loop, this alloca
-    // could lead to a stack overflow (the memory is only freed at the end of
-    // the stack frame). The reduction declare operation needs a deallocation
-    // region to undo the init region.
-    hlfir::Entity temp = createStackTempFromMold(loc, builder, source);
+    // Allocating on the heap in case the whole reduction is nested inside of a
+    // loop
+    auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
+    // if needsDealloc isn't statically false, add cleanup region. TODO: always
+    // do this for allocatable boxes because they might have been re-allocated
+    // in the body of the loop/parallel region
+    std::optional<int64_t> cstNeedsDealloc =
+        fir::getIntIfConstant(needsDealloc);
+    assert(cstNeedsDealloc.has_value() &&
+           "createTempFromMold decides this statically");
+    if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
+      auto insPt = builder.saveInsertionPoint();
+      createReductionCleanupRegion(builder, loc, reductionDecl);
+      builder.restoreInsertionPoint(insPt);
+    }
 
     // Put the temporary inside of a box:
     hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
@@ -450,7 +513,7 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
 
   mlir::Value init =
-      createReductionInitRegion(builder, loc, redId, type, isByRef);
+      createReductionInitRegion(builder, loc, decl, redId, type, isByRef);
   builder.create<mlir::omp::YieldOp>(loc, init);
 
   builder.createBlock(&decl.getReductionRegion(),
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
index 56dcabbb75c3a8..26c9d4f0850964 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
@@ -15,13 +15,15 @@ program reduce
 
 ! CHECK-LABEL:   omp.declare_reduction @add_reduction_byref_box_3xi32 : !fir.ref<!fir.box<!fir.array<3xi32>>> init {
 ! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
-! CHECK:           %[[VAL_1:.*]] = fir.alloca !fir.array<3xi32> {bindc_name = ".tmp"}
 ! CHECK:           %[[VAL_2:.*]] = arith.constant 0 : i32
 ! CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
 ! CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
 ! CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
-! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<3xi32>>, !fir.ref<!fir.array<3xi32>>)
-! CHECK:           %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
+! CHECK:           %[[VAL_1:.*]] = fir.allocmem !fir.array<3xi32> {bindc_name = ".tmp", uniq_name = ""}
+! CHECK:           %[[TRUE:.*]]  = arith.constant true
+! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<3xi32>>,
+!fir.shape<1>) -> (!fir.heap<!fir.array<3xi32>>, !fir.heap<!fir.array<3xi32>>)
+! CHECK:           %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.heap<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
 ! CHECK:           hlfir.assign %[[VAL_2]] to %[[VAL_7]] : i32, !fir.box<!fir.array<3xi32>>
 ! CHECK:           %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
 ! CHECK:           fir.store %[[VAL_7]] to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
@@ -43,6 +45,18 @@ program reduce
 ! CHECK:             fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
 ! CHECK:           }
 ! CHECK:           omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
+! CHECK:         }  cleanup {
+! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
+! CHECK:           %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
+! CHECK:           %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>) -> !fir.ref<!fir.array<3xi32>>
+! CHECK:           %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> i64
+! CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i64
+! CHECK:           %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
+! CHECK:           fir.if %[[VAL_5]] {
+! CHECK:             %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> !fir.heap<!fir.array<3xi32>>
+! CHECK:             fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<3xi32>>
+! CHECK:           }
+! CHECK:           omp.yield
 ! CHECK:         }
 
 ! CHECK-LABEL:   func.func @_QQmain()
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
index 94bff410a2f0d7..bed04401248bed 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
@@ -15,13 +15,15 @@ program reduce
 
 ! CHECK-LABEL:   omp.declare_reduction @add_reduction_byref_box_3xi32 : !fir.ref<!fir.box<!fir.array<3xi32>>> init {
 ! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
-! CHECK:           %[[VAL_1:.*]] = fir.alloca !fir.array<3xi32> {bindc_name = ".tmp"}
 ! CHECK:           %[[VAL_2:.*]] = arith.constant 0 : i32
 ! CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
 ! CHECK:           %[[VAL_4:.*]] = arith.constant 3 : index
 ! CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
-! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<3xi32>>, !fir.ref<!fir.array<3xi32>>)
-! CHECK:           %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
+! CHECK:           %[[VAL_1:.*]] = fir.allocmem !fir.array<3xi32>
+! CHECK:           %[[TRUE:.*]]  = arith.constant true
+! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<3xi32>>,
+!fir.shape<1>) -> (!fir.heap<!fir.array<3xi32>>, !fir.heap<!fir.array<3xi32>>)
+! CHECK:           %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.heap<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
 ! CHECK:           hlfir.assign %[[VAL_2]] to %[[VAL_7]] : i32, !fir.box<!fir.array<3xi32>>
 ! CHECK:           %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
 ! CHECK:           fir.store %[[VAL_7]] to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
@@ -43,6 +45,18 @@ program reduce
 ! CHECK:             fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
 ! CHECK:           }
 ! CHECK:           omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
+! CHECK:         }  cleanup {
+! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
+! CHECK:           %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
+! CHECK:           %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>) -> !fir.ref<!fir.array<3xi32>>
+! CHECK:           %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> i64
+! CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i64
+! CHECK:           %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
+! CHECK:           fir.if %[[VAL_5]] {
+! CHECK:             %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> !fir.heap<!fir.array<3xi32>>
+! CHECK:             fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<3xi32>>
+! CHECK:           }
+! CHECK:           omp.yield
 ! CHECK:         }
 
 ! CHECK-LABEL:   func.func @_QQmain() attributes {fir.bindc_name = "reduce"} {
diff --git a/flang/test/Lower/OpenMP/parallel-reduction3.f90 b/flang/test/Lower/OpenMP/parallel-reduction3.f90
index b25759713e318e..ce6bd17265ddba 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction3.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction3.f90
@@ -1,15 +1,6 @@
-! NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
-
-! The script is designed to make adding checks to
-! a test case fast, it is *not* designed to be authoritative
-! about what constitutes a good test! The CHECK should be
-! minimized and named to reflect the test intent.
-
 ! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
 
-
-
 ! CHECK-LABEL:   omp.declare_reduction @add_reduction_byref_box_Uxi32 : !fir.ref<!fir.box<!fir.array<?xi32>>> init {
 ! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
 ! CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i32
@@ -17,14 +8,14 @@
 ! CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 ! CHECK:           %[[VAL_4:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 ! CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]]#1 : (index) -> !fir.shape<1>
-! CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.array<?xi32>, %[[VAL_4]]#1 {bindc_name = ".tmp"}
-! CHECK:           %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?xi32>>)
+! CHECK:           %[[VAL_6:.*]] = fir.allocmem !fir.array<?xi32>, %[[VAL_4]]#1 {bindc_name = ".tmp", uniq_name = ""}
+! CHECK:           %[[TRUE:.*]]  = arith.constant true
+! CHECK:           %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.heap<!fir.array<?xi32>>)
 ! CHECK:           hlfir.assign %[[VAL_1]] to %[[VAL_7]]#0 : i32, !fir.box<!fir.array<?xi32>>
 ! CHECK:           %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
 ! CHECK:           fir.store %[[VAL_7]]#0 to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
 ! CHECK:           omp.yield(%[[VAL_8]] : !fir.ref<!fir.box<!fir.array<?xi32>>>)
-
-! CHECK-LABEL:   } combiner {
+! CHECK:         } combiner {
 ! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>, %[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
 ! CHECK:           %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
 ! CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
@@ -41,6 +32,18 @@
 ! CHECK:             fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
 ! CHECK:           }
 ! CHECK:           omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>)
+! CHECK:         }  cleanup {
+! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
+! CHECK:           %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
+! CHECK:           %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
+! CHECK:           %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xi32>>) -> i64
+! CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i64
+! CHECK:           %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
+! CHECK:           fir.if %[[VAL_5]] {
+! CHECK:             %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xi32>>) -> !fir.heap<!fir.array<?xi32>>
+! CHECK:             fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<?xi32>>
+! CHECK:           }
+! CHECK:           omp.yield
 ! CHECK:         }
 
 ! CHECK-LABEL:   func.func @_QPs(
@@ -122,4 +125,4 @@ subroutine s(x)
     !$omp end parallel do
 
     if (c(1) /= 5050) stop 1
-end subroutine s
\ No newline at end of file
+end subroutine s
diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90
index a1f339faea5cd5..8f83a30c9fe782 100644
--- a/flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90
@@ -29,8 +29,9 @@ subroutine reduce(r)
 ! CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 ! CHECK:           %[[VAL_4:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
 ! CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]]#1 : (index) -> !fir.shape<1>
-! CHECK:           %[[VAL_6:.*]] = fir.alloca !fir.array<?xf64>, %[[VAL_4]]#1 {bindc_name = ".tmp"}
-! CHECK:           %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<?xf64>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf64>>, !fir.ref<!fir.array<?xf64>>)
+! CHECK:           %[[VAL_6:.*]] = fir.allocmem !fir.array<?xf64>, %[[VAL_4]]#1 {bindc_name = ".tmp", uniq_name = ""}
+! CHECK:           %[[TRUE:.*]]  = arith.constant true
+! CHECK:           %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xf64>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf64>>, !fir.heap<!fir.array<?xf64>>)
 ! CHECK:           hlfir.assign %[[VAL_1]] to %[[VAL_7]]#0 : f64, !fir.box<!fir.array<?xf64>>
 ! CHECK:           %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<?xf64>>
 ! CHECK:           fir.store %[[VAL_7]]#0 to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
@@ -53,6 +54,18 @@ subroutine reduce(r)
 ! CHECK:             fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<f64>
 ! CHECK:           }
 ! CHECK:           omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xf64>>>)
+! CHECK:         }  cleanup {
+! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xf64>>>):
+! CHECK:           %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
+! CHECK:           %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<?xf64>>) -> !fir.ref<!fir.array<?xf64>>
+! CHECK:           %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf64>>) -> i64
+! CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i64
+! CHECK:           %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
+! CHECK:           fir.if %[[VAL_5]] {
+! CHECK:             %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf64>>) -> !fir.heap<!fir.array<?xf64>>
+! CHECK:             fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<?xf64>>
+! CHECK:           }
+! CHECK:           omp.yield
 ! CHECK:         }
 
 ! CHECK-LABEL:   func.func private @_QFPreduce(
diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-array.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-array.f90
index a898204c881d9b..a08bca9eb283b5 100644
--- a/flang/test/Lower/OpenMP/wsloop-reduction-array.f90
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-array.f90
@@ -16,13 +16,14 @@ program reduce
 
 ! CHECK-LABEL   omp.declare_reduction @add_reduction_byref_box_2xi32 : !fir.ref<!fir.box<!fir.array<2xi32>>> init {
 ! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<2xi32>>>):
-! CHECK:           %[[VAL_1:.*]] = fir.alloca !fir.array<2xi32> {bindc_name = ".tmp"}
 ! CHECK:           %[[VAL_2:.*]] = arith.constant 0 : i32
 ! CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<2xi32>>>
 ! CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
 ! CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
-! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<2xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<2xi32>>, !fir.ref<!fir.array<2xi32>>)
-! CHECK:           %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.ref<!fir.array<2xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<2xi32>>
+! CHECK:           %[[VAL_1:.*]] = fir.allocmem !fir.array<2xi32> {bindc_name = ".tmp", uniq_name = ""}
+! CHECK:           %[[TRUE:.*]]  = arith.constant true
+! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<2xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<2xi32>>, !fir.heap<!fir.array<2xi32>>)
+! CHECK:           %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[VAL_5]]) : (!fir.heap<!fir.array<2xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<2xi32>>
 ! CHECK:           hlfir.assign %[[VAL_2]] to %[[VAL_7]] : i32, !fir.box<!fir.array<2xi32>>
 ! CHECK:           %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<2xi32>>
 ! CHECK:           fir.store %[[VAL_7]] to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<2xi32>>>
@@ -45,6 +46,18 @@ program reduce
 ! CHECK:             fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
 ! CHECK:           }
 ! CHECK:           omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<2xi32>>>)
+! CHECK:         }  cleanup {
+! CHECK:         ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<2xi32>>>):
+! CHECK:           %[[VAL_1:.*...
[truncated]

@kiranchandramohan
Copy link
Contributor

Allocating on the stack vs heap has some tradeoffs with OpenMP. If there are a large number of threads and if the memory required is high then there is a possibility that we run out of stack space. On the other hand, allocating on the heap probably takes more time. The runtime probably uses locks so I guess it is safe, but please check.

Reductions might be inlined inside of a loop so stack allocations are not safe.

We could use stacksave and stackrestore as well. This was the original suggestion in the patch.
https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic
https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic

@tblah
Copy link
Contributor Author

tblah commented Apr 8, 2024

Allocating on the stack vs heap has some tradeoffs with OpenMP. If there are a large number of threads and if the memory required is high then there is a possibility that we run out of stack space. On the other hand, allocating on the heap probably takes more time. The runtime probably uses locks so I guess it is safe, but please check.

Reductions might be inlined inside of a loop so stack allocations are not safe.

We could use stacksave and stackrestore as well. This was the original suggestion in the patch. https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic

Stacksave/stackrestore would be awkward to implement because there isn't a convenient way to pass the saved stack pointer into the cleanup region. Adding an additional argument to the OpenMP operation definition feels very flang-specific. Maybe the stack save/restore should be added in OpenMPIRBuilder?

fir.allocmem is lowered to a call to malloc, which is thread safe: https://en.cppreference.com/w/c/memory/malloc

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG.

Could you add a TODO to do a performance comparison with the alloca version. Most allocas will be hoisted. Only those allocas which are variable length would need the stacksave and restore. stacksave and restore can be inserted when the reduction declaration is processed.

@tblah tblah merged commit 6f068b9 into llvm:main Apr 11, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants