Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 220 additions & 74 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstddef>
#include <iterator>
Expand Down Expand Up @@ -75,6 +77,112 @@ class MapInfoFinalizationPass
/// | |
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;

/// Return true if the given path exists in a list of paths.
static bool
containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
llvm::ArrayRef<int64_t> path) {
return llvm::any_of(paths, [&](const llvm::SmallVector<int64_t> &p) {
return p.size() == path.size() &&
std::equal(p.begin(), p.end(), path.begin());
});
}

/// Return true if the given path is already present in
/// op.getMembersIndexAttr().
static bool mappedIndexPathExists(mlir::omp::MapInfoOp op,
llvm::ArrayRef<int64_t> indexPath) {
if (mlir::ArrayAttr attr = op.getMembersIndexAttr()) {
for (mlir::Attribute list : attr) {
auto listAttr = mlir::cast<mlir::ArrayAttr>(list);
if (listAttr.size() != indexPath.size())
continue;
bool allEq = true;
for (auto [i, val] : llvm::enumerate(listAttr)) {
if (mlir::cast<mlir::IntegerAttr>(val).getInt() != indexPath[i]) {
allEq = false;
break;
}
}
if (allEq)
return true;
}
}
return false;
}

/// Build a compact string key for an index path for set-based
/// deduplication. Format: "N:v0,v1,..." where N is the length.
static void buildPathKey(llvm::ArrayRef<int64_t> path,
llvm::SmallString<64> &outKey) {
outKey.clear();
llvm::raw_svector_ostream os(outKey);
os << path.size() << ':';
for (size_t i = 0; i < path.size(); ++i) {
if (i)
os << ',';
os << path[i];
}
}

/// Create the member map for coordRef and append it (and its index
/// path) to the provided new* vectors, if it is not already present.
void appendMemberMapIfNew(
mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value coordRef, llvm::ArrayRef<int64_t> indexPath,
llvm::StringRef memberName,
llvm::SmallVectorImpl<mlir::Value> &newMapOpsForFields,
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &newMemberIndexPaths) {
// Local de-dup within this op invocation.
if (containsPath(newMemberIndexPaths, indexPath))
return;
// Global de-dup against already present member indices.
if (mappedIndexPathExists(op, indexPath))
return;

if (op.getMapperId()) {
mlir::omp::DeclareMapperOp symbol =
mlir::SymbolTable::lookupNearestSymbolFrom<
mlir::omp::DeclareMapperOp>(op, op.getMapperIdAttr());
assert(symbol && "missing symbol for declare mapper identifier");
mlir::omp::DeclareMapperInfoOp mapperInfo = symbol.getDeclareMapperInfo();
// TODO: Probably a way to cache these keys in someway so we don't
// constantly go through the process of rebuilding them on every check, to
// save some cycles, but it can wait for a subsequent patch.
for (auto v : mapperInfo.getMapVars()) {
mlir::omp::MapInfoOp map =
mlir::cast<mlir::omp::MapInfoOp>(v.getDefiningOp());
if (!map.getMembers().empty() && mappedIndexPathExists(map, indexPath))
return;
}
}

builder.setInsertionPoint(op);
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
builder, coordRef, /*isOptional=*/false, loc);
llvm::SmallVector<mlir::Value> bounds = fir::factory::genImplicitBoundsOps<
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
builder, info,
hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{coordRef})
.first,
/*dataExvIsAssumedSize=*/false, loc);

mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
builder, loc, coordRef.getType(), coordRef,
mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
op.getMapTypeAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
/*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
/*members_index=*/mlir::ArrayAttr{}, bounds,
/*mapperId=*/mlir::FlatSymbolRefAttr(),
builder.getStringAttr(op.getNameAttr().strref() + "." + memberName +
".implicit_map"),
/*partial_map=*/builder.getBoolAttr(false));

newMapOpsForFields.emplace_back(fieldMapOp);
newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
}

/// getMemberUserList gathers all users of a particular MapInfoOp that are
/// other MapInfoOp's and places them into the mapMemberUsers list, which
/// records the map that the current argument MapInfoOp "op" is part of
Expand Down Expand Up @@ -363,7 +471,7 @@ class MapInfoFinalizationPass
mlir::ArrayAttr newMembersAttr;
mlir::SmallVector<mlir::Value> newMembers;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
bool IsHasDeviceAddr = isHasDeviceAddr(op, target);
bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target);

if (!mapMemberUsers.empty() || !op.getMembers().empty())
getMemberIndicesAsVectors(
Expand Down Expand Up @@ -406,7 +514,7 @@ class MapInfoFinalizationPass
mapUser.parent.getMembersMutable().assign(newMemberOps);
mapUser.parent.setMembersIndexAttr(
builder.create2DI64ArrayAttr(memberIndices));
} else if (!IsHasDeviceAddr) {
} else if (!isHasDeviceAddrFlag) {
auto baseAddr =
genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
newMembers.push_back(baseAddr);
Expand All @@ -429,7 +537,7 @@ class MapInfoFinalizationPass
// The contents of the descriptor (the base address in particular) will
// remain unchanged though.
uint64_t mapType = op.getMapType();
if (IsHasDeviceAddr) {
if (isHasDeviceAddrFlag) {
mapType |= llvm::to_underlying(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
}
Expand Down Expand Up @@ -701,105 +809,143 @@ class MapInfoFinalizationPass

auto recordType = mlir::cast<fir::RecordType>(underlyingType);
llvm::SmallVector<mlir::Value> newMapOpsForFields;
llvm::SmallVector<int64_t> fieldIndicies;
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;

// 1) Handle direct top-level allocatable fields.
for (auto fieldMemTyPair : recordType.getTypeList()) {
auto &field = fieldMemTyPair.first;
auto memTy = fieldMemTyPair.second;

bool shouldMapField =
llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
if (!fir::isAllocatableType(memTy))
return false;

auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
if (!designateOp)
return false;

return designateOp.getComponent() &&
designateOp.getComponent()->strref() == field;
}) != mapVarForwardSlice.end();

// TODO Handle recursive record types. Adapting
// `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
// entities might be helpful here.

if (!shouldMapField)
if (!fir::isAllocatableType(memTy))
continue;

int32_t fieldIdx = recordType.getFieldIndex(field);
bool alreadyMapped = [&]() {
if (op.getMembersIndexAttr())
for (auto indexList : op.getMembersIndexAttr()) {
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
if (indexListAttr.size() == 1 &&
mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
fieldIdx)
return true;
}

return false;
}();

if (alreadyMapped)
bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
return designateOp && designateOp.getComponent() &&
designateOp.getComponent()->strref() == field;
});
if (!referenced)
continue;

int32_t fieldIdx = recordType.getFieldIndex(field);
builder.setInsertionPoint(op);
fir::IntOrValue idxConst =
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
auto fieldCoord = fir::CoordinateOp::create(
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
fir::factory::AddrAndBoundsInfo info =
fir::factory::getDataOperandBaseAddr(
builder, fieldCoord, /*isOptional=*/false, op.getLoc());
llvm::SmallVector<mlir::Value> bounds =
fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
builder, info,
hlfir::translateToExtendedValue(op.getLoc(), builder,
hlfir::Entity{fieldCoord})
.first,
/*dataExvIsAssumedSize=*/false, op.getLoc());

mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
builder, op.getLoc(), fieldCoord.getResult().getType(),
fieldCoord.getResult(),
mlir::TypeAttr::get(
fir::unwrapRefType(fieldCoord.getResult().getType())),
op.getMapTypeAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
/*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
/*members_index=*/mlir::ArrayAttr{}, bounds,
/*mapperId=*/mlir::FlatSymbolRefAttr(),
builder.getStringAttr(op.getNameAttr().strref() + "." + field +
".implicit_map"),
/*partial_map=*/builder.getBoolAttr(false));
newMapOpsForFields.emplace_back(fieldMapOp);
fieldIndicies.emplace_back(fieldIdx);
int64_t fieldIdx64 = static_cast<int64_t>(fieldIdx);
llvm::SmallVector<int64_t, 1> idxPath{fieldIdx64};
appendMemberMapIfNew(op, builder, op.getLoc(), fieldCoord, idxPath,
field, newMapOpsForFields, newMemberIndexPaths);
}

// Handle nested allocatable fields along any component chain
// referenced in the region via HLFIR designates.
llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths;
for (mlir::Operation *sliceOp : mapVarForwardSlice) {
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
if (!designateOp || !designateOp.getComponent())
continue;
llvm::SmallVector<llvm::StringRef> compPathReversed;
compPathReversed.push_back(designateOp.getComponent()->strref());
mlir::Value curBase = designateOp.getMemref();
bool rootedAtMapArg = false;
while (true) {
if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
if (!parentDes.getComponent())
break;
compPathReversed.push_back(parentDes.getComponent()->strref());
curBase = parentDes.getMemref();
continue;
}
if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
if (auto barg =
mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
rootedAtMapArg = (barg == opBlockArg);
} else if (auto blockArg =
mlir::dyn_cast_or_null<mlir::BlockArgument>(
curBase)) {
rootedAtMapArg = (blockArg == opBlockArg);
}
break;
}
// Only process nested paths (2+ components). Single-component paths
// for direct fields are handled above.
if (!rootedAtMapArg || compPathReversed.size() < 2)
continue;
builder.setInsertionPoint(op);
llvm::SmallVector<int64_t> indexPath;
mlir::Type curTy = underlyingType;
mlir::Value coordRef = op.getVarPtr();
bool validPath = true;
for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
if (!recTy) {
validPath = false;
break;
}
int32_t idx = recTy.getFieldIndex(compName);
if (idx < 0) {
validPath = false;
break;
}
indexPath.push_back(idx);
mlir::Type memTy = recTy.getType(idx);
fir::IntOrValue idxConst =
mlir::IntegerAttr::get(builder.getI32Type(), idx);
coordRef = fir::CoordinateOp::create(
builder, op.getLoc(), builder.getRefType(memTy), coordRef,
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
curTy = memTy;
}
if (!validPath)
continue;
if (auto finalRefTy =
mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
mlir::Type eleTy = finalRefTy.getElementType();
if (fir::isAllocatableType(eleTy)) {
if (!containsPath(seenIndexPaths, indexPath)) {
seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
appendMemberMapIfNew(op, builder, op.getLoc(), coordRef,
indexPath, compPathReversed.front(),
newMapOpsForFields, newMemberIndexPaths);
}
}
}
}

if (newMapOpsForFields.empty())
return mlir::WalkResult::advance();

op.getMembersMutable().append(newMapOpsForFields);
// Deduplicate by index path to avoid emitting duplicate members for
// the same component. Use a set-based key to keep this near O(n).
llvm::SmallVector<mlir::Value> dedupMapOps;
llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths;
llvm::StringSet<> seenKeys;
for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) {
const auto &path = newMemberIndexPaths[i];
llvm::SmallString<64> key;
buildPathKey(path, key);
if (seenKeys.contains(key))
continue;
seenKeys.insert(key);
dedupMapOps.push_back(mapOp);
dedupIndexPaths.emplace_back(path.begin(), path.end());
}
op.getMembersMutable().append(dedupMapOps);
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();

if (oldMembersIdxAttr)
for (mlir::Attribute indexList : oldMembersIdxAttr) {
if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
for (mlir::Attribute indexList : oldAttr) {
llvm::SmallVector<int64_t> listVec;

for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());

newMemberIndices.emplace_back(std::move(listVec));
}

for (int64_t newFieldIdx : fieldIndicies)
newMemberIndices.emplace_back(
llvm::SmallVector<int64_t>(1, newFieldIdx));
for (auto &path : dedupIndexPaths)
newMemberIndices.emplace_back(path);

op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
op.setPartialMap(true);
Expand Down
Loading