[Flang][OpenMP] Implicitly map nested allocatable components in derived types #160766
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-openmp

Author: Akash Banerjee (TIFitis)

Changes: This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.

This fixes #156461. Full diff: https://github.com/llvm/llvm-project/pull/160766.diff (3 files affected):
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 57be863cfa1b8..0c7b1ceaf21f9 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -701,40 +701,37 @@ class MapInfoFinalizationPass
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
llvm::SmallVector<mlir::Value> newMapOpsForFields;
- llvm::SmallVector<int64_t> fieldIndicies;
-
- for (auto fieldMemTyPair : recordType.getTypeList()) {
- auto &field = fieldMemTyPair.first;
- auto memTy = fieldMemTyPair.second;
-
- bool shouldMapField =
- llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
- if (!fir::isAllocatableType(memTy))
- return false;
-
- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
- if (!designateOp)
- return false;
-
- return designateOp.getComponent() &&
- designateOp.getComponent()->strref() == field;
- }) != mapVarForwardSlice.end();
-
- // TODO Handle recursive record types. Adapting
- // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
- // entities might be helpful here.
-
- if (!shouldMapField)
- continue;
-
- int32_t fieldIdx = recordType.getFieldIndex(field);
+ llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
+
+ auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
+ mlir::Type memTy,
+ llvm::ArrayRef<int64_t> indexPath,
+ llvm::StringRef memberName) {
+ // Avoid adding duplicates for the same index path within this op.
+ bool alreadyPlanned = llvm::any_of(
+ newMemberIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+ return p.size() == indexPath.size() &&
+ std::equal(p.begin(), p.end(), indexPath.begin());
+ });
+ if (alreadyPlanned)
+ return;
+
+ // Check if already mapped (index path equality).
bool alreadyMapped = [&]() {
if (op.getMembersIndexAttr())
for (auto indexList : op.getMembersIndexAttr()) {
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
- if (indexListAttr.size() == 1 &&
- mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
- fieldIdx)
+ if (indexListAttr.size() != indexPath.size())
+ continue;
+ bool allEq = true;
+ for (auto [i, attr] : llvm::enumerate(indexListAttr)) {
+ if (mlir::cast<mlir::IntegerAttr>(attr).getInt() !=
+ indexPath[i]) {
+ allEq = false;
+ break;
+ }
+ }
+ if (allEq)
return true;
}
@@ -742,53 +739,165 @@ class MapInfoFinalizationPass
}();
if (alreadyMapped)
- continue;
+ return;
builder.setInsertionPoint(op);
- fir::IntOrValue idxConst =
- mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
- auto fieldCoord = fir::CoordinateOp::create(
- builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
- llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
fir::factory::AddrAndBoundsInfo info =
- fir::factory::getDataOperandBaseAddr(
- builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+ fir::factory::getDataOperandBaseAddr(builder, coordRef,
+ /*isOptional=*/false, loc);
llvm::SmallVector<mlir::Value> bounds =
fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
builder, info,
- hlfir::translateToExtendedValue(op.getLoc(), builder,
- hlfir::Entity{fieldCoord})
+ hlfir::translateToExtendedValue(loc, builder,
+ hlfir::Entity{coordRef})
.first,
- /*dataExvIsAssumedSize=*/false, op.getLoc());
+ /*dataExvIsAssumedSize=*/false, loc);
mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
- builder, op.getLoc(), fieldCoord.getResult().getType(),
- fieldCoord.getResult(),
- mlir::TypeAttr::get(
- fir::unwrapRefType(fieldCoord.getResult().getType())),
+ builder, loc, coordRef.getType(), coordRef,
+ mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
op.getMapTypeAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
/*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
/*members_index=*/mlir::ArrayAttr{}, bounds,
/*mapperId=*/mlir::FlatSymbolRefAttr(),
- builder.getStringAttr(op.getNameAttr().strref() + "." + field +
- ".implicit_map"),
+ builder.getStringAttr(op.getNameAttr().strref() + "." +
+ memberName + ".implicit_map"),
/*partial_map=*/builder.getBoolAttr(false));
newMapOpsForFields.emplace_back(fieldMapOp);
- fieldIndicies.emplace_back(fieldIdx);
+ newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+ };
+
+ // 1) Handle direct top-level allocatable fields.
+ for (auto fieldMemTyPair : recordType.getTypeList()) {
+ auto &field = fieldMemTyPair.first;
+ auto memTy = fieldMemTyPair.second;
+
+ if (!fir::isAllocatableType(memTy))
+ continue;
+
+ bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
+ return designateOp && designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ });
+ if (!referenced)
+ continue;
+
+ int32_t fieldIdx = recordType.getFieldIndex(field);
+ builder.setInsertionPoint(op);
+ fir::IntOrValue idxConst =
+ mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
+ auto fieldCoord = fir::CoordinateOp::create(
+ builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+ llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+ appendMemberMap(op.getLoc(), fieldCoord, memTy, {fieldIdx}, field);
+ }
+
+ // Handle nested allocatable fields along any component chain
+ // referenced in the region via HLFIR designates.
+ llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths;
+ for (mlir::Operation *sliceOp : mapVarForwardSlice) {
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp || !designateOp.getComponent())
+ continue;
+ llvm::SmallVector<llvm::StringRef> compPathReversed;
+ compPathReversed.push_back(designateOp.getComponent()->strref());
+ mlir::Value curBase = designateOp.getMemref();
+ bool rootedAtMapArg = false;
+ while (true) {
+ if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
+ if (!parentDes.getComponent())
+ break;
+ compPathReversed.push_back(parentDes.getComponent()->strref());
+ curBase = parentDes.getMemref();
+ continue;
+ }
+ if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
+ if (auto barg =
+ mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
+ rootedAtMapArg = (barg == opBlockArg);
+ } else if (auto blockArg =
+ mlir::dyn_cast_or_null<mlir::BlockArgument>(
+ curBase)) {
+ rootedAtMapArg = (blockArg == opBlockArg);
+ }
+ break;
+ }
+ // Only process nested paths (2+ components). Single-component paths
+ // for direct fields are handled above.
+ if (!rootedAtMapArg || compPathReversed.size() < 2)
+ continue;
+ builder.setInsertionPoint(op);
+ llvm::SmallVector<int64_t> indexPath;
+ mlir::Type curTy = underlyingType;
+ mlir::Value coordRef = op.getVarPtr();
+ bool validPath = true;
+ for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
+ auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
+ if (!recTy) {
+ validPath = false;
+ break;
+ }
+ int32_t idx = recTy.getFieldIndex(compName);
+ if (idx < 0) {
+ validPath = false;
+ break;
+ }
+ indexPath.push_back(idx);
+ mlir::Type memTy = recTy.getType(idx);
+ fir::IntOrValue idxConst =
+ mlir::IntegerAttr::get(builder.getI32Type(), idx);
+ coordRef = fir::CoordinateOp::create(
+ builder, op.getLoc(), builder.getRefType(memTy), coordRef,
+ llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+ curTy = memTy;
+ }
+ if (!validPath)
+ continue;
+ if (auto finalRefTy =
+ mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
+ mlir::Type eleTy = finalRefTy.getElementType();
+ if (fir::isAllocatableType(eleTy)) {
+ bool isNew = llvm::none_of(
+ seenIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+ return p.size() == indexPath.size() &&
+ std::equal(p.begin(), p.end(), indexPath.begin());
+ });
+ if (isNew) {
+ seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+ appendMemberMap(op.getLoc(), coordRef, eleTy, indexPath,
+ compPathReversed.front());
+ }
+ }
+ }
}
if (newMapOpsForFields.empty())
return mlir::WalkResult::advance();
- op.getMembersMutable().append(newMapOpsForFields);
+ // Deduplicate by index path to avoid emitting duplicate members for
+ // the same component.
+ llvm::SmallVector<mlir::Value> dedupMapOps;
+ llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths;
+ for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) {
+ const auto &path = newMemberIndexPaths[i];
+ bool isNew = llvm::none_of(
+ dedupIndexPaths, [&](const llvm::SmallVector<int64_t> &p) {
+ return p.size() == path.size() &&
+ std::equal(p.begin(), p.end(), path.begin());
+ });
+ if (isNew) {
+ dedupMapOps.push_back(mapOp);
+ dedupIndexPaths.emplace_back(path.begin(), path.end());
+ }
+ }
+ op.getMembersMutable().append(dedupMapOps);
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
- mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
-
- if (oldMembersIdxAttr)
- for (mlir::Attribute indexList : oldMembersIdxAttr) {
+ if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
+ for (mlir::Attribute indexList : oldAttr) {
llvm::SmallVector<int64_t> listVec;
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
@@ -796,10 +905,8 @@ class MapInfoFinalizationPass
newMemberIndices.emplace_back(std::move(listVec));
}
-
- for (int64_t newFieldIdx : fieldIndicies)
- newMemberIndices.emplace_back(
- llvm::SmallVector<int64_t>(1, newFieldIdx));
+ for (auto &path : dedupIndexPaths)
+ newMemberIndices.emplace_back(path);
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
op.setPartialMap(true);
diff --git a/flang/test/Lower/OpenMP/declare-mapper.f90 b/flang/test/Lower/OpenMP/declare-mapper.f90
index 8a98c68a8d582..1c51666c80f8a 100644
--- a/flang/test/Lower/OpenMP/declare-mapper.f90
+++ b/flang/test/Lower/OpenMP/declare-mapper.f90
@@ -6,6 +6,7 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90
!--- omp-declare-mapper-1.f90
subroutine declare_mapper_1
@@ -262,3 +263,40 @@ subroutine use_inner()
!$omp end target
end subroutine
end program declare_mapper_5
+
+!--- omp-declare-mapper-6.f90
+subroutine declare_mapper_nested_parent
+ type :: inner_t
+ real, allocatable :: deep_arr(:)
+ end type inner_t
+
+ type, abstract :: base_t
+ real, allocatable :: base_arr(:)
+ type(inner_t) :: inner
+ end type base_t
+
+ type, extends(base_t) :: real_t
+ real, allocatable :: real_arr(:)
+ end type real_t
+
+ !$omp declare mapper (custommapper : real_t :: t) map(tofrom: t%base_arr, t%real_arr)
+
+ type(real_t) :: r
+
+ allocate(r%base_arr(10))
+ allocate(r%inner%deep_arr(10))
+ allocate(r%real_arr(10))
+ r%base_arr = 1.0
+ r%inner%deep_arr = 4.0
+ r%real_arr = 0.0
+
+ ! CHECK: omp.target
+ ! Check implicit maps for nested parent and deep nested allocatable payloads
+ ! CHECK-DAG: omp.map.info {{.*}} {name = "r.base_arr.implicit_map"}
+ ! CHECK-DAG: omp.map.info {{.*}} {name = "r.deep_arr.implicit_map"}
+ ! The declared mapper's own allocatable is still mapped implicitly
+ ! CHECK-DAG: omp.map.info {{.*}} {name = "r.real_arr.implicit_map"}
+ !$omp target map(mapper(custommapper), tofrom: r)
+ r%real_arr = r%base_arr(1) + r%inner%deep_arr(1)
+ !$omp end target
+end subroutine declare_mapper_nested_parent
diff --git a/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 b/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90
new file mode 100644
index 0000000000000..65e04af66e022
--- /dev/null
+++ b/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90
@@ -0,0 +1,43 @@
+! This test validates that declare mapper for a derived type that extends
+! a parent type with an allocatable component correctly maps the nested
+! allocatable payload via the mapper when the whole object is mapped on
+! target.
+
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-run-and-check-generic
+
+program target_declare_mapper_parent_allocatable
+ implicit none
+
+ type, abstract :: base_t
+ real, allocatable :: base_arr(:)
+ end type base_t
+
+ type, extends(base_t) :: real_t
+ real, allocatable :: real_arr(:)
+ end type real_t
+ !$omp declare mapper(custommapper: real_t :: t) map(t%base_arr, t%real_arr)
+
+ type(real_t) :: r
+ integer :: i
+ allocate(r%base_arr(10), source=1.0)
+ allocate(r%real_arr(10), source=1.0)
+
+ !$omp target map(mapper(custommapper), tofrom: r)
+ do i = 1, size(r%base_arr)
+ r%base_arr(i) = 2.0
+ r%real_arr(i) = 3.0
+ r%real_arr(i) = r%base_arr(1)
+ end do
+ !$omp end target
+
+
+ !CHECK: base_arr: 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
+ print*, "base_arr: ", r%base_arr
+ !CHECK: real_arr: 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
+ print*, "real_arr: ", r%real_arr
+
+ deallocate(r%real_arr)
+ deallocate(r%base_arr)
+end program target_declare_mapper_parent_allocatable
@llvm/pr-subscribers-offload

Author: Akash Banerjee (TIFitis)

Changes: This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.
Pull Request Overview
This PR enhances the OpenMP MapInfoFinalization pass to support nested allocatable components in derived types when using declare mappers. It generalizes the pass to handle arbitrarily nested derived types by traversing HLFIR designates and building complete coordinate chains for nested components.
Key changes:
- Refactored MapInfoFinalization to handle nested allocatable components beyond just direct fields
- Added logic to traverse component paths and build coordinate_of chains for nested access patterns (see the Fortran sketch after this list)
- Enhanced deduplication logic to prevent duplicate mappings for the same component paths
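For reference, a minimal Fortran sketch of the pattern this enables, adapted from the new declare-mapper test in this PR. The type, mapper, and subroutine names here are illustrative only, not taken from the diff:

```fortran
subroutine nested_alloc_map_sketch
  ! Inner derived type carrying an allocatable payload.
  type :: inner_t
    real, allocatable :: deep_arr(:)
  end type inner_t

  ! Outer derived type with a direct allocatable field and a nested
  ! derived-type component.
  type :: outer_t
    real, allocatable :: base_arr(:)
    type(inner_t) :: inner
  end type outer_t

  ! The mapper only names the top-level allocatable explicitly.
  !$omp declare mapper(custommapper : outer_t :: t) map(tofrom: t%base_arr)

  type(outer_t) :: r

  allocate(r%base_arr(10))
  allocate(r%inner%deep_arr(10))
  r%base_arr = 1.0
  r%inner%deep_arr = 4.0

  ! The nested allocatable r%inner%deep_arr referenced inside the target
  ! region is now implicitly mapped as well, not just the direct
  ! top-level allocatable fields.
  !$omp target map(mapper(custommapper), tofrom: r)
  r%base_arr(1) = r%inner%deep_arr(1)
  !$omp end target
end subroutine nested_alloc_map_sketch
```

In MapInfoFinalization terms, the r%inner%deep_arr designate chain yields a member index path of length two (inner, then deep_arr), which is what the new nested-path handling and deduplication operate on.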
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 | New test case for validating declare mapper with derived type inheritance and nested allocatable components |
flang/test/Lower/OpenMP/declare-mapper.f90 | Added test case for nested parent types with deep allocatable components |
flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp | Major refactoring to support nested allocatable component mapping with enhanced traversal and deduplication logic |
@agozillon Sorry, I found that base_arr wasn't getting fully mapped when rebasing this for downstream, so I had to revert #160116. I've fixed it here. |
no problem, LGTM!
52e7187 to 6f86956 (force-push)
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated no new comments.
ad30b76 to 0d97911 (force-push)
…ed types (llvm#160116)

This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.
- Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived object is mapped via declare mapper.
- Traverse HLFIR designates rooted at the target block arg and build full coordinate_of chains; append members with correct membersIndex.

This fixes llvm#156461.
Thank you for the update @TIFitis, LGTM! :-)
0d97911 to 4149676 (force-push)
…ed types (llvm#160766)

This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.
- Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived object is mapped via declare mapper.
- Traverse HLFIR designates rooted at the target block arg and build full coordinate_of chains; append members with correct membersIndex.

This fixes llvm#156461.
This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass.
This fixes #156461.