
Conversation


@ergawy ergawy commented Sep 3, 2025


llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

Extends support for mapping `do concurrent` to the device by adding support for `local` specifiers. The changes in this PR map the `local` variable to the `omp.target` op and use the mapped value as the `private` clause operand of the nested `omp.parallel` op.
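
For illustration, here is a minimal before/after sketch of the IR shape this conversion produces, trimmed down from the `local_device.mlir` test added below. SSA value names are illustrative, and details such as the `host_eval` clause, the loop-bound `map_entries`, and the intermediate `hlfir.declare`s are abbreviated:

```mlir
// Before: `do concurrent` loop with a `local` specifier.
fir.local {type = local} @_QFfooEmy_local_private_f32 : f32

fir.do_concurrent.loop (%iv) = (%lb) to (%ub) step (%st)
    local(@_QFfooEmy_local_private_f32 %my_local#0 -> %local_arg : !fir.ref<f32>) {
  // ... loop body uses %local_arg ...
}

// After (map-to=device): the localizer becomes an `omp.private` op, the local
// variable is mapped onto `omp.target`, and the mapped value is used as the
// `private` clause operand of the nested `omp.parallel`.
omp.private {type = private} @_QFfooEmy_local_private_f32.omp : f32

omp.target map_entries(..., %local_map -> %mapped_local : !fir.ref<f32>) {
  omp.teams {
    omp.parallel private(@_QFfooEmy_local_private_f32.omp %mapped_local -> %priv_arg : !fir.ref<f32>) {
      omp.distribute {
        omp.wsloop {
          omp.loop_nest (%iv) : index = (%lb) to (%ub) inclusive step (%st) {
            // ... loop body now uses %priv_arg ...
            omp.yield
          }
        }
      }
    }
  }
}
```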


Full diff: https://github.com/llvm/llvm-project/pull/156589.diff

3 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+12)
  • (modified) flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp (+114-78)
  • (added) flang/test/Transforms/DoConcurrent/local_device.mlir (+49)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index bc971e8fd6600..fc6eedc6ed4c6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3894,6 +3894,18 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
       return getReduceVars().size();
     }
 
+    unsigned getInductionVarsStart() {
+      return 0;
+    }
+
+    unsigned getLocalOperandsStart() {
+      return getNumInductionVars();
+    }
+
+    unsigned getReduceOperandsStart() {
+      return getLocalOperandsStart() + getNumLocalOperands();
+    }
+
     mlir::Block::BlockArgListType getInductionVars() {
       return getBody()->getArguments().slice(0, getNumInductionVars());
     }
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index a800a20129abe..66b778fecc208 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -137,6 +137,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
 
         liveIns.push_back(operand->get());
       });
+
+  for (mlir::Value local : loop.getLocalVars())
+    liveIns.push_back(local);
 }
 
 /// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -251,8 +254,7 @@ class DoConcurrentConversion
               .getIsTargetDevice();
 
       mlir::omp::TargetOperands targetClauseOps;
-      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
-                           loopNestClauseOps,
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps,
                            isTargetDevice ? nullptr : &targetClauseOps);
 
       LiveInShapeInfoMap liveInShapeInfoMap;
@@ -274,14 +276,13 @@ class DoConcurrentConversion
     }
 
     mlir::omp::ParallelOp parallelOp =
-        genParallelOp(doLoop.getLoc(), rewriter, ivInfos, mapper);
+        genParallelOp(rewriter, loop, ivInfos, mapper);
 
     // Only set as composite when part of `distribute parallel do`.
     parallelOp.setComposite(mapToDevice);
 
     if (!mapToDevice)
-      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
-                           loopNestClauseOps);
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps);
 
     for (mlir::Value local : locals)
       looputils::localizeLoopLocalValue(local, parallelOp.getRegion(),
@@ -290,10 +291,38 @@ class DoConcurrentConversion
     if (mapToDevice)
       genDistributeOp(doLoop.getLoc(), rewriter).setComposite(/*val=*/true);
 
-    mlir::omp::LoopNestOp ompLoopNest =
+    auto [loopNestOp, wsLoopOp] =
         genWsLoopOp(rewriter, loop, mapper, loopNestClauseOps,
                     /*isComposite=*/mapToDevice);
 
+    // `local` region arguments are transferred/cloned from the `do concurrent`
+    // loop to the loopnest op when the region is cloned above. Instead, these
+    // region arguments should be on the workshare loop's region.
+    if (mapToDevice) {
+      for (auto [parallelArg, loopNestArg] : llvm::zip_equal(
+               parallelOp.getRegion().getArguments(),
+               loopNestOp.getRegion().getArguments().slice(
+                   loop.getLocalOperandsStart(), loop.getNumLocalOperands())))
+        rewriter.replaceAllUsesWith(loopNestArg, parallelArg);
+
+      for (auto [wsloopArg, loopNestArg] : llvm::zip_equal(
+               wsLoopOp.getRegion().getArguments(),
+               loopNestOp.getRegion().getArguments().slice(
+                   loop.getReduceOperandsStart(), loop.getNumReduceOperands())))
+        rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
+    } else {
+      for (auto [wsloopArg, loopNestArg] :
+           llvm::zip_equal(wsLoopOp.getRegion().getArguments(),
+                           loopNestOp.getRegion().getArguments().drop_front(
+                               loopNestClauseOps.loopLowerBounds.size())))
+        rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
+    }
+
+    for (unsigned i = 0;
+         i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
+      loopNestOp.getRegion().eraseArgument(
+          loopNestClauseOps.loopLowerBounds.size());
+
     rewriter.setInsertionPoint(doLoop);
     fir::FirOpBuilder builder(
         rewriter,
@@ -314,7 +343,7 @@ class DoConcurrentConversion
     // Mark `unordered` loops that are not perfectly nested to be skipped from
     // the legality check of the `ConversionTarget` since we are not interested
     // in mapping them to OpenMP.
-    ompLoopNest->walk([&](fir::DoConcurrentOp doLoop) {
+    loopNestOp->walk([&](fir::DoConcurrentOp doLoop) {
       concurrentLoopsToSkip.insert(doLoop);
     });
 
@@ -370,11 +399,21 @@ class DoConcurrentConversion
       llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
 
   mlir::omp::ParallelOp
-  genParallelOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
+  genParallelOp(mlir::ConversionPatternRewriter &rewriter,
+                fir::DoConcurrentLoopOp loop,
                 looputils::InductionVariableInfos &ivInfos,
                 mlir::IRMapping &mapper) const {
-    auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
-    rewriter.createBlock(&parallelOp.getRegion());
+    mlir::omp::ParallelOperands parallelOps;
+
+    if (mapToDevice)
+      genPrivatizers(rewriter, mapper, loop, parallelOps);
+
+    mlir::Location loc = loop.getLoc();
+    auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc, parallelOps);
+    Fortran::common::openmp::EntryBlockArgs parallelArgs;
+    parallelArgs.priv.vars = parallelOps.privateVars;
+    Fortran::common::openmp::genEntryBlock(rewriter, parallelArgs,
+                                           parallelOp.getRegion());
     rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
 
     genLoopNestIndVarAllocs(rewriter, ivInfos, mapper);
@@ -411,7 +450,7 @@ class DoConcurrentConversion
 
   void genLoopNestClauseOps(
       mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
-      fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
+      fir::DoConcurrentLoopOp loop,
       mlir::omp::LoopNestOperands &loopNestClauseOps,
       mlir::omp::TargetOperands *targetClauseOps = nullptr) const {
     assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -440,59 +479,14 @@ class DoConcurrentConversion
     loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
   }
 
-  mlir::omp::LoopNestOp
+  std::pair<mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
   genWsLoopOp(mlir::ConversionPatternRewriter &rewriter,
               fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
               const mlir::omp::LoopNestOperands &clauseOps,
               bool isComposite) const {
     mlir::omp::WsloopOperands wsloopClauseOps;
-
-    auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
-                                           mlir::Region &ompRegion) {
-      if (!firRegion.empty()) {
-        rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
-        auto firYield =
-            mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
-        rewriter.setInsertionPoint(firYield);
-        mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
-                                   firYield.getOperands());
-        rewriter.eraseOp(firYield);
-      }
-    };
-
-    // For `local` (and `local_init`) opernads, emit corresponding `private`
-    // clauses and attach these clauses to the workshare loop.
-    if (!loop.getLocalVars().empty())
-      for (auto [op, sym, arg] : llvm::zip_equal(
-               loop.getLocalVars(),
-               loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
-               loop.getRegionLocalArgs())) {
-        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
-            sym.getLeafReference());
-        if (localizer.getLocalitySpecifierType() ==
-            fir::LocalitySpecifierType::LocalInit)
-          TODO(localizer.getLoc(),
-               "local_init conversion is not supported yet");
-
-        mlir::OpBuilder::InsertionGuard guard(rewriter);
-        rewriter.setInsertionPointAfter(localizer);
-
-        auto privatizer = mlir::omp::PrivateClauseOp::create(
-            rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
-            localizer.getTypeAttr().getValue(),
-            mlir::omp::DataSharingClauseType::Private);
-
-        cloneFIRRegionToOMP(localizer.getInitRegion(),
-                            privatizer.getInitRegion());
-        cloneFIRRegionToOMP(localizer.getDeallocRegion(),
-                            privatizer.getDeallocRegion());
-
-        moduleSymbolTable.insert(privatizer);
-
-        wsloopClauseOps.privateVars.push_back(op);
-        wsloopClauseOps.privateSyms.push_back(
-            mlir::SymbolRefAttr::get(privatizer));
-      }
+    if (!mapToDevice)
+      genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
 
     if (!loop.getReduceVars().empty()) {
       for (auto [op, byRef, sym, arg] : llvm::zip_equal(
@@ -515,15 +509,15 @@ class DoConcurrentConversion
               rewriter, firReducer.getLoc(), ompReducerName,
               firReducer.getTypeAttr().getValue());
 
-          cloneFIRRegionToOMP(firReducer.getAllocRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
                               ompReducer.getAllocRegion());
-          cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
                               ompReducer.getInitializerRegion());
-          cloneFIRRegionToOMP(firReducer.getReductionRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
                               ompReducer.getReductionRegion());
-          cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
                               ompReducer.getAtomicReductionRegion());
-          cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
                               ompReducer.getCleanupRegion());
           moduleSymbolTable.insert(ompReducer);
         }
@@ -555,21 +549,10 @@ class DoConcurrentConversion
 
     rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
     mlir::omp::YieldOp::create(rewriter, loop->getLoc());
+    loop->getParentOfType<mlir::ModuleOp>().print(
+        llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
 
-    // `local` region arguments are transferred/cloned from the `do concurrent`
-    // loop to the loopnest op when the region is cloned above. Instead, these
-    // region arguments should be on the workshare loop's region.
-    for (auto [wsloopArg, loopNestArg] :
-         llvm::zip_equal(wsloopOp.getRegion().getArguments(),
-                         loopNestOp.getRegion().getArguments().drop_front(
-                             clauseOps.loopLowerBounds.size())))
-      rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
-
-    for (unsigned i = 0;
-         i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
-      loopNestOp.getRegion().eraseArgument(clauseOps.loopLowerBounds.size());
-
-    return loopNestOp;
+    return {loopNestOp, wsloopOp};
   }
 
   void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -810,6 +793,59 @@ class DoConcurrentConversion
     return distOp;
   }
 
+  void cloneFIRRegionToOMP(mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Region &firRegion,
+                           mlir::Region &ompRegion) const {
+    if (!firRegion.empty()) {
+      rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
+      auto firYield =
+          mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
+      rewriter.setInsertionPoint(firYield);
+      mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
+                                 firYield.getOperands());
+      rewriter.eraseOp(firYield);
+    }
+  }
+
+  void genPrivatizers(mlir::ConversionPatternRewriter &rewriter,
+                      mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
+                      mlir::omp::PrivateClauseOps &privateClauseOps) const {
+    // For `local` (and `local_init`) operands, emit corresponding `private`
+    // clauses and attach these clauses to the workshare loop.
+    if (!loop.getLocalVars().empty())
+      for (auto [var, sym, arg] : llvm::zip_equal(
+               loop.getLocalVars(),
+               loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
+               loop.getRegionLocalArgs())) {
+        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
+            sym.getLeafReference());
+        if (localizer.getLocalitySpecifierType() ==
+            fir::LocalitySpecifierType::LocalInit)
+          TODO(localizer.getLoc(),
+               "local_init conversion is not supported yet");
+
+        mlir::OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointAfter(localizer);
+
+        auto privatizer = mlir::omp::PrivateClauseOp::create(
+            rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
+            localizer.getTypeAttr().getValue(),
+            mlir::omp::DataSharingClauseType::Private);
+
+        cloneFIRRegionToOMP(rewriter, localizer.getInitRegion(),
+                            privatizer.getInitRegion());
+        cloneFIRRegionToOMP(rewriter, localizer.getDeallocRegion(),
+                            privatizer.getDeallocRegion());
+
+        moduleSymbolTable.insert(privatizer);
+
+        privateClauseOps.privateVars.push_back(mapToDevice ? mapper.lookup(var)
+                                                           : var);
+        privateClauseOps.privateSyms.push_back(
+            mlir::SymbolRefAttr::get(privatizer));
+      }
+  }
+
   bool mapToDevice;
   llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
   mlir::SymbolTable &moduleSymbolTable;
diff --git a/flang/test/Transforms/DoConcurrent/local_device.mlir b/flang/test/Transforms/DoConcurrent/local_device.mlir
new file mode 100644
index 0000000000000..e54bb1aeb414e
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/local_device.mlir
@@ -0,0 +1,49 @@
+// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
+
+fir.local {type = local} @_QFfooEmy_local_private_f32 : f32
+
+func.func @_QPfoo() {
+  %0 = fir.dummy_scope : !fir.dscope
+  %3 = fir.alloca f32 {bindc_name = "my_local", uniq_name = "_QFfooEmy_local"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  fir.do_concurrent {
+    %7 = fir.alloca i32 {bindc_name = "i"}
+    %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+    fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@_QFfooEmy_local_private_f32 %4#0 -> %arg1 : !fir.ref<f32>) {
+      %9 = fir.convert %arg0 : (index) -> i32
+      fir.store %9 to %8#0 : !fir.ref<i32>
+      %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+      %cst = arith.constant 4.200000e+01 : f32
+      hlfir.assign %cst to %10#0 : f32, !fir.ref<f32>
+    }
+  }
+  return
+}
+
+// CHECK: omp.private {type = private} @[[OMP_PRIVATIZER:.*.omp]] : f32
+
+// CHECK: %[[LOCAL_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "{{.*}}my_local"}
+// CHECK: %[[LOCAL_MAP:.*]] = omp.map.info var_ptr(%[[LOCAL_DECL]]#1 : {{.*}})
+
+// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[LOCAL_MAP]] -> %[[LOCAL_MAP_ARG:.*]] : {{.*}}) {
+// CHECK:   %[[LOCAL_DEV_DECL:.*]]:2 = hlfir.declare %[[LOCAL_MAP_ARG]] {uniq_name = "_QFfooEmy_local"}
+
+// CHECK:   omp.teams {
+// CHECK:     omp.parallel private(@[[OMP_PRIVATIZER]] %[[LOCAL_DEV_DECL]]#0 -> %[[LOCAL_PRIV_ARG:.*]] : {{.*}}) {
+// CHECK:       omp.distribute {
+// CHECK:         omp.wsloop {
+// CHECK:           omp.loop_nest {{.*}} {
+// CHECK:             %[[LOCAL_LOOP_DECL:.*]]:2 = hlfir.declare %[[LOCAL_PRIV_ARG]] {uniq_name = "_QFfooEmy_local"}
+// CHECK:             hlfir.assign %{{.*}} to %[[LOCAL_LOOP_DECL]]#0
+// CHECK:             omp.yield
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }
+// CHECK:     }
+// CHECK:   }
+// CHECK: }

@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_5_offload_tests branch from d967c72 to d592609 on September 8, 2025 12:00
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_6_local_specs_device branch from 78e1013 to 13f4544 on September 8, 2025 12:00
ergawy added a commit that referenced this pull request Sep 8, 2025
…ide values (#155754)

Following up on #154483, this PR introduces further refactoring to
extract some shared utils between OpenMP lowering and `do concurrent`
conversion pass. In particular, this PR extracts 2 utils that handle
mapping or cloning values used inside target regions but defined
outside.

Later `do concurrent` PR(s) will also use these utils.

PR stack:
- #155754 ◀️
- #155987
- #155992
- #155993
- #156589
- #156610
- #156837
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_5_offload_tests branch from d592609 to 315f521 on September 8, 2025 12:36
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_6_local_specs_device branch from 13f4544 to c29c8a2 on September 8, 2025 12:36
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 8, 2025
… clone outside values (#155754)

Following up on #154483, this PR introduces further refactoring to
extract some shared utils between OpenMP lowering and `do concurrent`
conversion pass. In particular, this PR extracts 2 utils that handle
mapping or cloning values used inside target regions but defined
outside.

Later `do concurrent` PR(s) will also use these utils.

PR stack:
- llvm/llvm-project#155754 ◀️
- llvm/llvm-project#155987
- llvm/llvm-project#155992
- llvm/llvm-project#155993
- llvm/llvm-project#156589
- llvm/llvm-project#156610
- llvm/llvm-project#156837
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_5_offload_tests branch from 315f521 to e681a9f on September 9, 2025 10:13
@ergawy ergawy closed this Sep 9, 2025
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_6_local_specs_device branch from c29c8a2 to e681a9f on September 9, 2025 10:13
Labels: flang:fir-hlfir, flang