Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 55 additions & 125 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,175 +701,105 @@ class MapInfoFinalizationPass

auto recordType = mlir::cast<fir::RecordType>(underlyingType);
llvm::SmallVector<mlir::Value> newMapOpsForFields;
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
llvm::SmallVector<int64_t> fieldIndicies;

auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
mlir::Type memTy,
llvm::ArrayRef<int64_t> indexPath,
llvm::StringRef memberName) {
// Check if already mapped (index path equality).
for (auto fieldMemTyPair : recordType.getTypeList()) {
auto &field = fieldMemTyPair.first;
auto memTy = fieldMemTyPair.second;

bool shouldMapField =
llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
if (!fir::isAllocatableType(memTy))
return false;

auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
if (!designateOp)
return false;

return designateOp.getComponent() &&
designateOp.getComponent()->strref() == field;
}) != mapVarForwardSlice.end();

// TODO Handle recursive record types. Adapting
// `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
// entities might be helpful here.

if (!shouldMapField)
continue;

int32_t fieldIdx = recordType.getFieldIndex(field);
bool alreadyMapped = [&]() {
if (op.getMembersIndexAttr())
for (auto indexList : op.getMembersIndexAttr()) {
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
if (indexListAttr.size() != indexPath.size())
continue;
bool allEq = true;
for (auto [i, attr] : llvm::enumerate(indexListAttr)) {
if (mlir::cast<mlir::IntegerAttr>(attr).getInt() !=
indexPath[i]) {
allEq = false;
break;
}
}
if (allEq)
if (indexListAttr.size() == 1 &&
mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
fieldIdx)
return true;
}

return false;
}();

if (alreadyMapped)
return;
continue;

builder.setInsertionPoint(op);
fir::IntOrValue idxConst =
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
auto fieldCoord = fir::CoordinateOp::create(
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
fir::factory::AddrAndBoundsInfo info =
fir::factory::getDataOperandBaseAddr(builder, coordRef,
/*isOptional=*/false, loc);
fir::factory::getDataOperandBaseAddr(
builder, fieldCoord, /*isOptional=*/false, op.getLoc());
llvm::SmallVector<mlir::Value> bounds =
fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
builder, info,
hlfir::translateToExtendedValue(loc, builder,
hlfir::Entity{coordRef})
hlfir::translateToExtendedValue(op.getLoc(), builder,
hlfir::Entity{fieldCoord})
.first,
/*dataExvIsAssumedSize=*/false, loc);
/*dataExvIsAssumedSize=*/false, op.getLoc());

mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
builder, loc, coordRef.getType(), coordRef,
mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
builder, op.getLoc(), fieldCoord.getResult().getType(),
fieldCoord.getResult(),
mlir::TypeAttr::get(
fir::unwrapRefType(fieldCoord.getResult().getType())),
op.getMapTypeAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
/*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
/*members_index=*/mlir::ArrayAttr{}, bounds,
/*mapperId=*/mlir::FlatSymbolRefAttr(),
builder.getStringAttr(op.getNameAttr().strref() + "." +
memberName + ".implicit_map"),
builder.getStringAttr(op.getNameAttr().strref() + "." + field +
".implicit_map"),
/*partial_map=*/builder.getBoolAttr(false));
newMapOpsForFields.emplace_back(fieldMapOp);
newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
};

// 1) Handle direct top-level allocatable fields (existing behavior).
for (auto fieldMemTyPair : recordType.getTypeList()) {
auto &field = fieldMemTyPair.first;
auto memTy = fieldMemTyPair.second;

if (!fir::isAllocatableType(memTy))
continue;

bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
return designateOp && designateOp.getComponent() &&
designateOp.getComponent()->strref() == field;
});
if (!referenced)
continue;

int32_t fieldIdx = recordType.getFieldIndex(field);
builder.setInsertionPoint(op);
fir::IntOrValue idxConst =
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
auto fieldCoord = fir::CoordinateOp::create(
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
appendMemberMap(op.getLoc(), fieldCoord, memTy, {fieldIdx}, field);
}

// Handle nested allocatable fields along any component chain
// referenced in the region via HLFIR designates.
for (mlir::Operation *sliceOp : mapVarForwardSlice) {
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
if (!designateOp || !designateOp.getComponent())
continue;
llvm::SmallVector<llvm::StringRef> compPathReversed;
compPathReversed.push_back(designateOp.getComponent()->strref());
mlir::Value curBase = designateOp.getMemref();
bool rootedAtMapArg = false;
while (true) {
if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
if (!parentDes.getComponent())
break;
compPathReversed.push_back(parentDes.getComponent()->strref());
curBase = parentDes.getMemref();
continue;
}
if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
if (auto barg =
mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
rootedAtMapArg = (barg == opBlockArg);
} else if (auto blockArg =
mlir::dyn_cast_or_null<mlir::BlockArgument>(
curBase)) {
rootedAtMapArg = (blockArg == opBlockArg);
}
break;
}
if (!rootedAtMapArg || compPathReversed.size() < 2)
continue;
builder.setInsertionPoint(op);
llvm::SmallVector<int64_t> indexPath;
mlir::Type curTy = underlyingType;
mlir::Value coordRef = op.getVarPtr();
bool validPath = true;
for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
if (!recTy) {
validPath = false;
break;
}
int32_t idx = recTy.getFieldIndex(compName);
if (idx < 0) {
validPath = false;
break;
}
indexPath.push_back(idx);
mlir::Type memTy = recTy.getType(idx);
fir::IntOrValue idxConst =
mlir::IntegerAttr::get(builder.getI32Type(), idx);
coordRef = fir::CoordinateOp::create(
builder, op.getLoc(), builder.getRefType(memTy), coordRef,
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
curTy = memTy;
}
if (!validPath)
continue;
if (auto finalRefTy =
mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
mlir::Type eleTy = finalRefTy.getElementType();
if (fir::isAllocatableType(eleTy))
appendMemberMap(op.getLoc(), coordRef, eleTy, indexPath,
compPathReversed.front());
}
fieldIndicies.emplace_back(fieldIdx);
}

if (newMapOpsForFields.empty())
return mlir::WalkResult::advance();

op.getMembersMutable().append(newMapOpsForFields);
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
for (mlir::Attribute indexList : oldAttr) {
mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();

if (oldMembersIdxAttr)
for (mlir::Attribute indexList : oldMembersIdxAttr) {
llvm::SmallVector<int64_t> listVec;

for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());

newMemberIndices.emplace_back(std::move(listVec));
}
for (auto &path : newMemberIndexPaths)
newMemberIndices.emplace_back(path);

for (int64_t newFieldIdx : fieldIndicies)
newMemberIndices.emplace_back(
llvm::SmallVector<int64_t>(1, newFieldIdx));

op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
op.setPartialMap(true);
Expand Down
38 changes: 0 additions & 38 deletions flang/test/Lower/OpenMP/declare-mapper.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90

!--- omp-declare-mapper-1.f90
subroutine declare_mapper_1
Expand Down Expand Up @@ -263,40 +262,3 @@ subroutine use_inner()
!$omp end target
end subroutine
end program declare_mapper_5

!--- omp-declare-mapper-6.f90
subroutine declare_mapper_nested_parent
type :: inner_t
real, allocatable :: deep_arr(:)
end type inner_t

type, abstract :: base_t
real, allocatable :: base_arr(:)
type(inner_t) :: inner
end type base_t

type, extends(base_t) :: real_t
real, allocatable :: real_arr(:)
end type real_t

!$omp declare mapper (custommapper : real_t :: t) map(tofrom: t%base_arr, t%real_arr)

type(real_t) :: r

allocate(r%base_arr(10))
allocate(r%inner%deep_arr(10))
allocate(r%real_arr(10))
r%base_arr = 1.0
r%inner%deep_arr = 4.0
r%real_arr = 0.0

! CHECK: omp.target
! Check implicit maps for nested parent and deep nested allocatable payloads
! CHECK-DAG: omp.map.info {{.*}} {name = "r.base_arr.implicit_map"}
! CHECK-DAG: omp.map.info {{.*}} {name = "r.deep_arr.implicit_map"}
! The declared mapper's own allocatable is still mapped implicitly
! CHECK-DAG: omp.map.info {{.*}} {name = "r.real_arr.implicit_map"}
!$omp target map(mapper(custommapper), tofrom: r)
r%real_arr = r%base_arr(1) + r%inner%deep_arr(1)
!$omp end target
end subroutine declare_mapper_nested_parent

This file was deleted.

Loading