Skip to content

Conversation

agozillon
Copy link
Contributor

This patch is one in a series of four patches that seeks to refactor
slightly and extend the current record type map support that was
put in place for Fortran's descriptor types to handle explicit
member mapping for record types at a single level of depth.

For example, the below case where two members of a Fortran
derived type are mapped explicitly:

''''
type :: scalar_and_array
real(4) :: real
integer(4) :: array(10)
integer(4) :: int
end type scalar_and_array
type(scalar_and_array) :: scalar_arr

!$omp target map(tofrom: scalar_arr%int, scalar_arr%real)
''''

Current cases of derived type mapping left for future work are:

explicit member mapping of nested members (e.g. two layers of
record types where we explicitly map a member from the internal
record type)
Fortran's automagical mapping of all elements and nested elements
of a derived type
explicit member mapping of a derived type and then constituient members
(redundant in Fortran due to former case but still legal as far as I am aware)
explicit member mapping of a record type (may be handled reasonably, just
not fully tested in this iteration)
explicit member mapping for Fortran allocatable types (a variation of nested
record types)

This patch seeks to support this by extending the Flang-new OpenMP lowering to
support generation of this newly required information, creating the neccessary
parent <-to-> member map_info links, calculating the member indices and
setting if it's a partial map.

The OMPDescriptorMapInfoGen pass has also been generalized into a map
finalization phase, now named OMPMapInfoFinalization. This pass was extended
to support the insertion of member maps into the BlockArg and MapOperands of
relevant map carrying operations. Similar to the method in which descriptor types
are expanded and constituient members inserted.

Created using spr 1.3.4
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp labels Feb 12, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 12, 2024

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (agozillon)

Changes

This patch is one in a series of four patches that seeks to refactor
slightly and extend the current record type map support that was
put in place for Fortran's descriptor types to handle explicit
member mapping for record types at a single level of depth.

For example, the below case where two members of a Fortran
derived type are mapped explicitly:

''''
type :: scalar_and_array
real(4) :: real
integer(4) :: array(10)
integer(4) :: int
end type scalar_and_array
type(scalar_and_array) :: scalar_arr

!$omp target map(tofrom: scalar_arr%int, scalar_arr%real)
''''

Current cases of derived type mapping left for future work are:
> explicit member mapping of nested members (e.g. two layers of
record types where we explicitly map a member from the internal
record type)
> Fortran's automagical mapping of all elements and nested elements
of a derived type
> explicit member mapping of a derived type and then constituient members
(redundant in Fortran due to former case but still legal as far as I am aware)
> explicit member mapping of a record type (may be handled reasonably, just
not fully tested in this iteration)
> explicit member mapping for Fortran allocatable types (a variation of nested
record types)

This patch seeks to support this by extending the Flang-new OpenMP lowering to
support generation of this newly required information, creating the neccessary
parent <-to-> member map_info links, calculating the member indices and
setting if it's a partial map.

The OMPDescriptorMapInfoGen pass has also been generalized into a map
finalization phase, now named OMPMapInfoFinalization. This pass was extended
to support the insertion of member maps into the BlockArg and MapOperands of
relevant map carrying operations. Similar to the method in which descriptor types
are expanded and constituient members inserted.


Patch is 98.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81511.diff

20 Files Affected:

  • (modified) flang/docs/OpenMP-descriptor-management.md (+2-2)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+1-1)
  • (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+3-3)
  • (modified) flang/include/flang/Tools/CLOptions.inc (+1-1)
  • (modified) flang/lib/Lower/OpenMP.cpp (+210-25)
  • (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1-1)
  • (removed) flang/lib/Optimizer/Transforms/OMPDescriptorMapInfoGen.cpp (-168)
  • (added) flang/lib/Optimizer/Transforms/OMPMapInfoFinalization.cpp (+262)
  • (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+31-4)
  • (modified) flang/test/Integration/OpenMP/map-types-and-sizes.f90 (+153-1)
  • (modified) flang/test/Lower/OpenMP/FIR/array-bounds.f90 (+2-2)
  • (modified) flang/test/Lower/OpenMP/FIR/map-component-ref.f90 (+2-2)
  • (modified) flang/test/Lower/OpenMP/FIR/target.f90 (+2-2)
  • (modified) flang/test/Lower/OpenMP/allocatable-array-bounds.f90 (+6-6)
  • (modified) flang/test/Lower/OpenMP/allocatable-map.f90 (+4-4)
  • (modified) flang/test/Lower/OpenMP/array-bounds.f90 (+4-4)
  • (added) flang/test/Lower/OpenMP/derived-type-map.f90 (+105)
  • (modified) flang/test/Lower/OpenMP/map-component-ref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/target.f90 (+2-2)
  • (renamed) flang/test/Transforms/omp-map-info-finalization.fir (+29-5)
diff --git a/flang/docs/OpenMP-descriptor-management.md b/flang/docs/OpenMP-descriptor-management.md
index 90a20282e05126..af02b3a99cb07d 100644
--- a/flang/docs/OpenMP-descriptor-management.md
+++ b/flang/docs/OpenMP-descriptor-management.md
@@ -44,7 +44,7 @@ Currently, Flang will lower these descriptor types in the OpenMP lowering (lower
 to all other map types, generating an omp.MapInfoOp containing relevant information required for lowering
 the OpenMP dialect to LLVM-IR during the final stages of the MLIR lowering. However, after 
 the lowering to FIR/HLFIR has been performed an OpenMP dialect specific pass for Fortran, 
-`OMPDescriptorMapInfoGenPass` (Optimizer/OMPDescriptorMapInfoGen.cpp) will expand the 
+`OMPMapInfoFinalizationPass` (Optimizer/OMPMapInfoFinalization.cpp) will expand the 
 `omp.MapInfoOp`'s containing descriptors (which currently will be a `BoxType` or `BoxAddrOp`) into multiple 
 mappings, with one extra per pointer member in the descriptor that is supported on top of the original
 descriptor map operation. These pointers members are linked to the parent descriptor by adding them to 
@@ -52,7 +52,7 @@ the member field of the original descriptor map operation, they are then inserte
 owning operation's (`omp.TargetOp`, `omp.DataOp` etc.) map operand list and in cases where the owning operation
 is `IsolatedFromAbove`, it also inserts them as `BlockArgs` to canonicalize the mappings and simplify lowering.
 
-An example transformation by the `OMPDescriptorMapInfoGenPass`:
+An example transformation by the `OMPMapInfoFinalizationPass`:
 
 ```
 
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index e1d22c8c986da7..fc9a098c3931d3 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -76,7 +76,7 @@ std::unique_ptr<mlir::Pass>
 createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
 std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
 
-std::unique_ptr<mlir::Pass> createOMPDescriptorMapInfoGenPass();
+std::unique_ptr<mlir::Pass> createOMPMapInfoFinalizationPass();
 std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 createOMPMarkDeclareTargetPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 5fb576fd876254..0638ae49f5f4ea 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -318,15 +318,15 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
   let dependentDialects = [ "fir::FIROpsDialect" ];
 }
 
-def OMPDescriptorMapInfoGenPass
-    : Pass<"omp-descriptor-map-info-gen", "mlir::func::FuncOp"> {
+def OMPMapInfoFinalizationPass
+    : Pass<"omp-map-info-finalization", "mlir::func::FuncOp"> {
   let summary = "expands OpenMP MapInfo operations containing descriptors";
   let description = [{
     Expands MapInfo operations containing descriptor types into multiple 
     MapInfo's for each pointer element in the descriptor that requires 
     explicit individual mapping by the OpenMP runtime.
   }];
-  let constructor = "::fir::createOMPDescriptorMapInfoGenPass()";
+  let constructor = "::fir::createOMPMapInfoFinalizationPass()";
   let dependentDialects = ["mlir::omp::OpenMPDialect"];
 }
 
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 68e504d0ccb512..ec3d634ac0264b 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -274,7 +274,7 @@ inline void createHLFIRToFIRPassPipeline(
 /// rather than the host device.
 inline void createOpenMPFIRPassPipeline(
     mlir::PassManager &pm, bool isTargetDevice) {
-  pm.addPass(fir::createOMPDescriptorMapInfoGenPass());
+  pm.addPass(fir::createOMPMapInfoFinalizationPass());
   pm.addPass(fir::createOMPMarkDeclareTargetPass());
   if (isTargetDevice)
     pm.addPass(fir::createOMPFunctionFilteringPass());
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index fd18b212bad515..86fdc51602cf12 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -49,14 +49,16 @@ using DeclareTargetCapturePair =
 //===----------------------------------------------------------------------===//
 
 static Fortran::semantics::Symbol *
-getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
+getOmpObjParentSymbol(const Fortran::parser::OmpObject &ompObject) {
   Fortran::semantics::Symbol *sym = nullptr;
   std::visit(
       Fortran::common::visitors{
           [&](const Fortran::parser::Designator &designator) {
-            if (auto *arrayEle =
-                    Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
-                        designator)) {
+            if (auto *structComp = Fortran::parser::Unwrap<
+                    Fortran::parser::StructureComponent>(designator)) {
+              sym = GetFirstName(structComp->base).symbol;
+            } else if (auto *arrayEle = Fortran::parser::Unwrap<
+                           Fortran::parser::ArrayElement>(designator)) {
               sym = GetFirstName(arrayEle->base).symbol;
             } else if (auto *structComp = Fortran::parser::Unwrap<
                            Fortran::parser::StructureComponent>(designator)) {
@@ -72,6 +74,29 @@ getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
   return sym;
 }
 
+static Fortran::semantics::Symbol *
+getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
+  Fortran::semantics::Symbol *sym = nullptr;
+  std::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::parser::Designator &designator) {
+            if (auto *structComp = Fortran::parser::Unwrap<
+                    Fortran::parser::StructureComponent>(designator)) {
+              sym = structComp->component.symbol;
+            } else if (auto *arrayEle = Fortran::parser::Unwrap<
+                           Fortran::parser::ArrayElement>(designator)) {
+              sym = GetLastName(arrayEle->base).symbol;
+            } else if (const Fortran::parser::Name *name =
+                           Fortran::semantics::getDesignatorNameIfDataRef(
+                               designator)) {
+              sym = name->symbol;
+            }
+          },
+          [&](const Fortran::parser::Name &name) { sym = name.symbol; }},
+      ompObject.u);
+  return sym;
+}
+
 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
                           Fortran::lower::AbstractConverter &converter,
                           llvm::SmallVectorImpl<mlir::Value> &operands) {
@@ -1829,9 +1854,10 @@ static mlir::omp::MapInfoOp
 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                 mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
                 mlir::SmallVector<mlir::Value> bounds,
-                mlir::SmallVector<mlir::Value> members, uint64_t mapType,
+                mlir::SmallVector<mlir::Value> members,
+                mlir::ArrayAttr membersIndex, uint64_t mapType,
                 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
-                bool isVal = false) {
+                bool partialMap = false) {
   if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
     baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
     retTy = baseAddr.getType();
@@ -1841,14 +1867,112 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
       llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
 
   mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
-      loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
+      loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
       builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
       builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
-      builder.getStringAttr(name));
+      builder.getStringAttr(name), builder.getBoolAttr(partialMap));
 
   return op;
 }
 
+int findComponenetMemberPlacement(
+    const Fortran::semantics::Symbol *dTypeSym,
+    const Fortran::semantics::Symbol *componentSym) {
+  int placement = -1;
+  if (const auto *derived{
+          dTypeSym->detailsIf<Fortran::semantics::DerivedTypeDetails>()}) {
+    for (auto t : derived->componentNames()) {
+      placement++;
+      if (t == componentSym->name())
+        return placement;
+    }
+  }
+  return placement;
+}
+
+static void
+checkAndApplyDeclTargetMapFlags(Fortran::lower::AbstractConverter &converter,
+                                llvm::omp::OpenMPOffloadMappingFlags &mapFlags,
+                                Fortran::semantics::Symbol *symbol) {
+  mlir::Operation *op =
+      converter.getModuleOp().lookupSymbol(converter.mangleName(*symbol));
+  if (op)
+    if (auto declareTargetOp =
+            llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op)) {
+      // only Link clauses have OMP_MAP_PTR_AND_OBJ applied, To clause
+      // functions fairly different.
+      if (declareTargetOp.getDeclareTargetCaptureClause() ==
+          mlir::omp::DeclareTargetCaptureClause::link)
+        mapFlags |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
+    }
+}
+
+static void insertChildMapInfoIntoParent(
+    Fortran::lower::AbstractConverter &converter,
+    llvm::SmallVector<const Fortran::semantics::Symbol *> &memberParentSyms,
+    llvm::SmallVector<mlir::Value> &memberMaps,
+    llvm::SmallVector<mlir::Attribute> &memberPlacementIndices,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+    llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+    llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols) {
+  // TODO: For multi-nested record types the top level parent is currently
+  // the containing parent for all member operations.
+  for (auto [idx, sym] : llvm::enumerate(memberParentSyms)) {
+    bool parentExists = false;
+    size_t parentIdx = 0;
+    for (size_t i = 0; i < mapSymbols->size(); ++i) {
+      if ((*mapSymbols)[i] == sym) {
+        parentExists = true;
+        parentIdx = i;
+      }
+    }
+
+    if (parentExists) {
+      // found a parent, append.
+      if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
+              mapOperands[parentIdx].getDefiningOp())) {
+        mapOp.getMembersMutable().append(memberMaps[idx]);
+        llvm::SmallVector<mlir::Attribute> memberIndexTmp{
+            mapOp.getMembersIndexAttr().begin(),
+            mapOp.getMembersIndexAttr().end()};
+        memberIndexTmp.push_back(memberPlacementIndices[idx]);
+        mapOp.setMembersIndexAttr(mlir::ArrayAttr::get(
+            converter.getFirOpBuilder().getContext(), memberIndexTmp));
+      }
+    } else {
+      // NOTE: We take the map type of the first child, this may not
+      // be the correct thing to do, however, we shall see. For the moment
+      // it allows this to work with enter and exit without causing MLIR
+      // verification issues. The more appropriate thing may be to take
+      // the "main" map type clause from the directive being used.
+      uint64_t mapType = 0;
+      if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
+              memberMaps[idx].getDefiningOp()))
+        mapType = mapOp.getMapType().value_or(0);
+
+      // create parent to emplace and bind members
+      auto origSymbol = converter.getSymbolAddress(*sym);
+      mlir::Value mapOp = createMapInfoOp(
+          converter.getFirOpBuilder(),
+          converter.getFirOpBuilder().getUnknownLoc(), origSymbol,
+          mlir::Value(), sym->name().ToString(), {}, {memberMaps[idx]},
+          mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                               memberPlacementIndices[idx]),
+          mapType, mlir::omp::VariableCaptureKind::ByRef, origSymbol.getType(),
+          true);
+
+      mapOperands.push_back(mapOp);
+      if (mapSymTypes)
+        mapSymTypes->push_back(mapOp.getType());
+      if (mapSymLocs)
+        mapSymLocs->push_back(mapOp.getLoc());
+      if (mapSymbols)
+        mapSymbols->push_back(sym);
+    }
+  }
+}
+
 bool ClauseProcessor::processMap(
     mlir::Location currentLocation, const llvm::omp::Directive &directive,
     Fortran::semantics::SemanticsContext &semanticsContext,
@@ -1859,7 +1983,13 @@ bool ClauseProcessor::processMap(
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
     const {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  return findRepeatableClause<ClauseTy::Map>(
+
+  llvm::SmallVector<mlir::Value> memberMaps;
+  llvm::SmallVector<mlir::Attribute> memberPlacementIndices;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> memberParentSyms,
+      mapSyms;
+
+  bool clauseFound = findRepeatableClause<ClauseTy::Map>(
       [&](const ClauseTy::Map *mapClause,
           const Fortran::parser::CharBlock &source) {
         mlir::Location clauseLocation = converter.genLocation(source);
@@ -1906,8 +2036,22 @@ bool ClauseProcessor::processMap(
 
         for (const Fortran::parser::OmpObject &ompObject :
              std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
+          llvm::omp::OpenMPOffloadMappingFlags objectsMapTypeBits = mapTypeBits;
+          checkAndApplyDeclTargetMapFlags(converter, objectsMapTypeBits,
+                                          getOmpObjectSymbol(ompObject));
+
           llvm::SmallVector<mlir::Value> bounds;
           std::stringstream asFortran;
+          const Fortran::semantics::Symbol *parentSym = nullptr;
+
+          if (getOmpObjectSymbol(ompObject)->owner().IsDerivedType()) {
+            memberPlacementIndices.push_back(
+                firOpBuilder.getI64IntegerAttr(findComponenetMemberPlacement(
+                    getOmpObjectSymbol(ompObject)->owner().symbol(),
+                    getOmpObjectSymbol(ompObject))));
+            parentSym = getOmpObjParentSymbol(ompObject);
+            memberParentSyms.push_back(parentSym);
+          }
 
           Fortran::lower::AddrAndBoundsInfo info =
               Fortran::lower::gatherDataOperandAddrAndBounds<
@@ -1927,22 +2071,33 @@ bool ClauseProcessor::processMap(
           // types to optimise
           mlir::Value mapOp = createMapInfoOp(
               firOpBuilder, clauseLocation, symAddr, mlir::Value{},
-              asFortran.str(), bounds, {},
+              asFortran.str(), bounds, {}, mlir::ArrayAttr{},
               static_cast<
                   std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
-                  mapTypeBits),
+                  objectsMapTypeBits),
               mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
 
-          mapOperands.push_back(mapOp);
-          if (mapSymTypes)
-            mapSymTypes->push_back(symAddr.getType());
-          if (mapSymLocs)
-            mapSymLocs->push_back(symAddr.getLoc());
-
-          if (mapSymbols)
-            mapSymbols->push_back(getOmpObjectSymbol(ompObject));
+          if (parentSym) {
+            memberMaps.push_back(mapOp);
+          } else {
+            mapOperands.push_back(mapOp);
+            mapSyms.push_back(getOmpObjectSymbol(ompObject));
+            if (mapSymTypes)
+              mapSymTypes->push_back(symAddr.getType());
+            if (mapSymLocs)
+              mapSymLocs->push_back(symAddr.getLoc());
+          }
         }
       });
+
+  insertChildMapInfoIntoParent(converter, memberParentSyms, memberMaps,
+                               memberPlacementIndices, mapOperands, mapSymTypes,
+                               mapSymLocs, &mapSyms);
+
+  if (mapSymbols)
+    *mapSymbols = mapSyms;
+
+  return clauseFound;
 }
 
 bool ClauseProcessor::processReduction(
@@ -2021,7 +2176,12 @@ bool ClauseProcessor::processMotionClauses(
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::StatementContext &stmtCtx,
     llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
-  return findRepeatableClause<T>(
+  llvm::SmallVector<mlir::Value> memberMaps;
+  llvm::SmallVector<mlir::Attribute> memberPlacementIndices;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> memberParentSyms,
+      mapSymbols;
+
+  bool clauseFound = findRepeatableClause<T>(
       [&](const T *motionClause, const Fortran::parser::CharBlock &source) {
         mlir::Location clauseLocation = converter.genLocation(source);
         fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -2036,8 +2196,23 @@ bool ClauseProcessor::processMotionClauses(
                 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
 
         for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
+          llvm::omp::OpenMPOffloadMappingFlags objectsMapTypeBits = mapTypeBits;
+          checkAndApplyDeclTargetMapFlags(converter, objectsMapTypeBits,
+                                          getOmpObjectSymbol(ompObject));
+
           llvm::SmallVector<mlir::Value> bounds;
           std::stringstream asFortran;
+          const Fortran::semantics::Symbol *parentSym = nullptr;
+
+          if (getOmpObjectSymbol(ompObject)->owner().IsDerivedType()) {
+            memberPlacementIndices.push_back(
+                firOpBuilder.getI64IntegerAttr(findComponenetMemberPlacement(
+                    getOmpObjectSymbol(ompObject)->owner().symbol(),
+                    getOmpObjectSymbol(ompObject))));
+            parentSym = getOmpObjParentSymbol(ompObject);
+            memberParentSyms.push_back(parentSym);
+          }
+
           Fortran::lower::AddrAndBoundsInfo info =
               Fortran::lower::gatherDataOperandAddrAndBounds<
                   Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
@@ -2056,15 +2231,25 @@ bool ClauseProcessor::processMotionClauses(
           // types to optimise
           mlir::Value mapOp = createMapInfoOp(
               firOpBuilder, clauseLocation, symAddr, mlir::Value{},
-              asFortran.str(), bounds, {},
+              asFortran.str(), bounds, {}, mlir::ArrayAttr{},
               static_cast<
                   std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
-                  mapTypeBits),
+                  objectsMapTypeBits),
               mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
 
-          mapOperands.push_back(mapOp);
+          if (parentSym) {
+            memberMaps.push_back(mapOp);
+          } else {
+            mapOperands.push_back(mapOp);
+            mapSymbols.push_back(getOmpObjectSymbol(ompObject));
+          }
         }
       });
+
+  insertChildMapInfoIntoParent(converter, memberParentSyms, memberMaps,
+                               memberPlacementIndices, mapOperands, nullptr,
+                               nullptr, &mapSymbols);
+  return clauseFound;
 }
 
 template <typename... Ts>
@@ -2882,7 +3067,7 @@ static void genBodyOfTargetOp(
         firOpBuilder.setInsertionPoint(targetOp);
         mlir::Value mapOp = createMapInfoOp(
             firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(),
-            bounds, llvm::SmallVector<mlir::Value>{},
+            bounds, llvm::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
             static_cast<
                 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
@@ -3018,7 +3203,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
 
         mlir::Value mapOp = createMapInfoOp(
             converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, mlir::Value{},
-            name.str(), bounds, {},
+            name.str(), bounds, {}, mlir::ArrayAttr{},
             static_cast<
                 std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                 mapFlag),
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index ba2e267996150e..ce5ce3ed1bc48d 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -17,7 +17,7 @@ add_flang_library(FIRTransforms
   AddDebugFoundation.cpp
   PolymorphicOpConversion.cpp...
[truncated]

@agozillon
Copy link
Contributor Author

agozillon commented Feb 12, 2024

I believe this is the top of the PR stack that should pass... however, it's my first time using SPR so we'll see how it goes. If there's anything odd with the PR (outside of the usual code review) please don't hesitate to mention it so I can address it!

I believe I fragmented the stack appropriately into:

However, if you spot a test or file that isn't in the right part of the stack, please do point it out and I can move it, there's a lot of tests across the patch series so there's a chance I've misplaced one or two.

Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partially reviewed, will continue later.

Thanks Andrew, I am learning quite a bit from this PR.

Comment on lines 63 to 64
} else if (auto *structComp = Fortran::parser::Unwrap<
Fortran::parser::StructureComponent>(designator)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch is dead now, right? I will never execute AFAICT.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch! Thank you, it's a rebase artifact from putting it on top of a recent fix from @kparzysz

Comment on lines 77 to 78
static Fortran::semantics::Symbol *
getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having this and the above functions, can we have one function with a bool getParentObjWhenApplicable argument?

I am suggesting this because almost all the logic is repeated with the exception of the StructureComponent case, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, should be able to do!

const Fortran::semantics::Symbol *dTypeSym,
const Fortran::semantics::Symbol *componentSym) {
int placement = -1;
if (const auto *derived{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason not to use: if (const auto *derived = ....) instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no real reason, it's just the style that's used in a lot of places in Flang, so I mimic it here. But happy to change it

const Fortran::semantics::Symbol *dTypeSym,
const Fortran::semantics::Symbol *componentSym) {
int placement = -1;
if (const auto *derived{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The LLVM style guide suggests to use early exits when possibe. Can we invert the condition and exit with -1 in the if and then execute the main logic after we close the if?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be possible, will give it a try!

int placement = -1;
if (const auto *derived{
dTypeSym->detailsIf<Fortran::semantics::DerivedTypeDetails>()}) {
for (auto t : derived->componentNames()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this logic looks like a good candidate to be a method inside DerivedTypeDetails, WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy to do that, does that seem reasonable @kiranchandramohan

Comment on lines 1897 to 1901
mlir::Operation *op =
converter.getModuleOp().lookupSymbol(converter.mangleName(*symbol));
if (op)
if (auto declareTargetOp =
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use auto declareTargetOp = SymbolTable::lookupNearestSymbolFrom<omp::DeclareTargetInterface>(converter.getModuleOp(), converter.mangleName(*symbol)); instead?

It will collapse all the 3 lines and, I think, properly encapsulate what we are doing here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense!

static void insertChildMapInfoIntoParent(
Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> &memberParentSyms,
llvm::SmallVector<mlir::Value> &memberMaps,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to use the more general mlir::Value rather than omp::MapInfoOp? I think all elements of memberMaps are always instances of MapInfoOp, right?

It can be argued that we don't need full access to the op's data but my suggestion is to provide more "documentation" in the code by using as much specific types as possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No reason from what I recall, we just tend to pass things around as mlir::Value's (and I've gotten complacent with it I imagine) perhaps as it makes it easier to pass things around as we tend to use the more generalised mlir::Value most places as opposed to the operation itself. However, I'll see what I can do!

@agozillon
Copy link
Contributor Author

Partially reviewed, will continue later.

Thanks Andrew, I am learning quite a bit from this PR.

No worries, I'll await your review completion to update the PR! Please do take your time though, I'm aware it's a large PR.

for (size_t i = 0; i < mapSymbols->size(); ++i) {
if ((*mapSymbols)[i] == sym) {
parentExists = true;
parentIdx = i;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
parentIdx = i;
parentIdx = i;
break;

memberIndexTmp.push_back(memberPlacementIndices[idx]);
mapOp.setMembersIndexAttr(mlir::ArrayAttr::get(
converter.getFirOpBuilder().getContext(), memberIndexTmp));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if the parent is not a mapOp, is it a valid case? Should we assert to ensure that no pre-condtions are broken?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should never be anything other than a MapInfoOp, happy to add an assert.

// verification issues. The more appropriate thing may be to take
// the "main" map type clause from the directive being used.
uint64_t mapType = 0;
if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question here, is it valid to not be a MapInfoOp or a violation of what should be expected by the function all the time?

Apologies if I missed something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's assumed to only ever be a MapInfoOp, same as every other case where we cast to the MapInfoOp and then directly use its fields

auto origSymbol = converter.getSymbolAddress(*sym);
mlir::Value mapOp = createMapInfoOp(
converter.getFirOpBuilder(),
converter.getFirOpBuilder().getUnknownLoc(), origSymbol,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can provide more precise location info.

Suggested change
converter.getFirOpBuilder().getUnknownLoc(), origSymbol,
origSymbol.getLoc(), origSymbol,

Might help with troubleshooting issues, I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good.

designator)) {
if (auto *structComp = Fortran::parser::Unwrap<
Fortran::parser::StructureComponent>(designator)) {
sym = GetFirstName(structComp->base).symbol;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const Name &GetFirstName(const StructureComponent &x) does this for you.

Suggested change
sym = GetFirstName(structComp->base).symbol;
sym = GetFirstName(structComp).symbol;

Just to not repeat the logic if it is already provided somewhere else.

if (getOmpObjectSymbol(ompObject)->owner().IsDerivedType()) {
memberPlacementIndices.push_back(
firOpBuilder.getI64IntegerAttr(findComponenetMemberPlacement(
getOmpObjectSymbol(ompObject)->owner().symbol(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a dummy question, but in which cases would getOmpObjectSymbol(ompObject)->owner().symbol() and getOmpObjParentSymbol(ompObject) would return different symbols? Can you give me an example for such a case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a dumb question at all (no question is imo!), and someone else may be able to shine further light on it if my assumptions are wrong (still learning my away around the derived types myself).

I believe the way getOmpObjParentSymbol is currently set up so that it will retrieve the first symbol in the list (as it uses getLastName), e.g. if we had some kind of nested derived type mapping like below:

map(to: dtype1%dtype2%scalar)

From my understanding we'd get dtype1, in the case of getOmpObjectSymbol(ompObject)->owner().symbol(), we would get dtype2, the difference between one up and the first effectively! And in the above case we would like to get the index of the member in its direct parent I believe.

However, for the currently covered set of cases by this PR using either should result in the same result I believe, so perhaps I was getting a little ahead of myself with this line!

Copy link
Contributor Author

@agozillon agozillon Feb 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, I recalled a little more of the details behind this choice at the moment after looking into this little segment again, and I'd love for someone with perhaps more knowledge on this to chime in if there's perhaps a better method, perhaps there's something I am not doing correctly.

So another factor is that in the use case above, given the following derived type and its instantiation:

 type t0
    integer :: a0, a1
  end type

type(t0) :: a

!$omp target map(a%a1)
   a%a1 = 0
!$omp end target

getOmpObjParentSymbol will return the symbol:

a size=8 offset=0: ObjectEntity type: TYPE(t0)

which is the symbol for the instantiation of the derived type being used within the map clause, whereas if we use getOmpObjectSymbol(ompObject)->owner().symbol(), we end up with the following:

t0: DerivedType components: a0,a1

Which in this case appears to refer to the derived type declaration i.e. the owner of the derived type component symbol (a1 in this case). The latter is needed for finding the relevant component indices, as we can check the member symbols, the former is the symbol that will actually be bound and used for the map.

Please do take the above with a grain of salt though, I could be making certain assumptions based on the things I've done so far.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Aha, I see now. Thanks for clarifying this!

flang/lib/Lower/OpenMP.cpp Outdated Show resolved Hide resolved
Comment on lines 195 to 196
if (auto mapClauseOwner =
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Early return to reduce indentation.

Suggested change
if (auto mapClauseOwner =
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
auto mapClauseOwner =
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target);
if (!mapClauseOwner) return;

llvm::SmallVector<mlir::Value> newMapOps;
mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapOperands();

for (size_t i = 0; i < mapOperandsArr.size(); ++i) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can use std::find(mapOperandsArr.begin(), ..., op)? We can get rid of the for loop and make the code a bit easier to read.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, I'll give it a try, I really should get back into the pattern of using std/llvm helper functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's applicable in this case unfortunately (perhaps I am missing something though, which is quite possible), we need to iterate over the whole operand list as we still insert when we do not find the operation.

flang/lib/Optimizer/Transforms/OMPMapInfoFinalization.cpp Outdated Show resolved Hide resolved
Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed the entire PR and commented on all what I thought worth mentioning. But I will leave the approval to people working on this longer than I am. Hopefully my review makes other reviewers go over the PR faster.

One comment about the lit tests in the PR in general: I see in quite a few places, there are CHECK lines that use the SSA names directly (e.g. %20). I think this is fragile and it would be better to either capture the name of the SSA value if we care about it and want to check it later (i.e. using %[[name:.*]]) or capture and discard the name of the value in a generic way (i.e. using %{{.*}}).

Thanks Andrew for answering my questions.

flang/lib/Lower/OpenMP.cpp Outdated Show resolved Hide resolved
if (getOmpObjectSymbol(ompObject)->owner().IsDerivedType()) {
memberPlacementIndices.push_back(
firOpBuilder.getI64IntegerAttr(findComponenetMemberPlacement(
getOmpObjectSymbol(ompObject)->owner().symbol(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Aha, I see now. Thanks for clarifying this!

flang/lib/Optimizer/Transforms/OMPMapInfoFinalization.cpp Outdated Show resolved Hide resolved
@agozillon
Copy link
Contributor Author

Most recent updates were a rebase for the entire stack and an update of this current PR (Fortran level changes) to try and address the very good feedback. However, first time doing this using SPR, so if anyone notices any weirdness please don't hesitate to point it out!

The main changes from the update are:

  • Attempting to remove the few cases of hard-coded SSA values from tests
  • Early exitting where suggested and where possible
  • Emission of assert when certain expected conditions are not met
  • Movement of mapSymbols to a required argument for processMap, as we always end up generating the list in any case now as with the current implementation for finding prior mapped parents we need to keep track of the symbol list
  • Removal of seperate getOmpObj... function for parent, I've opted to just use getFirstName for this PR, I think it may need something a little more robust for nested derived types (I made an attempt at making something that'd retrieve the parent one up from the current element, but it's not required for this PR and I'd like to test it a bit more first), but no sense getting ahead of myself if I can keep it simple for now
  • Better/clearer method of getting the derived type definition symbol that we utilise to access the member symbols for calculating the member indices

Created using spr 1.3.4
Created using spr 1.3.4
@agozillon
Copy link
Contributor Author

Rebased on recent changes to the OpenMP lowering!

Plus a little ping for some reviewer attention if at all possible please to get some forward momentum on the PR stack, thank you very much ahead of time!

@agozillon
Copy link
Contributor Author

Going to close this and re-open it with the rest of the stack (one of the commit names was a little too long, so going to shorten them in general if I can).

@agozillon agozillon closed this Feb 23, 2024
@agozillon agozillon deleted the users/agozillon/spr/flangopenmpmlir-extend-derived-record-type-map-support-in-flang-openmp-by-adding-some-initial-support-for-explicit-member-mapping branch February 23, 2024 23:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants