Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][openmp] - depend clause support in target, target enter/update/exit data constructs #81610

Conversation

bhandarkar-pranav
Copy link
Contributor

This patch adds support in flang for the depend clause in target and target enter/update/exit constructs. Previously, the following line in a fortran program would have resulted in the error shown below it.

!$omp target map(to:a) depend(in:a)


"not yet implemented: Unhandled clause DEPEND in TARGET construct"

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp flang:semantics labels Feb 13, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 13, 2024

@llvm/pr-subscribers-flang-semantics
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Pranav Bhandarkar (bhandarkar-pranav)

Changes

This patch adds support in flang for the depend clause in target and target enter/update/exit constructs. Previously, the following line in a fortran program would have resulted in the error shown below it.

!$omp target map(to:a) depend(in:a)


"not yet implemented: Unhandled clause DEPEND in TARGET construct"

Full diff: https://github.com/llvm/llvm-project/pull/81610.diff

4 Files Affected:

  • (modified) flang/lib/Lower/OpenMP.cpp (+18-11)
  • (modified) flang/lib/Semantics/check-omp-structure.cpp (+8)
  • (modified) flang/test/Lower/OpenMP/target.f90 (+85)
  • (modified) flang/test/Semantics/OpenMP/clause-validity01.f90 (+1)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 06850bebd7d05a..7e36fdac0c4dbb 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2786,7 +2786,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, deviceOperand;
   mlir::UnitAttr nowaitAttr;
-  llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
+  llvm::SmallVector<mlir::Attribute> dependTypeOperands;
 
   Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
   llvm::omp::Directive directive;
@@ -2820,13 +2821,15 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
   } else {
     cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
   }
+  cp.processDepend(dependTypeOperands, dependOperands);
 
-  cp.processTODO<Fortran::parser::OmpClause::Depend>(currentLocation,
-                                                     directive);
-
-  return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
-                                   deviceOperand, nullptr, mlir::ValueRange(),
-                                   nowaitAttr, mapOperands);
+  return firOpBuilder.create<OpTy>(
+      currentLocation, ifClauseOperand, deviceOperand,
+      dependTypeOperands.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 dependTypeOperands),
+      dependOperands, nowaitAttr, mapOperands);
 }
 
 // This functions creates a block for the body of the targetOp's region. It adds
@@ -2993,7 +2996,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
   mlir::UnitAttr nowaitAttr;
-  llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+  llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
   llvm::SmallVector<mlir::Type> mapSymTypes;
   llvm::SmallVector<mlir::Location> mapSymLocs;
   llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
@@ -3006,8 +3010,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
                 &mapSymLocs, &mapSymbols);
+  cp.processDepend(dependTypeOperands, dependOperands);
   cp.processTODO<Fortran::parser::OmpClause::Private,
-                 Fortran::parser::OmpClause::Depend,
                  Fortran::parser::OmpClause::Firstprivate,
                  Fortran::parser::OmpClause::IsDevicePtr,
                  Fortran::parser::OmpClause::HasDeviceAddr,
@@ -3017,7 +3021,6 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
                  Fortran::parser::OmpClause::UsesAllocators,
                  Fortran::parser::OmpClause::Defaultmap>(
       currentLocation, llvm::omp::Directive::OMPD_target);
-
   // 5.8.1 Implicit Data-Mapping Attribute Rules
   // The following code follows the implicit data-mapping rules to map all the
   // symbols used inside the region that have not been explicitly mapped using
@@ -3091,7 +3094,11 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
 
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
-      nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
+      dependTypeOperands.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 dependTypeOperands),
+      dependOperands, nowaitAttr, mapOperands);
 
   genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
                     mapSymLocs, mapSymbols, currentLocation);
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index 03423de0c6104d..54101ab8a42bbf 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2815,6 +2815,14 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Device &x) {
 
 void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
   CheckAllowed(llvm::omp::Clause::OMPC_depend);
+  if ((std::holds_alternative<parser::OmpDependClause::Source>(x.v.u) ||
+          std::holds_alternative<parser::OmpDependClause::Sink>(x.v.u)) &&
+      GetContext().directive != llvm::omp::OMPD_ordered) {
+    context_.Say(GetContext().clauseSource,
+        "DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered"
+        " directive. Used here in the %s construct."_err_en_US,
+        parser::ToUpperCaseLetters(getDirectiveName(GetContext().directive)));
+  }
   if (const auto *inOut{std::get_if<parser::OmpDependClause::InOut>(&x.v.u)}) {
     const auto &designators{std::get<std::list<parser::Designator>>(inOut->t)};
     for (const auto &ele : designators) {
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index fa07b7f71d514e..030533e1a04553 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -14,6 +14,26 @@ subroutine omp_target_enter_simple
     return
 end subroutine omp_target_enter_simple
 
+!===============================================================================
+! Target_Enter `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_depend() {
+subroutine omp_target_enter_depend
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_enter_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+   !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
+   !CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}})   map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target_enter_data   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target enter data map(to: a) depend(in: a)
+    return
+end subroutine omp_target_enter_depend
+
 !===============================================================================
 ! Target_Enter Map types
 !===============================================================================
@@ -134,6 +154,45 @@ subroutine omp_target_exit_device
    !$omp target exit data map(from: a) device(d)
 end subroutine omp_target_exit_device
 
+!===============================================================================
+! Target_Exit `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_depend() {
+subroutine omp_target_exit_depend
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_exit_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+   !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
+   !CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}})   map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target_exit_data   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target exit data map(from: a) depend(out: a)
+end subroutine omp_target_exit_depend
+
+
+!===============================================================================
+! Target_Update `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_depend() {
+subroutine omp_target_update_depend
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_update_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+
+   !CHECK: %[[BOUNDS:.*]] = omp.bounds
+   !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update to(a) depend(in:a)
+end subroutine omp_target_update_depend
+
 !===============================================================================
 ! Target_Update `to` clause
 !===============================================================================
@@ -295,6 +354,32 @@ subroutine omp_target
    !CHECK: }
 end subroutine omp_target
 
+!===============================================================================
+! Target with region `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_depend() {
+subroutine omp_target_depend
+   !CHECK: %[[EXTENT_A:.*]] = arith.constant 1024 : index
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+   !CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
+   !CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
+   !CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
+   !CHECK: %[[BOUNDS_A:.*]] = omp.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
+   !CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp target map(tofrom: a) depend(in: a)
+      a(1) = 10
+      !CHECK: omp.terminator
+   !$omp end target
+   !CHECK: }
+ end subroutine omp_target_depend
+
 !===============================================================================
 ! Target implicit capture
 !===============================================================================
diff --git a/flang/test/Semantics/OpenMP/clause-validity01.f90 b/flang/test/Semantics/OpenMP/clause-validity01.f90
index 3fa86ed105a292..d9573a81821f32 100644
--- a/flang/test/Semantics/OpenMP/clause-validity01.f90
+++ b/flang/test/Semantics/OpenMP/clause-validity01.f90
@@ -481,6 +481,7 @@
   !$omp taskyield
   !$omp barrier
   !$omp taskwait
+  !ERROR: DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered directive. Used here in the TASKWAIT construct.
   !$omp taskwait depend(source)
   ! !$omp taskwait depend(sink:i-1)
   ! !$omp target enter data map(to:arrayA) map(alloc:arrayB)

@bhandarkar-pranav
Copy link
Contributor Author

@skatrak @clementval @kiranchandramohan - This is the 2nd (Flang) part of a now-closed PR (#80626). Could you please review this?

Copy link
Contributor

@skatrak skatrak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you Pranav, this LGTM. Generally I'd have suggested splitting the lowering and the semantics work into separate PRs, but in this case I think they are both small enough to not be necessary.

Please give it at least until after the weekend to make sure others have time to express any concerns before merging.


!CHECK-LABEL: func.func @_QPomp_target_enter_depend() {
subroutine omp_target_enter_depend
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_enter_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think it's not necessary to check the hlfir.declare operation here, since we're only really interested in matching the same SSA value in omp.task depend(taskdependout -> %[[A:.*]] : ... and in omp.target_enter_data ... depend(taskdependin -> %[[A]] : .... Same comment for the other tests added in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completely agree. Even I am not a fan of adding CHECKs for aspects that are not intended to be affected by the patch/PR in question. I did it this way only to be consistent with the rest of the this test file. I was bemused by that myself.

@@ -2820,13 +2821,15 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
} else {
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
}
cp.processDepend(dependTypeOperands, dependOperands);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ultra-nit: If this doesn't need to happen after processMap or processMotionClauses, I think it's easier to read if it's placed right below processNowait above the 'if'. But feel free to ignore this comment if you don't agree.

…/data constructs

This patch adds support in flang for the depend clause in target and target exit/update/data
constructs. Previously, the following line in a fortran program would have resulted
in the error shown below it.

    !$omp target map(to:a) depend(in:a)

    "not yet implemented: Unhandled clause DEPEND in TARGET construct"
@bhandarkar-pranav bhandarkar-pranav force-pushed the flang/target_depend_clause_pft_to_mlir branch from 407a175 to 69d068d Compare February 21, 2024 16:56
@bhandarkar-pranav bhandarkar-pranav merged commit 58f45d9 into llvm:main Feb 21, 2024
3 of 4 checks passed
@bhandarkar-pranav bhandarkar-pranav deleted the flang/target_depend_clause_pft_to_mlir branch March 13, 2024 20:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants