Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][openacc/mp] Do not read bounds on absent box #75252

Merged
merged 2 commits into from
Dec 15, 2023

Conversation

clementval
Copy link
Contributor

@clementval clementval commented Dec 12, 2023

Make sure we only load box and read its bounds when it is present.

  • Add AddrAndBoundInfo struct to be able to carry around the addr and isPresent values. This is likely to grow so we can make all the access in a single fir.if operation.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp openacc labels Dec 12, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 12, 2023

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-flang-openmp

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Make sure we only load box and read its bounds when it is present.

Fix also some template parameter ordering issues.


Full diff: https://github.com/llvm/llvm-project/pull/75252.diff

5 Files Affected:

  • (modified) flang/lib/Lower/DirectivesCommon.h (+87-17)
  • (modified) flang/lib/Lower/OpenACC.cpp (+17-13)
  • (modified) flang/lib/Lower/OpenMP.cpp (+2-2)
  • (modified) flang/test/Lower/OpenACC/acc-bounds.f90 (+31)
  • (modified) flang/test/Lower/OpenACC/acc-data.f90 (-1)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 88a8916663df75..39f87202f90f5f 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -620,25 +620,36 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
 
     // Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
     // `fir.ref<fir.class<T>>` type.
-    if (symAddr.getType().isa<fir::ReferenceType>())
+    if (symAddr.getType().isa<fir::ReferenceType>()) {
+      if (Fortran::semantics::IsOptional(sym)) {
+        mlir::Value isPresent =
+            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), symAddr);
+        return builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
+            .genThen([&]() {
+              mlir::Value load = builder.create<fir::LoadOp>(loc, symAddr);
+              builder.create<fir::ResultOp>(loc, mlir::ValueRange{load});
+            })
+            .genElse([&] {
+              mlir::Value absent = builder.create<fir::AbsentOp>(loc, boxTy);
+              builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
+            })
+            .getResults()[0];
+      }
       return builder.create<fir::LoadOp>(loc, symAddr);
+    }
   }
   return symAddr;
 }
 
-/// Generate the bounds operation from the descriptor information.
 template <typename BoundsOp, typename BoundsType>
-llvm::SmallVector<mlir::Value>
-genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
-                    Fortran::lower::AbstractConverter &converter,
+static llvm::SmallVector<mlir::Value>
+gatherBoundsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
                     fir::ExtendedValue dataExv, mlir::Value box) {
+  mlir::Value byteStride;
   llvm::SmallVector<mlir::Value> bounds;
   mlir::Type idxTy = builder.getIndexType();
   mlir::Type boundTy = builder.getType<BoundsType>();
   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-  assert(box.getType().isa<fir::BaseBoxType>() &&
-         "expect fir.box or fir.class");
-  mlir::Value byteStride;
   for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
     mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
     mlir::Value baseLb =
@@ -660,6 +671,58 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
   return bounds;
 }
 
+/// Generate the bounds operation from the descriptor information.
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
+                    Fortran::lower::AbstractConverter &converter,
+                    fir::ExtendedValue dataExv, mlir::Value box,
+                    bool isOptional = false) {
+  llvm::SmallVector<mlir::Value> bounds;
+  mlir::Type idxTy = builder.getIndexType();
+  mlir::Type boundTy = builder.getType<BoundsType>();
+
+  assert(box.getType().isa<fir::BaseBoxType>() &&
+         "expect fir.box or fir.class");
+
+  if (isOptional) {
+    mlir::Value isPresent =
+        builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), box);
+
+    llvm::SmallVector<mlir::Type> resTypes;
+    for (unsigned dim = 0; dim < dataExv.rank(); ++dim)
+      resTypes.push_back(boundTy);
+
+    auto ifOp =
+        builder.genIfOp(loc, resTypes, isPresent, /*withElseRegion=*/true)
+            .genThen([&]() {
+              llvm::SmallVector<mlir::Value> tempBounds =
+                  gatherBoundsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                            dataExv, box);
+              builder.create<fir::ResultOp>(loc, tempBounds);
+            })
+            .genElse([&] {
+              llvm::SmallVector<mlir::Value> tempBounds;
+              mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+              mlir::Value minusOne =
+                  builder.createIntegerConstant(loc, idxTy, -1);
+              for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
+                mlir::Value bound = builder.create<BoundsOp>(
+                    loc, boundTy, zero, minusOne, zero, mlir::Value(), false,
+                    mlir::Value{});
+                tempBounds.push_back(bound);
+              }
+              builder.create<fir::ResultOp>(loc, tempBounds);
+            });
+    bounds.append(ifOp.getResults().begin(), ifOp.getResults().end());
+  } else {
+    llvm::SmallVector<mlir::Value> tempBounds =
+        gatherBoundsFromBox<BoundsOp, BoundsType>(builder, loc, dataExv, box);
+    bounds.append(tempBounds.begin(), tempBounds.end());
+  }
+  return bounds;
+}
+
 /// Generate bounds operation for base array without any subscripts
 /// provided.
 template <typename BoundsOp, typename BoundsType>
@@ -885,20 +948,20 @@ mlir::Value gatherDataOperandAddrAndBounds(
 
                 if (!arrayElement->subscripts.empty()) {
                   asFortran << '(';
-                  bounds = genBoundsOps<BoundsType, BoundsOp>(
+                  bounds = genBoundsOps<BoundsOp, BoundsType>(
                       builder, operandLocation, converter, stmtCtx,
                       arrayElement->subscripts, asFortran, dataExv, baseAddr,
                       treatIndexAsSection);
                 }
                 asFortran << ')';
-              } else if (Fortran::parser::Unwrap<
+              } else if (auto structComp = Fortran::parser::Unwrap<
                              Fortran::parser::StructureComponent>(designator)) {
                 fir::ExtendedValue compExv =
                     converter.genExprAddr(operandLocation, *expr, stmtCtx);
                 baseAddr = fir::getBase(compExv);
                 if (fir::unwrapRefType(baseAddr.getType())
                         .isa<fir::SequenceType>())
-                  bounds = genBaseBoundsOps<BoundsType, BoundsOp>(
+                  bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
                       builder, operandLocation, converter, compExv, baseAddr);
                 asFortran << (*expr).AsFortran();
 
@@ -917,8 +980,11 @@ mlir::Value gatherDataOperandAddrAndBounds(
                 if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
                         baseAddr.getDefiningOp())) {
                   baseAddr = boxAddrOp.getVal();
-                  bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>(
-                      builder, operandLocation, converter, compExv, baseAddr);
+                  bool isOptional = Fortran::semantics::IsOptional(
+                      *Fortran::parser::GetLastName(*structComp).symbol);
+                  bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+                      builder, operandLocation, converter, compExv, baseAddr,
+                      isOptional);
                 }
               } else {
                 if (Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
@@ -943,12 +1009,16 @@ mlir::Value gatherDataOperandAddrAndBounds(
                   baseAddr = getDataOperandBaseAddr(
                       converter, builder, *name.symbol, operandLocation);
                   if (fir::unwrapRefType(baseAddr.getType())
-                          .isa<fir::BaseBoxType>())
-                    bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>(
-                        builder, operandLocation, converter, dataExv, baseAddr);
+                          .isa<fir::BaseBoxType>()) {
+                    bool isOptional =
+                        Fortran::semantics::IsOptional(*name.symbol);
+                    bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+                        builder, operandLocation, converter, dataExv, baseAddr,
+                        isOptional);
+                  }
                   if (fir::unwrapRefType(baseAddr.getType())
                           .isa<fir::SequenceType>())
-                    bounds = genBaseBoundsOps<BoundsType, BoundsOp>(
+                    bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
                         builder, operandLocation, converter, dataExv, baseAddr);
                   asFortran << name.ToString();
                 } else { // Unsupported
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e2abed1b9f4f67..531685948bc843 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -266,10 +266,11 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
     std::stringstream asFortran;
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-        Fortran::parser::AccObject, mlir::acc::DataBoundsType,
-        mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
-                                 accObject, operandLocation, asFortran, bounds,
-                                 /*treatIndexAsSection=*/true);
+        Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+        mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+                                   stmtCtx, accObject, operandLocation,
+                                   asFortran, bounds,
+                                   /*treatIndexAsSection=*/true);
     Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
                                   bounds, structured, implicit, dataClause,
                                   baseAddr.getType());
@@ -291,9 +292,10 @@ static void genDeclareDataOperandOperations(
     std::stringstream asFortran;
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-        Fortran::parser::AccObject, mlir::acc::DataBoundsType,
-        mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
-                                 accObject, operandLocation, asFortran, bounds);
+        Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+        mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+                                   stmtCtx, accObject, operandLocation,
+                                   asFortran, bounds);
     EntryOp op = createDataEntryOp<EntryOp>(
         builder, operandLocation, baseAddr, asFortran, bounds, structured,
         implicit, dataClause, baseAddr.getType());
@@ -748,9 +750,10 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
     std::stringstream asFortran;
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-        Fortran::parser::AccObject, mlir::acc::DataBoundsType,
-        mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
-                                 accObject, operandLocation, asFortran, bounds);
+        Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+        mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+                                   stmtCtx, accObject, operandLocation,
+                                   asFortran, bounds);
 
     RecipeOp recipe;
     mlir::Type retTy = getTypeFromBounds(bounds, baseAddr.getType());
@@ -1324,9 +1327,10 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
     std::stringstream asFortran;
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-        Fortran::parser::AccObject, mlir::acc::DataBoundsType,
-        mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
-                                 accObject, operandLocation, asFortran, bounds);
+        Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+        mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+                                   stmtCtx, accObject, operandLocation,
+                                   asFortran, bounds);
 
     mlir::Type reductionTy = fir::unwrapRefType(baseAddr.getType());
     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index eeba87fcd15116..59e06e8458e6c0 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1794,8 +1794,8 @@ bool ClauseProcessor::processMap(
           llvm::SmallVector<mlir::Value> bounds;
           std::stringstream asFortran;
           mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-              Fortran::parser::OmpObject, mlir::omp::DataBoundsType,
-              mlir::omp::DataBoundsOp>(
+              Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
+              mlir::omp::DataBoundsType>(
               converter, firOpBuilder, semanticsContext, stmtCtx, ompObject,
               clauseLocation, asFortran, bounds, treatIndexAsSection);
 
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index 8db18ab5aa9c4b..c8787c5e118f97 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -116,4 +116,35 @@ subroutine acc_multi_strides(a)
 ! CHECK: %[[PRESENT:.*]] = acc.present varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?x?x?xf32>>) bounds(%29, %33, %37) -> !fir.ref<!fir.array<?x?x?xf32>> {name = "a"}
 ! CHECK: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref<!fir.array<?x?x?xf32>>) {
 
+  subroutine acc_optional_data(a)
+    real, pointer, optional :: a(:)
+    !$acc data attach(a)
+    !$acc end data
+  end subroutine
+  
+  ! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data(
+  ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
+  ! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+  ! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
+  ! CHECK: %[[ADDR:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
+  ! CHECK:   %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+  ! CHECK:   fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  ! CHECK: } else {
+  ! CHECK:   %[[ABSENT:.*]] = fir.absent !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  ! CHECK:   fir.result %[[ABSENT]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  ! CHECK: }
+  ! CHECK: %[[BOUNDS:.*]] = fir.if %{{.*}} -> (!acc.data_bounds_ty) {
+  ! CHECK:   %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}}#1 : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
+  ! CHECK:   fir.result %[[BOUND]] : !acc.data_bounds_ty
+  ! CHECK: } else {
+  ! CHECK:   %[[C0:.*]] = arith.constant 0 : index
+  ! CHECK:   %[[CM1:.*]] = arith.constant -1 : index
+  ! CHECK:   %[[BOUND:.*]] = acc.bounds lowerbound(%[[C0]] : index) upperbound(%[[CM1]] : index) extent(%[[C0]] : index)
+  ! CHECK:   fir.result %[[BOUND]] : !acc.data_bounds_ty
+  ! CHECK: }
+  ! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[ADDR]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
+  ! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUNDS]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
+  ! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr<!fir.array<?xf32>>)
+  
+
 end module
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index d302be85c5df46..a6572e14707606 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -198,4 +198,3 @@ subroutine acc_data
 ! CHECK-NOT: acc.data
 
 end subroutine acc_data
-

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me. Thank you for fixing!

@clementval
Copy link
Contributor Author

Looks great to me. Thank you for fixing!

Thanks for the review. I'll follow on this next week since we have to cover more cases of optional.

@clementval clementval merged commit 22426d9 into llvm:main Dec 15, 2023
4 checks passed
@clementval clementval deleted the acc_optional branch December 15, 2023 21:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants