From 49c32ea2d32ec60d8e2fa423977089b4fab4c13f Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Fri, 19 Sep 2025 18:56:35 +0100 Subject: [PATCH 1/2] [Flang][OpenMP] Implicitly map nested allocatable components in derived types This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass. - Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived object is mapped via declare mapper. - Traverse HLFIR designates rooted at the target block arg and build full coordinate_of chains; append members with correct membersIndex. --- .../Optimizer/OpenMP/MapInfoFinalization.cpp | 181 ++++++++++++------ flang/test/Lower/OpenMP/declare-mapper.f90 | 38 ++++ ...rget-declare-mapper-parent-allocatable.f90 | 43 +++++ 3 files changed, 207 insertions(+), 55 deletions(-) create mode 100644 offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 57be863cfa1b8..5715bd9b1a3eb 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -701,40 +701,29 @@ class MapInfoFinalizationPass auto recordType = mlir::cast(underlyingType); llvm::SmallVector newMapOpsForFields; - llvm::SmallVector fieldIndicies; + llvm::SmallVector> newMemberIndexPaths; - for (auto fieldMemTyPair : recordType.getTypeList()) { - auto &field = fieldMemTyPair.first; - auto memTy = fieldMemTyPair.second; - - bool shouldMapField = - llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) { - if (!fir::isAllocatableType(memTy)) - return false; - - auto designateOp = mlir::dyn_cast(sliceOp); - if (!designateOp) - return false; - - return designateOp.getComponent() && - designateOp.getComponent()->strref() == field; - }) != mapVarForwardSlice.end(); - - // TODO Handle recursive record types. Adapting - // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR - // entities might be helpful here. - - if (!shouldMapField) - continue; - - int32_t fieldIdx = recordType.getFieldIndex(field); + auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef, + mlir::Type memTy, + llvm::ArrayRef indexPath, + llvm::StringRef memberName) { + // Check if already mapped (index path equality). bool alreadyMapped = [&]() { if (op.getMembersIndexAttr()) for (auto indexList : op.getMembersIndexAttr()) { auto indexListAttr = mlir::cast(indexList); - if (indexListAttr.size() == 1 && - mlir::cast(indexListAttr[0]).getInt() == - fieldIdx) + if (static_cast(indexListAttr.size()) != + static_cast(indexPath.size())) + continue; + bool allEq = true; + for (auto [i, attr] : llvm::enumerate(indexListAttr)) { + if (mlir::cast(attr).getInt() != + indexPath[i]) { + allEq = false; + break; + } + } + if (allEq) return true; } @@ -742,42 +731,128 @@ class MapInfoFinalizationPass }(); if (alreadyMapped) - continue; + return; builder.setInsertionPoint(op); - fir::IntOrValue idxConst = - mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx); - auto fieldCoord = fir::CoordinateOp::create( - builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(), - llvm::SmallVector{idxConst}); fir::factory::AddrAndBoundsInfo info = - fir::factory::getDataOperandBaseAddr( - builder, fieldCoord, /*isOptional=*/false, op.getLoc()); + fir::factory::getDataOperandBaseAddr(builder, coordRef, + /*isOptional=*/false, loc); llvm::SmallVector bounds = fir::factory::genImplicitBoundsOps( builder, info, - hlfir::translateToExtendedValue(op.getLoc(), builder, - hlfir::Entity{fieldCoord}) + hlfir::translateToExtendedValue(loc, builder, + hlfir::Entity{coordRef}) .first, - /*dataExvIsAssumedSize=*/false, op.getLoc()); + /*dataExvIsAssumedSize=*/false, loc); mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create( - builder, op.getLoc(), fieldCoord.getResult().getType(), - fieldCoord.getResult(), - mlir::TypeAttr::get( - fir::unwrapRefType(fieldCoord.getResult().getType())), + builder, loc, coordRef.getType(), coordRef, + mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())), op.getMapTypeAttr(), builder.getAttr( mlir::omp::VariableCaptureKind::ByRef), /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{}, /*members_index=*/mlir::ArrayAttr{}, bounds, /*mapperId=*/mlir::FlatSymbolRefAttr(), - builder.getStringAttr(op.getNameAttr().strref() + "." + field + - ".implicit_map"), + builder.getStringAttr(op.getNameAttr().strref() + "." + + memberName + ".implicit_map"), /*partial_map=*/builder.getBoolAttr(false)); newMapOpsForFields.emplace_back(fieldMapOp); - fieldIndicies.emplace_back(fieldIdx); + newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end()); + }; + + // 1) Handle direct top-level allocatable fields (existing behavior). + for (auto fieldMemTyPair : recordType.getTypeList()) { + auto &field = fieldMemTyPair.first; + auto memTy = fieldMemTyPair.second; + + if (!fir::isAllocatableType(memTy)) + continue; + + bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) { + auto designateOp = mlir::dyn_cast(opv); + return designateOp && designateOp.getComponent() && + designateOp.getComponent()->strref() == field; + }); + if (!referenced) + continue; + + int32_t fieldIdx = recordType.getFieldIndex(field); + builder.setInsertionPoint(op); + fir::IntOrValue idxConst = + mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx); + auto fieldCoord = fir::CoordinateOp::create( + builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(), + llvm::SmallVector{idxConst}); + appendMemberMap(op.getLoc(), fieldCoord, memTy, {fieldIdx}, field); + } + + // Handle nested allocatable fields along any component chain + // referenced in the region via HLFIR designates. + for (mlir::Operation *sliceOp : mapVarForwardSlice) { + auto designateOp = mlir::dyn_cast(sliceOp); + if (!designateOp || !designateOp.getComponent()) + continue; + llvm::SmallVector compPathReversed; + compPathReversed.push_back(designateOp.getComponent()->strref()); + mlir::Value curBase = designateOp.getMemref(); + bool rootedAtMapArg = false; + while (true) { + if (auto parentDes = curBase.getDefiningOp()) { + if (!parentDes.getComponent()) + break; + compPathReversed.push_back(parentDes.getComponent()->strref()); + curBase = parentDes.getMemref(); + continue; + } + if (auto decl = curBase.getDefiningOp()) { + if (auto barg = + mlir::dyn_cast(decl.getMemref())) + rootedAtMapArg = (barg == opBlockArg); + } else if (auto blockArg = + mlir::dyn_cast_or_null( + curBase)) { + rootedAtMapArg = (blockArg == opBlockArg); + } + break; + } + if (!rootedAtMapArg || compPathReversed.size() < 2) + continue; + builder.setInsertionPoint(op); + llvm::SmallVector indexPath; + mlir::Type curTy = underlyingType; + mlir::Value coordRef = op.getVarPtr(); + bool validPath = true; + for (llvm::StringRef compName : llvm::reverse(compPathReversed)) { + auto recTy = mlir::dyn_cast(curTy); + if (!recTy) { + validPath = false; + break; + } + int32_t idx = recTy.getFieldIndex(compName); + if (idx < 0) { + validPath = false; + break; + } + indexPath.push_back(idx); + mlir::Type memTy = recTy.getType(idx); + fir::IntOrValue idxConst = + mlir::IntegerAttr::get(builder.getI32Type(), idx); + coordRef = fir::CoordinateOp::create( + builder, op.getLoc(), builder.getRefType(memTy), coordRef, + llvm::SmallVector{idxConst}); + curTy = memTy; + } + if (!validPath) + continue; + if (auto finalRefTy = + mlir::dyn_cast(coordRef.getType())) { + mlir::Type eleTy = finalRefTy.getElementType(); + if (fir::isAllocatableType(eleTy)) + appendMemberMap(op.getLoc(), coordRef, eleTy, indexPath, + compPathReversed.front()); + } } if (newMapOpsForFields.empty()) @@ -785,10 +860,8 @@ class MapInfoFinalizationPass op.getMembersMutable().append(newMapOpsForFields); llvm::SmallVector> newMemberIndices; - mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr(); - - if (oldMembersIdxAttr) - for (mlir::Attribute indexList : oldMembersIdxAttr) { + if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr()) + for (mlir::Attribute indexList : oldAttr) { llvm::SmallVector listVec; for (mlir::Attribute index : mlir::cast(indexList)) @@ -796,10 +869,8 @@ class MapInfoFinalizationPass newMemberIndices.emplace_back(std::move(listVec)); } - - for (int64_t newFieldIdx : fieldIndicies) - newMemberIndices.emplace_back( - llvm::SmallVector(1, newFieldIdx)); + for (auto &path : newMemberIndexPaths) + newMemberIndices.emplace_back(path); op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices)); op.setPartialMap(true); diff --git a/flang/test/Lower/OpenMP/declare-mapper.f90 b/flang/test/Lower/OpenMP/declare-mapper.f90 index 8a98c68a8d582..1c51666c80f8a 100644 --- a/flang/test/Lower/OpenMP/declare-mapper.f90 +++ b/flang/test/Lower/OpenMP/declare-mapper.f90 @@ -6,6 +6,7 @@ ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90 +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90 !--- omp-declare-mapper-1.f90 subroutine declare_mapper_1 @@ -262,3 +263,40 @@ subroutine use_inner() !$omp end target end subroutine end program declare_mapper_5 + +!--- omp-declare-mapper-6.f90 +subroutine declare_mapper_nested_parent + type :: inner_t + real, allocatable :: deep_arr(:) + end type inner_t + + type, abstract :: base_t + real, allocatable :: base_arr(:) + type(inner_t) :: inner + end type base_t + + type, extends(base_t) :: real_t + real, allocatable :: real_arr(:) + end type real_t + + !$omp declare mapper (custommapper : real_t :: t) map(tofrom: t%base_arr, t%real_arr) + + type(real_t) :: r + + allocate(r%base_arr(10)) + allocate(r%inner%deep_arr(10)) + allocate(r%real_arr(10)) + r%base_arr = 1.0 + r%inner%deep_arr = 4.0 + r%real_arr = 0.0 + + ! CHECK: omp.target + ! Check implicit maps for nested parent and deep nested allocatable payloads + ! CHECK-DAG: omp.map.info {{.*}} {name = "r.base_arr.implicit_map"} + ! CHECK-DAG: omp.map.info {{.*}} {name = "r.deep_arr.implicit_map"} + ! The declared mapper's own allocatable is still mapped implicitly + ! CHECK-DAG: omp.map.info {{.*}} {name = "r.real_arr.implicit_map"} + !$omp target map(mapper(custommapper), tofrom: r) + r%real_arr = r%base_arr(1) + r%inner%deep_arr(1) + !$omp end target +end subroutine declare_mapper_nested_parent diff --git a/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 b/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 new file mode 100644 index 0000000000000..bd5e6828bc991 --- /dev/null +++ b/offload/test/offloading/fortran/target-declare-mapper-parent-allocatable.f90 @@ -0,0 +1,43 @@ +! This test validates that declare mapper for a derived type that extends +! a parent type with an allocatable component correctly maps the nested +! allocatable payload via the mapper when the whole object is mapped on +! target. + +! REQUIRES: flang, amdgpu + +! RUN: %libomptarget-compile-fortran-run-and-check-generic + +program target_declare_mapper_parent_allocatable + implicit none + + type, abstract :: base_t + real, allocatable :: base_arr(:) + end type base_t + + type, extends(base_t) :: real_t + real, allocatable :: real_arr(:) + end type real_t + !$omp declare mapper(custommapper: real_t :: t) map(t%base_arr, t%real_arr) + + type(real_t) :: r + integer :: i + allocate(r%base_arr(10), source=1.0) + allocate(r%real_arr(10), source=1.0) + + !$omp target map(tofrom: r) + do i = 1, size(r%base_arr) + r%base_arr(i) = 2.0 + r%real_arr(i) = 3.0 + r%real_arr(i) = r%base_arr(1) + end do + !$omp end target + + + !CHECK: base_arr: 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. + print*, "base_arr: ", r%base_arr + !CHECK: real_arr: 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. + print*, "real_arr: ", r%real_arr + + deallocate(r%real_arr) + deallocate(r%base_arr) +end program target_declare_mapper_parent_allocatable From da412de1a724b68212379937b94ee77dcc634351 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Mon, 22 Sep 2025 15:25:12 +0100 Subject: [PATCH 2/2] Address comment. --- flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 5715bd9b1a3eb..3659218f91ff6 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -712,8 +712,7 @@ class MapInfoFinalizationPass if (op.getMembersIndexAttr()) for (auto indexList : op.getMembersIndexAttr()) { auto indexListAttr = mlir::cast(indexList); - if (static_cast(indexListAttr.size()) != - static_cast(indexPath.size())) + if (indexListAttr.size() != indexPath.size()) continue; bool allEq = true; for (auto [i, attr] : llvm::enumerate(indexListAttr)) {