Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenMP][MLIR] Extend explicit derived type member mapping support for OpenMP dialects lowering to LLVM-IR #81510

Conversation

agozillon
Copy link
Contributor

This patch seeks to refactor slightly and extend the current record type map
support that was put in place for Fortran's descriptor types to handle explicit
member mapping for record types at a single level of depth (the case of explicit
mapping of nested record types is currently unsupported).

This patch seeks to support this by extending the OpenMPToLLVMIRTranslation phase
to more generally support record types, building on the prior groundwork in the
Fortran allocatables/pointers patch. It now supports different kinds of record type
mapping, in this case full record type mapping and then explicit member mapping
in which there is a special case for certain types when mapped individually to not
require any parent map link in the kernel argument structure. To facilitate this
required:

  • The movement of the setting of the map flag type "ptr_and_obj" to respective
    frontends, now supporting it as a possible flag that can be read and printed
    in mlir form. Some minor changes to declare target map type setting was
    neccesary for this.
  • The addition of a member index array operand, which tracks the position
    of the member in the parent, required for caclulating the appropriate size
    to offload to the target, alongside the parents offload pointer (always the
    first member currently being mapped).
  • A partial mapping attribute operand, to indicate if the entire record type is
    being mapped or just member components, aiding the ability to lower
    record types in the different manners that are possible.
  • Refactoring bounds calculation for record types and general arrays to one
    location (as well as load/store generation prior to assigning to the kernel
    argument structure), as a side affect enter/exit/update/data mapping
    should now be more correct and fully support bounds mapping, previously
    this would have only worked for target.

Created using spr 1.3.4
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 12, 2024

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir

Author: None (agozillon)

Changes

This patch seeks to refactor slightly and extend the current record type map
support that was put in place for Fortran's descriptor types to handle explicit
member mapping for record types at a single level of depth (the case of explicit
mapping of nested record types is currently unsupported).

This patch seeks to support this by extending the OpenMPToLLVMIRTranslation phase
to more generally support record types, building on the prior groundwork in the
Fortran allocatables/pointers patch. It now supports different kinds of record type
mapping, in this case full record type mapping and then explicit member mapping
in which there is a special case for certain types when mapped individually to not
require any parent map link in the kernel argument structure. To facilitate this
required:

  • The movement of the setting of the map flag type "ptr_and_obj" to respective
    frontends, now supporting it as a possible flag that can be read and printed
    in mlir form. Some minor changes to declare target map type setting was
    neccesary for this.
  • The addition of a member index array operand, which tracks the position
    of the member in the parent, required for caclulating the appropriate size
    to offload to the target, alongside the parents offload pointer (always the
    first member currently being mapped).
  • A partial mapping attribute operand, to indicate if the entire record type is
    being mapped or just member components, aiding the ability to lower
    record types in the different manners that are possible.
  • Refactoring bounds calculation for record types and general arrays to one
    location (as well as load/store generation prior to assigning to the kernel
    argument structure), as a side affect enter/exit/update/data mapping
    should now be more correct and fully support bounds mapping, previously
    this would have only worked for target.

Patch is 49.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81510.diff

4 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+300-204)
  • (modified) mlir/test/Target/LLVMIR/omptarget-fortran-allocatable-types-host.mlir (+14-16)
  • (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+27-22)
  • (added) mlir/test/Target/LLVMIR/omptarget-record-type-mapping-host.mlir (+63)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 79956f82ed141a..1ba8099135c35f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1750,9 +1750,9 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
 
       mapData.BaseType.push_back(
           moduleTranslation.convertType(mapOp.getVarType()));
-      mapData.Sizes.push_back(getSizeInBytes(
-          dl, mapOp.getVarType(), mapOp, mapData.BasePointers.back(),
-          mapData.BaseType.back(), builder, moduleTranslation));
+      mapData.Sizes.push_back(
+          getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
+                         mapData.BaseType.back(), builder, moduleTranslation));
       mapData.MapClause.push_back(mapOp.getOperation());
       mapData.Types.push_back(
           llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
@@ -1783,6 +1783,98 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
   }
 }
 
+static int getMapDataMemberIdx(MapInfoData &mapData,
+                               mlir::omp::MapInfoOp memberOp) {
+  int memberDataIdx = -1;
+  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
+    if (mapData.MapClause[i] == memberOp)
+      memberDataIdx = i;
+  }
+  return memberDataIdx;
+}
+
+static mlir::omp::MapInfoOp
+getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
+  // Only 1 member has been mapped, we can return it.
+  if (mapInfo.getMembersIndex()->size() == 1)
+    if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
+            mapInfo.getMembers()[0].getDefiningOp()))
+      return mapOp;
+
+  int64_t curPos =
+      mapInfo.getMembersIndex()->begin()->cast<mlir::IntegerAttr>().getInt();
+
+  int64_t idx = 1, curIdx = 0, memberPlacement = 0;
+  for (const auto *iter = std::next(mapInfo.getMembersIndex()->begin());
+       iter != mapInfo.getMembersIndex()->end(); iter++) {
+    memberPlacement = iter->cast<mlir::IntegerAttr>().getInt();
+    if (first) {
+      if (memberPlacement < curPos) {
+        curIdx = idx;
+        curPos = memberPlacement;
+      }
+    } else {
+      if (memberPlacement > curPos) {
+        curIdx = idx;
+        curPos = memberPlacement;
+      }
+    }
+    idx++;
+  }
+
+  if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
+          mapInfo.getMembers()[curIdx].getDefiningOp()))
+    return mapOp;
+
+  return {};
+}
+
+std::vector<llvm::Value *>
+calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
+                      llvm::IRBuilderBase &builder, bool isArrayTy,
+                      mlir::OperandRange bounds) {
+  std::vector<llvm::Value *> idx;
+  llvm::Value *offsetAddress = nullptr;
+  if (!bounds.empty()) {
+    idx.push_back(builder.getInt64(0));
+    if (isArrayTy) {
+      for (int i = bounds.size() - 1; i >= 0; --i) {
+        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
+                bounds[i].getDefiningOp())) {
+          idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
+        }
+      }
+    } else {
+      std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
+      for (size_t i = 1; i < bounds.size(); ++i) {
+        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
+                bounds[i].getDefiningOp())) {
+          dimensionIndexSizeOffset.push_back(builder.CreateMul(
+              moduleTranslation.lookupValue(boundOp.getExtent()),
+              dimensionIndexSizeOffset[i - 1]));
+        }
+      }
+
+      for (int i = bounds.size() - 1; i >= 0; --i) {
+        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
+                bounds[i].getDefiningOp())) {
+          if (!offsetAddress)
+            offsetAddress = builder.CreateMul(
+                moduleTranslation.lookupValue(boundOp.getLowerBound()),
+                dimensionIndexSizeOffset[i]);
+          else
+            offsetAddress = builder.CreateAdd(
+                offsetAddress, builder.CreateMul(moduleTranslation.lookupValue(
+                                                     boundOp.getLowerBound()),
+                                                 dimensionIndexSizeOffset[i]));
+        }
+      }
+    }
+  }
+
+  return offsetAddress ? std::vector<llvm::Value *>{offsetAddress} : idx;
+}
+
 // This creates two insertions into the MapInfosTy data structure for the
 // "parent" of a set of members, (usually a container e.g.
 // class/structure/derived type) when subsequent members have also been
@@ -1810,7 +1902,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
   combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
       mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
   combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
-  combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
 
   // Calculate size of the parent object being mapped based on the
   // addresses at runtime, highAddr - lowAddr = size. This of course
@@ -1819,24 +1910,40 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
   // Fortran pointers and allocatables, the mapping of the pointed to
   // data by the descriptor (which itself, is a structure containing
   // runtime information on the dynamically allocated data).
-  llvm::Value *lowAddr = builder.CreatePointerCast(
-      mapData.Pointers[mapDataIndex], builder.getPtrTy());
-  llvm::Value *highAddr = builder.CreatePointerCast(
-      builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
-                                 mapData.Pointers[mapDataIndex], 1),
-      builder.getPtrTy());
+  auto parentClause =
+      mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
+
+  llvm::Value *lowAddr, *highAddr;
+  if (!parentClause.getPartialMap()) {
+    lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
+                                        builder.getPtrTy());
+    highAddr = builder.CreatePointerCast(
+        builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
+                                   mapData.Pointers[mapDataIndex], 1),
+        builder.getPtrTy());
+    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
+  } else {
+    auto mapOp =
+        mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
+    int firstMemberIdx = getMapDataMemberIdx(
+        mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
+    lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
+                                        builder.getPtrTy());
+    int lastMemberIdx = getMapDataMemberIdx(
+        mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
+    highAddr = builder.CreatePointerCast(
+        builder.CreateGEP(mapData.BaseType[lastMemberIdx],
+                          mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
+        builder.getPtrTy());
+    combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
+  }
+
   llvm::Value *size = builder.CreateIntCast(
       builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
       builder.getInt64Ty(),
       /*isSigned=*/false);
   combinedInfo.Sizes.push_back(size);
 
-  // This creates the initial MEMBER_OF mapping that consists of
-  // the parent/top level container (same as above effectively, except
-  // with a fixed initial compile time size and seperate maptype which
-  // indicates the true mape type (tofrom etc.) and that it is a part
-  // of a larger mapping and indicating the link between it and it's
-  // members that are also explicitly mapped).
   llvm::omp::OpenMPOffloadMappingFlags mapFlag =
       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
   if (isTargetParams)
@@ -1846,15 +1953,22 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
       ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
   ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
 
-  combinedInfo.Types.emplace_back(mapFlag);
-  combinedInfo.DevicePointers.emplace_back(
-      llvm::OpenMPIRBuilder::DeviceInfoTy::None);
-  combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
-      mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
-  combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
-  combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
-  combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
-
+  // This creates the initial MEMBER_OF mapping that consists of
+  // the parent/top level container (same as above effectively, except
+  // with a fixed initial compile time size and seperate maptype which
+  // indicates the true mape type (tofrom etc.). This parent mapping is
+  // only relevant if the structure in it's totality is being mapped,
+  // otherwise the above suffices.
+  if (!parentClause.getPartialMap()) {
+    combinedInfo.Types.emplace_back(mapFlag);
+    combinedInfo.DevicePointers.emplace_back(
+        llvm::OpenMPIRBuilder::DeviceInfoTy::None);
+    combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
+        mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
+    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
+    combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
+    combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
+  }
   return memberOfFlag;
 }
 
@@ -1871,86 +1985,99 @@ static void processMapMembersWithParent(
   for (auto mappedMembers : parentClause.getMembers()) {
     auto memberClause =
         mlir::dyn_cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
-    int memberDataIdx = -1;
-    for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
-      if (mapData.MapClause[i] == memberClause)
-        memberDataIdx = i;
-    }
+    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
 
     assert(memberDataIdx >= 0 && "could not find mapped member of structure");
 
     // Same MemberOfFlag to indicate its link with parent and other members
-    // of, and we flag that it's part of a pointer and object coupling.
+    // of
     auto mapFlag =
         llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
     mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
+    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
     ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
-    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
+
     combinedInfo.Types.emplace_back(mapFlag);
     combinedInfo.DevicePointers.emplace_back(
         llvm::OpenMPIRBuilder::DeviceInfoTy::None);
     combinedInfo.Names.emplace_back(
         LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
-
-    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[memberDataIdx]);
-
-    std::vector<llvm::Value *> idx{builder.getInt64(0)};
-    llvm::Value *offsetAddress = nullptr;
-    if (!memberClause.getBounds().empty()) {
-      if (mapData.BaseType[memberDataIdx]->isArrayTy()) {
-        for (int i = memberClause.getBounds().size() - 1; i >= 0; --i) {
-          if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
-                  memberClause.getBounds()[i].getDefiningOp())) {
-            idx.push_back(
-                moduleTranslation.lookupValue(boundOp.getLowerBound()));
-          }
-        }
-      } else {
-        std::vector<llvm::Value *> dimensionIndexSizeOffset{
-            builder.getInt64(1)};
-        for (size_t i = 1; i < memberClause.getBounds().size(); ++i) {
-          if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
-                  memberClause.getBounds()[i].getDefiningOp())) {
-            dimensionIndexSizeOffset.push_back(builder.CreateMul(
-                moduleTranslation.lookupValue(boundOp.getExtent()),
-                dimensionIndexSizeOffset[i - 1]));
-          }
-        }
-
-        for (int i = memberClause.getBounds().size() - 1; i >= 0; --i) {
-          if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
-                  memberClause.getBounds()[i].getDefiningOp())) {
-            if (!offsetAddress)
-              offsetAddress = builder.CreateMul(
-                  moduleTranslation.lookupValue(boundOp.getLowerBound()),
-                  dimensionIndexSizeOffset[i]);
-            else
-              offsetAddress = builder.CreateAdd(
-                  offsetAddress,
-                  builder.CreateMul(
-                      moduleTranslation.lookupValue(boundOp.getLowerBound()),
-                      dimensionIndexSizeOffset[i]));
-          }
-        }
-      }
-    }
-
-    llvm::Value *memberIdx =
-        builder.CreateLoad(builder.getPtrTy(), mapData.Pointers[memberDataIdx]);
-    memberIdx = builder.CreateInBoundsGEP(
-        mapData.BaseType[memberDataIdx], memberIdx,
-        offsetAddress ? std::vector<llvm::Value *>{offsetAddress} : idx,
-        "member_idx");
-    combinedInfo.Pointers.emplace_back(memberIdx);
+    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
+    combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
     combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
   }
 }
 
+// This may be a bit of a naive check, the intent is to verify if the
+// mapped data being passed is a pointer -> pointee that requires special
+// handling in certain cases. There may be a better way to verify this, but
+// unfortunately with opaque pointers we lose the ability to easily check if
+// something is a pointer whilst maintaining access to the underlying type.
+static bool checkIfPointerMap(llvm::omp::OpenMPOffloadMappingFlags mapFlag) {
+  return static_cast<
+             std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+             mapFlag &
+             llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) != 0;
+}
+
+static void
+processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
+                     llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
+                     bool isTargetParams, int mapDataParentIdx = -1) {
+  // Declare Target Mappings are excluded from being marked as
+  // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
+  // marked with OMP_MAP_PTR_AND_OBJ instead.
+  auto mapFlag = mapData.Types[mapDataIdx];
+  if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
+    mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
+
+  if (auto mapInfoOp =
+          dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIdx]))
+    if (mapInfoOp.getMapCaptureType().value() ==
+            mlir::omp::VariableCaptureKind::ByCopy &&
+        !checkIfPointerMap(mapFlag))
+      mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
+
+  // if we're provided a mapDataParentIdx, then the data being mapped is
+  // part of a larger object (in a parent <-> member mapping) and in this
+  // case our BasePointer should be the parent.
+  if (mapDataParentIdx >= 0)
+    combinedInfo.BasePointers.emplace_back(
+        mapData.BasePointers[mapDataParentIdx]);
+  else
+    combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);
+
+  combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
+  combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
+  combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
+  combinedInfo.Types.emplace_back(mapFlag);
+  combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
+}
+
 static void processMapWithMembersOf(
     LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
     llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
     llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
     uint64_t mapDataIndex, bool isTargetParams) {
+  auto parentClause =
+      mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
+  // If we have a partial map (no parent referneced in the map clauses of the
+  // directive, only members) and only a single member, we do not need to bind
+  // the map of the member to the parent, we can pass the member seperately.
+  if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
+    auto memberClause = mlir::dyn_cast<mlir::omp::MapInfoOp>(
+        parentClause.getMembers()[0].getDefiningOp());
+    int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
+    // Primarily only scalars can be optimised this way it seems, array's
+    // need to be mapped as a regular record <-> member map even if partially
+    // mapping.
+    if (!mapData.BaseType[memberDataIdx]->isArrayTy()) {
+      processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams,
+                           mapDataIndex);
+      return;
+    }
+  }
+
   llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
       mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
                            combinedInfo, mapData, mapDataIndex, isTargetParams);
@@ -1959,6 +2086,79 @@ static void processMapWithMembersOf(
                               memberOfParentFlag);
 }
 
+// This is a variation on Clang's GenerateOpenMPCapturedVars, which
+// generates different operation (e.g. load/store) combinations for
+// arguments to the kernel, based on map capture kinds which are then
+// utilised in the combinedInfo in place of the original Map value.
+static void
+createAlteredByCaptureMap(MapInfoData &mapData,
+                          LLVM::ModuleTranslation &moduleTranslation,
+                          llvm::IRBuilderBase &builder) {
+  for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
+    // if it's declare target, skip it, it's handled seperately.
+    if (!mapData.IsDeclareTarget[i]) {
+      mlir::omp::VariableCaptureKind captureKind =
+          mlir::omp::VariableCaptureKind::ByRef;
+
+      auto mapOp =
+          mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
+      captureKind = mapOp.getMapCaptureType().value_or(
+          mlir::omp::VariableCaptureKind::ByRef);
+
+      bool isPtrTy = checkIfPointerMap(
+          llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
+
+      // Currently handles array sectioning lowerbound case, but more
+      // logic may be required in the future. Clang invokes EmitLValue,
+      // which has specialised logic for special Clang types such as user
+      // defines, so it is possible we will have to extend this for
+      // structures or other complex types. As the general idea is that this
+      // function mimics some of the logic from Clang that we require for
+      // kernel argument passing from host -> device.
+      switch (captureKind) {
+      case mlir::omp::VariableCaptureKind::ByRef: {
+        llvm::Value *newV = mapData.Pointers[i];
+        std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
+            moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
+            mapOp.getBounds());
+        if (isPtrTy)
+          newV = builder.CreateLoad(builder.getPtrTy(), newV);
+
+        if (!offsetIdx.empty())
+          newV = builder.CreateInBoundsGEP(mapData.BaseType[i], newV, offsetIdx,
+                                           "array_offset");
+        mapData.Pointers[i] = newV;
+      } break;
+      case mlir::omp::VariableCaptureKind::ByCopy: {
+        llvm::Value *newV;
+        if (mapData.Pointers[i]->getType()->isPointerTy())
+          newV = builder.CreateLoad(mapData.BaseType[i], mapData.Pointers[i]);
+        else
+          newV = mapData.Pointers[i];
+
+        if (!isPtrTy) {
+          auto curInsert = builder.saveIP();
+          builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
+          auto *memTempAlloc =
+              builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
+          builder.restoreIP(curInsert);
+
+          builder.CreateStore(newV, memTempAlloc);
+          newV = builder.CreateLoad(builder.getPtrTy(), memTempAlloc);
+        }
+
+        mapData.Pointers[i] = newV;
+        mapData.BasePointers[i] = newV;
+...
[truncated]

Comment on lines 1786 to 1794
static int getMapDataMemberIdx(MapInfoData &mapData,
mlir::omp::MapInfoOp memberOp) {
int memberDataIdx = -1;
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
if (mapData.MapClause[i] == memberOp)
memberDataIdx = i;
}
return memberDataIdx;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::find_if(...), just like the other PR 😛.

Also it seems like all the uses of this function assume that -1 will not be returned (there has to be an element matching the search key). So, I would suggest:

Suggested change
static int getMapDataMemberIdx(MapInfoData &mapData,
mlir::omp::MapInfoOp memberOp) {
int memberDataIdx = -1;
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
if (mapData.MapClause[i] == memberOp)
memberDataIdx = i;
}
return memberDataIdx;
}
static int getMapDataMemberIdx(MapInfoData &mapData,
mlir::omp::MapInfoOp memberOp) {
auto res = llvm::find(mapData.MapClause, memberOp);
assert(res != mapData.MapClause.end());
return std::distance(mapData.MapClause.begin(), res);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, that's an excellent solution :-)

Comment on lines 1808 to 1823
for (const auto *iter = std::next(mapInfo.getMembersIndex()->begin());
iter != mapInfo.getMembersIndex()->end(); iter++) {
memberPlacement = iter->cast<mlir::IntegerAttr>().getInt();
if (first) {
if (memberPlacement < curPos) {
curIdx = idx;
curPos = memberPlacement;
}
} else {
if (memberPlacement > curPos) {
curIdx = idx;
curPos = memberPlacement;
}
}
idx++;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we instead sort getMembersIndex() somehow and fetch either the head or the tail based on first? Not sure how easy it is but my feeling is that it should be possible and would be easier to understand.

Or guarantee that it is always sorted on op creation/update. But maybe this is not worth it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagine the former of sorting it when the function is called should be fine, I'll have to double check though.

The latter, I don't think is possible, or at the very least I'd like to maintain the user-ordering for the maps for the moment (and member indices by extension) as I think it is important specification wise (at least until we can rule it out in any case) as there's only one case where the specification calls for re-ordering (which we still need to add support for actually) and that's re-ordering from/tofrom/to map clauses in-front of alloca/delete/release.

mapInfo.getMembers()[curIdx].getDefiningOp()))
return mapOp;

return {};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for repeating this, but I think it is better to assert if some assumption in the code is violated rather than returning an empty value.

mlir::OperandRange bounds) {
std::vector<llvm::Value *> idx;
llvm::Value *offsetAddress = nullptr;
if (!bounds.empty()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!bounds.empty()) {
if (bounds.empty())
return idx;

Comment on lines 1861 to 1869
if (!offsetAddress)
offsetAddress = builder.CreateMul(
moduleTranslation.lookupValue(boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]);
else
offsetAddress = builder.CreateAdd(
offsetAddress, builder.CreateMul(moduleTranslation.lookupValue(
boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can get rid of offsetAddress this way:

Suggested change
if (!offsetAddress)
offsetAddress = builder.CreateMul(
moduleTranslation.lookupValue(boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]);
else
offsetAddress = builder.CreateAdd(
offsetAddress, builder.CreateMul(moduleTranslation.lookupValue(
boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]));
if (idx.empty())
idx.emplace_back(builder.CreateMul(
moduleTranslation.lookupValue(boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]));
else
offsetAddress.back() = builder.CreateAdd(
offsetAddress.back(), builder.CreateMul(moduleTranslation.lookupValue(
boundOp.getLowerBound()),
dimensionIndexSizeOffset[i]));

}

std::vector<llvm::Value *>
calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is quite some complicated logic in this function, can you add some documentation on the function to provide a high-level description? And in the function itself, can you add a few more comments explaining things like why you are iterating backwards in some cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, hopefully the comments are a little helpful!

From my (possibly flawed, as I may be attributing the need for this to the wrong thing, so if someone wishes to correct me please do) understanding the backwards iteration is required to create the appropriate array offsetting in Fortran due to Fortran's column-major layout ordering (the bounds are still provided in row-major order I believe, the same as it's written in Fortran I think). If this is indeed the case then it does mean we're currently making a bit of an exception for Fortran here, and to keep this lowering and dialect language agnostic whenever we get a C/C++ frontend we will likely have to agree on a bounds ordering, and one of the languages may have to do a re-ordering in their frontend.

Created using spr 1.3.4
Created using spr 1.3.4
@agozillon
Copy link
Contributor Author

agozillon commented Feb 17, 2024

Updated this PR (the commits on the other stacked PRs can be ignored, they're due to me running a SPR diff on everything, as I made a little mistake during a rebase).

Change list in the recent commit attempts to address review comments by:

  • Remove unnecessary map type flag removal for future PR where it will actually be required
  • Modify getMapDataMemberIdx as suggested
  • Make retrieval of first or last mapped member a little cleaner through the usage of standard/llvm library functions
  • Add further comments and notes to help understanding a little (and a TODO)

Created using spr 1.3.4
@agozillon
Copy link
Contributor Author

a little ping for some reviewer attention if at all possible please to get some forward momentum on the PR stack, thank you very much ahead of time!

@agozillon
Copy link
Contributor Author

Going to close this and re-open it with the rest of the stack (one of the commit names was a little too long, so going to shorten them in general if I can).

@agozillon agozillon closed this Feb 23, 2024
@agozillon agozillon deleted the users/agozillon/spr/openmpmlir-extend-explicit-derived-type-member-mapping-support-for-openmp-dialects-lowering-to-llvm-ir branch February 23, 2024 23:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants