Skip to content

Conversation

TIFitis
Copy link
Member

@TIFitis TIFitis commented Sep 25, 2025

This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.

  • Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived object is mapped via declare mapper.
  • Traverse HLFIR designates rooted at the target block arg and build full coordinate_of chains; append members with correct membersIndex.

This fixes #156461.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp offload labels Sep 25, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 25, 2025

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Akash Banerjee (TIFitis)

Changes

This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.

  • Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived object is mapped via declare mapper.
  • Traverse HLFIR designates rooted at the target block arg and build full coordinate_of chains; append members with correct membersIndex.

This fixes #156461.


Full diff: https://github.com/llvm/llvm-project/pull/160766.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+164-57)
  • (modified) flang/test/Lower/OpenMP/declare-mapper.f90 (+38)
  • (added) offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 (+43)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 57be863cfa1b8..0c7b1ceaf21f9 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -701,40 +701,37 @@ class MapInfoFinalizationPass
 
         auto recordType = mlir::cast<fir::RecordType>(underlyingType);
         llvm::SmallVector<mlir::Value> newMapOpsForFields;
-        llvm::SmallVector<int64_t> fieldIndicies;
-
-        for (auto fieldMemTyPair : recordType.getTypeList()) {
-          auto &field = fieldMemTyPair.first;
-          auto memTy = fieldMemTyPair.second;
-
-          bool shouldMapField =
-              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
-                if (!fir::isAllocatableType(memTy))
-                  return false;
-
-                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
-                if (!designateOp)
-                  return false;
-
-                return designateOp.getComponent() &&
-                       designateOp.getComponent()->strref() == field;
-              }) != mapVarForwardSlice.end();
-
-          // TODO Handle recursive record types. Adapting
-          // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
-          // entities might be helpful here.
-
-          if (!shouldMapField)
-            continue;
-
-          int32_t fieldIdx = recordType.getFieldIndex(field);
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
+
+        auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
+                                   mlir::Type memTy,
+                                   llvm::ArrayRef<int64_t> indexPath,
+                                   llvm::StringRef memberName) {
+          // Avoid adding duplicates for the same index path within this op.
+          bool alreadyPlanned = llvm::any_of(
+              newMemberIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+                return p.size() == indexPath.size() &&
+                       std::equal(p.begin(), p.end(), indexPath.begin());
+              });
+          if (alreadyPlanned)
+            return;
+
+          // Check if already mapped (index path equality).
           bool alreadyMapped = [&]() {
             if (op.getMembersIndexAttr())
               for (auto indexList : op.getMembersIndexAttr()) {
                 auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
-                if (indexListAttr.size() == 1 &&
-                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
-                        fieldIdx)
+                if (indexListAttr.size() != indexPath.size())
+                  continue;
+                bool allEq = true;
+                for (auto [i, attr] : llvm::enumerate(indexListAttr)) {
+                  if (mlir::cast<mlir::IntegerAttr>(attr).getInt() !=
+                      indexPath[i]) {
+                    allEq = false;
+                    break;
+                  }
+                }
+                if (allEq)
                   return true;
               }
 
@@ -742,53 +739,165 @@ class MapInfoFinalizationPass
           }();
 
           if (alreadyMapped)
-            continue;
+            return;
 
           builder.setInsertionPoint(op);
-          fir::IntOrValue idxConst =
-              mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
-          auto fieldCoord = fir::CoordinateOp::create(
-              builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
-              llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
           fir::factory::AddrAndBoundsInfo info =
-              fir::factory::getDataOperandBaseAddr(
-                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+              fir::factory::getDataOperandBaseAddr(builder, coordRef,
+                                                   /*isOptional=*/false, loc);
           llvm::SmallVector<mlir::Value> bounds =
               fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
                                                  mlir::omp::MapBoundsType>(
                   builder, info,
-                  hlfir::translateToExtendedValue(op.getLoc(), builder,
-                                                  hlfir::Entity{fieldCoord})
+                  hlfir::translateToExtendedValue(loc, builder,
+                                                  hlfir::Entity{coordRef})
                       .first,
-                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+                  /*dataExvIsAssumedSize=*/false, loc);
 
           mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
-              builder, op.getLoc(), fieldCoord.getResult().getType(),
-              fieldCoord.getResult(),
-              mlir::TypeAttr::get(
-                  fir::unwrapRefType(fieldCoord.getResult().getType())),
+              builder, loc, coordRef.getType(), coordRef,
+              mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
               op.getMapTypeAttr(),
               builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
                   mlir::omp::VariableCaptureKind::ByRef),
               /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
               /*members_index=*/mlir::ArrayAttr{}, bounds,
               /*mapperId=*/mlir::FlatSymbolRefAttr(),
-              builder.getStringAttr(op.getNameAttr().strref() + "." + field +
-                                    ".implicit_map"),
+              builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                    memberName + ".implicit_map"),
               /*partial_map=*/builder.getBoolAttr(false));
           newMapOpsForFields.emplace_back(fieldMapOp);
-          fieldIndicies.emplace_back(fieldIdx);
+          newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+        };
+
+        // 1) Handle direct top-level allocatable fields.
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          if (!fir::isAllocatableType(memTy))
+            continue;
+
+          bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
+            auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
+            return designateOp && designateOp.getComponent() &&
+                   designateOp.getComponent()->strref() == field;
+          });
+          if (!referenced)
+            continue;
+
+          int32_t fieldIdx = recordType.getFieldIndex(field);
+          builder.setInsertionPoint(op);
+          fir::IntOrValue idxConst =
+              mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
+          auto fieldCoord = fir::CoordinateOp::create(
+              builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+          appendMemberMap(op.getLoc(), fieldCoord, memTy, {fieldIdx}, field);
+        }
+
+        // Handle nested allocatable fields along any component chain
+        // referenced in the region via HLFIR designates.
+        llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths;
+        for (mlir::Operation *sliceOp : mapVarForwardSlice) {
+          auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+          if (!designateOp || !designateOp.getComponent())
+            continue;
+          llvm::SmallVector<llvm::StringRef> compPathReversed;
+          compPathReversed.push_back(designateOp.getComponent()->strref());
+          mlir::Value curBase = designateOp.getMemref();
+          bool rootedAtMapArg = false;
+          while (true) {
+            if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
+              if (!parentDes.getComponent())
+                break;
+              compPathReversed.push_back(parentDes.getComponent()->strref());
+              curBase = parentDes.getMemref();
+              continue;
+            }
+            if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
+              if (auto barg =
+                      mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
+                rootedAtMapArg = (barg == opBlockArg);
+            } else if (auto blockArg =
+                           mlir::dyn_cast_or_null<mlir::BlockArgument>(
+                               curBase)) {
+              rootedAtMapArg = (blockArg == opBlockArg);
+            }
+            break;
+          }
+          // Only process nested paths (2+ components). Single-component paths
+          // for direct fields are handled above.
+          if (!rootedAtMapArg || compPathReversed.size() < 2)
+            continue;
+          builder.setInsertionPoint(op);
+          llvm::SmallVector<int64_t> indexPath;
+          mlir::Type curTy = underlyingType;
+          mlir::Value coordRef = op.getVarPtr();
+          bool validPath = true;
+          for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
+            auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
+            if (!recTy) {
+              validPath = false;
+              break;
+            }
+            int32_t idx = recTy.getFieldIndex(compName);
+            if (idx < 0) {
+              validPath = false;
+              break;
+            }
+            indexPath.push_back(idx);
+            mlir::Type memTy = recTy.getType(idx);
+            fir::IntOrValue idxConst =
+                mlir::IntegerAttr::get(builder.getI32Type(), idx);
+            coordRef = fir::CoordinateOp::create(
+                builder, op.getLoc(), builder.getRefType(memTy), coordRef,
+                llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+            curTy = memTy;
+          }
+          if (!validPath)
+            continue;
+          if (auto finalRefTy =
+                  mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
+            mlir::Type eleTy = finalRefTy.getElementType();
+            if (fir::isAllocatableType(eleTy)) {
+              bool isNew = llvm::none_of(
+                  seenIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+                    return p.size() == indexPath.size() &&
+                           std::equal(p.begin(), p.end(), indexPath.begin());
+                  });
+              if (isNew) {
+                seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+                appendMemberMap(op.getLoc(), coordRef, eleTy, indexPath,
+                                compPathReversed.front());
+              }
+            }
+          }
         }
 
         if (newMapOpsForFields.empty())
           return mlir::WalkResult::advance();
 
-        op.getMembersMutable().append(newMapOpsForFields);
+        // Deduplicate by index path to avoid emitting duplicate members for
+        // the same component.
+        llvm::SmallVector<mlir::Value> dedupMapOps;
+        llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths;
+        for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) {
+          const auto &path = newMemberIndexPaths[i];
+          bool isNew = llvm::none_of(
+              dedupIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+                return p.size() == path.size() &&
+                       std::equal(p.begin(), p.end(), path.begin());
+              });
+          if (isNew) {
+            dedupMapOps.push_back(mapOp);
+            dedupIndexPaths.emplace_back(path.begin(), path.end());
+          }
+        }
+        op.getMembersMutable().append(dedupMapOps);
         llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
-        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
-
-        if (oldMembersIdxAttr)
-          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+        if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
+          for (mlir::Attribute indexList : oldAttr) {
             llvm::SmallVector<int64_t> listVec;
 
             for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
@@ -796,10 +905,8 @@ class MapInfoFinalizationPass
 
             newMemberIndices.emplace_back(std::move(listVec));
           }
-
-        for (int64_t newFieldIdx : fieldIndicies)
-          newMemberIndices.emplace_back(
-              llvm::SmallVector<int64_t>(1, newFieldIdx));
+        for (auto &path : dedupIndexPaths)
+          newMemberIndices.emplace_back(path);
 
         op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
         op.setPartialMap(true);
diff --git a/flang/test/Lower/OpenMP/declare-mapper.f90 b/flang/test/Lower/OpenMP/declare-mapper.f90
index 8a98c68a8d582..1c51666c80f8a 100644
--- a/flang/test/Lower/OpenMP/declare-mapper.f90
+++ b/flang/test/Lower/OpenMP/declare-mapper.f90
@@ -6,6 +6,7 @@
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90
 
 !--- omp-declare-mapper-1.f90
 subroutine declare_mapper_1
@@ -262,3 +263,40 @@ subroutine use_inner()
       !$omp end target
    end subroutine
 end program declare_mapper_5
+
+!--- omp-declare-mapper-6.f90
+subroutine declare_mapper_nested_parent
+  type :: inner_t
+    real, allocatable :: deep_arr(:)
+  end type inner_t
+
+  type, abstract :: base_t
+    real, allocatable :: base_arr(:)
+    type(inner_t) :: inner
+  end type base_t
+
+  type, extends(base_t) :: real_t
+    real, allocatable :: real_arr(:)
+  end type real_t
+
+  !$omp declare mapper (custommapper : real_t :: t) map(tofrom: t%base_arr, t%real_arr)
+
+  type(real_t) :: r
+
+  allocate(r%base_arr(10))
+  allocate(r%inner%deep_arr(10))
+  allocate(r%real_arr(10))
+  r%base_arr = 1.0
+  r%inner%deep_arr = 4.0
+  r%real_arr = 0.0
+
+  ! CHECK: omp.target
+  ! Check implicit maps for nested parent and deep nested allocatable payloads
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.base_arr.implicit_map"}
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.deep_arr.implicit_map"}
+  ! The declared mapper's own allocatable is still mapped implicitly
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.real_arr.implicit_map"}
+  !$omp target map(mapper(custommapper), tofrom: r)
+    r%real_arr = r%base_arr(1) + r%inner%deep_arr(1)
+  !$omp end target
+end subroutine declare_mapper_nested_parent
diff --git a/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 b/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90
new file mode 100644
index 0000000000000..65e04af66e022
--- /dev/null
+++ b/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90
@@ -0,0 +1,43 @@
+! This test validates that declare mapper for a derived type that extends
+! a parent type with an allocatable component correctly maps the nested
+! allocatable payload via the mapper when the whole object is mapped on
+! target.
+
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-run-and-check-generic
+
+program target_declare_mapper_parent_allocatable
+  implicit none
+
+  type, abstract :: base_t
+    real, allocatable :: base_arr(:)
+  end type base_t
+
+  type, extends(base_t) :: real_t
+    real, allocatable :: real_arr(:)
+  end type real_t
+  !$omp declare mapper(custommapper: real_t :: t) map(t%base_arr, t%real_arr)
+
+  type(real_t) :: r
+  integer :: i
+  allocate(r%base_arr(10), source=1.0)
+  allocate(r%real_arr(10), source=1.0)
+
+  !$omp target map(mapper(custommapper), tofrom: r)
+  do i = 1, size(r%base_arr)
+    r%base_arr(i) = 2.0
+    r%real_arr(i) = 3.0
+    r%real_arr(i) = r%base_arr(1)
+  end do
+  !$omp end target
+
+
+  !CHECK: base_arr:  2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
+  print*, "base_arr: ", r%base_arr
+  !CHECK: real_arr:  2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
+  print*, "real_arr: ", r%real_arr
+
+  deallocate(r%real_arr)
+  deallocate(r%base_arr)
+end program target_declare_mapper_parent_allocatable

@llvmbot
Copy link
Member

llvmbot commented Sep 25, 2025

@llvm/pr-subscribers-offload

Author: Akash Banerjee (TIFitis)

Changes

This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.

  • Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived object is mapped via declare mapper.
  • Traverse HLFIR designates rooted at the target block arg and build full coordinate_of chains; append members with correct membersIndex.

This fixes #156461.


Full diff: https://github.com/llvm/llvm-project/pull/160766.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+164-57)
  • (modified) flang/test/Lower/OpenMP/declare-mapper.f90 (+38)
  • (added) offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 (+43)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 57be863cfa1b8..0c7b1ceaf21f9 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -701,40 +701,37 @@ class MapInfoFinalizationPass
 
         auto recordType = mlir::cast<fir::RecordType>(underlyingType);
         llvm::SmallVector<mlir::Value> newMapOpsForFields;
-        llvm::SmallVector<int64_t> fieldIndicies;
-
-        for (auto fieldMemTyPair : recordType.getTypeList()) {
-          auto &field = fieldMemTyPair.first;
-          auto memTy = fieldMemTyPair.second;
-
-          bool shouldMapField =
-              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
-                if (!fir::isAllocatableType(memTy))
-                  return false;
-
-                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
-                if (!designateOp)
-                  return false;
-
-                return designateOp.getComponent() &&
-                       designateOp.getComponent()->strref() == field;
-              }) != mapVarForwardSlice.end();
-
-          // TODO Handle recursive record types. Adapting
-          // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
-          // entities might be helpful here.
-
-          if (!shouldMapField)
-            continue;
-
-          int32_t fieldIdx = recordType.getFieldIndex(field);
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
+
+        auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
+                                   mlir::Type memTy,
+                                   llvm::ArrayRef<int64_t> indexPath,
+                                   llvm::StringRef memberName) {
+          // Avoid adding duplicates for the same index path within this op.
+          bool alreadyPlanned = llvm::any_of(
+              newMemberIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+                return p.size() == indexPath.size() &&
+                       std::equal(p.begin(), p.end(), indexPath.begin());
+              });
+          if (alreadyPlanned)
+            return;
+
+          // Check if already mapped (index path equality).
           bool alreadyMapped = [&]() {
             if (op.getMembersIndexAttr())
               for (auto indexList : op.getMembersIndexAttr()) {
                 auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
-                if (indexListAttr.size() == 1 &&
-                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
-                        fieldIdx)
+                if (indexListAttr.size() != indexPath.size())
+                  continue;
+                bool allEq = true;
+                for (auto [i, attr] : llvm::enumerate(indexListAttr)) {
+                  if (mlir::cast<mlir::IntegerAttr>(attr).getInt() !=
+                      indexPath[i]) {
+                    allEq = false;
+                    break;
+                  }
+                }
+                if (allEq)
                   return true;
               }
 
@@ -742,53 +739,165 @@ class MapInfoFinalizationPass
           }();
 
           if (alreadyMapped)
-            continue;
+            return;
 
           builder.setInsertionPoint(op);
-          fir::IntOrValue idxConst =
-              mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
-          auto fieldCoord = fir::CoordinateOp::create(
-              builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
-              llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
           fir::factory::AddrAndBoundsInfo info =
-              fir::factory::getDataOperandBaseAddr(
-                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+              fir::factory::getDataOperandBaseAddr(builder, coordRef,
+                                                   /*isOptional=*/false, loc);
           llvm::SmallVector<mlir::Value> bounds =
               fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
                                                  mlir::omp::MapBoundsType>(
                   builder, info,
-                  hlfir::translateToExtendedValue(op.getLoc(), builder,
-                                                  hlfir::Entity{fieldCoord})
+                  hlfir::translateToExtendedValue(loc, builder,
+                                                  hlfir::Entity{coordRef})
                       .first,
-                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+                  /*dataExvIsAssumedSize=*/false, loc);
 
           mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
-              builder, op.getLoc(), fieldCoord.getResult().getType(),
-              fieldCoord.getResult(),
-              mlir::TypeAttr::get(
-                  fir::unwrapRefType(fieldCoord.getResult().getType())),
+              builder, loc, coordRef.getType(), coordRef,
+              mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
               op.getMapTypeAttr(),
               builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
                   mlir::omp::VariableCaptureKind::ByRef),
               /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
               /*members_index=*/mlir::ArrayAttr{}, bounds,
               /*mapperId=*/mlir::FlatSymbolRefAttr(),
-              builder.getStringAttr(op.getNameAttr().strref() + "." + field +
-                                    ".implicit_map"),
+              builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                    memberName + ".implicit_map"),
               /*partial_map=*/builder.getBoolAttr(false));
           newMapOpsForFields.emplace_back(fieldMapOp);
-          fieldIndicies.emplace_back(fieldIdx);
+          newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+        };
+
+        // 1) Handle direct top-level allocatable fields.
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          if (!fir::isAllocatableType(memTy))
+            continue;
+
+          bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
+            auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
+            return designateOp && designateOp.getComponent() &&
+                   designateOp.getComponent()->strref() == field;
+          });
+          if (!referenced)
+            continue;
+
+          int32_t fieldIdx = recordType.getFieldIndex(field);
+          builder.setInsertionPoint(op);
+          fir::IntOrValue idxConst =
+              mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
+          auto fieldCoord = fir::CoordinateOp::create(
+              builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+          appendMemberMap(op.getLoc(), fieldCoord, memTy, {fieldIdx}, field);
+        }
+
+        // Handle nested allocatable fields along any component chain
+        // referenced in the region via HLFIR designates.
+        llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths;
+        for (mlir::Operation *sliceOp : mapVarForwardSlice) {
+          auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+          if (!designateOp || !designateOp.getComponent())
+            continue;
+          llvm::SmallVector<llvm::StringRef> compPathReversed;
+          compPathReversed.push_back(designateOp.getComponent()->strref());
+          mlir::Value curBase = designateOp.getMemref();
+          bool rootedAtMapArg = false;
+          while (true) {
+            if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
+              if (!parentDes.getComponent())
+                break;
+              compPathReversed.push_back(parentDes.getComponent()->strref());
+              curBase = parentDes.getMemref();
+              continue;
+            }
+            if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
+              if (auto barg =
+                      mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
+                rootedAtMapArg = (barg == opBlockArg);
+            } else if (auto blockArg =
+                           mlir::dyn_cast_or_null<mlir::BlockArgument>(
+                               curBase)) {
+              rootedAtMapArg = (blockArg == opBlockArg);
+            }
+            break;
+          }
+          // Only process nested paths (2+ components). Single-component paths
+          // for direct fields are handled above.
+          if (!rootedAtMapArg || compPathReversed.size() < 2)
+            continue;
+          builder.setInsertionPoint(op);
+          llvm::SmallVector<int64_t> indexPath;
+          mlir::Type curTy = underlyingType;
+          mlir::Value coordRef = op.getVarPtr();
+          bool validPath = true;
+          for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
+            auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
+            if (!recTy) {
+              validPath = false;
+              break;
+            }
+            int32_t idx = recTy.getFieldIndex(compName);
+            if (idx < 0) {
+              validPath = false;
+              break;
+            }
+            indexPath.push_back(idx);
+            mlir::Type memTy = recTy.getType(idx);
+            fir::IntOrValue idxConst =
+                mlir::IntegerAttr::get(builder.getI32Type(), idx);
+            coordRef = fir::CoordinateOp::create(
+                builder, op.getLoc(), builder.getRefType(memTy), coordRef,
+                llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+            curTy = memTy;
+          }
+          if (!validPath)
+            continue;
+          if (auto finalRefTy =
+                  mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
+            mlir::Type eleTy = finalRefTy.getElementType();
+            if (fir::isAllocatableType(eleTy)) {
+              bool isNew = llvm::none_of(
+                  seenIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+                    return p.size() == indexPath.size() &&
+                           std::equal(p.begin(), p.end(), indexPath.begin());
+                  });
+              if (isNew) {
+                seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+                appendMemberMap(op.getLoc(), coordRef, eleTy, indexPath,
+                                compPathReversed.front());
+              }
+            }
+          }
         }
 
         if (newMapOpsForFields.empty())
           return mlir::WalkResult::advance();
 
-        op.getMembersMutable().append(newMapOpsForFields);
+        // Deduplicate by index path to avoid emitting duplicate members for
+        // the same component.
+        llvm::SmallVector<mlir::Value> dedupMapOps;
+        llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths;
+        for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) {
+          const auto &path = newMemberIndexPaths[i];
+          bool isNew = llvm::none_of(
+              dedupIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+                return p.size() == path.size() &&
+                       std::equal(p.begin(), p.end(), path.begin());
+              });
+          if (isNew) {
+            dedupMapOps.push_back(mapOp);
+            dedupIndexPaths.emplace_back(path.begin(), path.end());
+          }
+        }
+        op.getMembersMutable().append(dedupMapOps);
         llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
-        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
-
-        if (oldMembersIdxAttr)
-          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+        if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
+          for (mlir::Attribute indexList : oldAttr) {
             llvm::SmallVector<int64_t> listVec;
 
             for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
@@ -796,10 +905,8 @@ class MapInfoFinalizationPass
 
             newMemberIndices.emplace_back(std::move(listVec));
           }
-
-        for (int64_t newFieldIdx : fieldIndicies)
-          newMemberIndices.emplace_back(
-              llvm::SmallVector<int64_t>(1, newFieldIdx));
+        for (auto &path : dedupIndexPaths)
+          newMemberIndices.emplace_back(path);
 
         op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
         op.setPartialMap(true);
diff --git a/flang/test/Lower/OpenMP/declare-mapper.f90 b/flang/test/Lower/OpenMP/declare-mapper.f90
index 8a98c68a8d582..1c51666c80f8a 100644
--- a/flang/test/Lower/OpenMP/declare-mapper.f90
+++ b/flang/test/Lower/OpenMP/declare-mapper.f90
@@ -6,6 +6,7 @@
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90
 
 !--- omp-declare-mapper-1.f90
 subroutine declare_mapper_1
@@ -262,3 +263,40 @@ subroutine use_inner()
       !$omp end target
    end subroutine
 end program declare_mapper_5
+
+!--- omp-declare-mapper-6.f90
+subroutine declare_mapper_nested_parent
+  type :: inner_t
+    real, allocatable :: deep_arr(:)
+  end type inner_t
+
+  type, abstract :: base_t
+    real, allocatable :: base_arr(:)
+    type(inner_t) :: inner
+  end type base_t
+
+  type, extends(base_t) :: real_t
+    real, allocatable :: real_arr(:)
+  end type real_t
+
+  !$omp declare mapper (custommapper : real_t :: t) map(tofrom: t%base_arr, t%real_arr)
+
+  type(real_t) :: r
+
+  allocate(r%base_arr(10))
+  allocate(r%inner%deep_arr(10))
+  allocate(r%real_arr(10))
+  r%base_arr = 1.0
+  r%inner%deep_arr = 4.0
+  r%real_arr = 0.0
+
+  ! CHECK: omp.target
+  ! Check implicit maps for nested parent and deep nested allocatable payloads
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.base_arr.implicit_map"}
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.deep_arr.implicit_map"}
+  ! The declared mapper's own allocatable is still mapped implicitly
+  ! CHECK-DAG: omp.map.info {{.*}} {name = "r.real_arr.implicit_map"}
+  !$omp target map(mapper(custommapper), tofrom: r)
+    r%real_arr = r%base_arr(1) + r%inner%deep_arr(1)
+  !$omp end target
+end subroutine declare_mapper_nested_parent
diff --git a/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 b/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90
new file mode 100644
index 0000000000000..65e04af66e022
--- /dev/null
+++ b/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90
@@ -0,0 +1,43 @@
+! This test validates that declare mapper for a derived type that extends
+! a parent type with an allocatable component correctly maps the nested
+! allocatable payload via the mapper when the whole object is mapped on
+! target.
+
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-run-and-check-generic
+
+program target_declare_mapper_parent_allocatable
+  implicit none
+
+  type, abstract :: base_t
+    real, allocatable :: base_arr(:)
+  end type base_t
+
+  type, extends(base_t) :: real_t
+    real, allocatable :: real_arr(:)
+  end type real_t
+  !$omp declare mapper(custommapper: real_t :: t) map(t%base_arr, t%real_arr)
+
+  type(real_t) :: r
+  integer :: i
+  allocate(r%base_arr(10), source=1.0)
+  allocate(r%real_arr(10), source=1.0)
+
+  !$omp target map(mapper(custommapper), tofrom: r)
+  do i = 1, size(r%base_arr)
+    r%base_arr(i) = 2.0
+    r%real_arr(i) = 3.0
+    r%real_arr(i) = r%base_arr(1)
+  end do
+  !$omp end target
+
+
+  !CHECK: base_arr:  2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
+  print*, "base_arr: ", r%base_arr
+  !CHECK: real_arr:  2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
+  print*, "real_arr: ", r%real_arr
+
+  deallocate(r%real_arr)
+  deallocate(r%base_arr)
+end program target_declare_mapper_parent_allocatable

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR enhances the OpenMP MapInfoFinalization pass to support nested allocatable components in derived types when using declare mappers. It generalizes the pass to handle arbitrarily nested derived types by traversing HLFIR designates and building complete coordinate chains for nested components.

Key changes:

  • Refactored MapInfoFinalization to handle nested allocatable components beyond just direct fields
  • Added logic to traverse component paths and build coordinate_of chains for nested access patterns
  • Enhanced deduplication logic to prevent duplicate mappings for the same component paths

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 New test case for validating declare mapper with derived type inheritance and nested allocatable components
flang/test/Lower/OpenMP/declare-mapper.f90 Added test case for nested parent types with deep allocatable components
flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp Major refactoring to support nested allocatable component mapping with enhanced traversal and deduplication logic

@TIFitis
Copy link
Member Author

TIFitis commented Sep 26, 2025

@agozillon Sorry, I found the base_arr wasn't getting fully mapped when rebasing it for downstream and had to revert #160116. I've fixed it here.

Copy link
Contributor

@agozillon agozillon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no problem, LGTM!

@TIFitis TIFitis force-pushed the nested_derived_types branch from 52e7187 to 6f86956 Compare September 26, 2025 15:38
@TIFitis TIFitis requested a review from Copilot September 26, 2025 15:42
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated no new comments.

@TIFitis TIFitis force-pushed the nested_derived_types branch from ad30b76 to 0d97911 Compare October 1, 2025 20:07
TIFitis and others added 3 commits October 1, 2025 21:56
…ed types (llvm#160116)

This PR adds support for nested derived types and their mappers to the
MapInfoFinalization pass.

- Generalize MapInfoFinalization to add child maps for arbitrarily
nested allocatables when a derived object is mapped via declare mapper.
- Traverse HLFIR designates rooted at the target block arg and build
full coordinate_of chains; append members with correct membersIndex.

This fixes llvm#156461.
Copy link
Contributor

@agozillon agozillon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the update @TIFitis LGTM! :-)

@TIFitis TIFitis force-pushed the nested_derived_types branch from 0d97911 to 4149676 Compare October 2, 2025 15:56
@TIFitis TIFitis enabled auto-merge (squash) October 2, 2025 15:57
@TIFitis TIFitis merged commit ed12dc5 into llvm:main Oct 2, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…ed types (llvm#160766)

This PR adds support for nested derived types and their mappers to the
MapInfoFinalization pass.

- Generalize MapInfoFinalization to add child maps for arbitrarily
nested allocatables when a derived object is mapped via declare mapper.
- Traverse HLFIR designates rooted at the target block arg and build
full coordinate_of chains; append members with correct membersIndex.

This fixes llvm#156461.
@TIFitis TIFitis deleted the nested_derived_types branch October 7, 2025 15:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category offload

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[flang][OpenMP] Declare Mapper does not map base type members properly

3 participants